/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore;

import com.alibaba.fastjson.JSONObject;
import io.micrometer.observation.ObservationRegistry;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.DataType;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.collection.ReleaseCollectionParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DescribeIndexParam;
import io.milvus.param.index.DropIndexParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.MilvusFilterExpressionConverter;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class MilvusVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class);
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    public static final String DEFAULT_DATABASE_NAME = "default";
    public static final String DEFAULT_COLLECTION_NAME = "vector_store";
    public static final String DOC_ID_FIELD_NAME = "doc_id";
    public static final String CONTENT_FIELD_NAME = "content";
    public static final String METADATA_FIELD_NAME = "metadata";
    public static final String EMBEDDING_FIELD_NAME = "embedding";
    public static final String DISTANCE_FIELD_NAME = "distance";
    public static final List<String> SEARCH_OUTPUT_FIELDS = List.of("doc_id", "content", "metadata");
    public final FilterExpressionConverter filterExpressionConverter = new MilvusFilterExpressionConverter();
    private final MilvusServiceClient milvusClient;
    private final EmbeddingModel embeddingModel;
    private final MilvusVectorStoreConfig config;
    private final boolean initializeSchema;
    private final BatchingStrategy batchingStrategy;
    private static Map<MetricType, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, VectorStoreSimilarityMetric.DOT);

    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
        this(milvusClient, embeddingModel, MilvusVectorStoreConfig.defaultConfig(), initializeSchema, (BatchingStrategy)new TokenCountBatchingStrategy());
    }

    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, boolean initializeSchema, BatchingStrategy batchingStrategy) {
        this(milvusClient, embeddingModel, MilvusVectorStoreConfig.defaultConfig(), initializeSchema, batchingStrategy);
    }

    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, MilvusVectorStoreConfig config, boolean initializeSchema, BatchingStrategy batchingStrategy) {
        this(milvusClient, embeddingModel, config, initializeSchema, batchingStrategy, ObservationRegistry.NOOP, null);
    }

    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, MilvusVectorStoreConfig config, boolean initializeSchema, BatchingStrategy batchingStrategy, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) {
        super(observationRegistry, customObservationConvention);
        this.initializeSchema = initializeSchema;
        Assert.notNull((Object)milvusClient, (String)"MilvusServiceClient must not be null");
        Assert.notNull((Object)milvusClient, (String)"EmbeddingModel must not be null");
        this.milvusClient = milvusClient;
        this.embeddingModel = embeddingModel;
        this.config = config;
        this.batchingStrategy = batchingStrategy;
    }

    public void doAdd(List<Document> documents) {
        Assert.notNull(documents, (String)"Documents must not be null");
        ArrayList<String> docIdArray = new ArrayList<String>();
        ArrayList<String> contentArray = new ArrayList<String>();
        ArrayList<JSONObject> metadataArray = new ArrayList<JSONObject>();
        ArrayList<List> embeddingArray = new ArrayList<List>();
        this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        for (Document document : documents) {
            docIdArray.add(document.getId());
            contentArray.add(document.getContent());
            metadataArray.add(new JSONObject(document.getMetadata()));
            embeddingArray.add(EmbeddingUtils.toList((float[])document.getEmbedding()));
        }
        ArrayList<InsertParam.Field> fields = new ArrayList<InsertParam.Field>();
        fields.add(new InsertParam.Field(DOC_ID_FIELD_NAME, docIdArray));
        fields.add(new InsertParam.Field(CONTENT_FIELD_NAME, contentArray));
        fields.add(new InsertParam.Field(METADATA_FIELD_NAME, metadataArray));
        fields.add(new InsertParam.Field(EMBEDDING_FIELD_NAME, embeddingArray));
        InsertParam insertParam = InsertParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).withFields(fields).build();
        R status = this.milvusClient.insert(insertParam);
        if (status.getException() != null) {
            throw new RuntimeException("Failed to insert:", status.getException());
        }
    }

    public Optional<Boolean> doDelete(List<String> idList) {
        Assert.notNull(idList, (String)"Document id list must not be null");
        String deleteExpression = String.format("%s in [%s]", DOC_ID_FIELD_NAME, idList.stream().map(id -> "'" + id + "'").collect(Collectors.joining(",")));
        R status = this.milvusClient.delete(DeleteParam.newBuilder().withCollectionName(this.config.collectionName).withExpr(deleteExpression).build());
        long deleteCount = ((MutationResult)status.getData()).getDeleteCnt();
        if (deleteCount != (long)idList.size()) {
            logger.warn(String.format("Deleted only %s entries from requested %s ", deleteCount, idList.size()));
        }
        return Optional.of(status.getStatus().intValue() == R.Status.Success.getCode());
    }

    public List<Document> doSimilaritySearch(SearchRequest request) {
        R respSearch;
        String nativeFilterExpressions = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        Assert.notNull((Object)request.getQuery(), (String)"Query string must not be null");
        float[] embedding = this.embeddingModel.embed(request.getQuery());
        SearchParam.Builder searchParamBuilder = SearchParam.newBuilder().withCollectionName(this.config.collectionName).withConsistencyLevel(ConsistencyLevelEnum.STRONG).withMetricType(this.config.metricType).withOutFields(SEARCH_OUTPUT_FIELDS).withTopK(Integer.valueOf(request.getTopK())).withVectors(List.of(EmbeddingUtils.toList((float[])embedding))).withVectorFieldName(EMBEDDING_FIELD_NAME);
        if (StringUtils.hasText((String)nativeFilterExpressions)) {
            searchParamBuilder.withExpr(nativeFilterExpressions);
        }
        if ((respSearch = this.milvusClient.search(searchParamBuilder.build())).getException() != null) {
            throw new RuntimeException("Search failed!", respSearch.getException());
        }
        SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(((SearchResults)respSearch.getData()).getResults());
        return wrapperSearch.getRowRecords(0).stream().filter(rowRecord -> (double)this.getResultSimilarity((QueryResultsWrapper.RowRecord)rowRecord) >= request.getSimilarityThreshold()).map(rowRecord -> {
            String docId = (String)rowRecord.get(DOC_ID_FIELD_NAME);
            String content = (String)rowRecord.get(CONTENT_FIELD_NAME);
            JSONObject metadata = (JSONObject)rowRecord.get(METADATA_FIELD_NAME);
            metadata.put(DISTANCE_FIELD_NAME, (Object)Float.valueOf(1.0f - this.getResultSimilarity((QueryResultsWrapper.RowRecord)rowRecord)));
            return new Document(docId, content, metadata.getInnerMap());
        }).toList();
    }

    private float getResultSimilarity(QueryResultsWrapper.RowRecord rowRecord) {
        Float distance = (Float)rowRecord.get(DISTANCE_FIELD_NAME);
        return this.config.metricType == MetricType.IP || this.config.metricType == MetricType.COSINE ? distance.floatValue() : 1.0f - distance.floatValue();
    }

    public void afterPropertiesSet() throws Exception {
        if (!this.initializeSchema) {
            return;
        }
        this.createCollection();
    }

    void releaseCollection() {
        if (this.isDatabaseCollectionExists()) {
            this.milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder().withCollectionName(this.config.collectionName).build());
        }
    }

    private boolean isDatabaseCollectionExists() {
        return (Boolean)this.milvusClient.hasCollection(HasCollectionParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).build()).getData();
    }

    void createCollection() {
        R indexStatus;
        R indexDescriptionResponse;
        if (!this.isDatabaseCollectionExists()) {
            FieldType docIdFieldType = FieldType.newBuilder().withName(DOC_ID_FIELD_NAME).withDataType(DataType.VarChar).withMaxLength(Integer.valueOf(36)).withPrimaryKey(true).withAutoID(false).build();
            FieldType contentFieldType = FieldType.newBuilder().withName(CONTENT_FIELD_NAME).withDataType(DataType.VarChar).withMaxLength(Integer.valueOf(65535)).build();
            FieldType metadataFieldType = FieldType.newBuilder().withName(METADATA_FIELD_NAME).withDataType(DataType.JSON).build();
            FieldType embeddingFieldType = FieldType.newBuilder().withName(EMBEDDING_FIELD_NAME).withDataType(DataType.FloatVector).withDimension(Integer.valueOf(this.embeddingDimensions())).build();
            CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).withDescription("Spring AI Vector Store").withConsistencyLevel(ConsistencyLevelEnum.STRONG).withShardsNum(2).addFieldType(docIdFieldType).addFieldType(contentFieldType).addFieldType(metadataFieldType).addFieldType(embeddingFieldType).build();
            R collectionStatus = this.milvusClient.createCollection(createCollectionReq);
            if (collectionStatus.getException() != null) {
                throw new RuntimeException("Failed to create collection", collectionStatus.getException());
            }
        }
        if ((indexDescriptionResponse = this.milvusClient.describeIndex(DescribeIndexParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).build())).getData() == null && (indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).withFieldName(EMBEDDING_FIELD_NAME).withIndexType(this.config.indexType).withMetricType(this.config.metricType).withExtraParam(this.config.indexParameters).withSyncMode(Boolean.FALSE).build())).getException() != null) {
            throw new RuntimeException("Failed to create Index", indexStatus.getException());
        }
        R loadCollectionStatus = this.milvusClient.loadCollection(LoadCollectionParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).build());
        if (loadCollectionStatus.getException() != null) {
            throw new RuntimeException("Collection loading failed!", loadCollectionStatus.getException());
        }
    }

    int embeddingDimensions() {
        if (this.config.embeddingDimension != -1) {
            return this.config.embeddingDimension;
        }
        try {
            int embeddingDimensions = this.embeddingModel.dimensions();
            if (embeddingDimensions > 0) {
                return embeddingDimensions;
            }
        }
        catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:" + this.config.embeddingDimension, (Throwable)e);
        }
        return 1536;
    }

    void dropCollection() {
        R status = this.milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder().withCollectionName(this.config.collectionName).build());
        if (status.getException() != null) {
            throw new RuntimeException("Release collection failed!", status.getException());
        }
        status = this.milvusClient.dropIndex(DropIndexParam.newBuilder().withCollectionName(this.config.collectionName).build());
        if (status.getException() != null) {
            throw new RuntimeException("Drop Index failed!", status.getException());
        }
        status = this.milvusClient.dropCollection(DropCollectionParam.newBuilder().withDatabaseName(this.config.databaseName).withCollectionName(this.config.collectionName).build());
        if (status.getException() != null) {
            throw new RuntimeException("Drop Collection failed!", status.getException());
        }
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder((String)VectorStoreProvider.MILVUS.value(), (String)operationName).withCollectionName(this.config.collectionName).withDimensions(Integer.valueOf(this.embeddingModel.dimensions())).withSimilarityMetric(this.getSimilarityMetric()).withNamespace(this.config.databaseName);
    }

    private String getSimilarityMetric() {
        if (!SIMILARITY_TYPE_MAPPING.containsKey(this.config.metricType)) {
            return this.config.metricType.name();
        }
        return SIMILARITY_TYPE_MAPPING.get(this.config.metricType).value();
    }

    public static class MilvusVectorStoreConfig {
        private final String databaseName;
        private final String collectionName;
        private final int embeddingDimension;
        private final IndexType indexType;
        private final MetricType metricType;
        private final String indexParameters;

        public static Builder builder() {
            return new Builder();
        }

        public static MilvusVectorStoreConfig defaultConfig() {
            return MilvusVectorStoreConfig.builder().build();
        }

        private MilvusVectorStoreConfig(Builder builder) {
            this.databaseName = builder.databaseName;
            this.collectionName = builder.collectionName;
            this.embeddingDimension = builder.embeddingDimension;
            this.indexType = builder.indexType;
            this.metricType = builder.metricType;
            this.indexParameters = builder.indexParameters;
        }

        public static class Builder {
            private String databaseName = "default";
            private String collectionName = "vector_store";
            private int embeddingDimension = -1;
            private IndexType indexType = IndexType.IVF_FLAT;
            private MetricType metricType = MetricType.COSINE;
            private String indexParameters = "{\"nlist\":1024}";

            private Builder() {
            }

            public Builder withMetricType(MetricType metricType) {
                Assert.notNull((Object)metricType, (String)"Collection Name must not be empty");
                Assert.isTrue((metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE ? 1 : 0) != 0, (String)"Only the text metric types IP and L2 are supported");
                this.metricType = metricType;
                return this;
            }

            public Builder withIndexType(IndexType indexType) {
                this.indexType = indexType;
                return this;
            }

            public Builder withIndexParameters(String indexParameters) {
                this.indexParameters = indexParameters;
                return this;
            }

            public Builder withDatabaseName(String databaseName) {
                this.databaseName = databaseName;
                return this;
            }

            public Builder withCollectionName(String collectionName) {
                this.collectionName = collectionName;
                return this;
            }

            public Builder withEmbeddingDimension(int newEmbeddingDimension) {
                Assert.isTrue((newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768 ? 1 : 0) != 0, (String)"Dimension has to be withing the boundaries 1 and 32768 (inclusively)");
                this.embeddingDimension = newEmbeddingDimension;
                return this;
            }

            public MilvusVectorStoreConfig build() {
                return new MilvusVectorStoreConfig(this);
            }
        }
    }
}

