/*
 * Decompiled with CFR 0.152.
 */
package genepi.riskscore.tasks;

import genepi.io.table.writer.CsvTableWriter;
import genepi.riskscore.io.Chunk;
import genepi.riskscore.io.OutputFileWriter;
import genepi.riskscore.io.ReportFile;
import genepi.riskscore.io.RiskScoreFile;
import genepi.riskscore.io.SamplesFile;
import genepi.riskscore.io.VariantFile;
import genepi.riskscore.io.formats.RiskScoreFormatFactory;
import genepi.riskscore.io.vcf.FastVCFFileReader;
import genepi.riskscore.io.vcf.MinimalVariantContext;
import genepi.riskscore.model.ReferenceVariant;
import genepi.riskscore.model.RiskScore;
import genepi.riskscore.model.RiskScoreSummary;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import lukfor.progress.tasks.ITaskRunnable;
import lukfor.progress.tasks.monitors.ITaskMonitor;
import lukfor.progress.util.CountingInputStream;

/**
 * Applies one or more risk score (weight) files to the genotypes of a VCF file that contains a
 * single chromosome and writes one score per sample and risk score to the output file.
 *
 * <p>Optional outputs are a per-variant inclusion table ({@link #setOutputVariantFilename}), a
 * per-sample effect table ({@link #setOutputEffectsFilename}) and a report built from the
 * {@link RiskScoreSummary} statistics ({@link #setOutputReportFilename}). Instances are configured
 * through setters and then executed via {@link #run(ITaskMonitor)}; the class keeps mutable state
 * and contains no synchronization.
 */
public class ApplyScoreTask implements ITaskRunnable {

    /** Name of the INFO field holding the imputation quality. */
    public static final String INFO_R2 = "R2";
    public static final String DOSAGE_FORMAT = "DS";
    /** When true, {@link #debug(String)} output and error stack traces are printed to stdout. */
    public static boolean VERBOSE = false;
    /** Base complements used by {@link #flip(String)} to undo strand flips. */
    public static final Map<Character, Character> ALLELE_SWITCHES = new HashMap<Character, Character>();

    static {
        ALLELE_SWITCHES.put('A', 'T');
        ALLELE_SWITCHES.put('T', 'A');
        ALLELE_SWITCHES.put('G', 'C');
        ALLELE_SWITCHES.put('C', 'G');
    }

    /** One entry per VCF sample that passes the samples filter, in VCF column order. */
    private List<RiskScore> riskScores;
    private String vcf = null;
    private String[] riskScoreFilenames = null;
    private int countSamples = 0;
    private int countVariants = 0;
    /** Optional window: VCF records before its start are skipped, scanning stops after its end. */
    private Chunk chunk = null;
    /** Records whose INFO/R2 value is below this threshold are not used. */
    private float minR2 = 0.0f;
    private String outputVariantFilename = null;
    private String includeVariantFilename = null;
    private String includeSamplesFilename = null;
    private String outputReportFilename = null;
    private RiskScoreFormatFactory.RiskScoreFormat defaultFormat = RiskScoreFormatFactory.RiskScoreFormat.AUTO_DETECT;
    /** Format per risk score filename; filled with {@link #defaultFormat} by the filename setter. */
    private Map<String, RiskScoreFormatFactory.RiskScoreFormat> formats = new HashMap<String, RiskScoreFormatFactory.RiskScoreFormat>();
    /** FORMAT field from which genotype dosages are read. */
    private String genotypeFormat = "DS";
    private int numberRiskScores = 0;
    private RiskScoreSummary[] summaries;
    private String output;
    private String outputEffectsFilename;
    /** Passed through unchanged to every {@link RiskScoreFile}. */
    private String dbsnp = null;
    /** Passed through unchanged to every {@link RiskScoreFile}. */
    private String proxies;
    private boolean fixStrandFlips = false;
    private boolean removeAmbiguous = false;
    /**
     * For switched effect alleles: score {@code (2 - dosage) * weight} instead of
     * {@code dosage * -weight}.
     */
    private boolean inverseDosage = false;
    /** Never set in this class; when true, records with missing genotypes are not skipped. */
    private boolean averaging = false;

    /**
     * Sets the risk score files to apply and (re)assigns each of them the current default format.
     *
     * @param filenames risk score files, in output column order
     */
    public void setRiskScoreFilenames(String ... filenames) {
        this.riskScoreFilenames = filenames;
        for (String filename : filenames) {
            this.formats.put(filename, this.defaultFormat);
        }
    }

    public void setChunk(Chunk chunk) {
        this.chunk = chunk;
    }

    public void setVcfFilename(String vcf) {
        this.vcf = vcf;
    }

    public void setOutputVariantFilename(String outputVariantFilename) {
        this.outputVariantFilename = outputVariantFilename;
    }

    public void setIncludeVariantFilename(String includeVariantFilename) {
        this.includeVariantFilename = includeVariantFilename;
    }

    public void setIncludeSamplesFilename(String includeSamplesFilename) {
        this.includeSamplesFilename = includeSamplesFilename;
    }

    public void setGenotypeFormat(String genotypeFormat) {
        this.genotypeFormat = genotypeFormat;
    }

    public void setOutput(String output) {
        this.output = output;
    }

    public void setOutputReportFilename(String outputReportFilename) {
        this.outputReportFilename = outputReportFilename;
    }

    public void setOutputEffectsFilename(String outputEffectsFilename) {
        this.outputEffectsFilename = outputEffectsFilename;
    }

    public void setDbSnp(String dbsnp) {
        this.dbsnp = dbsnp;
    }

    public void setProxies(String proxies) {
        this.proxies = proxies;
    }

    /**
     * Validates the configuration, loads the risk score files for the chromosome of the VCF, scores
     * all samples and writes the configured output files.
     *
     * @param monitor progress monitor; when it reports cancellation, VCF scanning stops early
     * @throws Exception if the VCF, output or risk score filenames are missing, the VCF is empty or
     *     contains more than one chromosome, or any read/write fails ({@link Error}s are wrapped)
     */
    public void run(ITaskMonitor monitor) throws Exception {
        if (this.vcf == null || this.vcf.isEmpty()) {
            throw new Exception("Please specify a vcf file.");
        }
        if (this.output == null || this.output.isEmpty()) {
            throw new Exception("Please specify a output filename.");
        }
        if (this.riskScoreFilenames == null || this.riskScoreFilenames.length == 0) {
            throw new Exception("Reference can not be null or empty.");
        }
        try {
            String chromosome = this.readFirstChromosome(this.vcf);
            // Single-character contigs are zero-padded, e.g. "1" -> "[Chr 01]".
            String taskName = "[Chr " + (chromosome.length() == 1 ? "0" : "") + chromosome + "]";
            monitor.begin(taskName, new File(this.vcf).length());
            monitor.worked(0L);
            this.numberRiskScores = this.riskScoreFilenames.length;
            this.summaries = new RiskScoreSummary[this.numberRiskScores];
            for (int i = 0; i < this.numberRiskScores; ++i) {
                this.summaries[i] = new RiskScoreSummary(RiskScoreFile.getName(this.riskScoreFilenames[i]));
            }
            RiskScoreFile[] riskscores = this.loadReferenceFiles(monitor, chromosome, this.dbsnp, this.proxies, this.riskScoreFilenames);
            // If no file has weights for this chromosome, the VCF scan is skipped entirely.
            boolean empty = true;
            for (RiskScoreFile riskscore : riskscores) {
                if (riskscore.getLoadedVariants() > 0) {
                    empty = false;
                    break;
                }
            }
            this.processVCF(monitor, chromosome, this.vcf, riskscores, empty);
            OutputFileWriter outputFile = new OutputFileWriter(this.riskScores, this.summaries);
            outputFile.save(this.output);
            if (this.outputReportFilename != null) {
                ReportFile reportFile = new ReportFile(this.summaries);
                reportFile.save(this.outputReportFilename);
            }
            monitor.done();
        }
        catch (Exception e) {
            if (VERBOSE) {
                System.out.println("ERROR:");
                e.printStackTrace();
            }
            throw e;
        }
        catch (Error e) {
            if (VERBOSE) {
                System.out.println("ERROR:");
                e.printStackTrace();
            }
            throw new Exception(e);
        }
    }

    /**
     * Returns the contig of the first record of the VCF. The reader is closed on every path.
     *
     * @throws Exception "VCF file is empty." if the file has no records
     */
    private String readFirstChromosome(String vcfFilename) throws Exception {
        FastVCFFileReader reader = new FastVCFFileReader(vcfFilename);
        try {
            if (!reader.next()) {
                throw new Exception("VCF file is empty.");
            }
            return reader.get().getContig();
        } finally {
            reader.close();
        }
    }

    /**
     * Loads and indexes every risk score file for {@code chromosome} (restricted to the chunk if
     * one is set) and records total/ignored variant counts in the matching summary.
     *
     * @return one indexed {@link RiskScoreFile} per filename, in the same order
     */
    private RiskScoreFile[] loadReferenceFiles(ITaskMonitor monitor, String chromosome, String dbsnp, String proxies, String ... riskScoreFilenames) throws Exception {
        RiskScoreFile[] riskscores = new RiskScoreFile[this.numberRiskScores];
        for (int i = 0; i < this.numberRiskScores; ++i) {
            this.debug("Loading file " + riskScoreFilenames[i] + "...");
            RiskScoreFormatFactory.RiskScoreFormat format = this.formats.get(riskScoreFilenames[i]);
            RiskScoreFile riskscore = new RiskScoreFile(riskScoreFilenames[i], format, dbsnp, proxies);
            if (this.chunk != null) {
                riskscore.buildIndex(chromosome, this.chunk);
            } else {
                riskscore.buildIndex(chromosome);
            }
            this.summaries[i].setVariants(riskscore.getTotalVariants());
            this.summaries[i].setVariantsIgnored(riskscore.getIgnoredVariants());
            this.debug("Loaded " + riskscore.getLoadedVariants() + " weights for chromosome " + chromosome);
            riskscores[i] = riskscore;
            monitor.worked(0L);
        }
        return riskscores;
    }

    /**
     * Streams the VCF once and adds each usable record's effect to the per-sample scores.
     *
     * <p>The VCF reader and any optional CSV writers are closed on every exit path, including
     * cancellation and exceptions.
     *
     * @param empty when true no risk score has weights on this chromosome and no record is scored
     * @throws Exception "Different chromosomes found in file." if a record is on another contig
     */
    private void processVCF(ITaskMonitor monitor, String chromosome, String vcfFilename, RiskScoreFile[] riskscores, boolean empty) throws Exception {
        this.debug("Loading file " + vcfFilename + "...");
        VariantFile includeVariants = null;
        if (this.includeVariantFilename != null) {
            this.debug("Loading file " + this.includeVariantFilename + "...");
            includeVariants = new VariantFile(this.includeVariantFilename);
            includeVariants.buildIndex(chromosome);
            this.debug("Loaded " + includeVariants.getCacheSize() + " variants for chromosome " + chromosome);
        }
        SamplesFile samplesFile = null;
        if (this.includeSamplesFilename != null) {
            samplesFile = new SamplesFile(this.includeSamplesFilename);
            samplesFile.buildIndex();
        }
        this.countVariants = 0;
        // The counting stream reports bytes read to the monitor (progress is based on file size).
        CountingInputStream countingStream = new CountingInputStream((InputStream) new FileInputStream(vcfFilename), monitor);
        FastVCFFileReader vcfReader = new FastVCFFileReader((InputStream) countingStream, vcfFilename);
        CsvTableWriter variantsWriter = null;
        CsvTableWriter effectsWriter = null;
        try {
            this.countSamples = vcfReader.getGenotypedSamples().size();
            this.riskScores = new Vector<RiskScore>();
            // sampleNames[i] is the i-th genotyped sample; scoreIndex[i] is its slot in riskScores,
            // or -1 when the samples file excludes it. Computed once instead of per record.
            String[] sampleNames = new String[this.countSamples];
            int[] scoreIndex = new int[this.countSamples];
            for (int i = 0; i < this.countSamples; ++i) {
                String sample = vcfReader.getGenotypedSamples().get(i);
                sampleNames[i] = sample;
                if (samplesFile != null && !samplesFile.contains(sample)) {
                    scoreIndex[i] = -1;
                    continue;
                }
                scoreIndex[i] = this.riskScores.size();
                this.riskScores.add(new RiskScore(chromosome, sample, this.riskScoreFilenames.length));
            }
            if (this.outputVariantFilename != null) {
                variantsWriter = new CsvTableWriter(this.outputVariantFilename, '\t');
                variantsWriter.setColumns(new String[]{"score", "chr_name", "chr_position", "r2", "INCLUDE"});
            }
            if (this.outputEffectsFilename != null) {
                effectsWriter = new CsvTableWriter(this.outputEffectsFilename, ',');
                effectsWriter.setColumns(new String[]{"score", "sample", "chr_name", "chr_position", "effect"});
            }
            boolean outOfChunk = false;
            while (vcfReader.next() && !outOfChunk && !empty) {
                if (monitor.isCanceled()) {
                    return; // reader and writers are closed in the finally block
                }
                MinimalVariantContext variant = vcfReader.get();
                ++this.countVariants;
                if (!variant.getContig().equals(chromosome)) {
                    throw new Exception("Different chromosomes found in file.");
                }
                int position = variant.getStart();
                if (this.chunk != null) {
                    if (position < this.chunk.getStart()) {
                        continue;
                    }
                    if (position > this.chunk.getEnd()) {
                        // Stop scanning once past the chunk end (the loop condition checks this flag).
                        outOfChunk = true;
                        continue;
                    }
                }
                this.applyVariant(variant, riskscores, includeVariants, sampleNames, scoreIndex, variantsWriter, effectsWriter);
            }
            if (variantsWriter != null) {
                this.writeUnusedVariants(variantsWriter, chromosome, riskscores);
            }
        } finally {
            // Nested so a failing close() cannot prevent the remaining resources from closing.
            try {
                if (variantsWriter != null) {
                    variantsWriter.close();
                }
            } finally {
                try {
                    if (effectsWriter != null) {
                        effectsWriter.close();
                    }
                } finally {
                    vcfReader.close();
                }
            }
        }
        this.debug("Loaded " + this.countSamples + " samples and " + this.countVariants + " variants.");
    }

    /**
     * Runs one VCF record through the filter cascade of every risk score and, when the record is
     * used, adds its per-sample effect to the scores.
     *
     * <p>Each risk score's summary counts why a record was rejected: not found, filtered by the
     * include list, low R2, multi-allelic, ambiguous, allele mismatch, or missing genotypes.
     * Accepted records count as strand-flipped and/or switched where applicable, and as used.
     * A reference variant that was already used by an earlier record is skipped silently.
     */
    private void applyVariant(MinimalVariantContext variant, RiskScoreFile[] riskscores, VariantFile includeVariants,
            String[] sampleNames, int[] scoreIndex, CsvTableWriter variantsWriter, CsvTableWriter effectsWriter) throws Exception {
        int position = variant.getStart();
        for (int j = 0; j < this.riskScoreFilenames.length; ++j) {
            RiskScoreSummary summary = this.summaries[j];
            RiskScoreFile riskscore = riskscores[j];
            if (!riskscore.contains(position)) {
                summary.incNotFound();
                continue;
            }
            if (includeVariants != null && !includeVariants.contains(summary.getName(), position)) {
                summary.incFiltered();
                continue;
            }
            double r2 = variant.getInfoAsDouble(INFO_R2, 0.0);
            if (r2 < (double) this.minR2) {
                summary.incR2Filtered();
                continue;
            }
            ReferenceVariant referenceVariant = riskscore.getVariant(position);
            float effectWeight = referenceVariant.getEffectWeight();
            String referenceAllele = variant.getReferenceAllele();
            // An empty ALT string and a comma-separated ALT list are both counted as multi-allelic.
            if (variant.getAlternateAllele().length() == 0) {
                summary.incMultiAllelic();
                continue;
            }
            String[] alternateAlleles = variant.getAlternateAllele().split(",");
            if (alternateAlleles.length > 1) {
                summary.incMultiAllelic();
                continue;
            }
            String alternateAllele = alternateAlleles[0];
            if (this.removeAmbiguous && variant.isAmbigous()) {
                summary.incAmbiguous();
                continue;
            }
            if (!referenceVariant.hasAllele(referenceAllele) || !referenceVariant.hasAllele(alternateAllele)) {
                if (!this.fixStrandFlips) {
                    summary.incAlleleMissmatch();
                    continue;
                }
                String flippedReferenceAllele = ApplyScoreTask.flip(referenceAllele);
                String flippedAlternateAllele = ApplyScoreTask.flip(alternateAllele);
                // Accept the strand flip only if BOTH complemented alleles belong to the reference
                // variant. Without this check, non-ambiguous records were flipped unconditionally.
                if (!referenceVariant.hasAllele(flippedReferenceAllele) || !referenceVariant.hasAllele(flippedAlternateAllele)) {
                    summary.incAlleleMissmatch();
                    continue;
                }
                referenceAllele = flippedReferenceAllele;
                alternateAllele = flippedAlternateAllele;
                summary.incFlipped();
            }
            // Dosages count the ALT allele; if the effect allele is REF, the weight's sign flips.
            boolean switched = false;
            if (!referenceVariant.isEffectAllele(alternateAllele)) {
                if (referenceVariant.isEffectAllele(referenceAllele)) {
                    effectWeight = -effectWeight;
                    switched = true;
                    summary.incSwitched();
                } else {
                    summary.incAlleleMissmatch();
                    continue;
                }
            }
            if (referenceVariant.isUsed()) {
                continue;
            }
            if (!this.averaging && variant.hasMissingGenotypes(this.genotypeFormat)) {
                summary.incMissingGenotypes();
                continue;
            }
            referenceVariant.setUsed(true);
            if (variantsWriter != null) {
                variantsWriter.setString("score", summary.getName());
                variantsWriter.setString("chr_name", variant.getContig());
                variantsWriter.setInteger("chr_position", variant.getStart());
                variantsWriter.setDouble("r2", r2);
                variantsWriter.setInteger("INCLUDE", 1);
                variantsWriter.next();
            }
            float[] dosages = variant.getGenotypeDosages(this.genotypeFormat);
            for (int i = 0; i < sampleNames.length; ++i) {
                int target = scoreIndex[i];
                if (target < 0) {
                    continue; // excluded by the samples file
                }
                float dosage = dosages[i];
                if (!(dosage >= 0.0f)) {
                    // Negative or NaN dosage (presumably a missing call): nothing is added, but the
                    // sample keeps its own slot, so later samples are not shifted onto it.
                    continue;
                }
                double effect = this.inverseDosage && switched
                        ? (double) ((2.0f - dosage) * -effectWeight)
                        : (double) (dosage * effectWeight);
                this.riskScores.get(target).incScore(j, effect);
                if (effectsWriter != null) {
                    effectsWriter.setString("score", summary.getName());
                    effectsWriter.setString("sample", sampleNames[i]);
                    effectsWriter.setString("chr_name", variant.getContig());
                    effectsWriter.setInteger("chr_position", variant.getStart());
                    effectsWriter.setDouble("effect", effect);
                    effectsWriter.next();
                }
            }
            summary.incVariantsUsed();
        }
    }

    /**
     * Appends an {@code INCLUDE=0} row (empty r2) for every risk score variant inside the chunk
     * (if any) that was never marked as used while scanning the VCF.
     */
    private void writeUnusedVariants(CsvTableWriter variantsWriter, String chromosome, RiskScoreFile[] riskscores) throws Exception {
        for (int j = 0; j < this.riskScoreFilenames.length; ++j) {
            String scoreName = this.summaries[j].getName();
            for (Map.Entry<Integer, ReferenceVariant> item : riskscores[j].getVariants().entrySet()) {
                int position = item.getKey();
                if (this.chunk != null && (position < this.chunk.getStart() || position > this.chunk.getEnd())) {
                    continue;
                }
                if (item.getValue().isUsed()) {
                    continue;
                }
                variantsWriter.setString("score", scoreName);
                variantsWriter.setString("chr_name", chromosome);
                variantsWriter.setInteger("chr_position", position);
                variantsWriter.setString("r2", "");
                variantsWriter.setInteger("INCLUDE", 0);
                variantsWriter.next();
            }
        }
    }

    public void setMinR2(float minR2) {
        this.minR2 = minR2;
    }

    /**
     * Sets the default format and applies it to every risk score file already configured.
     */
    public void setDefaultRiskScoreFormat(RiskScoreFormatFactory.RiskScoreFormat defaultFormat) {
        this.defaultFormat = defaultFormat;
        if (this.riskScoreFilenames != null) {
            for (String file : this.riskScoreFilenames) {
                this.setRiskScoreFormat(file, defaultFormat);
            }
        }
    }

    public void setRiskScoreFormat(String file, RiskScoreFormatFactory.RiskScoreFormat format) {
        this.formats.put(file, format);
    }

    /** Returns the number of genotyped samples in the VCF (before the samples filter). */
    public int getCountSamples() {
        return this.countSamples;
    }

    public RiskScoreSummary[] getSummaries() {
        return this.summaries;
    }

    /** Returns the number of VCF records read during the last scan. */
    int getCountVariants() {
        return this.countVariants;
    }

    public String getOutput() {
        return this.output;
    }

    public String getOutputReportFilename() {
        return this.outputReportFilename;
    }

    public String getOutputEffectsFilename() {
        return this.outputEffectsFilename;
    }

    public String getOutputVariantFilename() {
        return this.outputVariantFilename;
    }

    /** Prints {@code text} to stdout when {@link #VERBOSE} is enabled. */
    public void debug(String text) {
        if (VERBOSE) {
            System.out.println(text);
        }
    }

    public void setFixStrandFlips(boolean fixStrandFlips) {
        this.fixStrandFlips = fixStrandFlips;
    }

    public void setRemoveAmbiguous(boolean removeAmbiguous) {
        this.removeAmbiguous = removeAmbiguous;
    }

    /**
     * Returns the base-wise complement of {@code allele} using {@link #ALLELE_SWITCHES}.
     *
     * <p>The bases are complemented in place, not reversed. Characters without a complement
     * (e.g. {@code N}, lowercase bases) are kept unchanged. Previously they were rendered as the
     * text {@code "null"}.
     */
    protected static String flip(String allele) {
        StringBuilder flippedAllele = new StringBuilder(allele.length());
        for (int i = 0; i < allele.length(); ++i) {
            char base = allele.charAt(i);
            Character complement = ALLELE_SWITCHES.get(base);
            flippedAllele.append(complement != null ? complement.charValue() : base);
        }
        return flippedAllele.toString();
    }

    public void setInverseDosage(boolean inverseDosage) {
        this.inverseDosage = inverseDosage;
    }
}

