package org.broadinstitute.gatk.tools.walkers.variantrecalibration;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.log4j.Logger;
import org.broadinstitute.gatk.engine.arguments.StandardCallerArgumentCollection;
import org.broadinstitute.gatk.tools.walkers.variantrecalibration.VariantDatum;
import org.broadinstitute.gatk.tools.walkers.variantrecalibration.VariantRecalibratorArgumentCollection;
import org.broadinstitute.gatk.utils.exceptions.UserException;

/* loaded from: input_file:org/broadinstitute/gatk/tools/walkers/variantrecalibration/TrancheManager.class */
/**
 * Finds VQSR tranches over a list of {@link VariantDatum}.
 *
 * <p>The list is sorted with {@link VariantDatum.VariantDatumLODComparator}. A {@link SelectionMetric}
 * then computes a running value per index. For each requested tranche value, the first index whose
 * running value reaches the metric's threshold becomes the tranche cut. Every variant whose LOD is
 * {@code >=} the LOD at that index is summarized into a {@link Tranche}.
 */
public class TrancheManager {
    protected static final Logger logger = Logger.getLogger(TrancheManager.class);

    /**
     * A metric that is evaluated cumulatively over the sorted variant list and used to choose
     * tranche cut points.
     */
    public static abstract class SelectionMetric {
        String name;

        /**
         * @param name human-readable metric name, used in log and error messages
         */
        public SelectionMetric(String name) {
            this.name = name;
        }

        /** Returns the metric name supplied at construction. */
        public String getName() {
            return this.name;
        }

        /**
         * Converts a requested tranche value into the running-metric value that a cut point must reach.
         */
        public abstract double getThreshold(double tranche);

        /** Returns the metric's target value. */
        public abstract double getTarget();

        /**
         * Precomputes the running metric for every index of {@code data}.
         *
         * <p>The implementations in this file allocate their running arrays here, so
         * {@link #getRunningMetric(int)} is only valid after this call.
         */
        public abstract void calculateRunningMetric(List<VariantDatum> data);

        /** Returns the precomputed running metric at {@code index}. */
        public abstract double getRunningMetric(int index);

        /** Returns this metric's per-variant value, as written to the debugging output. */
        public abstract int datumValue(VariantDatum datum);
    }

    /**
     * Selection metric based on the cumulative transition/transversion (Ti/Tv) ratio of novel
     * (non-known) variants.
     */
    public static class NovelTiTvMetric extends SelectionMetric {
        double[] runningTiTv;
        double targetTiTv;

        /**
         * @param targetTiTv target Ti/Tv, used by {@link TrancheManager#fdrToTiTv} to turn tranche
         *                   values into Ti/Tv thresholds
         */
        public NovelTiTvMetric(double targetTiTv) {
            super("NovelTiTv");
            this.targetTiTv = targetTiTv;
        }

        @Override
        public double getThreshold(double tranche) {
            return fdrToTiTv(tranche, this.targetTiTv);
        }

        @Override
        public double getTarget() {
            return this.targetTiTv;
        }

        /**
         * Walks {@code data} from the last element to the first. Along the way it accumulates
         * transition and non-transition counts over novel variants only.
         *
         * <p>Entry {@code i} holds the novel Ti/Tv of indices {@code >= i}. Entries for known
         * variants stay at 0.0, so they cannot meet a positive threshold in {@link #findTranche}.
         */
        @Override
        public void calculateRunningMetric(List<VariantDatum> data) {
            int transitions = 0;
            int transversions = 0; // every novel non-transition counts toward the denominator
            this.runningTiTv = new double[data.size()];
            for (int i = data.size() - 1; i >= 0; i--) {
                VariantDatum datum = data.get(i);
                if (datum.isKnown) {
                    continue;
                }
                if (datum.isTransition) {
                    transitions++;
                } else {
                    transversions++;
                }
                // Clamp the denominator to 1 so the ratio is defined before the first transversion.
                this.runningTiTv[i] = transitions / Math.max(1.0 * transversions, 1.0);
            }
        }

        @Override
        public double getRunningMetric(int index) {
            return this.runningTiTv[index];
        }

        @Override
        public int datumValue(VariantDatum datum) {
            return datum.isTransition ? 1 : 0;
        }
    }

    /**
     * Selection metric based on how many truth-set sites fall below a candidate cut.
     */
    public static class TruthSensitivityMetric extends SelectionMetric {
        double[] runningSensitivity;
        int nTrueSites;

        /**
         * @param nTrueSites number of truth sites; used as the denominator of the running fraction
         */
        public TruthSensitivityMetric(int nTrueSites) {
            super("TruthSensitivity");
            this.nTrueSites = nTrueSites;
        }

        /** Converts a tranche value given in percent into {@code 1 - tranche / 100}. */
        @Override
        public double getThreshold(double tranche) {
            return 1.0 - (tranche / 100.0);
        }

        @Override
        public double getTarget() {
            return 1.0;
        }

        /**
         * Walks {@code data} from the last element to the first.
         *
         * <p>Entry {@code i} is {@code 1 - (truth sites at indices >= i) / nTrueSites}.
         *
         * <p>NOTE(review): when {@code nTrueSites == 0} the entries become NaN or -Infinity. In that
         * case no threshold is met and {@link #findTranches} reports that no tranche was found.
         */
        @Override
        public void calculateRunningMetric(List<VariantDatum> data) {
            int truthSitesSeen = 0;
            this.runningSensitivity = new double[data.size()];
            for (int i = data.size() - 1; i >= 0; i--) {
                if (data.get(i).atTruthSite) {
                    truthSitesSeen++;
                }
                this.runningSensitivity[i] = 1.0 - (truthSitesSeen / (1.0 * this.nTrueSites));
            }
        }

        @Override
        public double getRunningMetric(int index) {
            return this.runningSensitivity[index];
        }

        @Override
        public int datumValue(VariantDatum datum) {
            return datum.atTruthSite ? 1 : 0;
        }
    }

    /**
     * Equivalent to {@link #findTranches(List, double[], SelectionMetric,
     * VariantRecalibratorArgumentCollection.Mode, File)} with no debugging output file.
     */
    public static List<Tranche> findTranches(List<VariantDatum> data, double[] trancheThresholds, SelectionMetric metric, VariantRecalibratorArgumentCollection.Mode mode) {
        return findTranches(data, trancheThresholds, metric, mode, null);
    }

    /**
     * Sorts {@code data} in place, computes the metric's running values, and finds one tranche
     * per requested threshold.
     *
     * <p>Thresholds are processed in the given order. Processing stops at the first threshold for
     * which no cut point exists, and the tranches found so far are returned.
     *
     * @param data              variants; sorted IN PLACE with {@link VariantDatum.VariantDatumLODComparator}
     * @param trancheThresholds requested tranche values, converted via {@link SelectionMetric#getThreshold}
     * @param metric            metric used to choose cut points
     * @param mode              passed through to each {@link Tranche}
     * @param debugFile         if non-null, a per-variant dump of LOD / metric value / running value is written here
     * @return the tranches found, in threshold order
     * @throws UserException if not even the first threshold can be reached
     */
    public static List<Tranche> findTranches(List<VariantDatum> data, double[] trancheThresholds, SelectionMetric metric, VariantRecalibratorArgumentCollection.Mode mode, File debugFile) {
        logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.length, data.size()));
        // The running metric and the cut indices both refer to this sorted order.
        Collections.sort(data, new VariantDatum.VariantDatumLODComparator());
        metric.calculateRunningMetric(data);
        if (debugFile != null) {
            writeTranchesDebuggingInfo(debugFile, data, metric);
        }

        List<Tranche> tranches = new ArrayList<Tranche>();
        for (double trancheThreshold : trancheThresholds) {
            Tranche tranche = findTranche(data, metric, trancheThreshold, mode);
            if (tranche == null) {
                if (tranches.isEmpty()) {
                    throw new UserException(String.format("Couldn't find any tranche containing variants with a %s > %.2f. Are you sure the truth files contain unfiltered variants which overlap the input data?", metric.getName(), metric.getThreshold(trancheThreshold)));
                }
                // Stop at the first unreachable threshold. Without this break the loop never
                // advances and never terminates.
                break;
            }
            tranches.add(tranche);
        }
        return tranches;
    }

    /**
     * Writes one line per variant: LOD, the metric's per-variant value, and the running metric.
     *
     * @throws UserException.CouldNotCreateOutputFile if {@code debugFile} cannot be opened
     */
    private static void writeTranchesDebuggingInfo(File debugFile, List<VariantDatum> data, SelectionMetric metric) {
        // try-with-resources closes the stream even if a metric accessor throws mid-write.
        try (PrintStream out = new PrintStream(debugFile)) {
            out.println("Qual metricValue runningValue");
            for (int i = 0; i < data.size(); i++) {
                VariantDatum datum = data.get(i);
                out.printf("%.4f %d %.4f%n", datum.lod, metric.datumValue(datum), metric.getRunningMetric(i));
            }
        } catch (FileNotFoundException e) {
            throw new UserException.CouldNotCreateOutputFile(debugFile, e);
        }
    }

    /**
     * Returns the tranche starting at the first index whose running metric is {@code >=} the
     * metric threshold for {@code trancheThreshold}, or {@code null} if no index qualifies.
     * Requires {@link SelectionMetric#calculateRunningMetric} to have been run on {@code data}.
     */
    public static Tranche findTranche(List<VariantDatum> data, SelectionMetric metric, double trancheThreshold, VariantRecalibratorArgumentCollection.Mode mode) {
        double metricThreshold = metric.getThreshold(trancheThreshold);
        logger.info(String.format("  Tranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metricThreshold));
        int n = data.size();
        for (int i = 0; i < n; i++) {
            double runningValue = metric.getRunningMetric(i);
            if (runningValue >= metricThreshold) {
                Tranche tranche = trancheOfVariants(data, i, trancheThreshold, mode);
                logger.info(String.format("  Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ", trancheThreshold, metricThreshold, i, runningValue));
                logger.info(String.format("  Tranche is %s", tranche));
                return tranche;
            }
        }
        return null;
    }

    /**
     * Builds a {@link Tranche} from every variant whose LOD is {@code >=} the LOD at {@code minIndex}.
     *
     * <p>The tranche records known/novel counts and their SNP Ti/Tv ratios, each with the
     * denominator clamped to 1. It also records the number of truth sites at any LOD and the
     * number at or above the cut LOD.
     */
    public static Tranche trancheOfVariants(List<VariantDatum> data, int minIndex, double trancheThreshold, VariantRecalibratorArgumentCollection.Mode mode) {
        int numKnown = 0;
        int numNovel = 0;
        int knownTi = 0;
        int knownTv = 0;
        int novelTi = 0;
        int novelTv = 0;
        double minLod = data.get(minIndex).lod;

        for (VariantDatum datum : data) {
            if (datum.lod >= minLod) {
                if (datum.isKnown) {
                    numKnown++;
                    if (datum.isSNP) {
                        if (datum.isTransition) {
                            knownTi++;
                        } else {
                            knownTv++;
                        }
                    }
                } else {
                    numNovel++;
                    if (datum.isSNP) {
                        if (datum.isTransition) {
                            novelTi++;
                        } else {
                            novelTv++;
                        }
                    }
                }
            }
        }

        double knownTiTv = knownTi / Math.max(1.0 * knownTv, 1.0);
        double novelTiTv = novelTi / Math.max(1.0 * novelTv, 1.0);
        int truthSitesAnyLod = countCallsAtTruth(data, Double.NEGATIVE_INFINITY);
        int truthSitesAboveCut = countCallsAtTruth(data, minLod);
        return new Tranche(trancheThreshold, minLod, numKnown, knownTiTv, numNovel, novelTiTv, truthSitesAnyLod, truthSitesAboveCut, mode);
    }

    /**
     * Linearly maps a tranche value onto a Ti/Tv threshold.
     *
     * <p>Tranche 0 maps to {@code targetTiTv} and tranche 100 maps to 0.5. Values in between are
     * interpolated linearly.
     * NOTE(review): 0.5 is presumably the Ti/Tv expected of random calls; confirm.
     */
    public static double fdrToTiTv(double tranche, double targetTiTv) {
        return ((1.0 - (tranche / 100.0)) * (targetTiTv - 0.5)) + 0.5;
    }

    /**
     * Counts variants at truth sites whose LOD is not below {@code minLod}.
     */
    public static int countCallsAtTruth(List<VariantDatum> data, double minLod) {
        int count = 0;
        for (VariantDatum datum : data) {
            // Written as !(lod < minLod) rather than lod >= minLod so that a NaN LOD is counted,
            // since NaN < x is always false.
            if (datum.atTruthSite && !(datum.lod < minLod)) {
                count++;
            }
        }
        return count;
    }
}
