/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model.gpt2;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.MLPBlock;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import java.util.Optional;

/**
 * GPT-2 flavour of {@link AbstractModel}.
 *
 * <p>Supplies the three weight-loading hooks of the base class: the input embedding
 * function, the per-layer transformer blocks, and the output (logits) weights. All
 * tensors are fetched by name through {@code this.weights}.
 */
public class GPT2Model extends AbstractModel {

    /**
     * Creates a model using {@link AbstractModel.InferenceType#FULL_GENERATION}.
     *
     * @param c            model configuration
     * @param w            loader used to fetch the named weight tensors
     * @param tokenizer    tokenizer forwarded to the base class
     * @param workingDType working data type forwarded to the base class
     * @param workingQType working quantization type forwarded to the base class
     * @param modelQType   optional model quantization type forwarded to the base class
     */
    public GPT2Model(Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
        super(AbstractModel.InferenceType.FULL_GENERATION, c, w, tokenizer, workingDType, workingQType, modelQType);
    }

    /**
     * Creates a model with an explicitly chosen inference type.
     *
     * @param inferenceType inference type forwarded to the base class
     * @param c             model configuration
     * @param w             loader used to fetch the named weight tensors
     * @param tokenizer     tokenizer forwarded to the base class
     * @param workingDType  working data type forwarded to the base class
     * @param workingQType  working quantization type forwarded to the base class
     * @param modelQType    optional model quantization type forwarded to the base class
     */
    public GPT2Model(AbstractModel.InferenceType inferenceType, Config c, WeightLoader w, Tokenizer tokenizer, DType workingDType, DType workingQType, Optional<DType> modelQType) {
        super(inferenceType, c, w, tokenizer, workingDType, workingQType, modelQType);
    }

    /**
     * @return always {@link ModelSupport.ModelType#GPT2}
     */
    @Override
    public ModelSupport.ModelType getModelType() {
        return ModelSupport.ModelType.GPT2;
    }

    /**
     * Loads {@code wte.weight} and {@code wpe.weight} and returns a function that, for a
     * given token and position, produces a {@code 1 x embeddingLength} dense tensor whose
     * entries are the element-wise sum of the token's {@code wte} row and the position's
     * {@code wpe} row.
     *
     * @return the embedding function backed by the two loaded tensors
     */
    @Override
    protected EmbedInput loadInputWeights() {
        AbstractTensor tokenTable = this.weights.load("wte.weight");
        AbstractTensor positionTable = this.weights.load("wpe.weight");
        return (inputToken, position) -> {
            AbstractTensor combined = this.makeDenseTensor(1, this.c.embeddingLength);
            for (int col = 0; col < this.c.embeddingLength; col++) {
                float sum = tokenTable.get(inputToken, col) + positionTable.get(position, col);
                combined.set(sum, 0, col);
            }
            return combined;
        };
    }

    /**
     * Builds one {@link TransformerBlock} for every layer index in
     * {@code [dctx().layerStart, dctx().layerEnd)}, reading weights under the
     * {@code "h.<layer>."} name prefix.
     *
     * <p>NOTE(review): the result array is sized by {@code dctx().numberOfLayers} but is
     * indexed by the absolute layer index. If {@code numberOfLayers} is a relative count
     * and {@code layerStart} can be non-zero, this would go out of bounds — confirm
     * against the distributed-context definition.
     *
     * @return the array of blocks; slots outside the loaded layer range stay {@code null}
     */
    @Override
    protected TransformerBlock[] loadTransformerBlockWeights() {
        TransformerBlock[] blocks = new TransformerBlock[this.c.dctx().numberOfLayers];
        for (int layer = this.c.dctx().layerStart; layer < this.c.dctx().layerEnd; ++layer) {
            String layerPrefix = "h." + layer + ".";
            // Load order (attention, MLP, ln_1, ln_2) is kept the same as it has always been.
            CausalSelfAttention attention = this.gpt2Attention(layer, layerPrefix + "attn.");
            MLPBlock mlp = this.gpt2Mlp(layerPrefix + "mlp.");
            LayerNorm preAttentionNorm = this.gpt2LayerNorm(layerPrefix + "ln_1.");
            LayerNorm preMlpNorm = this.gpt2LayerNorm(layerPrefix + "ln_2.");
            blocks[layer] = new TransformerBlock((AbstractModel) this, layer, preAttentionNorm, attention, preMlpNorm, mlp);
        }
        return blocks;
    }

    /**
     * Loads the attention weights found under {@code attnPrefix} and wires them into a
     * {@link CausalSelfAttention}.
     *
     * <p>The fused {@code c_attn} bias and (transposed) weight are each split into three
     * parts — presumably query/key/value, in that order; confirm against
     * {@code CausalSelfAttention}.
     */
    private CausalSelfAttention gpt2Attention(int layer, String attnPrefix) {
        AbstractTensor[] biasParts = this.weights.load(attnPrefix + "c_attn.bias").split(3, 1);
        AbstractTensor[] weightParts = this.weights.load(attnPrefix + "c_attn.weight").transpose().split(3, 0);
        AbstractTensor projectionBias = this.weights.load(attnPrefix + "c_proj.bias");
        AbstractTensor projectionWeight = this.weights.load(attnPrefix + "c_proj.weight").transpose();
        return new CausalSelfAttention(
                (AbstractModel) this,
                layer,
                biasParts[0],
                biasParts[1],
                biasParts[2],
                weightParts[0],
                weightParts[1],
                weightParts[2],
                projectionBias,
                projectionWeight);
    }

    /**
     * Loads the {@code c_fc} and {@code c_proj} bias/weight pairs found under
     * {@code mlpPrefix} (weights are transposed after loading) and builds the
     * {@link MLPBlock} using the configured activation function.
     */
    private MLPBlock gpt2Mlp(String mlpPrefix) {
        AbstractTensor fcBias = this.weights.load(mlpPrefix + "c_fc.bias");
        AbstractTensor fcWeight = this.weights.load(mlpPrefix + "c_fc.weight").transpose();
        AbstractTensor projectionBias = this.weights.load(mlpPrefix + "c_proj.bias");
        AbstractTensor projectionWeight = this.weights.load(mlpPrefix + "c_proj.weight").transpose();
        return new MLPBlock(this, this.c.activationFunction, fcBias, fcWeight, projectionBias, projectionWeight);
    }

    /**
     * Builds a {@link LayerNorm} from the {@code bias} and {@code weight} tensors stored
     * under {@code normPrefix} (bias is loaded first).
     */
    private LayerNorm gpt2LayerNorm(String normPrefix) {
        AbstractTensor bias = this.weights.load(normPrefix + "bias");
        AbstractTensor weight = this.weights.load(normPrefix + "weight");
        return new LayerNorm(this, bias, weight);
    }

    /**
     * Loads the output-side weights: {@code wte.weight} is handed out as the logits
     * weights, and {@code ln_f.bias}/{@code ln_f.weight} form the output layer norm.
     *
     * @return a {@link SampleOutput} exposing the loaded layer norm and logits weights
     */
    @Override
    protected SampleOutput loadOutputWeights() {
        AbstractTensor logitsWeights = this.weights.load("wte.weight");
        LayerNorm finalNorm = new LayerNorm(this, this.weights.load("ln_f.bias"), this.weights.load("ln_f.weight"));
        return new SampleOutput() {

            @Override
            public LayerNorm getOutputLayerNorm() {
                return finalNorm;
            }

            @Override
            public AbstractTensor getOutputLogitsWeights() {
                return logitsWeights;
            }
        };
    }
}

