/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.Float16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.collect.Maps;
import java.util.Objects;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import org.jctools.queues.MpmcUnboundedXaddArrayQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A pool of reusable tensors keyed by {@link DType} and {@link TensorShape}.
 *
 * <p>{@link #get} hands out an idle, previously released tensor of the requested dtype/shape when
 * one is queued; otherwise it allocates a new one. A newly allocated tensor is adopted by this cache
 * (via {@code setOwnerCache}) only while the running total of adopted tensor sizes stays strictly
 * below {@code bytesCapacity}; beyond that, tensors are still returned but are not tied to this cache.
 *
 * <p>Shared state lives in concurrent structures ({@link ConcurrentMap}, {@link AtomicLong} and
 * jctools MPMC queues). Nothing is ever evicted: {@code currentBytes} is only decremented when an
 * over-capacity reservation is rolled back, so it counts every tensor the cache has adopted, whether
 * idle in a queue or currently in use.
 */
public class TensorCache {
    // Declared before `instance` so it is initialized before any constructor code could use it.
    private static final Logger logger = LoggerFactory.getLogger(TensorCache.class);

    /** Process-wide shared cache; capacity 0x6400000 = 100 * 1024 * 1024. */
    public static final TensorCache instance = new TensorCache(0x6400000L);

    private final long bytesCapacity;
    private final AtomicLong currentBytes;
    private final ConcurrentMap<ShapeKey, MpmcUnboundedXaddArrayQueue<AbstractTensor>> availableByShape;
    // 128 is the chunk size handed to the jctools unbounded queue.
    private final Function<ShapeKey, MpmcUnboundedXaddArrayQueue<AbstractTensor>> queueFactory =
            s -> new MpmcUnboundedXaddArrayQueue<>(128);

    /**
     * Creates an empty cache.
     *
     * @param bytesCapacity exclusive upper bound on the summed {@code size()} of tensors this cache
     *     adopts
     */
    public TensorCache(long bytesCapacity) {
        this.bytesCapacity = bytesCapacity;
        this.currentBytes = new AtomicLong(0L);
        this.availableByShape = Maps.newConcurrentMap();
    }

    /**
     * Returns a tensor of the given dtype and shape, reusing an idle pooled tensor when available.
     *
     * @param dType element type of the requested tensor
     * @param shape shape of the requested tensor
     * @return a pooled tensor, or a freshly allocated one (adopted by this cache only if it fits)
     * @throws IllegalArgumentException if {@code dType} has no tensor implementation here
     */
    public AbstractTensor get(DType dType, TensorShape shape) {
        MpmcUnboundedXaddArrayQueue<AbstractTensor> availableQueue =
                availableByShape.computeIfAbsent(new ShapeKey(dType, shape), queueFactory);
        AbstractTensor pooled = availableQueue.poll();
        if (pooled != null) {
            return pooled;
        }

        AbstractTensor t = allocate(dType, shape);

        // Optimistically reserve the tensor's size, rolling back if it does not fit. Under contention
        // another thread may briefly observe the reservation and decline to adopt its own tensor.
        // NOTE(review): assumes AbstractTensor.size() reports bytes (field is "bytesCapacity") — confirm.
        if (currentBytes.addAndGet(t.size()) < bytesCapacity) {
            t.setOwnerCache(this);
        } else {
            logger.debug("Full!");
            currentBytes.addAndGet(-t.size());
        }
        return t;
    }

    /** Allocates a new, uncached tensor of the implementation matching {@code dType}. */
    private static AbstractTensor allocate(DType dType, TensorShape shape) {
        return switch (dType) {
            case F32 -> new FloatBufferTensor(shape);
            case F16 -> new Float16BufferTensor(shape);
            case BF16 -> new BFloat16BufferTensor(shape);
            case I8 -> new Q8ByteBufferTensor(shape);
            case Q4 -> new Q4ByteBufferTensor(shape);
            default -> throw new IllegalArgumentException("Unsupported tensor type: " + dType);
        };
    }

    /**
     * Clears {@code b} and enqueues it on the idle queue for its dtype/shape so that a later
     * {@link #get} can hand it out again.
     *
     * <p>NOTE(review): nothing here guards against releasing the same tensor twice, which would let
     * two callers receive the same instance — confirm callers release exactly once.
     *
     * @param b the tensor to return to the pool
     */
    void release(AbstractTensor b) {
        b.clear();
        availableByShape
                .computeIfAbsent(new ShapeKey(b.dType(), b.shape()), queueFactory)
                .offer(b);
    }

    /** Map key pairing a {@link DType} with a {@link TensorShape}; equality delegates to both. */
    public static class ShapeKey {
        final TensorShape shape;
        final DType dType;

        ShapeKey(DType dType, TensorShape shape) {
            this.dType = dType;
            this.shape = shape;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ShapeKey other = (ShapeKey) o;
            return Objects.equals(shape, other.shape) && dType == other.dType;
        }

        @Override
        public int hashCode() {
            return Objects.hash(shape, dType);
        }
    }
}

