/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.tensor.AbstractTensor;
import net.jafama.FastMath;

/**
 * Root-mean-square normalization layer.
 *
 * <p>For each row of the input, the columns {@code [offset, offset + length)} are multiplied by
 * {@code 1 / sqrt(sumOfSquares / embeddingLength + layerNormEps)}. Each column is then scaled by
 * {@code weightAdjustment + weights.get(0, column)}. No mean is subtracted from the input.
 */
public class RMSNorm extends LayerNorm {
    /** Constant added to every weight element before it scales the normalized value. */
    private final float weightAdjustment;

    /**
     * Creates an RMS norm whose weights are applied without any adjustment.
     *
     * @param m the owning model; supplies the config ({@code m.c}) and tensor allocation
     * @param weights per-column scale weights, read as {@code weights.get(0, column)}
     */
    public RMSNorm(AbstractModel m, AbstractTensor weights) {
        this(m, weights, 0.0f);
    }

    /**
     * Creates an RMS norm that adds {@code weightAdjustment} to every weight before scaling.
     *
     * @param m the owning model; supplies the config ({@code m.c}) and tensor allocation
     * @param weights per-column scale weights, read as {@code weights.get(0, column)}
     * @param weightAdjustment constant added to each weight element when the norm is applied
     */
    public RMSNorm(AbstractModel m, AbstractTensor weights, float weightAdjustment) {
        // NOTE(review): null is passed as the parent's second constructor argument;
        // presumably this is the bias, which this norm never reads. Confirm against LayerNorm.
        super(m, null, weights);
        this.weightAdjustment = weightAdjustment;
    }

    /**
     * Normalizes columns {@code [offset, offset + length)} of every row of {@code input}.
     *
     * @param input tensor read as {@code input.get(row, column)}; the number of rows is
     *     {@code input.shape().first()}
     * @param offset first column to normalize
     * @param length number of columns to normalize
     * @return a new tensor from {@code makeDenseTensor(input.shape())}. Only the normalized
     *     columns are written by this method; the contents of the other columns depend on
     *     {@code makeDenseTensor}.
     */
    @Override
    public AbstractTensor forward(AbstractTensor input, int offset, int length) {
        final int rows = input.shape().first();
        final int end = offset + length;
        AbstractTensor result = this.m.makeDenseTensor(input.shape());
        for (int row = 0; row < rows; row++) {
            // Narrow once per row; every column of the row uses this same float factor.
            final float scale = (float) inverseRms(input, row, offset, end);
            for (int col = offset; col < end; col++) {
                float gain = this.weightAdjustment + this.weights.get(0, col);
                float normalized = scale * input.get(row, col);
                result.set(gain * normalized, row, col);
            }
        }
        return result;
    }

    /**
     * Computes {@code 1 / sqrt(sumOfSquares / embeddingLength + layerNormEps)} for one row.
     *
     * @param input tensor read as {@code input.get(row, column)}
     * @param row row index to read
     * @param from first column (inclusive)
     * @param to last column (exclusive)
     * @return the inverse root-mean-square factor, in double precision
     */
    private double inverseRms(AbstractTensor input, int row, int from, int to) {
        double sumOfSquares = 0.0;
        for (int col = from; col < to; col++) {
            float value = input.get(row, col);
            // Square in float precision, then widen into the double accumulator.
            float square = value * value;
            sumOfSquares += square;
        }
        // NOTE(review): this divides by the full embedding length, not by (to - from).
        // The two differ whenever forward() is called on a partial column range; confirm this
        // is intended.
        double meanSquare = sumOfSquares / this.m.c.embeddingLength;
        return 1.0 / FastMath.sqrt(meanSquare + this.m.c.layerNormEps);
    }
}

