/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.util.Arrays;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;

/**
 * A tensor of half-precision ({@link DType#F16}) values stored as raw 16-bit patterns in a
 * {@link ShortBuffer}.
 *
 * <p>Scalar access ({@link #get(int...)} / {@link #set(float, int...)}) converts between
 * {@code float} and the stored half-precision bits with {@link Float#float16ToFloat(short)} and
 * {@link Float#floatToFloat16(float)}. Vector and bulk operations go through a
 * {@link MemorySegment} created over the same buffer, so both paths see the same storage.
 */
public class Float16BufferTensor extends AbstractTensor<ShortVector, Short> {
    /** Backing storage: one {@code short} (half-precision bit pattern) per element. */
    private final ShortBuffer b;
    /** Label shown by {@link #toString()} and propagated to slices created by {@code make}. */
    private final String name;
    /** Byte-addressed view over {@link #b}; element offsets are scaled by the dtype size. */
    private final MemorySegment segment;

    /**
     * Creates a new F16 tensor with the same shape as {@code ft} and copies every element into
     * it, converting each value through {@code float}.
     *
     * @param ft the source tensor; must not already be {@link DType#F16}
     * @throws IllegalArgumentException if {@code ft} is already F16
     */
    public Float16BufferTensor(AbstractTensor ft) {
        this(ft.shape);
        Preconditions.checkArgument(ft.dType != DType.F16, "This should never happen, likely a bug");
        int[] cursor = new int[ft.shape.dims()];
        // Visit every coordinate of the source: iterate() is assumed to advance the cursor in
        // place and return false once the last coordinate has been visited — TODO confirm.
        do {
            this.set(ft.get(cursor), cursor);
        } while (ft.iterate(cursor));
    }

    /**
     * Allocates a zero-named ("tmp") F16 tensor with the given dimensions.
     *
     * @param shape the size of each dimension
     */
    public Float16BufferTensor(int... shape) {
        this(TensorShape.of(shape));
    }

    /**
     * Allocates fresh storage for an F16 tensor of the given shape, with slice caching enabled.
     *
     * @param shape the tensor shape
     * @throws IllegalArgumentException if the byte size does not fit in an {@code int}
     */
    public Float16BufferTensor(TensorShape shape) {
        super(DType.F16, shape, true);
        this.name = "tmp";
        // Computed in long arithmetic so Ints.checkedCast can reject sizes above Integer.MAX_VALUE.
        long byteSize = this.size() * (long) this.dType().size();
        // 64-byte alignment, presumably to suit SIMD loads/stores — TODO confirm.
        this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(Ints.checkedCast(byteSize), 64L).asShortBuffer();
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    /**
     * Wraps an existing direct buffer under the name {@code "none"}.
     *
     * @param b           direct buffer holding half-precision bit patterns
     * @param shape       the tensor shape
     * @param cacheSlices whether the superclass should cache slices
     */
    public Float16BufferTensor(ShortBuffer b, TensorShape shape, boolean cacheSlices) {
        this("none", b, shape, cacheSlices);
    }

    /**
     * Wraps an existing direct buffer without copying it.
     *
     * @param name        label reported by {@link #toString()}
     * @param b           direct buffer holding half-precision bit patterns
     * @param shape       the tensor shape
     * @param cacheSlices whether the superclass should cache slices
     * @throws IllegalArgumentException if {@code b} is not a direct buffer
     */
    public Float16BufferTensor(String name, ShortBuffer b, TensorShape shape, boolean cacheSlices) {
        super(DType.F16, shape, cacheSlices);
        Preconditions.checkArgument(b.isDirect(), "Must use direct buffers");
        this.name = name;
        this.b = b;
        this.segment = MemorySegment.ofBuffer(b);
    }

    /** Creates a new, independently allocated F16 tensor of the given shape. */
    @Override
    protected AbstractTensor make(TensorShape shape) {
        return new Float16BufferTensor(shape);
    }

    /**
     * Creates a view sharing this tensor's storage, starting at element {@code offset} and
     * spanning {@code length} elements ({@link ShortBuffer#slice(int, int)} counts in shorts).
     */
    @Override
    protected AbstractTensor make(int offset, int length, TensorShape shape, boolean cacheSlices) {
        return new Float16BufferTensor(this.name, this.b.slice(offset, length), shape, cacheSlices);
    }

    /**
     * Reads the element at the given coordinates, widened to {@code float}.
     *
     * @throws IllegalArgumentException unless exactly one index per dimension is supplied
     */
    @Override
    public float get(int... dims) {
        Preconditions.checkArgument(dims.length <= this.shape.dims(), "Too many dimensions specified");
        Preconditions.checkArgument(dims.length == this.shape.dims(), "Must specify all dimensions");
        return Float.float16ToFloat(this.b.get(this.getOffset(dims)));
    }

    /**
     * Stores {@code v}, rounded to half precision, at the given coordinates.
     *
     * @throws IllegalArgumentException unless exactly one index per dimension is supplied, or if
     *                                  the backing buffer is read-only
     */
    @Override
    public void set(float v, int... dims) {
        Preconditions.checkArgument(dims.length <= this.shape.dims(), "Too many dimensions specified for tensor");
        Preconditions.checkArgument(dims.length == this.shape.dims(), "Must specify all dimensions");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't modify a read only buffer");
        this.b.put(this.getOffset(dims), Float.floatToFloat16(v));
    }

    /**
     * Loads a vector of raw half-precision bit patterns starting at the given coordinates.
     * The segment is always read as little-endian; this assumes the underlying bytes were
     * written in little-endian order — TODO confirm for every buffer source.
     */
    @Override
    public ShortVector getVector(VectorSpecies<Short> species, int... voffset) {
        int offset = this.getOffset(voffset);
        return ShortVector.fromMemorySegment(species, this.segment, this.getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
    }

    /**
     * Stores {@code vector} (little-endian) starting at the given coordinates.
     *
     * @throws IllegalArgumentException if the backing buffer is read-only
     */
    @Override
    public void intoTensor(ShortVector vector, int... aoffset) {
        Preconditions.checkArgument(!this.b.isReadOnly());
        int offset = this.getOffset(aoffset);
        vector.intoMemorySegment(this.segment, this.getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
    }

    /** Returns the byte-addressed segment backing this tensor. */
    @Override
    public MemorySegment getMemorySegment() {
        return this.segment;
    }

    /** Converts an element offset into a byte offset within {@link #getMemorySegment()}. */
    @Override
    public int getMemorySegmentOffset(int offset) {
        return offset * this.dType.size();
    }

    /**
     * Copies a contiguous run from {@code src} into this tensor.
     *
     * <p>NOTE(review): {@code srcOffset} and {@code destOffset} are element indices (converted to
     * bytes via {@link #getMemorySegmentOffset(int)}), but {@code length} is passed unchanged to
     * {@link MemorySegment#asSlice(long, long)}, which takes a byte count. For F16 that is half as
     * many elements as bytes. Confirm callers pass a byte length before changing this.
     *
     * @throws IllegalArgumentException if the dtypes differ or this tensor is read-only
     */
    @Override
    public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) {
        Preconditions.checkArgument(this.dType == src.dType, "different types");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Read-only");
        MemorySegment source = src.getMemorySegment().asSlice(src.getMemorySegmentOffset(srcOffset), length);
        this.segment.asSlice(this.getMemorySegmentOffset(destOffset), length).copyFrom(source);
    }

    /**
     * Zeroes every byte of the backing storage.
     *
     * @throws IllegalArgumentException if the backing buffer is read-only
     */
    @Override
    public void clear() {
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't clear a read-only buffer");
        this.segment.fill((byte) 0);
    }

    /**
     * Describes the tensor. The sample shows up to the first ten stored values as raw
     * half-precision bit patterns (not converted floats), read without moving the buffer position.
     */
    @Override
    public String toString() {
        short[] sample = new short[Math.min(10, this.b.remaining())];
        this.b.duplicate().get(sample);
        return "Float16BufferTensor{name='" + this.name + "', shape=" + this.shape + ", b=" + Arrays.toString(sample) + "...}";
    }
}

