package com.rtg.vcf.eval;

import com.rtg.launcher.CommonFlags;
import com.rtg.launcher.globals.GlobalFlags;
import com.rtg.launcher.globals.ToolsGlobalFlags;
import com.rtg.util.ContingencyTable;
import com.rtg.util.Environment;
import com.rtg.util.MathUtils;
import com.rtg.util.MultiSet;
import com.rtg.util.StringUtils;
import com.rtg.util.TextTable;
import com.rtg.util.Utils;
import com.rtg.util.cli.CommandLine;
import com.rtg.util.diagnostic.Diagnostic;
import com.rtg.util.io.FileUtils;
import com.rtg.util.io.LineWriter;
import com.rtg.vcf.VcfRecord;
import com.rtg.vcf.VcfUtils;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

/* loaded from: input_file:com/rtg/vcf/eval/RocContainer.class */
/**
 * Accumulates ROC points for a set of {@link RocFilter} selections and writes them out.
 *
 * <p>Each call-side record contributes a {@link RocPoint} keyed by the score obtained from a
 * {@link RocSortValueExtractor}. Points are kept in one sorted map per filter, ordered according
 * to the extractor's sort order. Baseline-side records only contribute to per-filter totals.
 * The accumulated data is written as:
 * <ul>
 *   <li>one ROC table per filter ({@link #writeRocs});</li>
 *   <li>optional slope files;</li>
 *   <li>a summary table ({@link #writeSummary}).</li>
 * </ul>
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class RocContainer {
    /** Decimal places used when formatting score thresholds. */
    private static final int SCORE_DP = 3;
    /** Decimal places used when formatting (possibly fractional) counts. */
    private static final int COUNT_DP = 2;
    /** Decimal places used when formatting precision, sensitivity and F-measure. */
    private static final int METRICS_DP = 4;
    /** File-name suffix substituted for the ROC suffix when deriving a slope file name. */
    private static final String SLOPE_EXT = "_slope.tsv";

    private final boolean mRescaleCategories;
    private final String mFieldLabel;
    /** Per-filter ROC points keyed by score; filters keep their registration order. */
    private final Map<RocFilter, SortedMap<Double, RocPoint<Double>>> mRocs = new LinkedHashMap<>();
    private final Comparator<Double> mComparator;
    private final RocSortValueExtractor mRocExtractor;
    private final String mFilePrefix;
    /** Number of records whose score was missing, non-finite, or could not be extracted. */
    private int mNoScoreVariants = 0;
    /** True if any registered filter needs the genotype to decide acceptance. */
    private boolean mRequiresGt = false;
    /** Cumulative ALL-filter point with the best F-measure found by the last {@link #writeRocs}. */
    private RocPoint<Double> mBest = null;
    private double mBestFMeasure = 0.0;
    MultiSet<RocFilter> mBaselineTotals = new MultiSet<>();
    MultiSet<RocFilter> mBaselineTpTotal = new MultiSet<>();

    /** Orders scores ascending using {@link Double#compareTo} (which places NaN last). */
    private static class AscendingDoubleComparator implements Comparator<Double>, Serializable {
        private AscendingDoubleComparator() {
        }

        @Override
        public int compare(Double a, Double b) {
            return a.compareTo(b);
        }
    }

    /** Orders scores descending, but still places NaN after every real value. */
    private static class DescendingDoubleComparator implements Comparator<Double>, Serializable {
        private DescendingDoubleComparator() {
        }

        @Override
        public int compare(Double a, Double b) {
            final boolean aNan = Double.isNaN(a);
            final boolean bNan = Double.isNaN(b);
            if (aNan || bNan) {
                // NaN sorts after everything else; two NaNs are equal.
                return Boolean.compare(aNan, bNan);
            }
            return b.compareTo(a);
        }
    }

    /** Column names used in ROC output headers. */
    public static class RocColumns {
        public static final String SCORE = "score";
        public static final String TRUE_POSITIVES = "true_positives";
        public static final String TRUE_POSITIVES_BASELINE = "true_positives_baseline";
        public static final String FALSE_POSITIVES = "false_positives";
        public static final String TRUE_POSITIVES_CALL = "true_positives_call";
        public static final String FALSE_NEGATIVES = "false_negatives";
        public static final String PRECISION = "precision";
        public static final String SENSITIVITY = "sensitivity";
        public static final String F_MEASURE = "f_measure";
    }

    /**
     * Creates a container whose output files carry no name prefix.
     * @param rocSortValueExtractor source of the per-record score and its sort order
     */
    public RocContainer(RocSortValueExtractor rocSortValueExtractor) {
        this(rocSortValueExtractor, "");
    }

    /**
     * Creates a container.
     * @param rocSortValueExtractor source of the per-record score and its sort order
     * @param filePrefix prefix prepended to every ROC and summary output file name
     */
    public RocContainer(RocSortValueExtractor rocSortValueExtractor, String filePrefix) {
        mRescaleCategories = GlobalFlags.getBooleanValue(ToolsGlobalFlags.VCFEVAL_ROC_SUBSET_RESCALE);
        switch (rocSortValueExtractor.getSortOrder()) {
            case ASCENDING:
                mComparator = new AscendingDoubleComparator();
                break;
            case DESCENDING:
            default:
                mComparator = new DescendingDoubleComparator();
                break;
        }
        mFieldLabel = rocSortValueExtractor.toString();
        mFilePrefix = filePrefix;
        mRocExtractor = rocSortValueExtractor;
    }

    /** @return number of records added without a usable (finite) score */
    public int getNumberOfIgnoredVariants() {
        return mNoScoreVariants;
    }

    /**
     * Registers a filter with an empty ROC map.
     * Re-adding an existing filter replaces its accumulated points.
     * @param rocFilter the filter to register
     */
    public void addFilter(RocFilter rocFilter) {
        mRocs.put(rocFilter, new TreeMap<>(mComparator));
        mRequiresGt |= rocFilter.requiresGt();
    }

    /**
     * Registers every filter in the set.
     * @param set filters to register; must not be null
     */
    public void addFilters(Set<RocFilter> set) {
        assert set != null;
        for (final RocFilter filter : set) {
            addFilter(filter);
        }
    }

    Collection<RocFilter> filters() {
        return mRocs.keySet();
    }

    /**
     * Returns the valid genotype for the record when some filter needs it, otherwise null.
     * NOTE(review): {@code sampleNo} is presumably a sample index — confirm against VcfUtils.
     */
    private int[] genotypeIfRequired(VcfRecord rec, int sampleNo) {
        return mRequiresGt ? VcfUtils.getValidGt(rec, sampleNo) : null;
    }

    /**
     * Counts a baseline record against every filter that accepts it.
     * @param vcfRecord the baseline record
     * @param sampleNo index passed to genotype extraction
     * @param isTruePositive if true, also counts the record in the baseline true-positive totals
     */
    public void incrementBaselineCount(VcfRecord vcfRecord, int sampleNo, boolean isTruePositive) {
        final int[] gt = genotypeIfRequired(vcfRecord, sampleNo);
        for (final RocFilter filter : filters()) {
            if (filter.accept(vcfRecord, gt)) {
                mBaselineTotals.add(filter);
                if (isTruePositive) {
                    mBaselineTpTotal.add(filter);
                }
            }
        }
    }

    /**
     * Adds a call-side record to the ROC of every filter that accepts it.
     *
     * <p>The record's score comes from the extractor. A NaN or infinite score, or an extraction
     * that fails with {@link IndexOutOfBoundsException}, is stored under a NaN threshold and
     * counted in {@link #getNumberOfIgnoredVariants()}.
     *
     * @param vcfRecord the call record
     * @param sampleNo index passed to score and genotype extraction
     * @param tpWeight first weight passed to the RocPoint constructor
     *        (presumably the baseline true-positive weight)
     * @param fpWeight second weight passed to the RocPoint constructor
     *        (presumably the false-positive weight)
     * @param tpRaw third weight passed to the RocPoint constructor
     *        (presumably the raw call-side true-positive weight)
     */
    public void addRocLine(VcfRecord vcfRecord, int sampleNo, double tpWeight, double fpWeight, double tpRaw) {
        double score = Double.NaN;
        try {
            score = mRocExtractor.getSortValue(vcfRecord, sampleNo);
        } catch (IndexOutOfBoundsException ignored) {
            // Deliberately best-effort: a value that cannot be extracted is treated like a
            // missing score and falls through to the NaN handling below.
        }
        final boolean unscored = Double.isNaN(score) || Double.isInfinite(score);
        if (unscored) {
            mNoScoreVariants++;
        }
        final RocPoint<Double> point = new RocPoint<>(unscored ? Double.NaN : score, tpWeight, fpWeight, tpRaw);
        final int[] gt = genotypeIfRequired(vcfRecord, sampleNo);
        for (final RocFilter filter : filters()) {
            if (filter.accept(vcfRecord, gt)) {
                addRocLine(point, filter);
            }
        }
    }

    /**
     * Merges a point into the filter's ROC map. A copy of the point is stored on first
     * insertion, so the caller's instance is never aliased by the map.
     */
    void addRocLine(RocPoint<Double> rocPoint, RocFilter rocFilter) {
        final SortedMap<Double, RocPoint<Double>> points = mRocs.get(rocFilter);
        final RocPoint<Double> existing = points.get(rocPoint.getThreshold());
        if (existing == null) {
            points.put(rocPoint.getThreshold(), new RocPoint<>(rocPoint));
        } else {
            existing.add(rocPoint);
        }
    }

    /**
     * Writes one ROC file per registered filter and resets the best-F-measure tracking.
     *
     * <p>When rescaling is disabled, every filter uses the ALL filter's baseline and call
     * totals, and only the ALL file gets the metric columns. When rescaling is enabled, each
     * non-ALL filter uses its own totals and applies a true-positive correction factor.
     *
     * @param outDir directory in which the files are created
     * @param zip passed to {@link FileUtils#getZippedFileName} to choose the output file name
     * @param slope if true, a slope file is derived from each ROC file after it is closed
     * @throws IOException if writing any file fails
     */
    public void writeRocs(File outDir, boolean zip, boolean slope) throws IOException {
        Diagnostic.developerLog("Writing ROC");
        mBestFMeasure = 0.0;
        mBest = null;
        for (final Map.Entry<RocFilter, SortedMap<Double, RocPoint<Double>>> entry : mRocs.entrySet()) {
            final RocFilter filter = entry.getKey();
            final boolean rescale = mRescaleCategories && filter != RocFilter.ALL;
            final RocFilter totalsFilter = rescale ? filter : RocFilter.ALL;
            final int baselineTotal = mBaselineTotals.get(totalsFilter);
            final RocPoint<Double> total = getTotal(totalsFilter);
            final int callTotal = (int) Math.round(total.getRawTruePositives() + total.getFalsePositives());
            final boolean writeMetrics;
            final double tpCorrection;
            if (rescale) {
                writeMetrics = baselineTotal > 0;
                final int baselineTp = mBaselineTpTotal.get(filter);
                tpCorrection = total.getTruePositives() > 0.0 ? baselineTp / total.getTruePositives() : 1.0;
                Diagnostic.userLog("Representation bias correction factor for " + filter + " " + baselineTp + "/" + total.getTruePositives() + " = " + tpCorrection);
            } else {
                writeMetrics = filter == RocFilter.ALL && baselineTotal > 0;
                tpCorrection = 1.0;
            }
            final File rocFile = FileUtils.getZippedFileName(zip, new File(outDir, mFilePrefix + filter.fileName()));
            try (LineWriter out = new LineWriter(new OutputStreamWriter(FileUtils.createOutputStream(rocFile)))) {
                rocHeader(out, filter, baselineTotal, callTotal, writeMetrics, rescale);
                writeRocBody(out, filter, entry.getValue(), baselineTotal, writeMetrics, tpCorrection);
            }
            // The ROC file must be fully closed before it is re-read to derive the slope.
            if (slope) {
                produceSlopeFile(rocFile);
            }
        }
    }

    /**
     * Writes one cumulative row per distinct formatted threshold, in the map's sort order.
     * Scores that format to the same label are merged into a single row.
     */
    private void writeRocBody(LineWriter out, RocFilter filter, SortedMap<Double, RocPoint<Double>> points, int baselineTotal, boolean writeMetrics, double tpCorrection) throws IOException {
        String pendingLabel = null;
        final RocPoint<Double> cumulative = new RocPoint<>();
        for (final RocPoint<Double> point : points.values()) {
            final double threshold = point.getThreshold();
            final String label = Double.isNaN(threshold) ? "None" : Utils.realFormat(threshold, SCORE_DP);
            // Emit the previous row only once the label changes, so equal labels accumulate.
            if (pendingLabel != null && !label.equals(pendingLabel)) {
                writeRocLine(out, filter, pendingLabel, baselineTotal, cumulative, writeMetrics, tpCorrection);
            }
            pendingLabel = label;
            cumulative.add(point);
            cumulative.setThreshold(point.getThreshold());
        }
        if (pendingLabel != null) {
            writeRocLine(out, filter, pendingLabel, baselineTotal, cumulative, writeMetrics, tpCorrection);
        }
    }

    /** Writes the comment header and column line of a ROC file. */
    private void rocHeader(LineWriter out, RocFilter filter, int baselineTotal, int callTotal, boolean writeMetrics, boolean rescaled) throws IOException {
        out.writeln("#Version " + Environment.getVersion() + ", ROC output 1.2");
        if (CommandLine.getCommandLine() != null) {
            out.writeln("#CL " + CommandLine.getCommandLine());
        }
        out.writeln("#selection: " + filter.name() + (rescaled ? " (baseline rescaled)" : ""));
        out.writeln("#total baseline variants: " + baselineTotal);
        out.writeln("#total call variants: " + callTotal);
        out.writeln("#score field: " + mFieldLabel);
        out.write("#" + String.join("\t", RocColumns.SCORE, RocColumns.TRUE_POSITIVES_BASELINE, RocColumns.FALSE_POSITIVES, RocColumns.TRUE_POSITIVES_CALL));
        if (writeMetrics) {
            out.write("\t" + String.join("\t", RocColumns.FALSE_NEGATIVES, RocColumns.PRECISION, RocColumns.SENSITIVITY, RocColumns.F_MEASURE));
        }
        out.newLine();
    }

    /**
     * Writes one ROC row.
     *
     * <p>When metrics are written, the best F-measure point is also tracked for the ALL filter
     * (ignoring the NaN threshold). On ties, the later row wins.
     */
    private void writeRocLine(LineWriter out, RocFilter filter, String label, int baselineTotal, RocPoint<Double> point, boolean writeMetrics, double tpCorrection) throws IOException {
        final double tp = point.getTruePositives() * tpCorrection;
        final double fp = point.getFalsePositives();
        final double rawTp = point.getRawTruePositives();
        out.write(label + "\t" + Utils.realFormat(tp, COUNT_DP) + "\t" + Utils.realFormat(fp, COUNT_DP) + "\t" + Utils.realFormat(rawTp, COUNT_DP));
        if (writeMetrics) {
            final double fn = baselineTotal - tp;
            final double precision = ContingencyTable.precision(rawTp, fp);
            final double recall = ContingencyTable.recall(tp, fn);
            final double fMeasure = ContingencyTable.fMeasure(precision, recall);
            out.write("\t" + Utils.realFormat(fn, COUNT_DP) + "\t" + Utils.realFormat(precision, METRICS_DP) + "\t" + Utils.realFormat(recall, METRICS_DP) + "\t" + Utils.realFormat(fMeasure, METRICS_DP));
            if (filter == RocFilter.ALL && !Double.isNaN(point.getThreshold()) && (mBest == null || fMeasure >= mBestFMeasure)) {
                mBestFMeasure = fMeasure;
                mBest = new RocPoint<>(point);
            }
        }
        out.newLine();
    }

    /** Adds a summary row: counts rounded to integers, metrics to {@link #METRICS_DP} places. */
    private static void addSummaryRow(TextTable table, String label, double tpBaseline, double fn, double tpCall, double fp) {
        final double precision = ContingencyTable.precision(tpCall, fp);
        final double recall = ContingencyTable.recall(tpBaseline, fn);
        table.addRow(label, Long.toString(MathUtils.round(tpBaseline)), Long.toString(MathUtils.round(tpCall)), Long.toString(MathUtils.round(fp)), Long.toString(MathUtils.round(fn)), Utils.realFormat(precision, METRICS_DP), Utils.realFormat(recall, METRICS_DP), Utils.realFormat(ContingencyTable.fMeasure(precision, recall), METRICS_DP));
    }

    /**
     * Builds the summary table, logs it via {@link Diagnostic#info}, and writes it to
     * the summary file in {@code outDir}.
     *
     * <p>The table has a best-F-measure row when {@link #writeRocs} found one, and always
     * has an unthresholded "None" row.
     *
     * @param outDir directory in which the summary file is created
     * @throws IOException if writing fails
     */
    public void writeSummary(File outDir) throws IOException {
        final File summaryFile = new File(outDir, mFilePrefix + CommonFlags.SUMMARY_FILE);
        final int baselineTotal = mBaselineTotals.get(RocFilter.ALL);
        final String summary;
        if (baselineTotal > 0) {
            final TextTable table = new TextTable();
            table.addRow("Threshold", "True-pos-baseline", "True-pos-call", "False-pos", "False-neg", "Precision", "Sensitivity", "F-measure");
            table.addSeparator();
            final RocPoint<Double> best = mBest;
            if (best == null) {
                Diagnostic.warning("Could not maximize F-measure from ROC data, only un-thresholded statistics will be computed. Consider selecting a different scoring attribute with --vcf-score-field");
            } else {
                addSummaryRow(table, Utils.realFormat(best.getThreshold(), SCORE_DP), best.getTruePositives(), baselineTotal - best.getTruePositives(), best.getRawTruePositives(), best.getFalsePositives());
            }
            final RocPoint<Double> total = getTotal(RocFilter.ALL);
            addSummaryRow(table, "None", total.getTruePositives(), baselineTotal - total.getTruePositives(), total.getRawTruePositives(), total.getFalsePositives());
            summary = table.toString();
        } else {
            summary = "0 total baseline variants, no summary statistics available" + StringUtils.LS;
        }
        Diagnostic.info(summary);
        try (OutputStream out = FileUtils.createOutputStream(summaryFile)) {
            out.write(summary.getBytes());
        }
    }

    /**
     * Derives a slope file next to a non-empty ROC file using {@link RocSlope#writeSlope}.
     * Does nothing if the ROC file is missing or empty.
     */
    private void produceSlopeFile(File rocFile) throws IOException {
        if (!rocFile.exists() || rocFile.length() <= 0) {
            return;
        }
        // Literal replacement: the ROC suffix is a file-name fragment, not a regular expression.
        final File slopeFile = new File(rocFile.getParentFile(), rocFile.getName().replace(RocFilter.ROC_EXT, SLOPE_EXT));
        try (PrintStream out = new PrintStream(FileUtils.createOutputStream(slopeFile));
             BufferedInputStream in = FileUtils.createInputStream(rocFile, false)) {
            RocSlope.writeSlope(in, out);
        }
    }

    /** @return true unless this container was built with the null extractor */
    public boolean isRocEnabled() {
        return mRocExtractor != RocSortValueExtractor.NULL_EXTRACTOR;
    }

    /** Warns about records without a usable score, if ROC output is enabled and any exist. */
    public void missingScoreWarning() {
        if (isRocEnabled() && getNumberOfIgnoredVariants() > 0) {
            Diagnostic.warning("There were " + getNumberOfIgnoredVariants() + " variants not thresholded in ROC data files due to missing or invalid " + mFieldLabel + " values.");
        }
    }

    /**
     * Sums every point recorded for a filter, including the NaN-threshold point.
     * @param rocFilter a registered filter
     * @return a new point holding the totals
     */
    public RocPoint<Double> getTotal(RocFilter rocFilter) {
        final RocPoint<Double> total = new RocPoint<>();
        for (final RocPoint<Double> point : mRocs.get(rocFilter).values()) {
            total.add(point);
        }
        return total;
    }
}
