/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.Pair;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

public class SafeTensorSplitter {
    /**
     * Span of the original data section assigned to each chunk, in bytes.
     * 0x500000000 bytes = 20 GiB.
     */
    static long MAX_CHUNK_SIZE = 0x500000000L;

    /**
     * Returns the chunk file name for a tensor, based on where its data ends in the original
     * single-file model.
     *
     * <p>The first number is {@code floor(info.dataOffsets[1] / MAX_CHUNK_SIZE)} (zero-based).
     * The second number is {@code floor(fileSize / MAX_CHUNK_SIZE)}, which is the zero-based
     * index of the last chunk rather than a chunk count. A model smaller than one chunk is
     * therefore named {@code model-00000-of-00000.safetensor}.
     *
     * @param info     tensor metadata from the original file; only {@code dataOffsets[1]} is read
     * @param fileSize size in bytes of the original {@code model.safetensors}
     * @return the chunk file name, resolved against the model directory by the caller
     */
    static String getChunkFile(TensorInfo info, long fileSize) {
        long fileChunk = Math.floorDiv(info.dataOffsets[1], MAX_CHUNK_SIZE);
        long totalChunks = Math.floorDiv(fileSize, MAX_CHUNK_SIZE);
        return String.format("model-%05d-of-%05d.safetensor", fileChunk, totalChunks);
    }

    /**
     * Splits {@code <modelDir>/model.safetensors} into chunk files and writes
     * {@code model.safetensors.index.json}, which maps every tensor name to its chunk file.
     *
     * <p>Each tensor is loaded through the {@link WeightLoader} and re-saved into a temporary
     * file belonging to its chunk. Once every tensor has been written, each chunk file is
     * assembled in {@code modelDir} from three parts:
     * <ol>
     *   <li>an 8-byte little-endian header length,</li>
     *   <li>the JSON header describing that chunk's tensors,</li>
     *   <li>the full contents of the temporary file.</li>
     * </ol>
     * Temporary files are closed and deleted whether or not the split succeeds.
     *
     * @param args {@code args[0]} is the model directory
     * @throws IllegalArgumentException if no directory is given, the path is not a directory,
     *     an index file already exists, or {@code model.safetensors} is missing
     * @throws RuntimeException wrapping any {@link IOException} raised while splitting
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            throw new IllegalArgumentException("Missing model name");
        }
        String modelDir = args[0];
        if (!new File(modelDir).isDirectory()) {
            throw new IllegalArgumentException("Not a directory");
        }
        if (Paths.get(modelDir, "model.safetensors.index.json").toFile().exists()) {
            throw new IllegalArgumentException("Already split");
        }
        File modelFile = Paths.get(modelDir, "model.safetensors").toFile();
        if (!modelFile.exists()) {
            throw new IllegalArgumentException("Missing model file");
        }
        // Loop-invariant: the source file size feeds every tensor's chunk name.
        long modelFileSize = modelFile.length();
        WeightLoader wl = SafeTensorSupport.loadWeights(new File(modelDir));

        // Tensor name -> chunk file name, in iteration order; becomes "weight_map".
        Map<String, String> tensorIndex = new LinkedHashMap<>();
        // Chunk file name -> (tensor name -> metadata relative to that chunk's data section).
        Map<String, Map<String, TensorInfo>> tensorsInChunk = new LinkedHashMap<>();
        // Chunk file name -> temporary data file, and the open handle used to write it.
        Map<String, File> tempFiles = new HashMap<>();
        Map<String, RandomAccessFile> chunkFiles = new HashMap<>();
        try {
            for (Map.Entry<String, TensorInfo> entry : wl.tensorInfoMap().entrySet()) {
                String name = entry.getKey();
                String chunkName = getChunkFile(entry.getValue(), modelFileSize);
                tensorIndex.put(name, chunkName);

                RandomAccessFile chunkFile = chunkFiles.get(chunkName);
                if (chunkFile == null) {
                    chunkFile = openTempChunk(chunkName, tempFiles);
                    chunkFiles.put(chunkName, chunkFile);
                }

                AbstractTensor t = wl.load(name);
                // NOTE(review): newInfo's offsets go into the chunk header unchanged. This assumes
                // save() appends at the channel's position and reports offsets from the channel
                // start (i.e. from the start of the chunk's data section) - confirm.
                TensorInfo newInfo = t.save(chunkFile.getChannel());
                System.out.println("Wrote " + name + " to " + chunkName + " at " + newInfo.dataOffsets[0] + " to " + newInfo.dataOffsets[1]);
                tensorsInChunk.computeIfAbsent(chunkName, n -> new LinkedHashMap<>()).put(name, newInfo);
            }

            for (Map.Entry<String, RandomAccessFile> entry : chunkFiles.entrySet()) {
                String chunkName = entry.getKey();
                writeChunk(modelDir, chunkName, tensorsInChunk.get(chunkName), entry.getValue().getChannel());
            }

            byte[] index = JsonSupport.om.writeValueAsBytes(
                    Map.<String, Object>of("metadata", new HashMap<String, Object>(), "weight_map", tensorIndex));
            // Files.write creates or truncates, so no stale bytes can survive.
            Files.write(Paths.get(modelDir, "model.safetensors.index.json"), index);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            closeAndDelete(chunkFiles, tempFiles);
        }
    }

    /**
     * Creates a temporary file for one chunk's data section, records it in {@code tempFiles}
     * under {@code chunkName}, and opens it for writing.
     */
    private static RandomAccessFile openTempChunk(String chunkName, Map<String, File> tempFiles) throws IOException {
        File tmp = File.createTempFile("jlama", "chunk");
        tmp.deleteOnExit();
        // Record before opening, so cleanup still finds the file if the open below fails.
        tempFiles.put(chunkName, tmp);
        return new RandomAccessFile(tmp, "rw");
    }

    /**
     * Writes {@code <modelDir>/<chunkName>} in three parts: the 8-byte little-endian length of
     * the JSON header, the JSON header built from {@code chunkTensors}, then every byte of
     * {@code data} from position 0.
     */
    private static void writeChunk(String modelDir, String chunkName, Map<String, TensorInfo> chunkTensors, FileChannel data) throws IOException {
        byte[] header = JsonSupport.om.writeValueAsBytes(chunkTensors);
        System.out.println("Writing " + chunkName + " with " + chunkTensors.size() + " tensors");
        byte[] hsize = new byte[8];
        ByteBuffer.wrap(hsize).order(ByteOrder.LITTLE_ENDIAN).putLong(header.length);
        try (RandomAccessFile raf = new RandomAccessFile(Paths.get(modelDir, chunkName).toFile(), "rw")) {
            // "rw" does not truncate; drop leftovers from an earlier, interrupted run so the data
            // lands directly after the header instead of after stale bytes.
            raf.setLength(0);
            raf.write(hsize);
            raf.write(header);
            // The channel shares the file pointer, so it is positioned right after the header.
            FileChannel out = raf.getChannel();
            System.out.println("Writing " + data.size() + " bytes of data from " + out.position());
            transferFully(data, out);
        }
    }

    /**
     * Copies all bytes of {@code src}, starting at position 0, to {@code dst} at its current
     * position.
     *
     * <p>A single {@link FileChannel#transferTo} call may move fewer bytes than requested, so
     * this loops until the whole source has been copied.
     *
     * @throws IOException if a transfer makes no progress, to avoid spinning forever
     */
    private static void transferFully(FileChannel src, FileChannel dst) throws IOException {
        long size = src.size();
        long copied = 0;
        while (copied < size) {
            long n = src.transferTo(copied, size - copied, dst);
            if (n <= 0) {
                throw new IOException("transferTo made no progress after " + copied + " of " + size + " bytes");
            }
            copied += n;
        }
    }

    /**
     * Best-effort cleanup: closes every temporary chunk handle, then deletes every temporary file.
     */
    private static void closeAndDelete(Map<String, RandomAccessFile> chunkFiles, Map<String, File> tempFiles) {
        for (RandomAccessFile raf : chunkFiles.values()) {
            try {
                raf.close();
            } catch (IOException e) {
                // By this point the data was either already copied out or the split has failed;
                // report the problem and keep cleaning up.
                System.err.println("Failed to close temporary chunk file: " + e);
            }
        }
        for (File tmp : tempFiles.values()) {
            // The return value is ignored: a failed delete is retried by deleteOnExit.
            tmp.delete();
        }
    }
}

