package org.springframework.ai.mcp.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.spec.McpError;
import org.springframework.ai.mcp.spec.McpSchema;
import org.springframework.ai.mcp.spec.McpTransport;
import org.springframework.ai.mcp.util.Assert;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

/**
 * MCP server transport over HTTP using Server-Sent Events.
 *
 * <p>Exposes two routes through {@link #getRouterFunction()}:
 * <ul>
 *   <li>{@code GET} {@value #SSE_ENDPOINT}: opens an SSE stream for a new {@link ClientSession}. The
 *   first event on the stream has type {@value #ENDPOINT_EVENT_TYPE} and carries the message
 *   endpoint path; later events have type {@value #MESSAGE_EVENT_TYPE}.</li>
 *   <li>{@code POST} to the message endpoint given to the constructor: deserializes a JSON-RPC
 *   message and passes it to the handler installed by {@link #connect(Function)}.</li>
 * </ul>
 *
 * <p>{@link #sendMessage} broadcasts each message to every active session.
 */
public class SseServerTransport implements McpTransport {

    private static final Logger logger = LoggerFactory.getLogger(SseServerTransport.class);

    /** SSE event type used for JSON-RPC messages pushed to clients. */
    public static final String MESSAGE_EVENT_TYPE = "message";

    /** SSE event type of the first event on every stream; its data is the message endpoint. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";

    /** Path of the SSE (GET) route. */
    public static final String SSE_ENDPOINT = "/sse";

    private final ObjectMapper objectMapper;

    private final String messageEndpoint;

    private final RouterFunction<?> routerFunction;

    /** Active sessions keyed by their random UUID session id. */
    private final ConcurrentHashMap<String, ClientSession> sessions = new ConcurrentHashMap<>();

    /** Set once graceful shutdown begins; new connections and messages are then rejected with 503. */
    private volatile boolean isClosing = false;

    /**
     * Handler installed by {@link #connect(Function)}. It is read by the POST handler, which need
     * not run on the thread that called {@code connect}, hence {@code volatile}.
     */
    private volatile Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> connectHandler;

    /**
     * State for one connected SSE client: its id and the sink that feeds its event stream.
     */
    public static class ClientSession {

        private final String id;

        /** Replay sink that replays only the latest event to a late subscriber. */
        private final Sinks.Many<ServerSentEvent<?>> messageSink;

        ClientSession(String id) {
            this.id = id;
            logger.debug("Creating new session: {}", id);
            this.messageSink = Sinks.many().replay().latest();
            logger.debug("Session {} initialized with replay sink", id);
        }

        /**
         * Completes this session's sink, which ends the client's event stream. A failed completion
         * is logged, not thrown.
         */
        public void close() {
            logger.debug("Closing session: {}", id);
            Sinks.EmitResult result = messageSink.tryEmitComplete();
            if (result.isFailure()) {
                logger.warn("Failed to complete message sink for session {}: {}", id, result);
            }
            else {
                logger.debug("Successfully completed message sink for session {}", id);
            }
        }
    }

    /**
     * Creates the transport and builds its routes.
     *
     * @param objectMapper mapper used to serialize outgoing and deserialize incoming JSON-RPC
     *     messages; rejected by {@code Assert.notNull} when null
     * @param messageEndpoint path of the POST route; also sent to clients as the data of the
     *     {@value #ENDPOINT_EVENT_TYPE} event; rejected by {@code Assert.notNull} when null
     */
    public SseServerTransport(ObjectMapper objectMapper, String messageEndpoint) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        this.objectMapper = objectMapper;
        this.messageEndpoint = messageEndpoint;
        this.routerFunction = RouterFunctions.route()
            .GET(SSE_ENDPOINT, this::handleSseConnection)
            .POST(messageEndpoint, this::handleMessage)
            .build();
    }

    /**
     * Installs the handler that processes incoming client messages.
     *
     * @param handler transforms a mono of the incoming message into a mono of the handler's result
     * @return an already-completed {@code Mono}
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        this.connectHandler = handler;
        return Mono.empty();
    }

    /**
     * Broadcasts a message to every active session as a {@value #MESSAGE_EVENT_TYPE} event.
     *
     * <p>Sessions are inspected when the returned {@code Mono} is subscribed, not when this method
     * is called. If there are no sessions, the result completes empty.
     *
     * @param message the message to serialize and broadcast
     * @return completes when every session accepted the event. Errors with the serialization
     *     {@link IOException}, or with a {@link RuntimeException} listing the session ids whose
     *     sinks rejected the event.
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        return Mono.defer(() -> {
            if (sessions.isEmpty()) {
                logger.debug("No active sessions to broadcast message to");
                return Mono.empty();
            }

            String json;
            try {
                json = objectMapper.writeValueAsString(message);
            }
            catch (IOException e) {
                logger.error("Failed to serialize message: {}", e.getMessage());
                return Mono.error(e);
            }

            ServerSentEvent<String> event = ServerSentEvent.<String>builder()
                .event(MESSAGE_EVENT_TYPE)
                .data(json)
                .build();

            logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
            // NOTE(review): Reactor sinks return FAIL_NON_SERIALIZED when tryEmitNext is called
            // concurrently on the same sink; such a race surfaces here as a broadcast failure.
            List<String> failedSessionIds = sessions.values()
                .stream()
                .filter(session -> session.messageSink.tryEmitNext(event).isFailure())
                .map(session -> session.id)
                .toList();

            if (failedSessionIds.isEmpty()) {
                logger.debug("Successfully broadcast message to all sessions");
                return Mono.empty();
            }
            String error = "Failed to broadcast message to sessions: " + String.join(", ", failedSessionIds);
            logger.error(error);
            return Mono.error(new RuntimeException(error));
        });
    }

    /**
     * Converts an already-parsed value (for example a map from a JSON-RPC payload) into the
     * requested type using the transport's {@link ObjectMapper}.
     */
    @Override
    public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
        return objectMapper.convertValue(data, typeRef);
    }

    /**
     * Stops accepting new connections and messages, then closes every active session.
     *
     * <p>Each session's sink is completed, then after a 100 ms pause the session is removed. The
     * whole operation fails with a timeout if it takes longer than 5 seconds. The set of sessions
     * is captured at subscription time, after the closing flag is raised, so connections accepted
     * just before shutdown are not missed.
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            isClosing = true;
            logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size());
            List<Mono<Void>> closings = sessions.values().stream().map(this::closeSession).toList();
            return Mono.when(closings);
        })
            .timeout(Duration.ofSeconds(5L))
            .doOnSuccess(ignored -> logger.info("Graceful shutdown completed"))
            .doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage()));
    }

    /** Completes one session's stream, waits briefly, then drops it from the session map. */
    private Mono<Void> closeSession(ClientSession session) {
        // NOTE(review): the 100 ms pause presumably gives the completion signal time to reach the
        // client before the session is forgotten; confirm the intent.
        return Mono.fromRunnable(session::close)
            .then(Mono.delay(Duration.ofMillis(100L)))
            .then(Mono.fromRunnable(() -> sessions.remove(session.id)));
    }

    /**
     * Returns the WebFlux routes for the SSE (GET) and message (POST) endpoints.
     *
     * @return the router function built in the constructor
     */
    public RouterFunction<?> getRouterFunction() {
        return this.routerFunction;
    }

    /**
     * Opens an SSE stream for a new session.
     *
     * <p>The stream first emits the endpoint event, then everything pushed to the session's sink.
     * Because the sink flux is concatenated (not bridged through a manual subscription), a client
     * disconnect cancels the sink subscription too. {@code doFinally} removes the session from the
     * map on completion, error or cancellation.
     *
     * @param request the incoming GET request (unused)
     * @return 503 while shutting down, otherwise a 200 {@code text/event-stream} response
     */
    private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
        if (isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
        }

        String sessionId = UUID.randomUUID().toString();
        logger.debug("Creating new SSE connection for session: {}", sessionId);
        ClientSession session = new ClientSession(sessionId);
        sessions.put(sessionId, session);

        ServerSentEvent<String> endpointEvent = ServerSentEvent.<String>builder()
            .event(ENDPOINT_EVENT_TYPE)
            .data(messageEndpoint)
            .build();

        Flux<ServerSentEvent<?>> sessionMessages = session.messageSink.asFlux()
            .doOnSubscribe(subscription -> logger.debug("Session {} subscribed to message sink", sessionId))
            .doOnNext(event -> logger.debug("Forwarding event to session {}: {}", sessionId, event));

        Flux<ServerSentEvent<?>> events = Flux.<ServerSentEvent<?>>just(endpointEvent)
            .doOnSubscribe(subscription -> logger.debug("Sending initial endpoint event to session: {}", sessionId))
            .concatWith(sessionMessages)
            .doOnComplete(() -> logger.debug("Session {} completed", sessionId))
            .doOnError(e -> logger.error("Error in session {}: {}", sessionId, e.getMessage()))
            .doOnCancel(() -> logger.debug("Session {} cancelled", sessionId))
            // Runs for complete, error and cancel alike, so no dead session lingers in the map.
            .doFinally(signal -> sessions.remove(sessionId));

        return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(events, ServerSentEvent.class);
    }

    /**
     * Handles a client-to-server JSON-RPC message posted to the message endpoint.
     *
     * @param request the POST request whose body is the JSON-RPC message
     * @return 503 while shutting down or before {@link #connect(Function)} installed a handler;
     *     400 with an {@link McpError} when the body cannot be deserialized; 500 with an
     *     {@link McpError} when the handler fails; otherwise 200 with no body
     */
    private Mono<ServerResponse> handleMessage(ServerRequest request) {
        if (isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
        }

        return request.bodyToMono(String.class).flatMap(body -> {
            var handler = this.connectHandler;
            if (handler == null) {
                // Previously transform(null) threw an NPE that escaped every catch clause.
                logger.error("Received message before a connect handler was registered");
                return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE)
                    .bodyValue(new McpError("Transport not connected"));
            }

            McpSchema.JSONRPCMessage message;
            try {
                message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
            }
            catch (IOException | IllegalArgumentException e) {
                logger.error("Failed to deserialize message: {}", e.getMessage());
                return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
            }

            // defer() turns a synchronous throw from the handler into an error signal, so it is
            // reported as a processing failure (500) rather than escaping the pipeline.
            return Mono.defer(() -> Mono.just(message).transform(handler))
                // Explicit 200 even when the handler completes without emitting a result.
                .then(ServerResponse.ok().build())
                .onErrorResume(e -> {
                    logger.error("Error processing message: {}", e.getMessage());
                    return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
                        .bodyValue(new McpError(e.getMessage()));
                });
        });
    }
}
