/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.mcp.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.spec.McpError;
import org.springframework.ai.mcp.spec.McpSchema;
import org.springframework.ai.mcp.spec.McpTransport;
import org.springframework.ai.mcp.util.Assert;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

public class SseServerTransport
implements McpTransport {
    private static final Logger logger = LoggerFactory.getLogger(SseServerTransport.class);
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String SSE_ENDPOINT = "/sse";
    private final ObjectMapper objectMapper;
    private final String messageEndpoint;
    private final RouterFunction<?> routerFunction;
    private final ConcurrentHashMap<String, ClientSession> sessions = new ConcurrentHashMap();
    private volatile boolean isClosing = false;
    private Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> connectHandler;

    public SseServerTransport(ObjectMapper objectMapper, String messageEndpoint) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        this.objectMapper = objectMapper;
        this.messageEndpoint = messageEndpoint;
        this.routerFunction = RouterFunctions.route().GET(SSE_ENDPOINT, this::handleSseConnection).POST(messageEndpoint, this::handleMessage).build();
    }

    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        this.connectHandler = handler;
        return Mono.empty().then();
    }

    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        return Mono.create(sink -> {
            try {
                String jsonText = this.objectMapper.writeValueAsString((Object)message);
                ServerSentEvent event = ServerSentEvent.builder().event(MESSAGE_EVENT_TYPE).data((Object)jsonText).build();
                logger.debug("Attempting to broadcast message to {} active sessions", (Object)this.sessions.size());
                List<String> failedSessions = this.sessions.values().stream().filter(session -> session.messageSink.tryEmitNext((Object)event).isFailure()).map(session -> session.id).toList();
                if (failedSessions.isEmpty()) {
                    logger.debug("Successfully broadcast message to all sessions");
                    sink.success();
                } else {
                    String error = "Failed to broadcast message to sessions: " + String.join((CharSequence)", ", failedSessions);
                    logger.error(error);
                    sink.error((Throwable)new RuntimeException(error));
                }
            }
            catch (IOException e) {
                logger.error("Failed to serialize message: {}", (Object)e.getMessage());
                sink.error((Throwable)e);
            }
        });
    }

    @Override
    public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
        return (T)this.objectMapper.convertValue(data, typeRef);
    }

    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            logger.debug("Initiating graceful shutdown with {} active sessions", (Object)this.sessions.size());
        }).then(Mono.when(this.sessions.values().stream().map(session -> {
            String sessionId = session.id;
            return Mono.fromRunnable(() -> session.close()).then(Mono.delay((Duration)Duration.ofMillis(100L))).then(Mono.fromRunnable(() -> this.sessions.remove(sessionId)));
        }).toList())).timeout(Duration.ofSeconds(5L)).doOnSuccess(v -> logger.info("Graceful shutdown completed")).doOnError(e -> logger.error("Error during graceful shutdown: {}", (Object)e.getMessage()));
    }

    public RouterFunction<?> getRouterFunction() {
        return this.routerFunction;
    }

    private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        String sessionId = UUID.randomUUID().toString();
        logger.debug("Creating new SSE connection for session: {}", (Object)sessionId);
        ClientSession session = new ClientSession(sessionId);
        this.sessions.put(sessionId, session);
        return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body((Object)Flux.create(sink -> {
            logger.debug("Sending initial endpoint event to session: {}", (Object)sessionId);
            sink.next((Object)ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data((Object)this.messageEndpoint).build());
            session.messageSink.asFlux().doOnSubscribe(s -> logger.debug("Session {} subscribed to message sink", (Object)sessionId)).doOnComplete(() -> {
                logger.debug("Session {} completed", (Object)sessionId);
                this.sessions.remove(sessionId);
            }).doOnError(error -> {
                logger.error("Error in session {}: {}", (Object)sessionId, (Object)error.getMessage());
                this.sessions.remove(sessionId);
            }).doOnCancel(() -> {
                logger.debug("Session {} cancelled", (Object)sessionId);
                this.sessions.remove(sessionId);
            }).subscribe(event -> {
                logger.debug("Forwarding event to session {}: {}", (Object)sessionId, event);
                sink.next(event);
            }, arg_0 -> ((FluxSink)sink).error(arg_0), () -> ((FluxSink)sink).complete());
            sink.onCancel(() -> {
                logger.debug("Session {} cancelled", (Object)sessionId);
                this.sessions.remove(sessionId);
            });
        }), ServerSentEvent.class);
    }

    private Mono<ServerResponse> handleMessage(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status((HttpStatusCode)HttpStatus.SERVICE_UNAVAILABLE).bodyValue((Object)"Server is shutting down");
        }
        return request.bodyToMono(String.class).flatMap(body -> {
            try {
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, body);
                return Mono.just((Object)message).transform(this.connectHandler).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
                    logger.error("Error processing message: {}", (Object)error.getMessage());
                    return ServerResponse.status((HttpStatusCode)HttpStatus.INTERNAL_SERVER_ERROR).bodyValue((Object)new McpError((Object)error.getMessage()));
                });
            }
            catch (IllegalArgumentException e) {
                logger.error("Failed to deserialize message: {}", (Object)e.getMessage());
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Invalid message format"));
            }
            catch (IOException e) {
                logger.error("Failed to deserialize message: {}", (Object)e.getMessage());
                return ServerResponse.badRequest().bodyValue((Object)new McpError((Object)"Invalid message format"));
            }
        });
    }

    private static class ClientSession {
        private final String id;
        private final Sinks.Many<ServerSentEvent<?>> messageSink;

        ClientSession(String id) {
            this.id = id;
            logger.debug("Creating new session: {}", (Object)id);
            this.messageSink = Sinks.many().replay().latest();
            logger.debug("Session {} initialized with replay sink", (Object)id);
        }

        void close() {
            logger.debug("Closing session: {}", (Object)this.id);
            Sinks.EmitResult result = this.messageSink.tryEmitComplete();
            if (result.isFailure()) {
                logger.warn("Failed to complete message sink for session {}: {}", (Object)this.id, (Object)result);
            } else {
                logger.debug("Successfully completed message sink for session {}", (Object)this.id);
            }
        }
    }
}

