/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.listener.TrainingListenerAdapter;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A training listener that times consecutive training batches and consecutive validation batches,
 * and writes the collected timing metrics to log files when training ends.
 *
 * <p>On every batch callback after the first one since the last reset, the timestamp captured at
 * the previous callback is passed to {@link Trainer#addMetric(String, long)} under the metric name
 * {@code "train"} or {@code "validate"}. Presumably the trainer turns that begin timestamp into an
 * elapsed-time metric; confirm against the {@code Trainer} implementation.
 *
 * <p>When training ends, the {@code "train"} and {@code "validate"} metrics are appended, one per
 * line, to {@code training.log} and {@code validate.log} in the configured output directory.
 */
public class TimeMeasureTrainingListener extends TrainingListenerAdapter {

    private static final Logger logger =
            LoggerFactory.getLogger(TimeMeasureTrainingListener.class);

    /** Sentinel meaning "no batch timestamp has been captured since the last reset". */
    private static final long UNSET = -1L;

    private final String outputDir;
    private long trainBatchBeginTime;
    private long validateBatchBeginTime;

    /**
     * Constructs a {@code TimeMeasureTrainingListener}.
     *
     * @param outputDir the directory the timing logs are written to when training ends; if {@code
     *     null}, no logs are written
     */
    public TimeMeasureTrainingListener(String outputDir) {
        this.outputDir = outputDir;
        trainBatchBeginTime = UNSET;
        validateBatchBeginTime = UNSET;
    }

    /**
     * Resets both batch timers so that no recorded interval spans an {@code onEpoch} call.
     *
     * @param trainer the trainer the listener is attached to (unused)
     */
    @Override
    public void onEpoch(Trainer trainer) {
        trainBatchBeginTime = UNSET;
        validateBatchBeginTime = UNSET;
    }

    /**
     * Records the time since the previous training-batch callback, then restarts the clock.
     *
     * @param trainer the trainer receiving the {@code "train"} metric
     * @param batchData the batch information (unused)
     */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        // The first batch after a reset has no previous timestamp; it only starts the clock.
        if (trainBatchBeginTime != UNSET) {
            trainer.addMetric("train", trainBatchBeginTime);
        }
        trainBatchBeginTime = System.nanoTime();
    }

    /**
     * Records the time since the previous validation-batch callback, then restarts the clock.
     *
     * @param trainer the trainer receiving the {@code "validate"} metric
     * @param batchData the batch information (unused)
     */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        // The first batch after a reset has no previous timestamp; it only starts the clock.
        if (validateBatchBeginTime != UNSET) {
            trainer.addMetric("validate", validateBatchBeginTime);
        }
        validateBatchBeginTime = System.nanoTime();
    }

    /**
     * Writes the collected timing metrics to the output directory.
     *
     * @param trainer the trainer whose metrics are dumped
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        dumpTrainingTimeInfo(trainer.getMetrics(), outputDir);
    }

    /**
     * Creates {@code logDir} if needed and appends the {@code "train"} and {@code "validate"}
     * metrics to {@code training.log} and {@code validate.log} inside it.
     *
     * <p>This is best-effort: I/O failures are logged, not propagated, so that a logging problem
     * does not fail the end of training. If writing {@code training.log} fails, {@code
     * validate.log} is not attempted.
     *
     * @param metrics the metrics to dump; nothing is written if {@code null}
     * @param logDir the target directory; nothing is written if {@code null}
     */
    private static void dumpTrainingTimeInfo(Metrics metrics, String logDir) {
        if (metrics == null || logDir == null) {
            return;
        }
        try {
            Path dir = Paths.get(logDir);
            Files.createDirectories(dir);
            dumpMetricToFile(dir.resolve("training.log"), metrics.getMetric("train"));
            dumpMetricToFile(dir.resolve("validate.log"), metrics.getMetric("validate"));
        } catch (IOException e) {
            logger.error("Failed dump training log", e);
        }
    }

    /**
     * Appends each metric's {@code toString()} form, one per line, to {@code path}. The file is
     * created if it does not exist.
     *
     * @param path the file to append to
     * @param metrics the metrics to write; nothing is written (and no file is created) if {@code
     *     null} or empty
     * @throws IOException if the file cannot be opened or written
     */
    private static void dumpMetricToFile(Path path, List<Metric> metrics) throws IOException {
        if (metrics == null || metrics.isEmpty()) {
            return;
        }
        try (BufferedWriter writer =
                Files.newBufferedWriter(
                        path, StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
            for (Metric metric : metrics) {
                writer.append(metric.toString());
                writer.newLine();
            }
        }
    }
}

