/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.Weights;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.SegmentedTensor;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.io.UncheckedIOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link WeightLoader} over one or more safetensors files.
 *
 * <p>The index is built either from a {@value #MODEL_INDEX_JSON} file (deserialized by Jackson,
 * whose {@code weight_map} values are file names relative to the model root) or from a single
 * model file. Each file is memory-mapped in splits no larger than {@link Integer#MAX_VALUE}
 * bytes. Tensors too large for one split are stored as {@code <name>-part-<n>} chunks and
 * reassembled by {@link #load} into a {@link SegmentedTensor}.
 *
 * <p>The internal maps are plain {@link HashMap}s with no synchronization.
 */
public class SafeTensorIndex implements WeightLoader, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(SafeTensorIndex.class);
    private static final ObjectMapper om = new ObjectMapper();
    public static final String SINGLE_MODEL_NAME = "model.safetensors";
    public static final String MODEL_INDEX_JSON = "model.safetensors.index.json";

    private final Map<String, String> metadata;
    /** Tensor infos parsed from every loaded file header. */
    final Map<String, TensorInfo> allTensorInfoMap = new HashMap<>();
    /** The {@code weight_map}; its values are file names relative to the model root. */
    final Map<String, String> weightFileMap;
    /** Tensor (or chunk) name to the mapped split that holds its bytes. */
    private final Map<String, Weights> weightMap = new HashMap<>();
    /** Open file handles keyed by file name; released by {@link #close()}. */
    private final Map<String, RandomAccessFile> fileMap = new HashMap<>();

    /**
     * Loads {@value #MODEL_INDEX_JSON} from {@code modelRoot} and maps every file it references.
     *
     * @param modelRoot directory containing the index file and the weight files
     * @return the loaded index; the caller owns it and should {@link #close()} it
     * @throws UncheckedIOException if the index or any weight file cannot be read
     */
    public static SafeTensorIndex loadWithWeights(Path modelRoot) {
        try {
            File indexFile = Paths.get(modelRoot.toString(), MODEL_INDEX_JSON).toFile();
            SafeTensorIndex index = om.readValue(indexFile, SafeTensorIndex.class);
            return loadWeightsOrClose(index, modelRoot);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Builds an index over a single safetensors file and maps it.
     *
     * @param modelRoot directory containing {@code modelFile}
     * @param modelFile file name relative to {@code modelRoot}
     * @return the loaded index; the caller owns it and should {@link #close()} it
     * @throws UncheckedIOException if the file cannot be read
     */
    public static SafeTensorIndex loadSingleFile(Path modelRoot, String modelFile) {
        try {
            SafeTensorIndex index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", modelFile));
            return loadWeightsOrClose(index, modelRoot);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Runs {@link #loadWeights}. If it fails, closes {@code index} so that no file handle opened
     * so far is leaked.
     */
    private static SafeTensorIndex loadWeightsOrClose(SafeTensorIndex index, Path modelRoot) throws IOException {
        try {
            loadWeights(index, modelRoot);
            return index;
        } catch (Throwable t) {
            closeAfterFailure(index, t);
            // Precise rethrow: only IOException or unchecked throwables can reach here.
            throw t;
        }
    }

    /** Closes a partially loaded index and attaches any close failure to {@code cause}. */
    private static void closeAfterFailure(SafeTensorIndex index, Throwable cause) {
        try {
            index.close();
        } catch (Exception suppressed) {
            cause.addSuppressed(suppressed);
        }
    }

    /**
     * Opens every distinct file named in {@code index.weightFileMap}, parses its header and
     * memory-maps its tensor data in int-sized splits, registering each tensor (or chunk) in
     * {@code index.weightMap}.
     *
     * @param index index whose weight files should be mapped
     * @param modelRoot directory the file names are resolved against
     * @throws IOException if a file cannot be opened or mapped
     */
    static void loadWeights(SafeTensorIndex index, Path modelRoot) throws IOException {
        for (String fileName : index.weightFileMap.values()) {
            // Many tensors typically share one file; open and map each file only once.
            if (index.fileMap.containsKey(fileName)) {
                continue;
            }
            RandomAccessFile raf = new RandomAccessFile(Paths.get(modelRoot.toString(), fileName).toFile(), "r");
            // Register right away so close() releases the handle even if parsing below fails.
            index.fileMap.put(fileName, raf);
            FileChannel channel = raf.getChannel();
            long fileLength = raf.length();

            // NOTE(review): only the first 1 MiB (or the whole file if smaller) is mapped for header
            // parsing; a header larger than that would not fit — confirm against readTensorInfoMap.
            MappedByteBuffer header = channel.map(FileChannel.MapMode.READ_ONLY, 0L, Math.min(0x100000L, fileLength));
            Map<String, String> fileMetadata = new HashMap<>();
            Map<String, TensorInfo> tensorInfoMap = SafeTensorSupport.readTensorInfoMap(header, Optional.of(fileMetadata));
            index.allTensorInfoMap.putAll(tensorInfoMap);
            // The header buffer's position after parsing is taken as the start of the data region;
            // split offsets below are relative to it.
            int endOfHeaderPosition = header.position();

            Map<List<Long>, List<String>> splits = index.computeMmapSplits(tensorInfoMap, fileLength);
            for (Map.Entry<List<Long>, List<String>> split : splits.entrySet()) {
                long start = split.getKey().get(0);
                long end = split.getKey().get(1);
                List<String> tensors = split.getValue();
                int lengthInt = Ints.checkedCast(end - start);
                MappedByteBuffer buf = channel.map(FileChannel.MapMode.READ_ONLY, endOfHeaderPosition + start, lengthInt);

                // Set lookup keeps this filter linear in the number of tensors in the file.
                Set<String> tensorNames = new HashSet<>(tensors);
                Map<String, TensorInfo> mmapTensorInfoMap = tensorInfoMap.entrySet().stream()
                        .filter(entry -> tensorNames.contains(entry.getKey()))
                        .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
                Weights mmapWeights = new Weights(fileMetadata, mmapTensorInfoMap, buf, Optional.of(index));
                for (String tensor : tensors) {
                    index.weightMap.put(tensor, mmapWeights);
                }
            }
        }
    }

    /**
     * Groups the tensors of one file into mmap splits that each fit in an int-indexed buffer.
     *
     * <p>Side effects on {@code tensorInfoMap}:
     * <ul>
     *   <li>Each tensor's {@code dataOffsets} are rebased in place by the start of the split
     *       being built.</li>
     *   <li>A tensor that alone exceeds the limit is cut into {@code <name>-part-<n>} chunks, and
     *       a {@link TensorInfo} for each chunk is added.</li>
     * </ul>
     *
     * <p>NOTE(review): rebasing by {@code lastSplitOffset} yields buffer-relative offsets only if
     * tensors are iterated in ascending, contiguous offset order — confirm that
     * {@code SafeTensorSupport.readTensorInfoMap} guarantees this.
     *
     * @param tensorInfoMap tensor infos for one file; must be mutable
     * @param fileLength length of the file, used as the initial lower bound for split starts
     * @return map from {@code [start, end)} (relative to the data region) to the names stored there
     */
    private Map<List<Long>, List<String>> computeMmapSplits(Map<String, TensorInfo> tensorInfoMap, long fileLength) {
        Map<List<Long>, List<String>> splits = new HashMap<>();
        long lastSplitOffset = 0L;
        int tensorsInFile = tensorInfoMap.size();
        int tensorsSplit = 0;
        List<String> tensors = new ArrayList<>();
        // Iterate over a snapshot because chunk entries are added to tensorInfoMap as we go.
        Iterator<Map.Entry<String, TensorInfo>> it = new ArrayList<>(tensorInfoMap.entrySet()).iterator();
        // A tensor that did not fit in the previous split stays pending here for the next one.
        Map.Entry<String, TensorInfo> next = null;
        while (tensorsSplit < tensorsInFile && (it.hasNext() || next != null)) {
            tensors.clear();
            long limit = lastSplitOffset + Integer.MAX_VALUE;
            long startOffset = fileLength;
            long endOffset = 0L;
            while (it.hasNext() || next != null) {
                next = next == null ? it.next() : next;
                TensorInfo info = next.getValue();
                logger.debug("Tensor {} {} {} limit {}", next.getKey(), info.dataOffsets[0], info.dataOffsets[1], limit);
                if (info.dataOffsets[1] < limit) {
                    tensors.add(next.getKey());
                    ++tensorsSplit;
                    if (info.dataOffsets[1] > endOffset) {
                        endOffset = info.dataOffsets[1];
                    }
                    if (info.dataOffsets[0] < startOffset) {
                        startOffset = info.dataOffsets[0];
                    }
                    info.dataOffsets[0] = info.dataOffsets[0] - lastSplitOffset;
                    info.dataOffsets[1] = info.dataOffsets[1] - lastSplitOffset;
                    logger.debug("Adding tensor {} to split {}-{}", next.getKey(), info.dataOffsets[0], info.dataOffsets[1]);
                    next = null;
                    continue;
                }
                if (!tensors.isEmpty()) {
                    // Close the current split; the pending tensor starts the next one.
                    break;
                }
                // This tensor alone is larger than one split can hold: map it as chunks.
                if (info.dataOffsets[1] > endOffset) {
                    endOffset = info.dataOffsets[1];
                }
                if (info.dataOffsets[0] < startOffset) {
                    startOffset = info.dataOffsets[0];
                }
                info.dataOffsets[0] = info.dataOffsets[0] - lastSplitOffset;
                info.dataOffsets[1] = info.dataOffsets[1] - lastSplitOffset;
                if (addOversizedTensorChunks(next.getKey(), info, lastSplitOffset, tensorInfoMap, splits)) {
                    ++tensorsSplit;
                    next = null;
                }
                break;
            }
            assert (tensorsSplit > 0) : "No tensors in split";
            logger.debug("Adding split {}-{} with {} tensors of {}", startOffset, endOffset, tensors.size(), tensorsSplit);
            if (!tensors.isEmpty()) {
                splits.put(List.of(startOffset, endOffset), new ArrayList<>(tensors));
            }
            lastSplitOffset = Math.max(lastSplitOffset, endOffset);
        }
        assert (tensorsInFile == tensorsSplit) : "Not all tensors were split: " + tensorsSplit + " != " + tensorsInFile;
        return splits;
    }

    /**
     * Cuts one oversized 2D tensor into {@code <name>-part-<n>} chunks. Each chunk becomes its own
     * split and gets its own {@link TensorInfo} entry in {@code tensorInfoMap}.
     *
     * <p>Offset handling:
     * <ul>
     *   <li>{@code info.dataOffsets} must already be rebased by {@code lastSplitOffset}.</li>
     *   <li>Split keys are converted back to data-region offsets, which is how
     *       {@code loadWeights} maps them.</li>
     *   <li>Each chunk's info is {@code {0, length}} because its mapping begins at the chunk's
     *       first byte. This assumes {@code Weights} reads offsets relative to its buffer, as the
     *       rebasing of regular tensors implies.</li>
     * </ul>
     *
     * @return {@code true} if at least one chunk was added
     */
    private static boolean addOversizedTensorChunks(String name, TensorInfo info, long lastSplitOffset,
            Map<String, TensorInfo> tensorInfoMap, Map<List<Long>, List<String>> splits) {
        assert (info.shape.length == 2) : "Only 2D tensors supported";
        int bytesPerColumn = info.dType.size() * info.shape[1];
        // Every chunk holds whole rows and still fits in one int-indexed mapping.
        long chunkSize = Integer.MAX_VALUE - Integer.MAX_VALUE % bytesPerColumn;
        long relStart = info.dataOffsets[0];
        long relEnd = info.dataOffsets[1];
        int chunk = 0;
        boolean added = false;
        for (long chunkStart = relStart; chunkStart < relEnd; chunkStart += chunkSize) {
            long chunkEnd = Math.min(chunkStart + chunkSize, relEnd);
            long chunkLength = chunkEnd - chunkStart;
            String chunkName = name + "-part-" + chunk++;
            long fileStart = chunkStart + lastSplitOffset;
            long fileEnd = chunkEnd + lastSplitOffset;
            logger.debug("Adding chunk {} to split {}-{} {}", chunkName, fileStart, fileEnd, Ints.checkedCast(chunkLength));
            splits.put(List.of(fileStart, fileEnd), List.of(chunkName));
            int numRowsInChunk = Ints.checkedCast(chunkLength / bytesPerColumn);
            TensorInfo chunkInfo = new TensorInfo(info.dType, new long[]{numRowsInChunk, info.shape[1]}, new long[]{0L, chunkLength});
            tensorInfoMap.put(chunkName, chunkInfo);
            added = true;
        }
        return added;
    }

    /**
     * Jackson creator for {@value #MODEL_INDEX_JSON}.
     *
     * @param metadata the {@code metadata} object; Jackson passes {@code null} when the property
     *     is absent, which is treated as empty
     * @param weightFileMap the {@code weight_map} object; required
     * @throws NullPointerException if {@code weightFileMap} is {@code null}
     */
    @JsonCreator
    SafeTensorIndex(@JsonProperty(value="metadata") Map<String, String> metadata, @JsonProperty(value="weight_map") Map<String, String> weightFileMap) {
        this.metadata = metadata == null ? ImmutableMap.of() : ImmutableMap.copyOf(metadata);
        this.weightFileMap = ImmutableMap.copyOf(Objects.requireNonNull(weightFileMap, "weight_map"));
    }

    @Override
    public Map<String, String> metadata() {
        return this.metadata;
    }

    @Override
    public Map<String, TensorInfo> tensorInfoMap() {
        return this.allTensorInfoMap;
    }

    /**
     * Loads a tensor by name. Names that were chunked at mapping time are reassembled from their
     * {@code -part-<n>} pieces into a {@link SegmentedTensor}.
     *
     * @throws NoSuchElementException if neither the name nor any of its chunks is present
     */
    @Override
    public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) {
        Weights w = this.weightMap.get(name);
        if (w != null) {
            return w.load(name, dctx, sparseRows, sparseColumns);
        }
        List<AbstractTensor> segments = new ArrayList<>();
        for (int idx = 0; ; idx++) {
            String segmentName = name + "-part-" + idx;
            Weights segmentWeights = this.weightMap.get(segmentName);
            if (segmentWeights == null) {
                break;
            }
            segments.add(segmentWeights.load(segmentName, dctx, sparseRows, sparseColumns));
        }
        if (segments.isEmpty()) {
            throw new NoSuchElementException(name);
        }
        return SegmentedTensor.wrap(segments);
    }

    /**
     * Returns the model dtype as reported by one of the mapped splits. Which split is consulted
     * is unspecified because {@code weightMap} is a {@link HashMap}.
     *
     * @throws NoSuchElementException if no weights have been loaded
     */
    @Override
    public DType getModelDType() {
        Iterator<Weights> weights = this.weightMap.values().iterator();
        if (!weights.hasNext()) {
            throw new NoSuchElementException("No weights loaded");
        }
        return weights.next().getModelDType();
    }

    /**
     * Drops all mapped weights and closes every opened file. Closing is best-effort: a file that
     * fails to close is logged and the remaining files are still closed.
     */
    @Override
    public void close() throws Exception {
        this.weightMap.clear();
        this.fileMap.forEach((fileName, raf) -> {
            try {
                raf.close();
            } catch (IOException e) {
                logger.warn("Failed to close weight file {}", fileName, e);
            }
        });
        this.fileMap.clear();
        this.allTensorInfoMap.clear();
    }
}

