/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.openai;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

public class OpenAiChatClient
implements ChatClient,
StreamingChatClient {
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
    private static final boolean IS_RUNTIME_CALL = true;
    private OpenAiChatOptions defaultOptions;
    private Map<String, FunctionCallback> functionCallbackRegister = new ConcurrentHashMap<String, FunctionCallback>();
    private FunctionCallbackContext functionCallbackContext;
    public final RetryTemplate retryTemplate = RetryTemplate.builder().maxAttempts(10).retryOn(OpenAiApi.OpenAiApiException.class).exponentialBackoff(Duration.ofMillis(2000L), 5.0, Duration.ofMillis(180000L)).withListener(new RetryListener(){

        public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback, Throwable throwable) {
            OpenAiChatClient.this.logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
        }
    }).build();
    private final OpenAiApi openAiApi;

    public OpenAiChatClient(OpenAiApi openAiApi) {
        this(openAiApi, OpenAiChatOptions.builder().withModel("gpt-3.5-turbo").withTemperature(Float.valueOf(0.7f)).build());
    }

    public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options) {
        this(openAiApi, options, null);
    }

    public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options, FunctionCallbackContext functionCallbackContext) {
        Assert.notNull((Object)openAiApi, (String)"OpenAiApi must not be null");
        Assert.notNull((Object)options, (String)"Options must not be null");
        this.openAiApi = openAiApi;
        this.defaultOptions = options;
        this.functionCallbackContext = functionCallbackContext;
    }

    @Deprecated(since="0.8.0", forRemoval=true)
    public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
        this.defaultOptions = options;
        return this;
    }

    public ChatResponse call(Prompt prompt) {
        return (ChatResponse)this.retryTemplate.execute(ctx -> {
            OpenAiApi.ChatCompletionRequest request = this.createRequest(prompt, false);
            ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = this.chatCompletionWithTools(request);
            OpenAiApi.ChatCompletion chatCompletion = (OpenAiApi.ChatCompletion)completionEntity.getBody();
            if (chatCompletion == null) {
                this.logger.warn("No chat completion returned for prompt: {}", (Object)prompt);
                return new ChatResponse(List.of());
            }
            RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
            List<Generation> generations = chatCompletion.choices().stream().map(choice -> new Generation(choice.message().content(), this.toMap(choice.message())).withGenerationMetadata(ChatGenerationMetadata.from((String)choice.finishReason().name(), null))).toList();
            return new ChatResponse(generations, (ChatResponseMetadata)OpenAiChatResponseMetadata.from((OpenAiApi.ChatCompletion)completionEntity.getBody()).withRateLimit(rateLimits));
        });
    }

    public Flux<ChatResponse> stream(Prompt prompt) {
        return (Flux)this.retryTemplate.execute(ctx -> {
            OpenAiApi.ChatCompletionRequest request = this.createRequest(prompt, true);
            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request);
            ConcurrentHashMap roleMap = new ConcurrentHashMap();
            return completionChunks.map(chunk -> {
                String chunkId = chunk.id();
                List<Generation> generations = chunk.choices().stream().map(choice -> {
                    if (choice.delta().role() != null) {
                        roleMap.putIfAbsent(chunkId, choice.delta().role().name());
                    }
                    Generation generation = new Generation(choice.delta().content(), Map.of("role", roleMap.get(chunkId)));
                    if (choice.finishReason() != null) {
                        generation = generation.withGenerationMetadata(ChatGenerationMetadata.from((String)choice.finishReason().name(), null));
                    }
                    return generation;
                }).toList();
                return new ChatResponse(generations);
            });
        });
    }

    OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        HashSet<String> functionsForThisRequest = new HashSet<String>();
        List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
        OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
        if (prompt.getOptions() != null) {
            ModelOptions modelOptions = prompt.getOptions();
            if (modelOptions instanceof ChatOptions) {
                ChatOptions runtimeOptions = (ChatOptions)modelOptions;
                OpenAiChatOptions updatedRuntimeOptions = (OpenAiChatOptions)ModelOptionsUtils.copyToTarget((Object)runtimeOptions, ChatOptions.class, OpenAiChatOptions.class);
                Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, true);
                functionsForThisRequest.addAll(promptEnabledFunctions);
                request = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge((Object)updatedRuntimeOptions, (Object)request, OpenAiApi.ChatCompletionRequest.class);
            } else {
                throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + prompt.getOptions().getClass().getSimpleName());
            }
        }
        if (this.defaultOptions != null) {
            Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, false);
            functionsForThisRequest.addAll(defaultEnabledFunctions);
            request = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge((Object)request, (Object)this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
        }
        if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
            if (stream) {
                throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode");
            }
            request = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge((Object)OpenAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), (Object)request, OpenAiApi.ChatCompletionRequest.class);
        }
        return request;
    }

    private Set<String> handleFunctionCallbackConfigurations(OpenAiChatOptions options, boolean isRuntimeCall) {
        HashSet<String> functionToCall = new HashSet<String>();
        if (options != null) {
            if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) {
                options.getFunctionCallbacks().stream().forEach(functionCallback -> {
                    if (isRuntimeCall) {
                        this.functionCallbackRegister.put(functionCallback.getName(), (FunctionCallback)functionCallback);
                    } else {
                        this.functionCallbackRegister.putIfAbsent(functionCallback.getName(), (FunctionCallback)functionCallback);
                    }
                    if (isRuntimeCall) {
                        functionToCall.add(functionCallback.getName());
                    }
                });
            }
            if (!CollectionUtils.isEmpty(options.getFunctions())) {
                functionToCall.addAll(options.getFunctions());
            }
        }
        return functionToCall;
    }

    Map<String, FunctionCallback> getFunctionCallbackRegister() {
        return this.functionCallbackRegister;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
        ArrayList<OpenAiApi.FunctionTool> functionTools = new ArrayList<OpenAiApi.FunctionTool>();
        for (String functionName : functionNames) {
            FunctionCallback functionCallback;
            if (!this.functionCallbackRegister.containsKey(functionName)) {
                if (this.functionCallbackContext == null) throw new IllegalStateException("No function callback found for name: " + functionName);
                functionCallback = this.functionCallbackContext.getFunctionCallback(functionName, null);
                if (functionCallback == null) throw new IllegalStateException("No function callback [" + functionName + "] fund in tht FunctionCallbackContext");
                this.functionCallbackRegister.put(functionName, functionCallback);
            }
            functionCallback = this.functionCallbackRegister.get(functionName);
            OpenAiApi.FunctionTool.Function function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(), functionCallback.getInputTypeSchema());
            functionTools.add(new OpenAiApi.FunctionTool(function));
        }
        return functionTools;
    }

    private ResponseEntity<OpenAiApi.ChatCompletion> chatCompletionWithTools(OpenAiApi.ChatCompletionRequest request) {
        ResponseEntity<OpenAiApi.ChatCompletion> chatCompletion = this.openAiApi.chatCompletionEntity(request);
        if (Boolean.FALSE.equals(this.isToolCall(chatCompletion))) {
            return chatCompletion;
        }
        ArrayList<OpenAiApi.ChatCompletionMessage> conversationMessages = new ArrayList<OpenAiApi.ChatCompletionMessage>(request.messages());
        OpenAiApi.ChatCompletionMessage responseMessage = ((OpenAiApi.ChatCompletion)chatCompletion.getBody()).choices().iterator().next().message();
        if (((OpenAiApi.ChatCompletion)chatCompletion.getBody()).choices().size() > 1) {
            this.logger.warn("More than one choice returned. Only the first choice is processed.");
        }
        conversationMessages.add(responseMessage);
        for (OpenAiApi.ChatCompletionMessage.ToolCall toolCall : responseMessage.toolCalls()) {
            String functionName = toolCall.function().name();
            String functionArguments = toolCall.function().arguments();
            if (!this.functionCallbackRegister.containsKey(functionName)) {
                throw new IllegalStateException("No function callback found for function name: " + functionName);
            }
            String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
            conversationMessages.add(new OpenAiApi.ChatCompletionMessage(functionResponse, OpenAiApi.ChatCompletionMessage.Role.TOOL, null, toolCall.id(), null));
        }
        OpenAiApi.ChatCompletionRequest newRequest = new OpenAiApi.ChatCompletionRequest(conversationMessages, request.stream());
        newRequest = (OpenAiApi.ChatCompletionRequest)ModelOptionsUtils.merge((Object)newRequest, (Object)request, OpenAiApi.ChatCompletionRequest.class);
        return this.chatCompletionWithTools(newRequest);
    }

    private Map<String, Object> toMap(OpenAiApi.ChatCompletionMessage message) {
        HashMap<String, Object> map = new HashMap<String, Object>();
        if (message.toolCalls() != null) {
            map.put("tool_calls", message.toolCalls());
        }
        if (message.toolCallId() != null) {
            map.put("tool_call_id", message.toolCallId());
        }
        if (message.role() != null) {
            map.put("role", message.role().name());
        }
        return map;
    }

    private Boolean isToolCall(ResponseEntity<OpenAiApi.ChatCompletion> chatCompletion) {
        OpenAiApi.ChatCompletion body = (OpenAiApi.ChatCompletion)chatCompletion.getBody();
        if (body == null) {
            return false;
        }
        List<OpenAiApi.ChatCompletion.Choice> choices = body.choices();
        if (CollectionUtils.isEmpty(choices)) {
            return false;
        }
        return choices.get(0).message().toolCalls() != null;
    }
}

