package genepi.riskscore.tasks;

import genepi.io.table.reader.CsvTableReader;
import genepi.io.table.writer.CsvTableWriter;
import genepi.riskscore.io.SamplesFile;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;

/* loaded from: input_file:genepi/riskscore/tasks/PopulationPredictor.class */
/**
 * Predicts the (super-)population of study samples with a weighted k-nearest-neighbour vote in
 * principal-component space.
 *
 * <p>Inputs (all tab-separated):
 *
 * <ul>
 *   <li>samples file: columns {@code indivID} and {@code superpopID}, i.e. the population label of
 *       every reference sample;
 *   <li>reference file: columns {@code indivID} and {@code PC1..PCn} for the reference samples;
 *   <li>study file: columns {@code indivID} and {@code PC1..PCn} for the samples to classify.
 * </ul>
 *
 * <p>Only the first {@code maxPcs} PCs are used. Each of the {@code K} nearest reference samples
 * (Euclidean distance) votes for its label with weight {@code 1 / distance}. The vote weights are
 * normalised to sum to 1. A study sample is assigned the winning label only when that label's
 * weight is at least {@code weightThreshold}; otherwise it is reported as {@code "Unknown"}.
 */
public class PopulationPredictor {
    /** Label written when no population reaches the weight threshold. */
    private static final String LABEL_UNKNOWN = "Unknown";

    /** Field separator of every input and output file. */
    private static final char SEPARATOR = '\t';

    private static final String COLUMN_INDIV_ID = "indivID";
    private static final String COLUMN_SUPERPOP_ID = "superpopID";
    private static final String COLUMN_SAMPLE = "sample";
    // The misspelling is kept on purpose: it is the header of the existing output format and
    // renaming it would change the file this class writes.
    private static final String COLUMN_VOTING_POPULATION = "voting_popluation";
    private static final String COLUMN_VOTING = "voting";

    private String studyFile = null;
    private String referenceFile = null;
    private String samplesFile = null;
    private int maxPcs = 3;
    private int K = 10;
    private double weightThreshold = 0.75d;

    /** A reference sample together with its distance to the study sample being classified. */
    public class Neighbor implements Comparable<Neighbor> {
        private final ReferenceSample sample;
        private final double distance;

        public Neighbor(ReferenceSample referenceSample, double d) {
            this.sample = referenceSample;
            this.distance = d;
        }

        public double getDistance() {
            return this.distance;
        }

        public ReferenceSample getSample() {
            return this.sample;
        }

        /** Orders neighbours by ascending distance (closest first). */
        @Override
        public int compareTo(Neighbor neighbor) {
            return Double.compare(this.distance, neighbor.getDistance());
        }
    }

    /** Accumulated vote of one population label among the nearest neighbours. */
    public class PredictedPopulation implements Comparable<PredictedPopulation> {
        private String label;
        private double sumWeight = 0.0d;
        private int count = 0;

        PredictedPopulation() {
        }

        public void setLabel(String str) {
            this.label = str;
        }

        public String getLabel() {
            return this.label;
        }

        /** Returns the summed vote weight (normalised to a fraction after {@code getVoting}). */
        public double getWeight() {
            return this.sumWeight;
        }

        public void setSumWeight(double d) {
            this.sumWeight = d;
        }

        /** Records one more neighbour of this population voting with weight {@code d}. */
        public void addSample(double d) {
            this.count++;
            this.sumWeight += d;
        }

        /** Returns the number of neighbours that belong to this population. */
        public int getCount() {
            return this.count;
        }

        /** Orders populations by descending weight (strongest vote first). */
        @Override
        public int compareTo(PredictedPopulation predictedPopulation) {
            return Double.compare(predictedPopulation.getWeight(), getWeight());
        }
    }

    /** A labelled reference sample and its PC coordinates. */
    public class ReferenceSample {
        private String label;
        private double[] pcs;

        ReferenceSample() {
        }

        public void setLabel(String str) {
            this.label = str;
        }

        public String getLabel() {
            return this.label;
        }

        public void setPcs(double[] dArr) {
            this.pcs = dArr;
        }

        public double[] getPcs() {
            return this.pcs;
        }

        /**
         * Euclidean distance between this sample's PCs and {@code dArr}, computed over the first
         * {@code getPcs().length} coordinates. {@code dArr} must be at least that long.
         */
        public double distanceTo(double[] dArr) {
            double sum = 0.0d;
            for (int i = 0; i < this.pcs.length; i++) {
                double diff = dArr[i] - this.pcs[i];
                sum += diff * diff;
            }
            return Math.sqrt(sum);
        }
    }

    /**
     * Classifies every sample of the study file and writes one row per sample to {@code str}.
     *
     * <p>The output has the columns {@code sample}, {@link SamplesFile#COLUMN_POPULATION} (winning
     * label, or {@code "Unknown"} below the threshold), {@code voting_popluation} (winning label),
     * {@code voting} (its normalised weight) and {@code PC1..PCn}. Reference samples whose
     * {@code indivID} has no label in the samples file are skipped with a warning.
     *
     * @param str path of the output file
     */
    public void predictPopulation(String str) {
        Map<String, ReferenceSample> samplesById = loadReferenceLabels();
        System.out.println("Loaded " + samplesById.size() + " reference samples.");

        List<ReferenceSample> references = loadReferencePcs(samplesById);
        System.out.println("Loaded " + this.maxPcs + " PCs for " + references.size() + "/"
                + samplesById.size() + " reference samples.");

        int predicted = predictStudySamples(references, str);
        System.out.println("Predicted the population for " + predicted + " samples");
    }

    /** Reads the samples file into a map from {@code indivID} to a labelled sample without PCs. */
    private Map<String, ReferenceSample> loadReferenceLabels() {
        Map<String, ReferenceSample> samplesById = new HashMap<>();
        CsvTableReader reader = new CsvTableReader(this.samplesFile, SEPARATOR);
        try {
            while (reader.next()) {
                ReferenceSample sample = new ReferenceSample();
                sample.setLabel(reader.getString(COLUMN_SUPERPOP_ID));
                samplesById.put(reader.getString(COLUMN_INDIV_ID), sample);
            }
        } finally {
            reader.close();
        }
        return samplesById;
    }

    /**
     * Reads the reference PC file, attaches the PCs to the matching labelled samples and returns
     * those samples in file order. Rows without a label are skipped, because an unlabelled sample
     * cannot vote.
     */
    private List<ReferenceSample> loadReferencePcs(Map<String, ReferenceSample> samplesById) {
        List<ReferenceSample> references = new ArrayList<>();
        int unlabelled = 0;
        CsvTableReader reader = new CsvTableReader(this.referenceFile, SEPARATOR);
        try {
            while (reader.next()) {
                ReferenceSample sample = samplesById.get(reader.getString(COLUMN_INDIV_ID));
                if (sample == null) {
                    unlabelled++;
                    continue;
                }
                sample.setPcs(readPcs(reader));
                references.add(sample);
            }
        } finally {
            reader.close();
        }
        if (unlabelled > 0) {
            System.out.println("Warning: ignored " + unlabelled + " reference samples without a label in "
                    + this.samplesFile + ".");
        }
        return references;
    }

    /** Classifies every study sample, writes the result table and returns the number of rows. */
    private int predictStudySamples(List<ReferenceSample> references, String outputFile) {
        CsvTableWriter writer = new CsvTableWriter(outputFile, SEPARATOR, false);
        try {
            writer.setColumns(outputColumns());
            int count = 0;
            CsvTableReader reader = new CsvTableReader(this.studyFile, SEPARATOR);
            try {
                while (reader.next()) {
                    String sampleId = reader.getString(COLUMN_INDIV_ID);
                    double[] pcs = readPcs(reader);
                    PredictedPopulation[] voting = getVoting(getNearestNeighbors(references, pcs, this.K));
                    writeRow(writer, sampleId, pcs, voting);
                    count++;
                }
            } finally {
                reader.close();
            }
            return count;
        } finally {
            writer.close();
        }
    }

    /** Writes one output row; an empty vote (no usable neighbours) is reported as unknown. */
    private void writeRow(CsvTableWriter writer, String sampleId, double[] pcs, PredictedPopulation[] voting) {
        writer.setString(COLUMN_SAMPLE, sampleId);
        if (voting.length == 0) {
            writer.setString(SamplesFile.COLUMN_POPULATION, LABEL_UNKNOWN);
            writer.setString(COLUMN_VOTING_POPULATION, LABEL_UNKNOWN);
            writer.setDouble(COLUMN_VOTING, 0.0d);
        } else {
            PredictedPopulation best = voting[0];
            boolean confident = best.getWeight() >= this.weightThreshold;
            writer.setString(SamplesFile.COLUMN_POPULATION, confident ? best.getLabel() : LABEL_UNKNOWN);
            writer.setString(COLUMN_VOTING_POPULATION, best.getLabel());
            writer.setDouble(COLUMN_VOTING, best.getWeight());
        }
        for (int i = 0; i < pcs.length; i++) {
            writer.setDouble(pcColumn(i), pcs[i]);
        }
        writer.next();
    }

    /** Header of the output table. */
    private String[] outputColumns() {
        String[] columns = new String[4 + this.maxPcs];
        columns[0] = COLUMN_SAMPLE;
        columns[1] = SamplesFile.COLUMN_POPULATION;
        columns[2] = COLUMN_VOTING_POPULATION;
        columns[3] = COLUMN_VOTING;
        for (int i = 0; i < this.maxPcs; i++) {
            columns[4 + i] = pcColumn(i);
        }
        return columns;
    }

    /** Reads the first {@code maxPcs} PC columns of the reader's current row. */
    private double[] readPcs(CsvTableReader reader) {
        double[] pcs = new double[this.maxPcs];
        for (int i = 0; i < pcs.length; i++) {
            pcs[i] = reader.getDouble(pcColumn(i));
        }
        return pcs;
    }

    /** Column name of the zero-based PC {@code index} ({@code 0 -> "PC1"}). */
    private static String pcColumn(int index) {
        return "PC" + (index + 1);
    }

    public void setReferenceFile(String str) {
        this.referenceFile = str;
    }

    public void setSamplesFile(String str) {
        this.samplesFile = str;
    }

    public void setStudyFile(String str) {
        this.studyFile = str;
    }

    public void setMaxPcs(int i) {
        this.maxPcs = i;
    }

    public void setK(int i) {
        this.K = i;
    }

    public void setWeightThreshold(double d) {
        this.weightThreshold = d;
    }

    /**
     * Aggregates the neighbours' votes per label.
     *
     * <p>Each neighbour votes with weight {@code 1 / distance}, and the weights are normalised to
     * sum to 1. If any neighbour is at distance 0, only the exact matches vote (weight 1 each),
     * which is the limit of inverse-distance weighting. Without this rule the {@code Infinity}
     * weights would normalise to NaN. Labels are grouped in order of first appearance, so ties are
     * won by the label whose member is nearest (the sort is stable).
     *
     * @param neighborArr neighbours of one study sample, ideally sorted by ascending distance
     * @return populations sorted by descending normalised weight; empty if there are no neighbours
     */
    protected PredictedPopulation[] getVoting(Neighbor[] neighborArr) {
        boolean hasExactMatch = false;
        for (Neighbor neighbor : neighborArr) {
            if (neighbor.getDistance() == 0.0d) {
                hasExactMatch = true;
                break;
            }
        }

        Map<String, PredictedPopulation> byLabel = new LinkedHashMap<>();
        for (Neighbor neighbor : neighborArr) {
            String label = neighbor.getSample().getLabel();
            PredictedPopulation population = byLabel.get(label);
            if (population == null) {
                population = new PredictedPopulation();
                population.setLabel(label);
                byLabel.put(label, population);
            }
            double distance = neighbor.getDistance();
            double weight;
            if (hasExactMatch) {
                weight = distance == 0.0d ? 1.0d : 0.0d;
            } else {
                weight = 1.0d / distance;
            }
            population.addSample(weight);
        }

        double total = 0.0d;
        for (PredictedPopulation population : byLabel.values()) {
            total += population.getWeight();
        }
        // Guard against 0/0 (e.g. all distances infinite), which would otherwise produce NaN.
        if (total > 0.0d) {
            for (PredictedPopulation population : byLabel.values()) {
                population.setSumWeight(population.getWeight() / total);
            }
        }

        PredictedPopulation[] result = byLabel.values().toArray(new PredictedPopulation[0]);
        Arrays.sort(result);
        return result;
    }

    /**
     * Returns the {@code i} reference samples closest to {@code dArr}, closest first.
     *
     * <p>{@code i} is clamped to {@code [0, collection.size()]}, so asking for more neighbours than
     * there are reference samples returns all of them instead of failing.
     *
     * @param collection reference samples whose PCs have been set
     * @param dArr PCs of the study sample
     * @param i number of neighbours requested (k)
     * @return up to {@code i} neighbours sorted by ascending distance
     */
    protected Neighbor[] getNearestNeighbors(Collection<ReferenceSample> collection, double[] dArr, int i) {
        Neighbor[] all = new Neighbor[collection.size()];
        int index = 0;
        for (ReferenceSample referenceSample : collection) {
            all[index++] = new Neighbor(referenceSample, referenceSample.distanceTo(dArr));
        }
        Arrays.sort(all);
        int k = Math.max(0, Math.min(i, all.length));
        return Arrays.copyOf(all, k);
    }
}
