/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.github.tjake.jlama.util.BiIntConsumer;
import com.github.tjake.jlama.util.MachineSpec;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import com.google.common.base.Preconditions;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class PanamaTensorOperations
implements TensorOperations {
    /** Class logger. */
    private static final Logger logger = LoggerFactory.getLogger(PanamaTensorOperations.class);
    // Q4_* constants: byte vectors with every lane set to 8, 15 (0x0F) and 4 respectively, in
    // 128-bit and 64-bit species. NOTE(review): they are not referenced in the part of the file
    // visible here; presumably they serve the Q4 GEMM kernels (nibble offset / mask / shift) - confirm.
    static final ByteVector Q4_BYTE_SUB_128 = ByteVector.broadcast((VectorSpecies)ByteVector.SPECIES_128, (long)8L);
    static final ByteVector Q4_BYTE_MASK_128 = ByteVector.broadcast((VectorSpecies)ByteVector.SPECIES_128, (long)15L);
    static final ByteVector Q4_BYTE_SHIFT_128 = ByteVector.broadcast((VectorSpecies)ByteVector.SPECIES_128, (long)4L);
    static final ByteVector Q4_BYTE_SUB_64 = ByteVector.broadcast((VectorSpecies)ByteVector.SPECIES_64, (long)8L);
    static final ByteVector Q4_BYTE_MASK_64 = ByteVector.broadcast((VectorSpecies)ByteVector.SPECIES_64, (long)15L);
    static final ByteVector Q4_BYTE_SHIFT_64 = ByteVector.broadcast((VectorSpecies)ByteVector.SPECIES_64, (long)4L);
    // BF16_BYTE_SHIFT*: int vectors with every lane set to 16. The BF16 paths shift the widened
    // 16-bit pattern left by this amount before reinterpreting the lanes as floats, and shift right
    // by it when narrowing a float product back to 16 bits.
    static final IntVector BF16_BYTE_SHIFT = IntVector.broadcast((VectorSpecies)IntVector.SPECIES_PREFERRED, (int)16);
    static final IntVector BF16_BYTE_SHIFT_512 = IntVector.broadcast((VectorSpecies)IntVector.SPECIES_512, (int)16);
    // F32_ROUND_UP_*: float vectors with every lane set to 0.5f; the Q8 quantizers add them to the
    // scaled values right before the float-to-byte conversion.
    static final FloatVector F32_ROUND_UP_512 = FloatVector.broadcast((VectorSpecies)FloatVector.SPECIES_512, (float)0.5f);
    static final IntVector BF16_BYTE_SHIFT_256 = IntVector.broadcast((VectorSpecies)IntVector.SPECIES_256, (int)16);
    static final FloatVector F32_ROUND_UP_256 = FloatVector.broadcast((VectorSpecies)FloatVector.SPECIES_256, (float)0.5f);
    static final IntVector BF16_BYTE_SHIFT_128 = IntVector.broadcast((VectorSpecies)IntVector.SPECIES_128, (int)16);
    static final FloatVector F32_ROUND_UP_128 = FloatVector.broadcast((VectorSpecies)FloatVector.SPECIES_128, (float)0.5f);
    // Mask over the 8-lane (64-bit) byte species selecting only the first four lanes; the ARM
    // quantizers pass it to intoTensor so that each store writes four bytes.
    static final VectorMask<Byte> BYTE_MASK_32 = VectorMask.fromValues((VectorSpecies)ByteVector.SPECIES_64, (boolean[])new boolean[]{true, true, true, true, false, false, false, false});
    /** Vector flavour of the machine; switched on when picking Q4 GEMM kernels and Q8 quantizers. */
    private final MachineSpec.Type vectorType;

    /**
     * Creates an operations instance bound to a machine vector flavour.
     *
     * @param vectorType vector type stored for later dispatch; the methods visible in this file
     *                   branch on {@code AVX_256}, {@code AVX_512} and {@code ARM_128}
     */
    public PanamaTensorOperations(MachineSpec.Type vectorType) {
        this.vectorType = vectorType;
    }

    /**
     * Returns the display name of this {@link TensorOperations} implementation.
     *
     * @return the constant string {@code "Panama Vector Operations"}
     */
    @Override
    public String name() {
        final String displayName = "Panama Vector Operations";
        return displayName;
    }

    /**
     * Returns the split size to use for parallel work.
     *
     * @return the core count reported by the shared {@link PhysicalCoreExecutor} instance
     */
    @Override
    public int parallelSplitSize() {
        var executor = PhysicalCoreExecutor.instance.get();
        return executor.getCoreCount();
    }

    /**
     * Runs a matrix multiplication of {@code a} against {@code b}, writing into {@code result}.
     *
     * <p>A {@code Gemmer} kernel is chosen from the element types of {@code a} and {@code b}
     * (and, for Q4 weights, from the machine vector type) and then invoked over rows
     * {@code [0, a.shape().dim(0))} and b-rows {@code [bRowOffset, bRowOffset + rowChunkSize)}.
     * Supported pairs: F32xF32, F32xBF16, F32xQ4 (AVX_256/AVX_512), I8xQ4
     * (AVX_256/AVX_512/ARM_128) and BF16xBF16.
     *
     * @param result        2-D output tensor; its first dimension must equal that of {@code a}
     * @param a             2-D left operand
     * @param b             2-D right operand
     * @param aColumnOffset column offset into {@code a}, forwarded to the kernel
     * @param bColumnOffset column offset into {@code b}, forwarded to the kernel
     * @param columnLength  number of columns to reduce over, forwarded to the kernel
     * @param rOffset       result offset; must be 0 or at least {@code bRowOffset}
     * @param bRowOffset    first row of {@code b} to process
     * @param rowChunkSize  number of rows of {@code b} to process
     * @throws IllegalArgumentException      if a shape/offset precondition fails
     * @throws UnsupportedOperationException if no kernel exists for the type combination
     */
    @Override
    public void batchDotProduct(AbstractTensor result, AbstractTensor a, AbstractTensor b, int aColumnOffset, int bColumnOffset, int columnLength, int rOffset, int bRowOffset, int rowChunkSize) {
        Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2);
        Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M");
        Preconditions.checkArgument(rOffset == 0 || rOffset >= bRowOffset, "Result offset must be >= b row offset");
        final int rowCount = a.shape().dim(0);
        final int bRowEnd = bRowOffset + rowChunkSize;
        final int depth = columnLength;
        Gemmer kernel = switch (a.dType()) {
            case F32 -> switch (b.dType()) {
                case F32 -> new GemmerF32(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                case BF16 -> new GemmerF32BF16(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                case Q4 -> switch (this.vectorType) {
                    case AVX_256 -> new GemmerF32Q4_256(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                    case AVX_512 -> new GemmerF32Q4_512(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                    default -> throw new UnsupportedOperationException(this.vectorType.name());
                };
                default -> throw new UnsupportedOperationException(b.dType().name());
            };
            case I8 -> switch (b.dType()) {
                case Q4 -> switch (this.vectorType) {
                    case AVX_256 -> new GemmerI8Q4_256(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                    case AVX_512 -> new GemmerI8Q4_512(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                    case ARM_128 -> new GemmerI8Q4_arm(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                    default -> throw new UnsupportedOperationException(this.vectorType.name());
                };
                default -> throw new UnsupportedOperationException(b.dType().name());
            };
            case BF16 -> switch (b.dType()) {
                case BF16 -> new GemmerBF16(this, depth, a, b, result, aColumnOffset, bColumnOffset, rOffset);
                default -> throw new UnsupportedOperationException(b.dType().name());
            };
            default -> throw new UnsupportedOperationException(a.dType().name() + " " + b.dType().name());
        };
        kernel.matmul(0, rowCount, bRowOffset, bRowEnd);
    }

    /**
     * Converts a 2-D tensor to another element type by dispatching to a type-specific routine.
     *
     * <p>Supported conversions: F32 to I8 and BF16 to I8 (vector types AVX_512, AVX_256 and
     * ARM_128 only), F32 to BF16, and BF16 to F32.
     *
     * @param t      source tensor; must be 2-D
     * @param qtype  target element type
     * @param offset first column to convert, forwarded to the chosen routine
     * @param length number of columns to convert; must be a multiple of 32
     * @return the converted tensor produced by the chosen routine
     * @throws IllegalArgumentException      if {@code t} is not 2-D or {@code length} is not a
     *                                       multiple of 32
     * @throws UnsupportedOperationException if the conversion, or the machine vector type for an
     *                                       I8 conversion, is not supported
     */
    @Override
    public AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int length) {
        Preconditions.checkArgument(t.dims() == 2 && length % 32 == 0, "quantize requires a 2D tensor and a length that is a multiple of 32");
        return switch (t.dType()) {
            case F32 -> switch (qtype) {
                case I8 -> switch (this.vectorType) {
                    case AVX_512 -> this.quantizeQ8_512((FloatBufferTensor) t, offset, length);
                    case AVX_256 -> this.quantizeQ8_256((FloatBufferTensor) t, offset, length);
                    case ARM_128 -> this.quantizeQ8_arm((FloatBufferTensor) t, offset, length);
                    // Name the offending vector type so the failure is diagnosable.
                    default -> throw new UnsupportedOperationException("F32 => I8 is not supported for vector type " + this.vectorType.name());
                };
                case BF16 -> this.quantizeBF16((FloatBufferTensor) t, offset, length);
                default -> throw new UnsupportedOperationException("F32 => " + qtype);
            };
            case BF16 -> switch (qtype) {
                case I8 -> switch (this.vectorType) {
                    case AVX_512 -> this.quantizeBF16_Q8_512((BFloat16BufferTensor) t, offset, length);
                    case AVX_256 -> this.quantizeBF16_Q8_256((BFloat16BufferTensor) t, offset, length);
                    case ARM_128 -> this.quantizeBF16_Q8_arm((BFloat16BufferTensor) t, offset, length);
                    default -> throw new UnsupportedOperationException("BF16 => I8 is not supported for vector type " + this.vectorType.name());
                };
                case F32 -> this.quantizeBF16_F32((BFloat16BufferTensor) t, offset, length);
                default -> throw new UnsupportedOperationException("BF16 => " + qtype);
            };
            default -> throw new UnsupportedOperationException(String.valueOf(t.dType()));
        };
    }

    /**
     * Converts an F32 tensor to BF16 by constructing a {@link BFloat16BufferTensor} from it.
     *
     * @param ft     source tensor
     * @param offset not used by this method
     * @param length not used by this method
     * @return a new BF16 tensor built from {@code ft}
     */
    public BFloat16BufferTensor quantizeBF16(FloatBufferTensor ft, int offset, int length) {
        BFloat16BufferTensor converted = new BFloat16BufferTensor(ft);
        return converted;
    }

    /**
     * Widens BF16 values to F32 using the preferred vector species.
     *
     * <p>Each 16-bit lane is widened to an int, shifted left by 16 bits and reinterpreted as a
     * float. One short vector yields two float vectors, which are stored back to back.
     *
     * @param ft     source BF16 tensor; its first dimension is iterated as the batch
     * @param offset first column to convert
     * @param length number of columns to convert; the loop advances one preferred short vector
     *               at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public FloatBufferTensor quantizeBF16_F32(BFloat16BufferTensor ft, int offset, int length) {
        FloatBufferTensor widened = (FloatBufferTensor) TensorCache.instance.get(DType.F32, ft.shape());
        final int rows = ft.shape().first();
        final int end = offset + length;
        final int shortLanes = ShortVector.SPECIES_PREFERRED.length();
        final int floatLanes = FloatVector.SPECIES_PREFERRED.length();
        for (int row = 0; row < rows; row++) {
            for (int col = offset; col < end; col += shortLanes) {
                ShortVector packed = ft.getVector(ShortVector.SPECIES_PREFERRED, new int[]{row, col});
                // Part 0 expands the lower half of the short lanes, part 1 the upper half.
                FloatVector lower = packed.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                        .reinterpretAsFloats();
                FloatVector upper = packed.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                        .reinterpretAsFloats();
                widened.intoTensor(lower, row, col);
                widened.intoTensor(upper, row, col + floatLanes);
            }
        }
        return widened;
    }

    /**
     * Quantizes F32 values to Q8 with 512-bit vectors, using one scale per 32-element block.
     *
     * <p>For each block the largest absolute value {@code max} is computed. The scale
     * {@code max / 127} is written to the block-scale tensor, and every element is multiplied by
     * {@code 127 / max} (0 when the block is all zeros), increased by 0.5 and converted to a byte.
     *
     * <p>A failure while storing the block scale propagates to the caller, as it does in the
     * 256-bit and ARM variants, instead of being printed and ignored.
     *
     * @param ft     source tensor; its first dimension is iterated as the batch
     * @param offset first column to quantize
     * @param length number of columns to quantize; processed 32 at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public Q8ByteBufferTensor quantizeQ8_512(FloatBufferTensor ft, int offset, int length) {
        Q8ByteBufferTensor qft = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, ft.shape());
        int batchSize = ft.shape().first();
        int end = offset + length;
        for (int b = 0; b < batchSize; ++b) {
            for (int i = offset; i < end; i += 32) {
                FloatVector fv0 = ft.getVector(FloatVector.SPECIES_512, new int[]{b, i});
                FloatVector fv1 = ft.getVector(FloatVector.SPECIES_512, new int[]{b, i + 16});

                // Block maximum of |x| across both halves.
                float maxScalar = fv0.abs().max(fv1.abs()).reduceLanes(VectorOperators.MAX);
                float d = maxScalar / 127.0f;
                float id = maxScalar != 0.0f ? 127.0f / maxScalar : 0.0f;
                FloatVector vid = FloatVector.broadcast(FloatVector.SPECIES_512, id);

                FloatVector fvq0 = fv0.mul(vid).add(F32_ROUND_UP_512);
                FloatVector fvq1 = fv1.mul(vid).add(F32_ROUND_UP_512);
                ByteVector bvq0 = fvq0.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                ByteVector bvq1 = fvq1.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                qft.intoTensor(bvq0, b, i);
                qft.intoTensor(bvq1, b, i + 16);

                // 0.03125 == 1/32: converts the column index into the block index.
                qft.getBlockF().set(d, b, (int) ((float) i * 0.03125f));
            }
        }
        return qft;
    }

    /**
     * Quantizes F32 values to Q8 with 256-bit vectors, using one scale per 32-element block.
     *
     * <p>Each block is read as four 8-lane float vectors. The block's largest absolute value
     * {@code max} gives the stored scale {@code max / 127}; elements are multiplied by
     * {@code 127 / max} (0 when the block is all zeros), increased by 0.5 and converted to bytes.
     *
     * @param ft     source tensor; its first dimension is iterated as the batch
     * @param offset first column to quantize
     * @param length number of columns to quantize; processed 32 at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public Q8ByteBufferTensor quantizeQ8_256(FloatBufferTensor ft, int offset, int length) {
        Q8ByteBufferTensor quantized = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, ft.shape());
        final int rows = ft.shape().first();
        final int end = offset + length;
        for (int row = 0; row < rows; row++) {
            for (int col = offset; col < end; col += 32) {
                FloatVector v0 = ft.getVector(FloatVector.SPECIES_256, new int[]{row, col});
                FloatVector v1 = ft.getVector(FloatVector.SPECIES_256, new int[]{row, col + 8});
                FloatVector v2 = ft.getVector(FloatVector.SPECIES_256, new int[]{row, col + 16});
                FloatVector v3 = ft.getVector(FloatVector.SPECIES_256, new int[]{row, col + 24});

                // Pairwise max of |x|, then a lane reduction for the block maximum.
                FloatVector max01 = v0.abs().max(v1.abs());
                FloatVector max23 = v2.abs().max(v3.abs());
                float blockMax = max01.max(max23).reduceLanes(VectorOperators.MAX);
                float scale = blockMax / 127.0f;
                float inverseScale = blockMax != 0.0f ? 127.0f / blockMax : 0.0f;
                FloatVector inverse = FloatVector.broadcast(FloatVector.SPECIES_256, inverseScale);

                ByteVector q0 = v0.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q1 = v1.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q2 = v2.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q3 = v3.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();

                quantized.intoTensor(q0, row, col);
                quantized.intoTensor(q1, row, col + 8);
                quantized.intoTensor(q2, row, col + 16);
                quantized.intoTensor(q3, row, col + 24);

                // 0.03125 == 1/32: converts the column index into the block index.
                quantized.getBlockF().set(scale, row, (int) ((float) col * 0.03125f));
            }
        }
        return quantized;
    }

    /**
     * Quantizes F32 values to Q8 with 128-bit vectors, using one scale per 32-element block.
     *
     * <p>Each block is read as eight 4-lane float vectors. The block's largest absolute value
     * {@code max} gives the stored scale {@code max / 127}; elements are multiplied by
     * {@code 127 / max} (0 when the block is all zeros), increased by 0.5 and converted to bytes.
     * Every converted vector is stored through {@code BYTE_MASK_32}, so only its first four
     * lanes are written.
     *
     * @param ft     source tensor; its first dimension is iterated as the batch
     * @param offset first column to quantize
     * @param length number of columns to quantize; processed 32 at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public Q8ByteBufferTensor quantizeQ8_arm(FloatBufferTensor ft, int offset, int length) {
        Q8ByteBufferTensor quantized = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, ft.shape());
        final int rows = ft.shape().first();
        final int end = offset + length;
        for (int row = 0; row < rows; row++) {
            for (int col = offset; col < end; col += 32) {
                FloatVector v0 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col});
                FloatVector v1 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 4});
                FloatVector v2 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 8});
                FloatVector v3 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 12});
                FloatVector v4 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 16});
                FloatVector v5 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 20});
                FloatVector v6 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 24});
                FloatVector v7 = ft.getVector(FloatVector.SPECIES_128, new int[]{row, col + 28});

                // Tree reduction of |x| down to a single block maximum.
                FloatVector max01 = v0.abs().max(v1.abs());
                FloatVector max23 = v2.abs().max(v3.abs());
                FloatVector max45 = v4.abs().max(v5.abs());
                FloatVector max67 = v6.abs().max(v7.abs());
                FloatVector max0123 = max01.max(max23);
                FloatVector max4567 = max45.max(max67);
                float blockMax = max0123.max(max4567).reduceLanes(VectorOperators.MAX);
                float scale = blockMax / 127.0f;
                float inverseScale = blockMax != 0.0f ? 127.0f / blockMax : 0.0f;
                FloatVector inverse = FloatVector.broadcast(FloatVector.SPECIES_128, inverseScale);

                ByteVector q0 = v0.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q1 = v1.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q2 = v2.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q3 = v3.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q4 = v4.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q5 = v5.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q6 = v6.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q7 = v7.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();

                quantized.intoTensor(q0, BYTE_MASK_32, row, col);
                quantized.intoTensor(q1, BYTE_MASK_32, row, col + 4);
                quantized.intoTensor(q2, BYTE_MASK_32, row, col + 8);
                quantized.intoTensor(q3, BYTE_MASK_32, row, col + 12);
                quantized.intoTensor(q4, BYTE_MASK_32, row, col + 16);
                quantized.intoTensor(q5, BYTE_MASK_32, row, col + 20);
                quantized.intoTensor(q6, BYTE_MASK_32, row, col + 24);
                quantized.intoTensor(q7, BYTE_MASK_32, row, col + 28);

                // 0.03125 == 1/32: converts the column index into the block index.
                quantized.getBlockF().set(scale, row, (int) ((float) col * 0.03125f));
            }
        }
        return quantized;
    }

    /**
     * Quantizes BF16 values to Q8 with 512-bit vectors, using one scale per 32-element block.
     *
     * <p>Each block is loaded as one 32-lane short vector and widened into two float vectors
     * (widen to int, shift left 16, reinterpret as float). The block's largest absolute value
     * {@code max} gives the stored scale {@code max / 127}; elements are multiplied by
     * {@code 127 / max} (0 when the block is all zeros), increased by 0.5 and converted to bytes.
     *
     * <p>A failure while storing the block scale propagates to the caller, as it does in the
     * 256-bit and ARM variants, instead of being printed and ignored.
     *
     * @param ft     source BF16 tensor; its first dimension is iterated as the batch
     * @param offset first column to quantize
     * @param length number of columns to quantize; processed 32 at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public Q8ByteBufferTensor quantizeBF16_Q8_512(BFloat16BufferTensor ft, int offset, int length) {
        Q8ByteBufferTensor qft = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, ft.shape());
        int batchSize = ft.shape().first();
        int end = offset + length;
        for (int b = 0; b < batchSize; ++b) {
            for (int i = offset; i < end; i += 32) {
                ShortVector sv = ft.getVector(ShortVector.SPECIES_512, new int[]{b, i});
                // Part 0 widens the lower 16 shorts, part 1 the upper 16.
                FloatVector fv0 = sv.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512).reinterpretAsFloats();
                FloatVector fv1 = sv.convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512).reinterpretAsFloats();

                float maxScalar = fv0.abs().max(fv1.abs()).reduceLanes(VectorOperators.MAX);
                float d = maxScalar / 127.0f;
                float id = maxScalar != 0.0f ? 127.0f / maxScalar : 0.0f;
                FloatVector vid = FloatVector.broadcast(FloatVector.SPECIES_512, id);

                FloatVector fvq0 = fv0.mul(vid).add(F32_ROUND_UP_512);
                FloatVector fvq1 = fv1.mul(vid).add(F32_ROUND_UP_512);
                ByteVector bvq0 = fvq0.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                ByteVector bvq1 = fvq1.convertShape(VectorOperators.F2B, ByteVector.SPECIES_128, 0).reinterpretAsBytes();
                qft.intoTensor(bvq0, b, i);
                qft.intoTensor(bvq1, b, i + 16);

                // 0.03125 == 1/32: converts the column index into the block index.
                qft.getBlockF().set(d, b, (int) ((float) i * 0.03125f));
            }
        }
        return qft;
    }

    /**
     * Quantizes BF16 values to Q8 with 256-bit vectors, using one scale per 32-element block.
     *
     * <p>Each block is loaded as two 16-lane short vectors, each widened into two float vectors
     * (widen to int, shift left 16, reinterpret as float). The block's largest absolute value
     * {@code max} gives the stored scale {@code max / 127}; elements are multiplied by
     * {@code 127 / max} (0 when the block is all zeros), increased by 0.5 and converted to bytes.
     *
     * @param ft     source BF16 tensor; its first dimension is iterated as the batch
     * @param offset first column to quantize
     * @param length number of columns to quantize; processed 32 at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public Q8ByteBufferTensor quantizeBF16_Q8_256(BFloat16BufferTensor ft, int offset, int length) {
        Q8ByteBufferTensor quantized = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, ft.shape());
        final int rows = ft.shape().first();
        final int end = offset + length;
        for (int row = 0; row < rows; row++) {
            for (int col = offset; col < end; col += 32) {
                // First 16 BF16 values of the block, widened in two parts.
                ShortVector head = ft.getVector(ShortVector.SPECIES_256, new int[]{row, col});
                FloatVector v0 = head.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();
                FloatVector v1 = head.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();
                // Remaining 16 BF16 values of the block.
                ShortVector tail = ft.getVector(ShortVector.SPECIES_256, new int[]{row, col + 16});
                FloatVector v2 = tail.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();
                FloatVector v3 = tail.convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256).reinterpretAsFloats();

                FloatVector max01 = v0.abs().max(v1.abs());
                FloatVector max23 = v2.abs().max(v3.abs());
                float blockMax = max01.max(max23).reduceLanes(VectorOperators.MAX);
                float scale = blockMax / 127.0f;
                float inverseScale = blockMax != 0.0f ? 127.0f / blockMax : 0.0f;
                FloatVector inverse = FloatVector.broadcast(FloatVector.SPECIES_256, inverseScale);

                ByteVector q0 = v0.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q1 = v1.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q2 = v2.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q3 = v3.mul(inverse).add(F32_ROUND_UP_256)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();

                quantized.intoTensor(q0, row, col);
                quantized.intoTensor(q1, row, col + 8);
                quantized.intoTensor(q2, row, col + 16);
                quantized.intoTensor(q3, row, col + 24);

                // 0.03125 == 1/32: converts the column index into the block index.
                quantized.getBlockF().set(scale, row, (int) ((float) col * 0.03125f));
            }
        }
        return quantized;
    }

    /**
     * Quantizes BF16 values to Q8 with 128-bit vectors, using one scale per 32-element block.
     *
     * <p>Each block is loaded as four 8-lane short vectors, each widened into two float vectors
     * (widen to int, shift left 16, reinterpret as float). The block's largest absolute value
     * {@code max} gives the stored scale {@code max / 127}; elements are multiplied by
     * {@code 127 / max} (0 when the block is all zeros), increased by 0.5 and converted to bytes.
     * Every converted vector is stored through {@code BYTE_MASK_32}, so only its first four
     * lanes are written.
     *
     * @param ft     source BF16 tensor; its first dimension is iterated as the batch
     * @param offset first column to quantize
     * @param length number of columns to quantize; processed 32 at a time
     * @return tensor obtained from {@link TensorCache} with the same shape as {@code ft}
     */
    public Q8ByteBufferTensor quantizeBF16_Q8_arm(BFloat16BufferTensor ft, int offset, int length) {
        Q8ByteBufferTensor quantized = (Q8ByteBufferTensor) TensorCache.instance.get(DType.I8, ft.shape());
        final int rows = ft.shape().first();
        final int end = offset + length;
        for (int row = 0; row < rows; row++) {
            for (int col = offset; col < end; col += 32) {
                // Four loads of 8 BF16 values; each is widened in two 4-lane parts.
                ShortVector s0 = ft.getVector(ShortVector.SPECIES_128, new int[]{row, col});
                FloatVector v0 = s0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector v1 = s0.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                ShortVector s1 = ft.getVector(ShortVector.SPECIES_128, new int[]{row, col + 8});
                FloatVector v2 = s1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector v3 = s1.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                ShortVector s2 = ft.getVector(ShortVector.SPECIES_128, new int[]{row, col + 16});
                FloatVector v4 = s2.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector v5 = s2.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                ShortVector s3 = ft.getVector(ShortVector.SPECIES_128, new int[]{row, col + 24});
                FloatVector v6 = s3.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();
                FloatVector v7 = s3.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128).reinterpretAsFloats();

                // Tree reduction of |x| down to a single block maximum.
                FloatVector max01 = v0.abs().max(v1.abs());
                FloatVector max23 = v2.abs().max(v3.abs());
                FloatVector max45 = v4.abs().max(v5.abs());
                FloatVector max67 = v6.abs().max(v7.abs());
                FloatVector max0123 = max01.max(max23);
                FloatVector max4567 = max45.max(max67);
                float blockMax = max0123.max(max4567).reduceLanes(VectorOperators.MAX);
                float scale = blockMax / 127.0f;
                float inverseScale = blockMax != 0.0f ? 127.0f / blockMax : 0.0f;
                FloatVector inverse = FloatVector.broadcast(FloatVector.SPECIES_128, inverseScale);

                ByteVector q0 = v0.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q1 = v1.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q2 = v2.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q3 = v3.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q4 = v4.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q5 = v5.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q6 = v6.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();
                ByteVector q7 = v7.mul(inverse).add(F32_ROUND_UP_128)
                        .convertShape(VectorOperators.F2B, ByteVector.SPECIES_64, 0).reinterpretAsBytes();

                quantized.intoTensor(q0, BYTE_MASK_32, row, col);
                quantized.intoTensor(q1, BYTE_MASK_32, row, col + 4);
                quantized.intoTensor(q2, BYTE_MASK_32, row, col + 8);
                quantized.intoTensor(q3, BYTE_MASK_32, row, col + 12);
                quantized.intoTensor(q4, BYTE_MASK_32, row, col + 16);
                quantized.intoTensor(q5, BYTE_MASK_32, row, col + 20);
                quantized.intoTensor(q6, BYTE_MASK_32, row, col + 24);
                quantized.intoTensor(q7, BYTE_MASK_32, row, col + 28);

                // 0.03125 == 1/32: converts the column index into the block index.
                quantized.getBlockF().set(scale, row, (int) ((float) col * 0.03125f));
            }
        }
        return quantized;
    }

    /**
     * Multiplies each row of {@code aBatch} element-wise by {@code bBatch}, in place.
     *
     * <p>When {@code bBatch} has more than one row, row {@code i} of {@code aBatch} is paired
     * with row {@code i} of {@code bBatch}; otherwise {@code bBatch} itself is used for every row.
     *
     * @param aBatch tensor updated in place; must have the same element type as {@code bBatch}
     * @param bBatch multiplier tensor
     * @param offset first column to process
     * @param limit  number of columns to process; must be a multiple of 8
     * @throws IllegalArgumentException      if the element types differ or {@code limit} is not a
     *                                       multiple of 8
     * @throws UnsupportedOperationException if the element type is neither F32 nor BF16
     */
    @Override
    public void maccumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset, int limit) {
        Preconditions.checkArgument(aBatch.dType() == bBatch.dType());
        Preconditions.checkArgument(limit % 8 == 0);
        final boolean pairedRows = bBatch.shape().first() > 1;
        for (int row = 0; row < aBatch.shape().first(); row++) {
            AbstractTensor a = aBatch.slice(row);
            AbstractTensor b;
            if (pairedRows) {
                b = bBatch.slice(row);
            } else {
                b = bBatch;
            }
            switch (a.dType()) {
                case F32 -> this.maccumulateF32((FloatBufferTensor) a, (FloatBufferTensor) b, offset, limit);
                case BF16 -> this.maccumulateBF16((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                default -> throw new UnsupportedOperationException(a.dType().name());
            }
        }
    }

    /**
     * Multiplies {@code a} by {@code b} element-wise over one row of F32 values, in place.
     *
     * <p>Whole preferred-width vectors are processed first; any remaining elements are handled
     * one at a time.
     *
     * @param a      tensor updated in place (row index 0 is used)
     * @param b      multiplier tensor (row index 0 is used)
     * @param offset first column to process
     * @param limit  number of columns to process
     */
    void maccumulateF32(FloatBufferTensor a, FloatBufferTensor b, int offset, int limit) {
        final int lanes = FloatVector.SPECIES_PREFERRED.length();
        final int vectorEnd = offset + FloatVector.SPECIES_PREFERRED.loopBound(limit);
        final int end = offset + limit;
        int col = offset;
        while (col < vectorEnd) {
            FloatVector lhs = a.getVector(FloatVector.SPECIES_PREFERRED, new int[]{0, col});
            FloatVector rhs = b.getVector(FloatVector.SPECIES_PREFERRED, new int[]{0, col});
            a.intoTensor(lhs.mul(rhs), 0, col);
            col += lanes;
        }
        // Scalar tail for whatever did not fill a whole vector.
        for (; col < end; col++) {
            a.set(a.get(0, col) * b.get(0, col), 0, col);
        }
    }

    /**
     * Multiplies {@code a} by {@code b} element-wise over one row of BF16 values, in place.
     *
     * <p>Each short vector is widened into two float vectors (widen to int, shift left 16,
     * reinterpret as float), the halves are multiplied, and each product is narrowed again by
     * shifting right 16 bits and converting to short, which drops the low 16 bits of the float
     * pattern. The two narrowed halves are blended into one short vector and stored. Elements
     * that do not fill a whole vector are handled one at a time.
     *
     * @param a      tensor updated in place (row index 0 is used)
     * @param b      multiplier tensor (row index 0 is used)
     * @param offset first column to process
     * @param limit  number of columns to process
     */
    void maccumulateBF16(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
        int i;
        int upperBound = offset + ShortVector.SPECIES_PREFERRED.loopBound(limit);
        int half = ShortVector.SPECIES_PREFERRED.length() / 2;
        // Loop-invariant: selects the upper half of the short lanes, where the part -1
        // contraction places its result. Built once instead of on every iteration.
        VectorMask<Short> upperHalf = VectorMask.fromLong(ShortVector.SPECIES_PREFERRED, (1L << half) - 1L).not();
        for (i = offset; i < upperBound; i += ShortVector.SPECIES_PREFERRED.length()) {
            ShortVector sa = a.getVector(ShortVector.SPECIES_PREFERRED, new int[]{0, i});
            FloatVector af0 = sa.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            FloatVector af1 = sa.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            ShortVector sb = b.getVector(ShortVector.SPECIES_PREFERRED, new int[]{0, i});
            FloatVector bf0 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            FloatVector bf1 = sb.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1)
                    .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
            // Part 0 lands in the lower half of the short vector, part -1 in the upper half.
            Vector<Short> r0 = af0.mul(bf0).reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, 0);
            Vector<Short> r1 = af1.mul(bf1).reinterpretAsInts()
                    .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
                    .convertShape(VectorOperators.I2S, ShortVector.SPECIES_PREFERRED, -1);
            Vector<Short> r = r0.blend(r1, upperHalf);
            a.intoTensor((ShortVector) r, 0, i);
        }
        // Scalar tail for whatever did not fill a whole vector.
        for (; i < offset + limit; i++) {
            a.set(a.get(0, i) * b.get(0, i), 0, i);
        }
    }

    /**
     * Adds {@code bBatch} into {@code aBatch} row by row, dispatching to a kernel chosen from
     * the two dtypes and, for some combinations, from {@code this.vectorType}.
     *
     * Row {@code r} of {@code aBatch} is paired with row {@code r} of {@code bBatch} when
     * {@code bBatch} has more than one row; otherwise every row is paired with {@code bBatch}
     * itself. Supported combinations are F32 += F32, F32 += Q4, F32 += BF16 and BF16 += BF16;
     * the Q4 and BF16 kernels exist only for the AVX_256, AVX_512 and ARM_128 vector types.
     *
     * @param aBatch tensor that is read and overwritten
     * @param bBatch addend tensor
     * @param offset first column index passed to the kernel
     * @param limit number of columns passed to the kernel
     * @throws UnsupportedOperationException if the dtype pair, or the vector type for that
     *         pair, has no kernel; the message names the unsupported combination
     */
    @Override
    public void accumulate(AbstractTensor aBatch, AbstractTensor bBatch, int offset, int limit) {
        final boolean pairRows = bBatch.shape().first() > 1;
        for (int row = 0; row < aBatch.shape().first(); row++) {
            final AbstractTensor a = aBatch.slice(row);
            final AbstractTensor b = pairRows ? bBatch.slice(row) : bBatch;
            switch (a.dType()) {
                case F32:
                    switch (b.dType()) {
                        case F32:
                            this.accumulateF32((FloatBufferTensor) a, (FloatBufferTensor) b, offset, limit);
                            break;
                        case Q4:
                            switch (this.vectorType) {
                                case AVX_256:
                                case AVX_512:
                                    this.accumulateF32Q4_256((FloatBufferTensor) a, (Q4ByteBufferTensor) b, offset, limit);
                                    break;
                                case ARM_128:
                                    this.accumulateF32Q4_arm((FloatBufferTensor) a, (Q4ByteBufferTensor) b, offset, limit);
                                    break;
                                default:
                                    throw new UnsupportedOperationException("F32 => Q4 is not supported for vector type " + this.vectorType);
                            }
                            break;
                        case BF16:
                            switch (this.vectorType) {
                                case AVX_256:
                                case AVX_512:
                                    this.accumulateF32BF16_256((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                                    break;
                                case ARM_128:
                                    this.accumulateF32BF16_arm((FloatBufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                                    break;
                                default:
                                    throw new UnsupportedOperationException("F32 => BF16 is not supported for vector type " + this.vectorType);
                            }
                            break;
                        default:
                            throw new UnsupportedOperationException("F32 => " + b.dType());
                    }
                    break;
                case BF16:
                    switch (b.dType()) {
                        case BF16:
                            switch (this.vectorType) {
                                case AVX_512:
                                    this.accumulateBF16_512((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                                    break;
                                case AVX_256:
                                    this.accumulateBF16_256((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                                    break;
                                case ARM_128:
                                    this.accumulateBF16_arm((BFloat16BufferTensor) a, (BFloat16BufferTensor) b, offset, limit);
                                    break;
                                default:
                                    throw new UnsupportedOperationException("BF16 => BF16 is not supported for vector type " + this.vectorType);
                            }
                            break;
                        default:
                            // Previously thrown without a message; now mirrors the "F32 => ..." form.
                            throw new UnsupportedOperationException("BF16 => " + b.dType());
                    }
                    break;
                default:
                    throw new UnsupportedOperationException(String.valueOf(a.dType()));
            }
        }
    }

    /**
     * Adds a Q4 row into an F32 row using 128-bit float vectors, 32 elements per iteration.
     *
     * Per iteration one scale factor is read via {@code b.getFactorForIndex(0, pos)} and two
     * 64-bit byte vectors are loaded from {@code b} at indices {@code pos} and {@code pos + 16}.
     * Each byte yields two values: {@code (byte & 15) - 8} and {@code ((byte >> 4) & 15) - 8}.
     * The low-nibble values of the two byte vectors are added (times the scale) to elements
     * {@code pos .. pos + 15}, the high-nibble values to elements {@code pos + 16 .. pos + 31}.
     *
     * NOTE(review): the loop advances 32 elements at a time but its bound is only rounded to
     * the 128-bit float lane count, and there is no scalar remainder, so this presumably
     * expects {@code limit} to be a multiple of 32 - confirm against callers.
     *
     * @param a F32 tensor that is read and overwritten
     * @param b Q4 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    private void accumulateF32Q4_arm(FloatBufferTensor a, Q4ByteBufferTensor b, int offset, int limit) {
        final VectorSpecies<Float> f128 = FloatVector.SPECIES_128;
        final VectorSpecies<Short> s128 = ShortVector.SPECIES_128;
        final VectorSpecies<Byte> b64 = ByteVector.SPECIES_64;
        final int blockLen = 32;
        final int end = offset + f128.loopBound(limit);

        // The original kept separate, always-equal cursors for a and b; one cursor serves both.
        for (int pos = offset; pos < end; pos += blockLen) {
            final FloatVector scale = FloatVector.broadcast(f128, (float) b.getFactorForIndex(0, pos));

            // Load all eight accumulator vectors before anything is written back.
            final FloatVector acc0 = a.getVector(f128, new int[]{0, pos});
            final FloatVector acc1 = a.getVector(f128, new int[]{0, pos + 4});
            final FloatVector acc2 = a.getVector(f128, new int[]{0, pos + 8});
            final FloatVector acc3 = a.getVector(f128, new int[]{0, pos + 12});
            final FloatVector acc4 = a.getVector(f128, new int[]{0, pos + 16});
            final FloatVector acc5 = a.getVector(f128, new int[]{0, pos + 20});
            final FloatVector acc6 = a.getVector(f128, new int[]{0, pos + 24});
            final FloatVector acc7 = a.getVector(f128, new int[]{0, pos + 28});

            final ByteVector packedFirst = b.getVector(b64, new int[]{0, pos});
            final ByteVector packedSecond = b.getVector(b64, new int[]{0, pos + 16});

            // First byte vector: low and high nibbles, each recentred by subtracting 8.
            final ByteVector firstLow = packedFirst
                .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                .sub(Q4_BYTE_SUB_64);
            final ByteVector firstHigh = packedFirst
                .lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64)
                .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                .sub(Q4_BYTE_SUB_64);
            final Vector<Short> firstLowShorts = firstLow.castShape(s128, 0);
            final Vector<Float> firstLowA = firstLowShorts.convertShape(VectorOperators.S2F, f128, 0);
            final Vector<Float> firstLowB = firstLowShorts.convertShape(VectorOperators.S2F, f128, 1);
            final Vector<Short> firstHighShorts = firstHigh.castShape(s128, 0);
            final Vector<Float> firstHighA = firstHighShorts.convertShape(VectorOperators.S2F, f128, 0);
            final Vector<Float> firstHighB = firstHighShorts.convertShape(VectorOperators.S2F, f128, 1);

            // Second byte vector, decoded the same way.
            final ByteVector secondLow = packedSecond
                .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                .sub(Q4_BYTE_SUB_64);
            final ByteVector secondHigh = packedSecond
                .lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64)
                .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                .sub(Q4_BYTE_SUB_64);
            final Vector<Short> secondLowShorts = secondLow.castShape(s128, 0);
            final Vector<Float> secondLowA = secondLowShorts.convertShape(VectorOperators.S2F, f128, 0);
            final Vector<Float> secondLowB = secondLowShorts.convertShape(VectorOperators.S2F, f128, 1);
            final Vector<Short> secondHighShorts = secondHigh.castShape(s128, 0);
            final Vector<Float> secondHighA = secondHighShorts.convertShape(VectorOperators.S2F, f128, 0);
            final Vector<Float> secondHighB = secondHighShorts.convertShape(VectorOperators.S2F, f128, 1);

            // Low nibbles fill the first 16 outputs, high nibbles the last 16.
            a.intoTensor(acc0.add(firstLowA.mul(scale)), 0, pos);
            a.intoTensor(acc1.add(firstLowB.mul(scale)), 0, pos + 4);
            a.intoTensor(acc2.add(secondLowA.mul(scale)), 0, pos + 8);
            a.intoTensor(acc3.add(secondLowB.mul(scale)), 0, pos + 12);
            a.intoTensor(acc4.add(firstHighA.mul(scale)), 0, pos + 16);
            a.intoTensor(acc5.add(firstHighB.mul(scale)), 0, pos + 20);
            a.intoTensor(acc6.add(secondHighA.mul(scale)), 0, pos + 24);
            a.intoTensor(acc7.add(secondHighB.mul(scale)), 0, pos + 28);
        }
    }

    /**
     * In-place element-wise sum of two F32 rows: {@code a[0, i] += b[0, i]} for
     * {@code i} in {@code [offset, offset + limit)}.
     *
     * Whole vectors of the preferred float species are processed first; the remaining
     * columns are handled one at a time.
     *
     * @param a tensor that is read and overwritten
     * @param b addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateF32(FloatBufferTensor a, FloatBufferTensor b, int offset, int limit) {
        final VectorSpecies<Float> species = FloatVector.SPECIES_PREFERRED;
        final int lanes = species.length();
        final int vectorEnd = offset + species.loopBound(limit);
        final int end = offset + limit;

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector lhs = a.getVector(species, new int[]{0, pos});
            final FloatVector rhs = b.getVector(species, new int[]{0, pos});
            a.intoTensor(lhs.add(rhs), 0, pos);
            pos += lanes;
        }

        // Scalar remainder.
        for (; pos < end; pos++) {
            a.set(a.get(0, pos) + b.get(0, pos), 0, pos);
        }
    }

    /**
     * Adds a Q4 row into an F32 row using 256-bit float vectors, 32 elements per iteration.
     *
     * Per iteration one scale factor is read via {@code b.getFactorForIndex(0, pos)} and one
     * 128-bit byte vector is loaded from {@code b}. Each byte yields two values:
     * {@code (byte & 15) - 8} and {@code (byte >>> 4) - 8}. The low-nibble values are scaled
     * and added to elements {@code pos .. pos + 15}, the high-nibble values to elements
     * {@code pos + 16 .. pos + 31}.
     *
     * NOTE(review): the loop advances 32 elements at a time but its bound is only rounded to
     * the 256-bit float lane count, and there is no scalar remainder, so this presumably
     * expects {@code limit} to be a multiple of 32 - confirm against callers.
     *
     * @param a F32 tensor that is read and overwritten
     * @param b Q4 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateF32Q4_256(FloatBufferTensor a, Q4ByteBufferTensor b, int offset, int limit) {
        final VectorSpecies<Float> f256 = FloatVector.SPECIES_256;
        final int blockLen = 32;
        final int end = offset + f256.loopBound(limit);

        // The original kept separate, always-equal cursors for a and b; one cursor serves both.
        for (int pos = offset; pos < end; pos += blockLen) {
            final FloatVector scale = FloatVector.broadcast(f256, (float) b.getFactorForIndex(0, pos));
            final ByteVector packed = b.getVector(ByteVector.SPECIES_128, new int[]{0, pos});
            final ByteVector lowNibbles = packed.and((byte) 15).sub((byte) 8);
            final ByteVector highNibbles = packed.lanewise(VectorOperators.LSHR, 4L).sub((byte) 8);

            // All four sums are computed before the first store, as in the original.
            final FloatVector sum0 = a.getVector(f256, new int[]{0, pos})
                .add(lowNibbles.castShape(f256, 0).mul(scale));
            final FloatVector sum1 = a.getVector(f256, new int[]{0, pos + 8})
                .add(lowNibbles.castShape(f256, 1).mul(scale));
            final FloatVector sum2 = a.getVector(f256, new int[]{0, pos + 16})
                .add(highNibbles.castShape(f256, 0).mul(scale));
            final FloatVector sum3 = a.getVector(f256, new int[]{0, pos + 16 + 8})
                .add(highNibbles.castShape(f256, 1).mul(scale));

            a.intoTensor(sum0, 0, pos);
            a.intoTensor(sum1, 0, pos + 8);
            a.intoTensor(sum2, 0, pos + 16);
            a.intoTensor(sum3, 0, pos + 16 + 8);
        }
    }

    /**
     * Adds a BF16 row into an F32 row using 256-bit float vectors:
     * {@code a[0, i] += b[0, i]} for {@code i} in {@code [offset, offset + limit)}.
     *
     * Eight BF16 values are loaded as a 128-bit short vector, widened to 32-bit ints, shifted
     * left by {@code BF16_BYTE_SHIFT_256} (declared elsewhere in this class; presumably 16)
     * and reinterpreted as floats. Leftover columns are handled by scalar get/set.
     *
     * @param a F32 tensor that is read and overwritten
     * @param b BF16 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateF32BF16_256(FloatBufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
        final VectorSpecies<Float> floats = FloatVector.SPECIES_256;
        final int lanes = floats.length();
        final int vectorEnd = offset + floats.loopBound(limit);
        final int end = offset + limit;

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector current = a.getVector(floats, new int[]{0, pos});
            final FloatVector widened = b.getVector(ShortVector.SPECIES_128, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                .reinterpretAsFloats();
            a.intoTensor(current.add(widened), 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) + b.get(0, pos), 0, pos);
        }
    }

    /**
     * Adds a BF16 row into an F32 row using 128-bit float vectors:
     * {@code a[0, i] += b[0, i]} for {@code i} in {@code [offset, offset + limit)}.
     *
     * Four BF16 values are loaded as a 64-bit short vector, widened to 32-bit ints, shifted
     * left by {@code BF16_BYTE_SHIFT_128} (declared elsewhere in this class; presumably 16)
     * and reinterpreted as floats. Leftover columns are handled by scalar get/set.
     *
     * @param a F32 tensor that is read and overwritten
     * @param b BF16 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateF32BF16_arm(FloatBufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
        final VectorSpecies<Float> floats = FloatVector.SPECIES_128;
        final int lanes = floats.length();
        final int vectorEnd = offset + floats.loopBound(limit);
        final int end = offset + limit;

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector current = a.getVector(floats, new int[]{0, pos});
            final FloatVector widened = b.getVector(ShortVector.SPECIES_64, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
                .reinterpretAsFloats();
            a.intoTensor(current.add(widened), 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) + b.get(0, pos), 0, pos);
        }
    }

    /**
     * In-place sum of two BF16 rows using 128-bit float arithmetic:
     * {@code a[0, i] += b[0, i]} for {@code i} in {@code [offset, offset + limit)}.
     *
     * Both operands are loaded as 64-bit short vectors, widened to ints, shifted left by
     * {@code BF16_BYTE_SHIFT_128} (declared elsewhere in this class; presumably 16) and
     * reinterpreted as floats. The float sum is shifted right by the same amount and narrowed
     * int to short before being stored. Leftover columns use scalar get/set.
     *
     * @param a BF16 tensor that is read and overwritten
     * @param b BF16 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateBF16_arm(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_64;
        final VectorSpecies<Integer> ints = IntVector.SPECIES_128;
        // The step is the 128-bit float lane count, which is what a 64-bit short vector holds.
        final int lanes = FloatVector.SPECIES_128.length();
        final int vectorEnd = offset + FloatVector.SPECIES_128.loopBound(limit);
        final int end = offset + limit;

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector lhs = a.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
                .reinterpretAsFloats();
            final FloatVector rhs = b.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_128)
                .reinterpretAsFloats();
            final Vector<Short> narrowed = lhs.add(rhs)
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_128)
                .convertShape(VectorOperators.I2S, shorts, 0);
            a.intoTensor((ShortVector) narrowed, 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) + b.get(0, pos), 0, pos);
        }
    }

    /**
     * In-place sum of two BF16 rows using 256-bit float arithmetic:
     * {@code a[0, i] += b[0, i]} for {@code i} in {@code [offset, offset + limit)}.
     *
     * Both operands are loaded as 128-bit short vectors, widened to ints, shifted left by
     * {@code BF16_BYTE_SHIFT_256} (declared elsewhere in this class; presumably 16) and
     * reinterpreted as floats. The float sum is shifted right by the same amount and narrowed
     * int to short before being stored. Leftover columns use scalar get/set.
     *
     * @param a BF16 tensor that is read and overwritten
     * @param b BF16 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateBF16_256(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_128;
        final VectorSpecies<Integer> ints = IntVector.SPECIES_256;
        // The step is the 256-bit float lane count, which is what a 128-bit short vector holds.
        final int lanes = FloatVector.SPECIES_256.length();
        final int vectorEnd = offset + FloatVector.SPECIES_256.loopBound(limit);
        final int end = offset + limit;

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector lhs = a.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                .reinterpretAsFloats();
            final FloatVector rhs = b.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                .reinterpretAsFloats();
            final Vector<Short> narrowed = lhs.add(rhs)
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_256)
                .convertShape(VectorOperators.I2S, shorts, 0);
            a.intoTensor((ShortVector) narrowed, 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) + b.get(0, pos), 0, pos);
        }
    }

    /**
     * In-place sum of two BF16 rows using 512-bit float arithmetic:
     * {@code a[0, i] += b[0, i]} for {@code i} in {@code [offset, offset + limit)}.
     *
     * Both operands are loaded as 256-bit short vectors, widened to ints, shifted left by
     * {@code BF16_BYTE_SHIFT_512} (declared elsewhere in this class; presumably 16) and
     * reinterpreted as floats. The float sum is shifted right by the same amount and narrowed
     * int to short before being stored. Leftover columns use scalar get/set.
     *
     * @param a BF16 tensor that is read and overwritten
     * @param b BF16 addend tensor
     * @param offset first column index
     * @param limit number of columns to process
     */
    void accumulateBF16_512(BFloat16BufferTensor a, BFloat16BufferTensor b, int offset, int limit) {
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_256;
        final VectorSpecies<Integer> ints = IntVector.SPECIES_512;
        // The step is the 512-bit float lane count, which is what a 256-bit short vector holds.
        final int lanes = FloatVector.SPECIES_512.length();
        final int vectorEnd = offset + FloatVector.SPECIES_512.loopBound(limit);
        final int end = offset + limit;

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector lhs = a.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512)
                .reinterpretAsFloats();
            final FloatVector rhs = b.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512)
                .reinterpretAsFloats();
            final Vector<Short> narrowed = lhs.add(rhs)
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_512)
                .convertShape(VectorOperators.I2S, shorts, 0);
            a.intoTensor((ShortVector) narrowed, 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) + b.get(0, pos), 0, pos);
        }
    }

    /**
     * Multiplies every row of {@code aBatch} by {@code factor}, in place, over the columns
     * {@code [offset, offset + length)}.
     *
     * F32 rows are always supported. BF16 rows are only supported when {@code this.vectorType}
     * is AVX_512 or AVX_256; no kernel is wired up here for any other vector type.
     *
     * @param factor scalar multiplier
     * @param aBatch tensor that is read and overwritten
     * @param offset first column index passed to the kernel
     * @param length number of columns passed to the kernel
     * @throws UnsupportedOperationException if the row dtype is neither F32 nor BF16, or if a
     *         BF16 row is scaled on a vector type without a kernel; the message names the
     *         unsupported dtype / vector type
     */
    @Override
    public void scale(float factor, AbstractTensor aBatch, int offset, int length) {
        for (int row = 0; row < aBatch.shape().first(); row++) {
            final AbstractTensor a = aBatch.slice(row);
            switch (a.dType()) {
                case F32:
                    this.scaleF32(factor, (FloatBufferTensor) a, offset, length);
                    break;
                case BF16:
                    switch (this.vectorType) {
                        case AVX_512:
                            this.scaleBF16_512(factor, (BFloat16BufferTensor) a, offset, length);
                            break;
                        case AVX_256:
                            this.scaleBF16_256(factor, (BFloat16BufferTensor) a, offset, length);
                            break;
                        default:
                            // Previously thrown without a message, which hid the real cause.
                            throw new UnsupportedOperationException("scale of BF16 is not supported for vector type " + this.vectorType);
                    }
                    break;
                default:
                    throw new UnsupportedOperationException("scale is not supported for dtype " + a.dType());
            }
        }
    }

    /**
     * Multiplies {@code a[0, i]} by {@code factor} for {@code i} in
     * {@code [offset, offset + length)}, in place.
     *
     * Whole vectors of the preferred float species are scaled first; the remaining columns
     * are scaled one at a time.
     *
     * @param factor scalar multiplier
     * @param a F32 tensor that is read and overwritten
     * @param offset first column index
     * @param length number of columns to process
     */
    public void scaleF32(float factor, FloatBufferTensor a, int offset, int length) {
        final VectorSpecies<Float> species = FloatVector.SPECIES_PREFERRED;
        final int lanes = species.length();
        final int vectorEnd = species.loopBound(length) + offset;
        final int end = offset + length;
        final FloatVector factorLanes = FloatVector.broadcast(species, factor);

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector values = a.getVector(species, new int[]{0, pos});
            a.intoTensor(values.mul(factorLanes), 0, pos);
            pos += lanes;
        }

        // Scalar remainder.
        for (; pos < end; pos++) {
            a.set(a.get(0, pos) * factor, 0, pos);
        }
    }

    /**
     * Multiplies a BF16 row by {@code factor} in place using 512-bit float arithmetic, over
     * the columns {@code [offset, offset + length)}.
     *
     * Values are loaded as 256-bit short vectors, widened to ints, shifted left by
     * {@code BF16_BYTE_SHIFT_512} (declared elsewhere in this class; presumably 16) and
     * reinterpreted as floats. The scaled floats are shifted right by the same amount and
     * narrowed int to short before being stored. Leftover columns use scalar get/set.
     *
     * @param factor scalar multiplier
     * @param a BF16 tensor that is read and overwritten
     * @param offset first column index
     * @param length number of columns to process
     */
    public void scaleBF16_512(float factor, BFloat16BufferTensor a, int offset, int length) {
        final VectorSpecies<Float> floats = FloatVector.SPECIES_512;
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_256;
        final int lanes = floats.length();
        final int vectorEnd = floats.loopBound(length) + offset;
        final int end = offset + length;
        final FloatVector factorLanes = FloatVector.broadcast(floats, factor);

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector widened = a.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, IntVector.SPECIES_512, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_512)
                .reinterpretAsFloats();
            final Vector<Short> narrowed = widened.mul(factorLanes)
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_512)
                .convertShape(VectorOperators.I2S, shorts, 0);
            a.intoTensor((ShortVector) narrowed, 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) * factor, 0, pos);
        }
    }

    /**
     * Multiplies a BF16 row by {@code factor} in place using 256-bit float arithmetic, over
     * the columns {@code [offset, offset + length)}.
     *
     * Values are loaded as 128-bit short vectors, widened to ints, shifted left by
     * {@code BF16_BYTE_SHIFT_256} (declared elsewhere in this class; presumably 16) and
     * reinterpreted as floats. The scaled floats are shifted right by the same amount and
     * narrowed int to short before being stored. Leftover columns use scalar get/set.
     *
     * @param factor scalar multiplier
     * @param a BF16 tensor that is read and overwritten
     * @param offset first column index
     * @param length number of columns to process
     */
    public void scaleBF16_256(float factor, BFloat16BufferTensor a, int offset, int length) {
        final VectorSpecies<Float> floats = FloatVector.SPECIES_256;
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_128;
        final int lanes = floats.length();
        final int vectorEnd = floats.loopBound(length) + offset;
        final int end = offset + length;
        final FloatVector factorLanes = FloatVector.broadcast(floats, factor);

        int pos = offset;
        while (pos < vectorEnd) {
            final FloatVector widened = a.getVector(shorts, new int[]{0, pos})
                .convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT_256)
                .reinterpretAsFloats();
            final Vector<Short> narrowed = widened.mul(factorLanes)
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT_256)
                .convertShape(VectorOperators.I2S, shorts, 0);
            a.intoTensor((ShortVector) narrowed, 0, pos);
            pos += lanes;
        }

        for (; pos < end; pos++) {
            a.set(a.get(0, pos) * factor, 0, pos);
        }
    }

    /**
     * Computes {@code y += alpha * x} over {@code limit} columns, dispatching on the dtypes of
     * {@code x} and {@code y}.
     *
     * Supported combinations are F32/F32, BF16/F32 and BF16/BF16 (x/y).
     *
     * @param alpha scalar multiplier applied to {@code x}
     * @param x source tensor
     * @param y destination tensor; its first dimension must be 1
     * @param xoffset first column index read from {@code x}
     * @param yoffset first column index updated in {@code y}
     * @param limit number of columns; must be even
     * @throws IllegalArgumentException if {@code y} has more than one row, the dtypes are not
     *         an accepted pair, or {@code limit} is odd
     * @throws UnsupportedOperationException if no kernel exists for the dtype pair
     */
    @Override
    public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        Preconditions.checkArgument(y.shape().first() == 1);
        Preconditions.checkArgument(x.dType() == y.dType() || (x.dType() == DType.BF16 && y.dType() == DType.F32));
        Preconditions.checkArgument(limit % 2 == 0);

        switch (x.dType()) {
            case F32:
                this.saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit);
                return;
            case BF16:
                switch (y.dType()) {
                    case F32:
                        this.saxpyBF16F32(alpha, (BFloat16BufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit);
                        return;
                    case BF16:
                        this.saxpyBF16(alpha, (BFloat16BufferTensor) x, (BFloat16BufferTensor) y, xoffset, yoffset, limit);
                        return;
                    default:
                        throw new UnsupportedOperationException();
                }
            default:
                throw new UnsupportedOperationException();
        }
    }

    /**
     * F32 kernel for {@code y[0, yoffset + i] += alpha * x[0, xoffset + i]},
     * {@code 0 <= i < limit}.
     *
     * Whole vectors of the preferred float species are handled with a vector {@code fma}
     * call; the remaining columns use a plain scalar multiply and add.
     *
     * @param alpha scalar multiplier applied to {@code x}
     * @param x source tensor
     * @param y destination tensor, updated in place
     * @param xoffset first column index read from {@code x}
     * @param yoffset first column index updated in {@code y}
     * @param limit number of columns to process
     */
    void saxpyF32(float alpha, FloatBufferTensor x, FloatBufferTensor y, int xoffset, int yoffset, int limit) {
        final VectorSpecies<Float> species = FloatVector.SPECIES_PREFERRED;
        final int lanes = species.length();
        final int vectorCount = species.loopBound(limit);
        final FloatVector alphaLanes = FloatVector.broadcast(species, alpha);

        int xPos = xoffset;
        int yPos = yoffset;
        while (xPos < xoffset + vectorCount && yPos < yoffset + vectorCount) {
            final FloatVector source = x.getVector(species, new int[]{0, xPos});
            final FloatVector target = y.getVector(species, new int[]{0, yPos});
            y.intoTensor(source.fma(alphaLanes, target), 0, yPos);
            xPos += lanes;
            yPos += lanes;
        }

        // Scalar remainder.
        for (; xPos < xoffset + limit && yPos < yoffset + limit; xPos++, yPos++) {
            final float updated = y.get(0, yPos) + alpha * x.get(0, xPos);
            y.set(updated, 0, yPos);
        }
    }

    /**
     * Batched saxpy dispatcher: forwards to the batched kernel that matches the dtypes of
     * {@code x} and {@code y}. Supported combinations are F32/F32, BF16/F32 and BF16/BF16
     * (x/y).
     *
     * @param alpha tensor holding one multiplier per batch entry
     * @param x source tensor
     * @param y destination tensor
     * @param xoffset first column index read from {@code x}
     * @param yoffset first column index updated in {@code y}
     * @param limit number of columns; must be even
     * @param aOffset index of the first multiplier used from {@code alpha}
     * @param xOffset index of the first row used from {@code x}
     * @param batchSize number of (multiplier, row) pairs to process
     * @throws IllegalArgumentException if {@code limit} is odd
     * @throws UnsupportedOperationException if no kernel exists for the dtype pair
     */
    @Override
    public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int aOffset, int xOffset, int batchSize) {
        Preconditions.checkArgument(limit % 2 == 0);

        switch (x.dType()) {
            case F32:
                this.saxpyF32(alpha, (FloatBufferTensor) x, (FloatBufferTensor) y, xoffset, yoffset, limit, aOffset, xOffset, batchSize);
                return;
            case BF16:
                switch (y.dType()) {
                    case F32:
                        this.saxpyBF16F32(alpha, x, y, xoffset, yoffset, limit, aOffset, xOffset, batchSize);
                        return;
                    case BF16:
                        this.saxpyBF16(alpha, x, y, xoffset, yoffset, limit, aOffset, xOffset, batchSize);
                        return;
                    default:
                        throw new UnsupportedOperationException();
                }
            default:
                throw new UnsupportedOperationException();
        }
    }

    /**
     * Batched F32 saxpy: for each batch entry {@code n} in {@code [0, batchSize)} performs
     * {@code y[0, yoffset + i] += alpha[0, aOffset + n] * x[xOffset + n, xoffset + i]} for
     * {@code 0 <= i < limit}.
     *
     * Batch entries are processed four at a time so that {@code y} is loaded and stored once
     * per group; entries left over after the last full group of four are delegated to the
     * single-alpha {@code saxpyF32} overload.
     *
     * Fix: the unrolled path used to run only the vector loop (up to
     * {@code loopBound(limit)}) and silently skipped the trailing columns when {@code limit}
     * was not a multiple of the vector length, while the leftover-entry path did process
     * them. A scalar remainder loop now covers those columns for the unrolled groups too.
     *
     * @param alpha tensor holding one multiplier per batch entry, read at row 0
     * @param x source tensor, one row per batch entry
     * @param y destination tensor, row 0 updated in place
     * @param xoffset first column index read from {@code x}
     * @param yoffset first column index updated in {@code y}
     * @param limit number of columns to process
     * @param aOffset index of the first multiplier used from {@code alpha}
     * @param xOffset index of the first row used from {@code x}
     * @param batchSize number of (multiplier, row) pairs to process
     */
    public void saxpyF32(AbstractTensor alpha, FloatBufferTensor x, FloatBufferTensor y, int xoffset, int yoffset, int limit, int aOffset, int xOffset, int batchSize) {
        final VectorSpecies<Float> species = FloatVector.SPECIES_PREFERRED;
        final int lanes = species.length();
        final int upperBound = species.loopBound(limit);
        // End of the last full group of four batch entries.
        final int unrolledEnd = aOffset + (batchSize - batchSize % 4);

        int a = aOffset;
        int xi = xOffset;
        while (a < unrolledEnd) {
            final float s0 = (float) alpha.get(0, a);
            final float s1 = (float) alpha.get(0, a + 1);
            final float s2 = (float) alpha.get(0, a + 2);
            final float s3 = (float) alpha.get(0, a + 3);
            final FloatVector a0 = FloatVector.broadcast(species, s0);
            final FloatVector a1 = FloatVector.broadcast(species, s1);
            final FloatVector a2 = FloatVector.broadcast(species, s2);
            final FloatVector a3 = FloatVector.broadcast(species, s3);

            int xo = xoffset;
            int yo = yoffset;
            for (; xo < xoffset + upperBound && yo < yoffset + upperBound; xo += lanes, yo += lanes) {
                final FloatVector x0 = x.getVector(species, new int[]{xi, xo});
                final FloatVector x1 = x.getVector(species, new int[]{xi + 1, xo});
                final FloatVector x2 = x.getVector(species, new int[]{xi + 2, xo});
                final FloatVector x3 = x.getVector(species, new int[]{xi + 3, xo});
                FloatVector acc = y.getVector(species, new int[]{0, yo});
                acc = x0.fma(a0, acc);
                acc = x1.fma(a1, acc);
                acc = x2.fma(a2, acc);
                acc = x3.fma(a3, acc);
                y.intoTensor(acc, 0, yo);
            }

            // Scalar remainder for the columns past the last full vector (same accumulation
            // order as the vector loop, plain multiply-add like the single-alpha kernel).
            for (; xo < xoffset + limit && yo < yoffset + limit; xo++, yo++) {
                float acc = y.get(0, yo);
                acc += s0 * x.get(xi, xo);
                acc += s1 * x.get(xi + 1, xo);
                acc += s2 * x.get(xi + 2, xo);
                acc += s3 * x.get(xi + 3, xo);
                y.set(acc, 0, yo);
            }

            a += 4;
            xi += 4;
        }

        // Batch entries that did not fill a group of four.
        while (a < aOffset + batchSize) {
            this.saxpyF32(alpha.get(0, a++), (FloatBufferTensor) x.slice(xi++), y, xoffset, yoffset, limit);
        }
    }

    /**
     * Batched BF16/BF16 saxpy: applies the single-alpha {@code saxpyBF16} kernel once per
     * batch entry, pairing multiplier {@code alpha[0, aOffset + n]} with row
     * {@code xOffset + n} of {@code xt}.
     *
     * @param alpha tensor holding one multiplier per batch entry, read at row 0
     * @param xt source tensor; must be a {@code BFloat16BufferTensor}
     * @param yt destination tensor; must be a {@code BFloat16BufferTensor}
     * @param xoffset first column index read from each source row
     * @param yoffset first column index updated in the destination
     * @param limit number of columns passed to the kernel
     * @param aOffset index of the first multiplier used from {@code alpha}
     * @param xOffset index of the first row used from {@code xt}
     * @param batchSize number of (multiplier, row) pairs to process
     * @throws ClassCastException if either tensor is not a {@code BFloat16BufferTensor}
     */
    public void saxpyBF16(AbstractTensor alpha, AbstractTensor xt, AbstractTensor yt, int xoffset, int yoffset, int limit, int aOffset, int xOffset, int batchSize) {
        final BFloat16BufferTensor source = (BFloat16BufferTensor) xt;
        final BFloat16BufferTensor target = (BFloat16BufferTensor) yt;
        final int alphaEnd = aOffset + batchSize;
        for (int alphaIdx = aOffset, rowIdx = xOffset; alphaIdx < alphaEnd; alphaIdx++, rowIdx++) {
            this.saxpyBF16(alpha.get(0, alphaIdx), (BFloat16BufferTensor) source.slice(rowIdx), target, xoffset, yoffset, limit);
        }
    }

    /**
     * Batched BF16-into-F32 saxpy: applies the single-alpha {@code saxpyBF16F32} kernel once
     * per batch entry, pairing multiplier {@code alpha[0, aOffset + n]} with row
     * {@code xOffset + n} of {@code xt}.
     *
     * @param alpha tensor holding one multiplier per batch entry, read at row 0
     * @param xt source tensor; must be a {@code BFloat16BufferTensor}
     * @param yt destination tensor; must be a {@code FloatBufferTensor}
     * @param xoffset first column index read from each source row
     * @param yoffset first column index updated in the destination
     * @param limit number of columns passed to the kernel
     * @param aOffset index of the first multiplier used from {@code alpha}
     * @param xOffset index of the first row used from {@code xt}
     * @param batchSize number of (multiplier, row) pairs to process
     * @throws ClassCastException if the tensors are not of the expected concrete types
     */
    public void saxpyBF16F32(AbstractTensor alpha, AbstractTensor xt, AbstractTensor yt, int xoffset, int yoffset, int limit, int aOffset, int xOffset, int batchSize) {
        final BFloat16BufferTensor source = (BFloat16BufferTensor) xt;
        final FloatBufferTensor target = (FloatBufferTensor) yt;
        final int alphaEnd = aOffset + batchSize;
        for (int alphaIdx = aOffset, rowIdx = xOffset; alphaIdx < alphaEnd; alphaIdx++, rowIdx++) {
            this.saxpyBF16F32(alpha.get(0, alphaIdx), (BFloat16BufferTensor) source.slice(rowIdx), target, xoffset, yoffset, limit);
        }
    }

    /**
     * BF16 kernel for {@code b[0, boffset + i] += alpha * a[0, aoffset + i]},
     * {@code 0 <= i < limit}.
     *
     * There is no scalar remainder: {@code limit} must be a whole number of preferred-width
     * short vectors. Each short vector is split into two halves that are widened to ints,
     * shifted left by {@code BF16_BYTE_SHIFT} (16) and reinterpreted as floats; the float
     * results are shifted right 16 bits and narrowed int to short, then the two halves are
     * blended back into one short vector.
     *
     * @param alpha scalar multiplier applied to {@code a}
     * @param a source tensor
     * @param b destination tensor, updated in place
     * @param aoffset first column index read from {@code a}
     * @param boffset first column index updated in {@code b}
     * @param limit number of columns; must be a multiple of the preferred short vector length
     * @throws IllegalArgumentException if {@code limit} is not a multiple of the vector length
     */
    void saxpyBF16(float alpha, BFloat16BufferTensor a, BFloat16BufferTensor b, int aoffset, int boffset, int limit) {
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_PREFERRED;
        final VectorSpecies<Integer> ints = IntVector.SPECIES_PREFERRED;
        final int vectorCount = shorts.loopBound(limit);
        Preconditions.checkArgument(vectorCount == limit);

        final int lanes = shorts.length();
        // Mask with only the upper half of the short lanes set (the complement of the low-half mask).
        final VectorMask<Short> upperHalf = VectorMask.fromLong(shorts, (1L << (lanes / 2)) - 1L).not();

        int srcPos = aoffset;
        int dstPos = boffset;
        while (srcPos < aoffset + vectorCount && dstPos < boffset + vectorCount) {
            final ShortVector rawSrc = a.getVector(shorts, new int[]{0, srcPos});
            final FloatVector srcLow = rawSrc.convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                .reinterpretAsFloats();
            final FloatVector srcHigh = rawSrc.convertShape(VectorOperators.S2I, ints, 1)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                .reinterpretAsFloats();

            final ShortVector rawDst = b.getVector(shorts, new int[]{0, dstPos});
            final FloatVector dstLow = rawDst.convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                .reinterpretAsFloats();
            final FloatVector dstHigh = rawDst.convertShape(VectorOperators.S2I, ints, 1)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                .reinterpretAsFloats();

            // Part 0 is used for the low-half result and part -1 for the high-half result;
            // the blend below takes the upper-half lanes from the second vector.
            final Vector<Short> resultLow = dstLow.add(srcLow.mul(alpha))
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
                .convertShape(VectorOperators.I2S, shorts, 0);
            final Vector<Short> resultHigh = dstHigh.add(srcHigh.mul(alpha))
                .reinterpretAsInts()
                .lanewise(VectorOperators.ASHR, BF16_BYTE_SHIFT)
                .convertShape(VectorOperators.I2S, shorts, -1);

            b.intoTensor((ShortVector) resultLow.blend(resultHigh, upperHalf), 0, dstPos);
            srcPos += lanes;
            dstPos += lanes;
        }
    }

    /**
     * Mixed kernel for {@code b[0, boffset + i] += alpha * a[0, aoffset + i]},
     * {@code 0 <= i < limit}, with a BF16 source and an F32 destination.
     *
     * There is no scalar remainder: {@code limit} must be a whole number of preferred-width
     * short vectors. Each short vector is widened in two halves (ints shifted left by
     * {@code BF16_BYTE_SHIFT} (16), reinterpreted as floats), and each half updates one
     * preferred-width float vector of the destination.
     *
     * @param alpha scalar multiplier applied to {@code a}
     * @param a BF16 source tensor
     * @param b F32 destination tensor, updated in place
     * @param aoffset first column index read from {@code a}
     * @param boffset first column index updated in {@code b}
     * @param limit number of columns; must be a multiple of the preferred short vector length
     * @throws IllegalArgumentException if {@code limit} is not a multiple of the vector length
     */
    void saxpyBF16F32(float alpha, BFloat16BufferTensor a, FloatBufferTensor b, int aoffset, int boffset, int limit) {
        final VectorSpecies<Short> shorts = ShortVector.SPECIES_PREFERRED;
        final VectorSpecies<Integer> ints = IntVector.SPECIES_PREFERRED;
        final VectorSpecies<Float> floats = FloatVector.SPECIES_PREFERRED;
        final int vectorCount = shorts.loopBound(limit);
        Preconditions.checkArgument(vectorCount == limit);

        final int lanes = shorts.length();
        int srcPos = aoffset;
        int dstPos = boffset;
        while (srcPos < aoffset + vectorCount && dstPos < boffset + vectorCount) {
            final ShortVector rawSrc = a.getVector(shorts, new int[]{0, srcPos});
            final FloatVector srcLow = rawSrc.convertShape(VectorOperators.S2I, ints, 0)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                .reinterpretAsFloats();
            final FloatVector srcHigh = rawSrc.convertShape(VectorOperators.S2I, ints, 1)
                .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                .reinterpretAsFloats();

            // The second destination vector starts one float vector after the first.
            final FloatVector dstLow = b.getVector(floats, new int[]{0, dstPos});
            final FloatVector dstHigh = b.getVector(floats, new int[]{0, dstPos + FloatVector.SPECIES_PREFERRED.length()});

            final FloatVector resultLow = dstLow.add(srcLow.mul(alpha));
            final FloatVector resultHigh = dstHigh.add(srcHigh.mul(alpha));
            b.intoTensor(resultLow, 0, dstPos);
            b.intoTensor(resultHigh, 0, dstPos + FloatVector.SPECIES_PREFERRED.length());

            srcPos += lanes;
            dstPos += lanes;
        }
    }

    /**
     * Gemmer whose tile kernels read both operands as float vectors of
     * {@code FloatVector.SPECIES_PREFERRED} width.
     *
     * <p>Each kernel is an {@code (i, j)} callback. It walks {@code k} columns one vector at a
     * time, starting at {@code aColumnOffset} in {@code a} and {@code bColumnOffset} in
     * {@code b}, fuses multiply-adds of the vector loaded from {@code a} at row {@code i} with
     * the vector loaded from {@code b} at row {@code j}, and finally stores the horizontal sum
     * of the accumulator into {@code c} at {@code (i, j + rOffset)}. The RxC variants do this
     * for R consecutive rows of {@code a} against C consecutive rows of {@code b}.
     *
     * <p>NOTE(review): the loops always load whole vectors and have no tail handling, so
     * {@code k} is presumably a multiple of the preferred lane count - confirm with callers.
     */
    private class GemmerF32 extends Gemmer {
        final BiIntConsumer matmul1x1 = this.initMatmul1x1();
        final BiIntConsumer matmul1x4 = this.initMatmul1x4();
        final BiIntConsumer matmul3x4 = this.initMatmul3x4();
        final BiIntConsumer matmul4x1 = this.initMatmul4x1();

        /**
         * Forwards every argument unchanged to the {@code Gemmer} constructor.
         * NOTE(review): {@code GemmerI8Q4_256} names the {@code ith}/{@code nth} slots
         * {@code aColumnOffset}/{@code bColumnOffset} - confirm which naming is accurate.
         */
        GemmerF32(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b, AbstractTensor c, int ith, int nth, int rOffset) {
            super(panamaTensorOperations, k, a, b, c, ith, nth, rOffset);
        }

        /**
         * Chooses the largest tile that fits in the remaining {@code [m0, m) x [n0, n)} region,
         * in the preference order 3x4, 4x1, 1x4, 1x1, and hands it to {@code kernel}.
         *
         * @return the chosen tile packed as {@code (rows << 4) | cols}
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int rowsLeft = m - m0;
            final int colsLeft = n - n0;
            final int tileRows;
            final int tileCols;
            final BiIntConsumer tileKernel;
            if (rowsLeft >= 3 && colsLeft >= 4) {
                tileRows = 3;
                tileCols = 4;
                tileKernel = this.matmul3x4;
            } else if (rowsLeft >= 4 && colsLeft >= 1) {
                tileRows = 4;
                tileCols = 1;
                tileKernel = this.matmul4x1;
            } else if (rowsLeft >= 1 && colsLeft >= 4) {
                tileRows = 1;
                tileCols = 4;
                tileKernel = this.matmul1x4;
            } else {
                tileRows = 1;
                tileCols = 1;
                tileKernel = this.matmul1x1;
            }
            this.kernel(m0, m, tileRows, n0, n, tileCols, tileKernel);
            return (tileRows << 4) | tileCols;
        }

        /** Builds the kernel producing the single output {@code c(i, j + rOffset)}. */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                while (aPos < aEnd || bPos < bEnd) {
                    FloatVector lhs = this.a.getVector(FloatVector.SPECIES_PREFERRED, i, aPos).reinterpretAsFloats();
                    FloatVector rhs = this.b.getVector(FloatVector.SPECIES_PREFERRED, j, bPos).reinterpretAsFloats();
                    sum = lhs.fma(rhs, sum);
                    aPos += lanes;
                    bPos += lanes;
                }
                this.c.set(sum.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }

        /** Builds the kernel producing outputs for row {@code i} and columns {@code j .. j + 3}. */
        protected BiIntConsumer initMatmul1x4() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                // One accumulator per b-row; the single a-vector is shared by all four.
                FloatVector sum0 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum1 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum2 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum3 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                while (aPos < aEnd || bPos < bEnd) {
                    FloatVector lhs = this.a.getVector(FloatVector.SPECIES_PREFERRED, i, aPos).reinterpretAsFloats();
                    FloatVector rhs0 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 0, bPos).reinterpretAsFloats();
                    FloatVector rhs1 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 1, bPos).reinterpretAsFloats();
                    FloatVector rhs2 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 2, bPos).reinterpretAsFloats();
                    FloatVector rhs3 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 3, bPos).reinterpretAsFloats();
                    sum0 = lhs.fma(rhs0, sum0);
                    sum1 = lhs.fma(rhs1, sum1);
                    sum2 = lhs.fma(rhs2, sum2);
                    sum3 = lhs.fma(rhs3, sum3);
                    aPos += lanes;
                    bPos += lanes;
                }
                final int outCol = j + this.rOffset;
                this.c.set(sum0.reduceLanes(VectorOperators.ADD), i, outCol);
                this.c.set(sum1.reduceLanes(VectorOperators.ADD), i, outCol + 1);
                this.c.set(sum2.reduceLanes(VectorOperators.ADD), i, outCol + 2);
                this.c.set(sum3.reduceLanes(VectorOperators.ADD), i, outCol + 3);
            };
        }

        /** Builds the kernel producing outputs for rows {@code i .. i + 2} and columns {@code j .. j + 3}. */
        protected BiIntConsumer initMatmul3x4() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                // sumRC accumulates a-row (i + R) against b-row (j + C).
                FloatVector sum00 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum01 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum02 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum03 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum10 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum11 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum12 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum13 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum20 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum21 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum22 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum23 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                while (aPos < aEnd || bPos < bEnd) {
                    // The four b-vectors are loaded once and reused for all three a-rows.
                    FloatVector rhs0 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 0, bPos).reinterpretAsFloats();
                    FloatVector rhs1 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 1, bPos).reinterpretAsFloats();
                    FloatVector rhs2 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 2, bPos).reinterpretAsFloats();
                    FloatVector rhs3 = this.b.getVector(FloatVector.SPECIES_PREFERRED, j + 3, bPos).reinterpretAsFloats();

                    FloatVector lhs0 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 0, aPos).reinterpretAsFloats();
                    sum00 = lhs0.fma(rhs0, sum00);
                    sum01 = lhs0.fma(rhs1, sum01);
                    sum02 = lhs0.fma(rhs2, sum02);
                    sum03 = lhs0.fma(rhs3, sum03);

                    FloatVector lhs1 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 1, aPos).reinterpretAsFloats();
                    sum10 = lhs1.fma(rhs0, sum10);
                    sum11 = lhs1.fma(rhs1, sum11);
                    sum12 = lhs1.fma(rhs2, sum12);
                    sum13 = lhs1.fma(rhs3, sum13);

                    FloatVector lhs2 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 2, aPos).reinterpretAsFloats();
                    sum20 = lhs2.fma(rhs0, sum20);
                    sum21 = lhs2.fma(rhs1, sum21);
                    sum22 = lhs2.fma(rhs2, sum22);
                    sum23 = lhs2.fma(rhs3, sum23);

                    aPos += lanes;
                    bPos += lanes;
                }
                final int outCol = j + this.rOffset;
                this.c.set(sum00.reduceLanes(VectorOperators.ADD), i + 0, outCol);
                this.c.set(sum01.reduceLanes(VectorOperators.ADD), i + 0, outCol + 1);
                this.c.set(sum02.reduceLanes(VectorOperators.ADD), i + 0, outCol + 2);
                this.c.set(sum03.reduceLanes(VectorOperators.ADD), i + 0, outCol + 3);
                this.c.set(sum10.reduceLanes(VectorOperators.ADD), i + 1, outCol);
                this.c.set(sum11.reduceLanes(VectorOperators.ADD), i + 1, outCol + 1);
                this.c.set(sum12.reduceLanes(VectorOperators.ADD), i + 1, outCol + 2);
                this.c.set(sum13.reduceLanes(VectorOperators.ADD), i + 1, outCol + 3);
                this.c.set(sum20.reduceLanes(VectorOperators.ADD), i + 2, outCol);
                this.c.set(sum21.reduceLanes(VectorOperators.ADD), i + 2, outCol + 1);
                this.c.set(sum22.reduceLanes(VectorOperators.ADD), i + 2, outCol + 2);
                this.c.set(sum23.reduceLanes(VectorOperators.ADD), i + 2, outCol + 3);
            };
        }

        /** Builds the kernel producing outputs for rows {@code i .. i + 3} and the single column {@code j}. */
        protected BiIntConsumer initMatmul4x1() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                // One accumulator per a-row; the single b-vector is shared by all four.
                FloatVector sum0 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum1 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum2 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                FloatVector sum3 = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                while (aPos < aEnd || bPos < bEnd) {
                    FloatVector lhs0 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 0, aPos).reinterpretAsFloats();
                    FloatVector lhs1 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 1, aPos).reinterpretAsFloats();
                    FloatVector lhs2 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 2, aPos).reinterpretAsFloats();
                    FloatVector lhs3 = this.a.getVector(FloatVector.SPECIES_PREFERRED, i + 3, aPos).reinterpretAsFloats();
                    FloatVector rhs = this.b.getVector(FloatVector.SPECIES_PREFERRED, j, bPos).reinterpretAsFloats();
                    sum0 = lhs0.fma(rhs, sum0);
                    sum1 = lhs1.fma(rhs, sum1);
                    sum2 = lhs2.fma(rhs, sum2);
                    sum3 = lhs3.fma(rhs, sum3);
                    aPos += lanes;
                    bPos += lanes;
                }
                final int outCol = j + this.rOffset;
                this.c.set(sum0.reduceLanes(VectorOperators.ADD), i + 0, outCol);
                this.c.set(sum1.reduceLanes(VectorOperators.ADD), i + 1, outCol);
                this.c.set(sum2.reduceLanes(VectorOperators.ADD), i + 2, outCol);
                this.c.set(sum3.reduceLanes(VectorOperators.ADD), i + 3, outCol);
            };
        }
    }

    /**
     * Gemmer for a float {@code a} operand and a 16-bit {@code b} operand
     * ({@link BFloat16BufferTensor}). Only a 1x1 kernel exists, so every tile is a single
     * output element.
     */
    private class GemmerF32BF16 extends Gemmer {
        final BiIntConsumer matmul1x1 = this.initMatmul1x1();
        /** {@code ta} narrowed to its concrete type; assigned in the constructor. */
        final FloatBufferTensor a;
        /** {@code tb} narrowed to its concrete type; assigned in the constructor. */
        final BFloat16BufferTensor b;

        /**
         * Forwards all arguments to the {@code Gemmer} constructor and keeps typed references
         * to the two operands.
         *
         * @throws ClassCastException if {@code ta} is not a {@code FloatBufferTensor} or
         *         {@code tb} is not a {@code BFloat16BufferTensor}
         */
        GemmerF32BF16(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, ith, nth, rOffset);
            this.a = (FloatBufferTensor)ta;
            this.b = (BFloat16BufferTensor)tb;
        }

        /**
         * Always dispatches the 1x1 kernel.
         *
         * @return the tile size packed as {@code (rows << 4) | cols}, i.e. {@code 0x11}
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int tileRows = 1;
            final int tileCols = 1;
            this.kernel(m0, m, tileRows, n0, n, tileCols, this.matmul1x1);
            return (tileRows << 4) | tileCols;
        }

        /**
         * Builds the kernel producing {@code c(i, j + rOffset)}.
         *
         * <p>Each step consumes one preferred-width short vector from {@code b} and the two
         * float vectors from {@code a} that cover the same element range. The 16-bit values
         * are zero-extended to ints, shifted left by {@code BF16_BYTE_SHIFT} and the resulting
         * bit patterns reinterpreted as floats before the fused multiply-add.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int step = ShortVector.SPECIES_PREFERRED.length();
                final int halfStep = FloatVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                while (aPos < aEnd && bPos < bEnd) {
                    FloatVector lhsLo = this.a.getVector(FloatVector.SPECIES_PREFERRED, new int[]{i, aPos});
                    FloatVector lhsHi = this.a.getVector(FloatVector.SPECIES_PREFERRED, new int[]{i, aPos + halfStep});
                    ShortVector packed = this.b.getVector(ShortVector.SPECIES_PREFERRED, new int[]{j, bPos});
                    // Part 0 / part 1 select the first / second half of the short lanes.
                    FloatVector rhsLo = packed.convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_PREFERRED, 0)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                        .reinterpretAsFloats();
                    FloatVector rhsHi = packed.convertShape(VectorOperators.ZERO_EXTEND_S2I, IntVector.SPECIES_PREFERRED, 1)
                        .lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT)
                        .reinterpretAsFloats();
                    sum = lhsLo.fma(rhsLo, sum);
                    sum = lhsHi.fma(rhsHi, sum);
                    aPos += step;
                    bPos += step;
                }
                float dot = sum.reduceLanes(VectorOperators.ADD);
                this.c.set(dot, i, j + this.rOffset);
            };
        }
    }

    /**
     * Gemmer for a float {@code a} operand and a 4-bit quantised {@code b} operand
     * ({@link Q4ByteBufferTensor}), using 256-bit float vectors.
     *
     * <p>Each loop step covers 32 elements: one 128-bit byte vector (16 bytes) is read from
     * {@code b}, its low nibbles are paired with {@code a[aoffset .. aoffset + 16)} and its
     * high nibbles with {@code a[aoffset + 16 .. aoffset + 32)}. Every nibble has
     * {@code Q4_BYTE_SUB_128} (8) subtracted before being cast to float, and the summed
     * products are multiplied by the factor returned by {@code b.getFactorForIndex}.
     *
     * <p>Only {@link #matmul1x1} is ever dispatched by {@link #pickKernel}; {@code matmul3x4}
     * and {@code matmul4x1} are {@code null}.
     */
    private class GemmerF32Q4_256 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q4ByteBufferTensor b;
        final FloatBufferTensor a;

        /**
         * Forwards all arguments to the {@code Gemmer} constructor, keeps typed references to
         * the operands and builds the kernels.
         *
         * @throws ClassCastException if {@code ta} is not a {@code FloatBufferTensor} or
         *         {@code tb} is not a {@code Q4ByteBufferTensor}
         */
        GemmerF32Q4_256(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, ith, nth, rOffset);
            this.a = (FloatBufferTensor)ta;
            this.b = (Q4ByteBufferTensor)tb;
            this.matmul1x1 = this.initMatmul1x1();
            this.matmul1x4 = this.initMatmul1x4();
            this.matmul3x4 = null;
            this.matmul4x1 = null;
        }

        /**
         * Always dispatches the 1x1 kernel.
         *
         * @return the tile size packed as {@code (rows << 4) | cols}, i.e. {@code 0x11}
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int mc = 1;
            final int nc = 1;
            this.kernel(m0, m, mc, n0, n, nc, this.matmul1x1);
            return (mc << 4) | nc;
        }

        /** Builds the kernel producing the single output {@code c(i, j + rOffset)}. */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int slen = 32;
                int aoffset = this.aColumnOffset;
                // boffset must be initialised before blim is derived from it.
                int boffset = this.bColumnOffset;
                final int alim = aoffset + this.k;
                final int blim = boffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_256);
                for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256, (float)this.b.getFactorForIndex(j, boffset));
                    ByteVector b0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j, boffset});
                    ByteVector b0lo = b0.and(Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128);
                    ByteVector b0hi = b0.lanewise(VectorOperators.LSHR, Q4_BYTE_SHIFT_128).sub(Q4_BYTE_SUB_128);
                    // castShape part 0 / part 1 yield the first / second 8 of the 16 nibble values.
                    FloatVector af0 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset}).mul(b0lo.castShape(FloatVector.SPECIES_256, 0));
                    FloatVector af1 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset + 8}).mul(b0lo.castShape(FloatVector.SPECIES_256, 1));
                    FloatVector af2 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset + 16}).mul(b0hi.castShape(FloatVector.SPECIES_256, 0));
                    FloatVector af3 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset + 16 + 8}).mul(b0hi.castShape(FloatVector.SPECIES_256, 1));
                    acc = af0.add(af1).add(af2).add(af3).fma(scale, acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }

        /**
         * Builds a kernel that shares the {@code a} loads between two rows of {@code b}.
         *
         * <p>Despite the name, it writes only two outputs: {@code c(i, j + rOffset)} and
         * {@code c(i, j + 1 + rOffset)}. {@link #pickKernel} never dispatches to it.
         */
        protected BiIntConsumer initMatmul1x4() {
            return (i, j) -> {
                final int slen = 32;
                int aoffset = this.aColumnOffset;
                // boffset must be initialised before blim is derived from it.
                int boffset = this.bColumnOffset;
                final int alim = aoffset + this.k;
                final int blim = boffset + this.k;
                FloatVector acc0 = FloatVector.zero(FloatVector.SPECIES_256);
                FloatVector acc1 = FloatVector.zero(FloatVector.SPECIES_256);
                for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) {
                    FloatVector scale0 = FloatVector.broadcast(FloatVector.SPECIES_256, (float)this.b.getFactorForIndex(j + 0, boffset));
                    FloatVector scale1 = FloatVector.broadcast(FloatVector.SPECIES_256, (float)this.b.getFactorForIndex(j + 1, boffset));
                    FloatVector af0 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset});
                    FloatVector af1 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset + 8});
                    FloatVector af2 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset + 16});
                    FloatVector af3 = this.a.getVector(FloatVector.SPECIES_256, new int[]{i, aoffset + 16 + 8});

                    // Row j of b.
                    ByteVector bf0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j + 0, boffset});
                    ByteVector lo0 = bf0.and(Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128);
                    ByteVector hi0 = bf0.lanewise(VectorOperators.LSHR, Q4_BYTE_SHIFT_128).sub(Q4_BYTE_SUB_128);
                    FloatVector af0l = af0.mul(lo0.castShape(FloatVector.SPECIES_256, 0));
                    FloatVector af1l = af1.mul(lo0.castShape(FloatVector.SPECIES_256, 1));
                    FloatVector af2l = af2.mul(hi0.castShape(FloatVector.SPECIES_256, 0));
                    FloatVector af3l = af3.mul(hi0.castShape(FloatVector.SPECIES_256, 1));
                    acc0 = af0l.add(af1l).add(af2l).add(af3l).fma(scale0, acc0);

                    // Row j + 1 of b, reusing the a-vectors loaded above.
                    bf0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j + 1, boffset});
                    lo0 = bf0.and(Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128);
                    hi0 = bf0.lanewise(VectorOperators.LSHR, Q4_BYTE_SHIFT_128).sub(Q4_BYTE_SUB_128);
                    af0l = af0.mul(lo0.castShape(FloatVector.SPECIES_256, 0));
                    af1l = af1.mul(lo0.castShape(FloatVector.SPECIES_256, 1));
                    af2l = af2.mul(hi0.castShape(FloatVector.SPECIES_256, 0));
                    af3l = af3.mul(hi0.castShape(FloatVector.SPECIES_256, 1));
                    acc1 = af0l.add(af1l).add(af2l).add(af3l).fma(scale1, acc1);
                }
                this.c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0 + this.rOffset);
                this.c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1 + this.rOffset);
            };
        }
    }

    /**
     * Gemmer for a float {@code a} operand and a 4-bit quantised {@code b} operand
     * ({@link Q4ByteBufferTensor}), using 512-bit float vectors.
     *
     * <p>Each loop step covers 32 elements: one 128-bit byte vector (16 bytes) is read from
     * {@code b}; its low nibbles become one 16-lane float vector paired with
     * {@code a[aoffset .. aoffset + 16)} and its high nibbles another paired with
     * {@code a[aoffset + 16 .. aoffset + 32)}. Every nibble has {@code Q4_BYTE_SUB_128} (8)
     * subtracted, is converted to float and multiplied by the factor returned by
     * {@code b.getFactorForIndex} before the fused multiply-add.
     *
     * <p>{@code matmul3x4} is {@code null} and is never dispatched.
     */
    private class GemmerF32Q4_512 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q4ByteBufferTensor b;
        final FloatBufferTensor a;

        /**
         * Forwards all arguments to the {@code Gemmer} constructor, keeps typed references to
         * the operands and builds the kernels.
         *
         * @throws ClassCastException if {@code ta} is not a {@code FloatBufferTensor} or
         *         {@code tb} is not a {@code Q4ByteBufferTensor}
         */
        GemmerF32Q4_512(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, ith, nth, rOffset);
            this.a = (FloatBufferTensor)ta;
            this.b = (Q4ByteBufferTensor)tb;
            this.matmul1x1 = this.initMatmul1x1();
            this.matmul1x4 = this.initMatmul1x4();
            this.matmul3x4 = null;
            this.matmul4x1 = this.initMatmul4x1();
        }

        /**
         * Chooses the tile for the remaining {@code [m0, m) x [n0, n)} region, preferring 4x1,
         * then 1x4, then 1x1, and hands it to {@code kernel}.
         *
         * @return the chosen tile packed as {@code (rows << 4) | cols}
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int mc;
            final int nc;
            final BiIntConsumer chosen;
            if (m - m0 >= 4 && n - n0 >= 1) {
                mc = 4;
                nc = 1;
                chosen = this.matmul4x1;
            } else if (m - m0 >= 1 && n - n0 >= 4) {
                mc = 1;
                nc = 4;
                chosen = this.matmul1x4;
            } else {
                mc = 1;
                nc = 1;
                chosen = this.matmul1x1;
            }
            this.kernel(m0, m, mc, n0, n, nc, chosen);
            return (mc << 4) | nc;
        }

        /** Builds the kernel producing the single output {@code c(i, j + rOffset)}. */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int slen = 32;
                int aoffset = this.aColumnOffset;
                // boffset must be initialised before blim is derived from it.
                int boffset = this.bColumnOffset;
                final int alim = aoffset + this.k;
                final int blim = boffset + this.k;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_512);
                for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_512, (float)this.b.getFactorForIndex(j, boffset));
                    FloatVector af0 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i, aoffset});
                    FloatVector af1 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i, aoffset + 16});
                    ByteVector bf0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j, boffset});
                    Vector<Float> low0 = bf0.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale);
                    // The arithmetic shift smears the sign bit; the mask afterwards keeps only the nibble.
                    Vector<Float> high0 = bf0.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128)
                        .lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale);
                    acc = af0.fma(low0, acc);
                    acc = af1.fma(high0, acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }

        /**
         * Builds the kernel producing outputs for rows {@code i .. i + 3} and the single
         * column {@code j}; the dequantised {@code b} vectors are shared by all four rows.
         */
        protected final BiIntConsumer initMatmul4x1() {
            return (i, j) -> {
                final int slen = 32;
                int aoffset = this.aColumnOffset;
                // boffset must be initialised before blim is derived from it.
                int boffset = this.bColumnOffset;
                final int alim = aoffset + this.k;
                final int blim = boffset + this.k;
                FloatVector acc0 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc1 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc3 = FloatVector.zero(FloatVector.SPECIES_512);
                for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_512, (float)this.b.getFactorForIndex(j, boffset));
                    ByteVector bf0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j, boffset});
                    Vector<Float> low0 = bf0.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale);
                    Vector<Float> high0 = bf0.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128)
                        .lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale);
                    FloatVector af00 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i, aoffset});
                    FloatVector af01 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i, aoffset + 16});
                    FloatVector af10 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i + 1, aoffset});
                    FloatVector af11 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i + 1, aoffset + 16});
                    FloatVector af20 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i + 2, aoffset});
                    FloatVector af21 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i + 2, aoffset + 16});
                    FloatVector af30 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i + 3, aoffset});
                    FloatVector af31 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i + 3, aoffset + 16});
                    acc0 = af00.fma(low0, acc0);
                    acc0 = af01.fma(high0, acc0);
                    acc1 = af10.fma(low0, acc1);
                    acc1 = af11.fma(high0, acc1);
                    acc2 = af20.fma(low0, acc2);
                    acc2 = af21.fma(high0, acc2);
                    acc3 = af30.fma(low0, acc3);
                    acc3 = af31.fma(high0, acc3);
                }
                this.c.set(acc0.reduceLanes(VectorOperators.ADD), i + 0, j + this.rOffset);
                this.c.set(acc1.reduceLanes(VectorOperators.ADD), i + 1, j + this.rOffset);
                this.c.set(acc2.reduceLanes(VectorOperators.ADD), i + 2, j + this.rOffset);
                this.c.set(acc3.reduceLanes(VectorOperators.ADD), i + 3, j + this.rOffset);
            };
        }

        /**
         * Builds the kernel producing outputs for row {@code i} and columns
         * {@code j .. j + 3}; the two {@code a} vectors are loaded once per step and reused
         * for all four rows of {@code b}.
         */
        protected BiIntConsumer initMatmul1x4() {
            return (i, j) -> {
                final int slen = 32;
                int aoffset = this.aColumnOffset;
                // boffset must be initialised before blim is derived from it.
                int boffset = this.bColumnOffset;
                final int alim = aoffset + this.k;
                final int blim = boffset + this.k;
                FloatVector acc0 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc1 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc3 = FloatVector.zero(FloatVector.SPECIES_512);
                for (; aoffset < alim && boffset < blim; aoffset += slen, boffset += slen) {
                    FloatVector af0 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i, aoffset});
                    FloatVector af1 = this.a.getVector(FloatVector.SPECIES_512, new int[]{i, aoffset + 16});

                    // Row j of b.
                    FloatVector scale0 = FloatVector.broadcast(FloatVector.SPECIES_512, (float)this.b.getFactorForIndex(j + 0, boffset));
                    ByteVector bf0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j + 0, boffset});
                    Vector<Float> low0 = bf0.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale0);
                    Vector<Float> high0 = bf0.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128)
                        .lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale0);
                    acc0 = af0.fma(low0, acc0);
                    acc0 = af1.fma(high0, acc0);

                    // Row j + 1 of b.
                    FloatVector scale1 = FloatVector.broadcast(FloatVector.SPECIES_512, (float)this.b.getFactorForIndex(j + 1, boffset));
                    ByteVector bf1 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j + 1, boffset});
                    Vector<Float> low1 = bf1.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale1);
                    Vector<Float> high1 = bf1.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128)
                        .lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale1);
                    acc1 = af0.fma(low1, acc1);
                    acc1 = af1.fma(high1, acc1);

                    // Row j + 2 of b.
                    FloatVector scale2 = FloatVector.broadcast(FloatVector.SPECIES_512, (float)this.b.getFactorForIndex(j + 2, boffset));
                    ByteVector bf2 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j + 2, boffset});
                    Vector<Float> low2 = bf2.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale2);
                    Vector<Float> high2 = bf2.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128)
                        .lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale2);
                    acc2 = af0.fma(low2, acc2);
                    acc2 = af1.fma(high2, acc2);

                    // Row j + 3 of b.
                    FloatVector scale3 = FloatVector.broadcast(FloatVector.SPECIES_512, (float)this.b.getFactorForIndex(j + 3, boffset));
                    ByteVector bf3 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j + 3, boffset});
                    Vector<Float> low3 = bf3.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale3);
                    Vector<Float> high3 = bf3.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128)
                        .lanewise(VectorOperators.AND, Q4_BYTE_MASK_128)
                        .sub(Q4_BYTE_SUB_128)
                        .convertShape(VectorOperators.B2F, FloatVector.SPECIES_512, 0)
                        .mul(scale3);
                    acc3 = af0.fma(low3, acc3);
                    acc3 = af1.fma(high3, acc3);
                }
                this.c.set(acc0.reduceLanes(VectorOperators.ADD), i, j + 0 + this.rOffset);
                this.c.set(acc1.reduceLanes(VectorOperators.ADD), i, j + 1 + this.rOffset);
                this.c.set(acc2.reduceLanes(VectorOperators.ADD), i, j + 2 + this.rOffset);
                this.c.set(acc3.reduceLanes(VectorOperators.ADD), i, j + 3 + this.rOffset);
            };
        }
    }

    /**
     * Gemmer for an 8-bit quantised {@code a} operand ({@link Q8ByteBufferTensor}) and a 4-bit
     * quantised {@code b} operand ({@link Q4ByteBufferTensor}), using 256-bit vectors.
     *
     * <p>Only {@link #matmul1x1} is built; the other kernel fields are {@code null} and
     * {@link #pickKernel} never dispatches to them.
     */
    private class GemmerI8Q4_256 extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q8ByteBufferTensor a;
        final Q4ByteBufferTensor b;

        /**
         * Forwards all arguments to the {@code Gemmer} constructor, keeps typed references to
         * the operands and builds the 1x1 kernel.
         *
         * @throws ClassCastException if {@code ta} is not a {@code Q8ByteBufferTensor} or
         *         {@code tb} is not a {@code Q4ByteBufferTensor}
         */
        GemmerI8Q4_256(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, aColumnOffset, bColumnOffset, rOffset);
            this.a = (Q8ByteBufferTensor)ta;
            this.b = (Q4ByteBufferTensor)tb;
            this.matmul1x1 = this.initMatmul1x1();
            this.matmul1x4 = null;
            this.matmul3x4 = null;
            this.matmul4x1 = null;
        }

        /**
         * Always dispatches the 1x1 kernel.
         *
         * @return the tile size packed as {@code (rows << 4) | cols}, i.e. {@code 0x11}
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            final int mc = 1;
            final int nc = 1;
            this.kernel(m0, m, mc, n0, n, nc, this.matmul1x1);
            return (mc << 4) | nc;
        }

        /**
         * Builds the kernel producing the single output {@code c(i, j + rOffset)}.
         *
         * <p>The outer loop loads one 256-bit vector of per-block scales from each operand's
         * {@code getBlockF()} tensor and multiplies them lane-wise. The inner loop then handles
         * one 32-element block per scale lane: 32 bytes of {@code a} are widened to two
         * 16-lane short vectors, 16 bytes of {@code b} are split into low and high nibbles
         * (each minus {@code Q4_BYTE_SUB_128}), the low nibbles are multiplied with the first
         * 16 {@code a} values and the high nibbles with the last 16, and the integer sums are
         * converted to float and scaled into the accumulator.
         *
         * <p>NOTE(review): the inner loop always processes a full vector of blocks, so
         * {@code k} is presumably a multiple of {@code 32 * FloatVector.SPECIES_256.length()}
         * - confirm with callers.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int blockSize = 32;
                final int scalesPerLoad = FloatVector.SPECIES_256.length();
                final int blocksNeeded = this.k / blockSize;
                int aoffset = this.aColumnOffset;
                int boffset = this.bColumnOffset;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_256);
                for (int bi = 0; bi < blocksNeeded; bi += scalesPerLoad) {
                    // Integer division gives the exact block index; routing the offset through
                    // float arithmetic would round once it exceeds float's integer precision.
                    FloatVector ablock = this.a.getBlockF().getVector(FloatVector.SPECIES_256, new int[]{i, aoffset / blockSize});
                    FloatVector bblock = this.b.getBlockF().getVector(FloatVector.SPECIES_256, new int[]{j, boffset / blockSize});
                    FloatVector scales = ablock.mul(bblock);
                    for (int lane = 0; lane < scalesPerLoad; ++lane) {
                        FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_256, scales.lane(lane));
                        ByteVector ai = this.a.getVector(ByteVector.SPECIES_256, new int[]{i, aoffset});
                        // Part 0 / part 1 are the first / last 16 of the 32 a-bytes.
                        Vector<Short> af0 = ai.convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                        Vector<Short> af1 = ai.convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 1);
                        ByteVector b0 = this.b.getVector(ByteVector.SPECIES_128, new int[]{j, boffset});
                        ByteVector b0low = b0.and(Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128);
                        ByteVector b0hi = b0.lanewise(VectorOperators.LSHR, Q4_BYTE_SHIFT_128).sub(Q4_BYTE_SUB_128);
                        Vector<Short> isum = b0low.convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(af0);
                        isum = isum.add(b0hi.convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0).mul(af1));
                        Vector<Float> r0 = isum.convertShape(VectorOperators.S2F, FloatVector.SPECIES_256, 0);
                        Vector<Float> r1 = isum.convertShape(VectorOperators.S2F, FloatVector.SPECIES_256, 1);
                        acc = scale.fma(r0.add(r1), acc);
                        aoffset += blockSize;
                        boffset += blockSize;
                    }
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }
    }

    /**
     * Multiplies a Q8 left operand by a Q4 right operand using 512-bit float accumulators.
     * Three kernels are provided: 1x1, 1x4 and a 2x2 kernel (stored in the historically
     * named field {@code matmul3x4}). Every kernel walks the shared dimension one
     * 32-element block at a time and multiplies each block's integer dot product by
     * the product of the two operands' factors for that block.
     */
    private class GemmerI8Q4_512
    extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        /** NOTE: despite the name, this kernel computes a 2x2 output tile (see {@link #pickKernel}). */
        final BiIntConsumer matmul3x4;
        final Q8ByteBufferTensor a;
        final Q4ByteBufferTensor b;

        /**
         * @param ith forwarded to the base class as the column offset into {@code ta}
         * @param nth forwarded to the base class as the column offset into {@code tb}
         */
        GemmerI8Q4_512(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, ith, nth, rOffset);
            this.a = (Q8ByteBufferTensor)ta;
            this.b = (Q4ByteBufferTensor)tb;
            this.matmul1x1 = this.initMatmul1x1();
            this.matmul1x4 = this.initMatmul1x4();
            this.matmul3x4 = this.initMatmul3x4();
        }

        /**
         * Chooses the largest tile that fits the region, runs it over every whole tile,
         * and reports the tile shape.
         *
         * @return the tile shape used, encoded as {@code rows << 4 | cols}
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            int nc;
            int mc;
            if (m - m0 >= 2 && n - n0 >= 2) {
                mc = 2;
                nc = 2;
                this.kernel(m0, m, 2, n0, n, 2, this.matmul3x4);
            } else if (m - m0 >= 1 && n - n0 >= 4) {
                mc = 1;
                nc = 4;
                this.kernel(m0, m, 1, n0, n, 4, this.matmul1x4);
            } else {
                mc = 1;
                nc = 1;
                this.kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            }
            return mc << 4 | nc;
        }

        /** Builds the kernel that writes the single output element {@code c[i, j + rOffset]}. */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                int aoffset = this.aColumnOffset;
                int boffset = this.bColumnOffset;
                FloatVector acc = FloatVector.zero(FloatVector.SPECIES_512);
                for (int l = 0; l < this.k; l += 32, aoffset += 32, boffset += 32) {
                    FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_512, this.a.getFactorForIndex(i, aoffset) * this.b.getFactorForIndex(j, boffset));
                    // 32 Q8 values widened to shorts, then viewed as two 16-lane halves.
                    ShortVector af = this.a.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_256, new int[]{i, aoffset}).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0).reinterpretAsShorts();
                    // 16 packed bytes: low and high nibbles, each shifted down by 8 and widened to shorts.
                    ByteVector bf0 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j, boffset});
                    Vector<Short> low0 = bf0.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high0 = bf0.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> isum = low0.mul(af.castShape(ShortVector.SPECIES_256, 0)).add(high0.mul(af.castShape(ShortVector.SPECIES_256, 1)));
                    Vector<Float> r0 = isum.convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    acc = scale.fma(r0, acc);
                }
                this.c.set(acc.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }

        /**
         * Builds the kernel that writes the four output elements
         * {@code c[i, j + rOffset .. j + 3 + rOffset]}, reusing one load of row {@code i}
         * of the left operand for four rows of the right operand.
         *
         * <p>The shared dimension is walked one 32-element block at a time up to {@code k},
         * exactly like the 1x1 and 2x2 kernels. Consuming blocks in fixed groups of
         * {@code FloatVector.SPECIES_512.length()} would read and accumulate past {@code k}
         * whenever {@code k} is not a multiple of 512.
         */
        protected BiIntConsumer initMatmul1x4() {
            return (i, j) -> {
                int aoffset = this.aColumnOffset;
                int boffset = this.bColumnOffset;
                FloatVector acc0 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc1 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc2 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc3 = FloatVector.zero(FloatVector.SPECIES_512);
                for (int l = 0; l < this.k; l += 32, aoffset += 32, boffset += 32) {
                    float aScale = this.a.getFactorForIndex(i, aoffset);
                    FloatVector scale0 = FloatVector.broadcast(FloatVector.SPECIES_512, aScale * this.b.getFactorForIndex(j + 0, boffset));
                    FloatVector scale1 = FloatVector.broadcast(FloatVector.SPECIES_512, aScale * this.b.getFactorForIndex(j + 1, boffset));
                    FloatVector scale2 = FloatVector.broadcast(FloatVector.SPECIES_512, aScale * this.b.getFactorForIndex(j + 2, boffset));
                    FloatVector scale3 = FloatVector.broadcast(FloatVector.SPECIES_512, aScale * this.b.getFactorForIndex(j + 3, boffset));
                    // 32 Q8 values widened to shorts, then split into two 16-lane halves.
                    Vector<Short> af = this.a.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_256, new int[]{i, aoffset}).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0);
                    Vector<Short> af0 = af.castShape(ShortVector.SPECIES_256, 0);
                    Vector<Short> af1 = af.castShape(ShortVector.SPECIES_256, 1);
                    ByteVector bf0 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j + 0, boffset});
                    ByteVector bf1 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j + 1, boffset});
                    ByteVector bf2 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j + 2, boffset});
                    ByteVector bf3 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j + 3, boffset});
                    // Low / high nibbles of each packed row, shifted down by 8 and widened to shorts.
                    Vector<Short> low0 = bf0.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high0 = bf0.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> low1 = bf1.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high1 = bf1.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> low2 = bf2.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high2 = bf2.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> low3 = bf3.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high3 = bf3.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Float> p0 = low0.mul(af0).add(high0.mul(af1)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector<Float> p1 = low1.mul(af0).add(high1.mul(af1)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector<Float> p2 = low2.mul(af0).add(high2.mul(af1)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector<Float> p3 = low3.mul(af0).add(high3.mul(af1)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    acc0 = scale0.fma(p0, acc0);
                    acc1 = scale1.fma(p1, acc1);
                    acc2 = scale2.fma(p2, acc2);
                    acc3 = scale3.fma(p3, acc3);
                }
                float r0 = acc0.reduceLanes(VectorOperators.ADD);
                float r1 = acc1.reduceLanes(VectorOperators.ADD);
                float r2 = acc2.reduceLanes(VectorOperators.ADD);
                float r3 = acc3.reduceLanes(VectorOperators.ADD);
                this.c.set(r0, i, j + 0 + this.rOffset);
                this.c.set(r1, i, j + 1 + this.rOffset);
                this.c.set(r2, i, j + 2 + this.rOffset);
                this.c.set(r3, i, j + 3 + this.rOffset);
            };
        }

        /**
         * Builds the 2x2 kernel (the method name is historical): it writes
         * {@code c[i .. i + 1, j + rOffset .. j + 1 + rOffset]}, sharing the loads of two
         * left-operand rows and two right-operand rows across the four products.
         */
        protected BiIntConsumer initMatmul3x4() {
            return (i, j) -> {
                int aoffset = this.aColumnOffset;
                int boffset = this.bColumnOffset;
                FloatVector acc00 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc01 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc10 = FloatVector.zero(FloatVector.SPECIES_512);
                FloatVector acc11 = FloatVector.zero(FloatVector.SPECIES_512);
                for (int l = 0; l < this.k; l += 32, aoffset += 32, boffset += 32) {
                    float as0 = this.a.getFactorForIndex(i + 0, aoffset);
                    float as1 = this.a.getFactorForIndex(i + 1, aoffset);
                    float bs0 = this.b.getFactorForIndex(j + 0, boffset);
                    float bs1 = this.b.getFactorForIndex(j + 1, boffset);
                    FloatVector scale00 = FloatVector.broadcast(FloatVector.SPECIES_512, as0 * bs0);
                    FloatVector scale01 = FloatVector.broadcast(FloatVector.SPECIES_512, as0 * bs1);
                    FloatVector scale10 = FloatVector.broadcast(FloatVector.SPECIES_512, as1 * bs0);
                    FloatVector scale11 = FloatVector.broadcast(FloatVector.SPECIES_512, as1 * bs1);
                    Vector<Short> af0 = this.a.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_256, new int[]{i + 0, aoffset}).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0);
                    Vector<Short> af1 = this.a.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_256, new int[]{i + 1, aoffset}).convertShape(VectorOperators.B2S, ShortVector.SPECIES_512, 0);
                    Vector<Short> af0low = af0.castShape(ShortVector.SPECIES_256, 0);
                    Vector<Short> af0high = af0.castShape(ShortVector.SPECIES_256, 1);
                    Vector<Short> af1low = af1.castShape(ShortVector.SPECIES_256, 0);
                    Vector<Short> af1high = af1.castShape(ShortVector.SPECIES_256, 1);
                    ByteVector bf0 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j + 0, boffset});
                    ByteVector bf1 = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{j + 1, boffset});
                    Vector<Short> low0 = bf0.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high0 = bf0.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> low1 = bf1.lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Short> high1 = bf1.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_128).lanewise(VectorOperators.AND, Q4_BYTE_MASK_128).sub(Q4_BYTE_SUB_128).convertShape(VectorOperators.B2S, ShortVector.SPECIES_256, 0);
                    Vector<Float> p00 = low0.mul(af0low).add(high0.mul(af0high)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector<Float> p01 = low1.mul(af0low).add(high1.mul(af0high)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector<Float> p10 = low0.mul(af1low).add(high0.mul(af1high)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    Vector<Float> p11 = low1.mul(af1low).add(high1.mul(af1high)).convertShape(VectorOperators.S2F, FloatVector.SPECIES_512, 0);
                    acc00 = scale00.fma(p00, acc00);
                    acc01 = scale01.fma(p01, acc01);
                    acc10 = scale10.fma(p10, acc10);
                    acc11 = scale11.fma(p11, acc11);
                }
                float r00 = acc00.reduceLanes(VectorOperators.ADD);
                float r01 = acc01.reduceLanes(VectorOperators.ADD);
                float r10 = acc10.reduceLanes(VectorOperators.ADD);
                float r11 = acc11.reduceLanes(VectorOperators.ADD);
                this.c.set(r00, i + 0, j + 0 + this.rOffset);
                this.c.set(r01, i + 0, j + 1 + this.rOffset);
                this.c.set(r10, i + 1, j + 0 + this.rOffset);
                this.c.set(r11, i + 1, j + 1 + this.rOffset);
            };
        }
    }

    /**
     * Multiplies a Q8 left operand by a Q4 right operand using 128-bit (and 64-bit byte)
     * vectors. Only a 1x1 output-tile kernel exists; the other kernel slots are declared
     * but intentionally left {@code null}.
     */
    private class GemmerI8Q4_arm
    extends Gemmer {
        final BiIntConsumer matmul1x1;
        final BiIntConsumer matmul1x4;
        final BiIntConsumer matmul3x4;
        final BiIntConsumer matmul4x1;
        final Q8ByteBufferTensor a;
        final Q4ByteBufferTensor b;

        GemmerI8Q4_arm(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, aColumnOffset, bColumnOffset, rOffset);
            this.a = (Q8ByteBufferTensor)ta;
            this.b = (Q4ByteBufferTensor)tb;
            // No wider tiles are implemented for this vector width.
            this.matmul1x4 = null;
            this.matmul3x4 = null;
            this.matmul4x1 = null;
            this.matmul1x1 = this.initMatmul1x1();
        }

        /**
         * Runs the 1x1 kernel over the whole requested region.
         *
         * @return the tile shape used, encoded as {@code rows << 4 | cols} (always 1x1 here)
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            this.kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return 1 << 4 | 1;
        }

        /**
         * Builds the kernel that writes the single output element {@code c[i, j + rOffset]}.
         * Blocks of 32 elements are consumed in groups of {@code FloatVector.SPECIES_128.length()},
         * because that many block factors are loaded per operand at once.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int lanes = FloatVector.SPECIES_128.length();
                final int blockCount = this.k / 32;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128);
                for (int block = 0; block < blockCount; block += lanes) {
                    // Load `lanes` consecutive block factors of each operand (index = offset / 32) and combine them.
                    FloatVector aScales = this.a.getBlockF().getVector((VectorSpecies<Float>)FloatVector.SPECIES_128, new int[]{i, (int)(0.03125f * (float)aPos)});
                    FloatVector bScales = this.b.getBlockF().getVector((VectorSpecies<Float>)FloatVector.SPECIES_128, new int[]{j, (int)(0.03125f * (float)bPos)});
                    FloatVector scales = aScales.mul(bScales);
                    for (int lane = 0; lane < lanes; ++lane, aPos += 32, bPos += 32) {
                        FloatVector scale = FloatVector.broadcast(FloatVector.SPECIES_128, scales.lane(lane));
                        // 32 Q8 values, loaded as two 16-byte vectors and widened into four 8-lane short vectors.
                        ByteVector aFront = this.a.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{i, aPos});
                        ByteVector aBack = this.a.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_128, new int[]{i, aPos + 16});
                        Vector<Short> aQ0 = aFront.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
                        Vector<Short> aQ1 = aFront.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 1);
                        Vector<Short> aQ2 = aBack.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
                        Vector<Short> aQ3 = aBack.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 1);
                        // Two 8-byte packed loads; each yields a low-nibble and a high-nibble vector, shifted down by 8.
                        ByteVector packedFront = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_64, new int[]{j, bPos});
                        ByteVector packedBack = this.b.getVector((VectorSpecies<Byte>)ByteVector.SPECIES_64, new int[]{j, bPos + 16});
                        ByteVector frontLow = packedFront.lanewise(VectorOperators.AND, Q4_BYTE_MASK_64).sub(Q4_BYTE_SUB_64);
                        ByteVector frontHigh = packedFront.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64).lanewise(VectorOperators.AND, Q4_BYTE_MASK_64).sub(Q4_BYTE_SUB_64);
                        ByteVector backLow = packedBack.lanewise(VectorOperators.AND, Q4_BYTE_MASK_64).sub(Q4_BYTE_SUB_64);
                        ByteVector backHigh = packedBack.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64).lanewise(VectorOperators.AND, Q4_BYTE_MASK_64).sub(Q4_BYTE_SUB_64);
                        Vector<Short> bQ0 = frontLow.castShape(ShortVector.SPECIES_128, 0);
                        Vector<Short> bQ2 = frontHigh.castShape(ShortVector.SPECIES_128, 0);
                        Vector<Short> bQ1 = backLow.castShape(ShortVector.SPECIES_128, 0);
                        Vector<Short> bQ3 = backHigh.castShape(ShortVector.SPECIES_128, 0);
                        // Integer partial sums for the block, accumulated in shorts.
                        ShortVector partial = ShortVector.zero(ShortVector.SPECIES_128);
                        partial = partial.add(aQ0.mul(bQ0));
                        partial = partial.add(aQ1.mul(bQ1));
                        partial = partial.add(aQ2.mul(bQ2));
                        partial = partial.add(aQ3.mul(bQ3));
                        // Widen both halves of the short sums to floats, scale, and accumulate.
                        Vector<Float> scaledFront = partial.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0).mul(scale);
                        sum = sum.add(scaledFront);
                        Vector<Float> scaledBack = partial.convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1).mul(scale);
                        sum = sum.add(scaledBack);
                    }
                }
                this.c.set(sum.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }
    }

    /**
     * Multiplies two BFloat16 operands with a single 1x1 output-tile kernel that
     * expands the 16-bit values to floats on the fly.
     */
    private class GemmerBF16
    extends Gemmer {
        final BiIntConsumer matmul1x1 = this.initMatmul1x1();
        final BFloat16BufferTensor a;
        final BFloat16BufferTensor b;

        GemmerBF16(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int ith, int nth, int rOffset) {
            super(panamaTensorOperations, k, ta, tb, c, ith, nth, rOffset);
            this.a = (BFloat16BufferTensor)ta;
            this.b = (BFloat16BufferTensor)tb;
        }

        /**
         * Runs the 1x1 kernel over the whole requested region.
         *
         * @return the tile shape used, encoded as {@code rows << 4 | cols} (always 1x1 here)
         */
        @Override
        protected int pickKernel(int m0, int m, int n0, int n) {
            this.kernel(m0, m, 1, n0, n, 1, this.matmul1x1);
            return 1 << 4 | 1;
        }

        /**
         * Builds the kernel that writes the single output element {@code c[i, j + rOffset]}.
         * Each step loads one preferred-width short vector per operand, widens both halves
         * to ints, shifts them left by {@code BF16_BYTE_SHIFT} and reinterprets the bits as floats.
         */
        protected BiIntConsumer initMatmul1x1() {
            return (i, j) -> {
                final int step = ShortVector.SPECIES_PREFERRED.length();
                final int aEnd = this.aColumnOffset + this.k;
                final int bEnd = this.bColumnOffset + this.k;
                int aPos = this.aColumnOffset;
                int bPos = this.bColumnOffset;
                FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
                while (aPos < aEnd && bPos < bEnd) {
                    ShortVector rawA = this.a.getVector((VectorSpecies<Short>)ShortVector.SPECIES_PREFERRED, new int[]{i, aPos});
                    ShortVector rawB = this.b.getVector((VectorSpecies<Short>)ShortVector.SPECIES_PREFERRED, new int[]{j, bPos});
                    // First half of the lanes.
                    FloatVector aFront = rawA.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
                    FloatVector bFront = rawB.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
                    // Second half of the lanes.
                    FloatVector aBack = rawA.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
                    FloatVector bBack = rawB.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 1).lanewise(VectorOperators.LSHL, BF16_BYTE_SHIFT).reinterpretAsFloats();
                    sum = aFront.fma(bFront, sum);
                    sum = aBack.fma(bBack, sum);
                    aPos += step;
                    bPos += step;
                }
                this.c.set(sum.reduceLanes(VectorOperators.ADD), i, j + this.rOffset);
            };
        }
    }

    /**
     * Base class of the tiled matrix-multiply drivers. A subclass supplies the
     * per-tile kernels and, through {@link #pickKernel}, decides which tile shape to
     * run for a given output region; this class splits the region into whole tiles
     * plus remainders and makes sure every output cell is visited.
     */
    private abstract class Gemmer {
        /** Length of the shared dimension walked by the kernels. */
        final int k;
        final AbstractTensor a;
        final AbstractTensor b;
        /** Output tensor the kernels write into. */
        final AbstractTensor c;
        /** Starting column offset used when reading {@code a}. */
        final int aColumnOffset;
        /** Starting column offset used when reading {@code b}. */
        final int bColumnOffset;
        /** Added to the output column index when writing {@code c}. */
        final int rOffset;

        Gemmer(PanamaTensorOperations panamaTensorOperations, int k, AbstractTensor a, AbstractTensor b, AbstractTensor c, int aColumnOffset, int bColumnOffset, int rOffset) {
            this.k = k;
            this.a = a;
            this.b = b;
            this.c = c;
            this.aColumnOffset = aColumnOffset;
            this.bColumnOffset = bColumnOffset;
            this.rOffset = rOffset;
        }

        /**
         * Computes the output region with rows {@code [m0, m)} and columns {@code [n0, n)}.
         */
        void matmul(int m0, int m, int n0, int n) {
            this.mnpack(m0, m, n0, n);
        }

        /**
         * Covers the region {@code [m0, m) x [n0, n)}: {@link #pickKernel} handles all
         * whole tiles, then the leftovers are handled recursively.
         */
        private void mnpack(int m0, int m, int n0, int n) {
            if (m - m0 <= 0 || n - n0 <= 0) {
                return;
            }
            // pickKernel runs its kernel over every whole tile and returns the tile shape as rows << 4 | cols.
            int r = this.pickKernel(m0, m, n0, n);
            int mc = r >> 4;
            int nc = r & 0xF;
            // First row / column not covered by a whole tile.
            int mp = m0 + (m - m0) / mc * mc;
            int np = n0 + (n - n0) / nc * nc;
            // Leftover rows below the tiled area.
            this.mnpack(mp, m, n0, np);
            // Leftover columns to the right, over the FULL row range so the bottom-right
            // corner [mp, m) x [np, n) is computed too; limiting this call to [m0, mp)
            // would leave that corner of the output unwritten.
            this.mnpack(m0, m, np, n);
        }

        /**
         * Runs the most suitable kernel over all whole tiles of the region
         * {@code [var1, var2) x [var3, var4)} (rows, then columns).
         *
         * @return the tile shape that was used, encoded as {@code rows << 4 | cols}
         */
        protected abstract int pickKernel(int var1, int var2, int var3, int var4);

        /**
         * Invokes {@code action} once per whole {@code RM x RN} tile of the region
         * {@code [m0, m) x [n0, n)}, passing the tile's top-left row and column.
         * Tiles are visited row by row; partial tiles at the edges are not visited.
         */
        void kernel(int m0, int m, int RM, int n0, int n, int RN, BiIntConsumer action) {
            int ytiles = (m - m0) / RM;
            int xtiles = (n - n0) / RN;
            int tiles = ytiles * xtiles;
            for (int job = 0; job < tiles; ++job) {
                int i = m0 + job / xtiles * RM;
                int j = n0 + job % xtiles * RN;
                action.accept(i, j);
            }
        }
    }
}

