/*
 * Decompiled with CFR 0.152.
 */
package x.y.z;

import de.unima.ki.anyburl.data.Triple;
import de.unima.ki.anyburl.data.TripleSet;
import de.unima.ki.anyburl.eval.CompletionResult;
import de.unima.ki.anyburl.eval.HitsAtK;
import de.unima.ki.anyburl.eval.ResultSet;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import x.y.z.AlphaBeta;
import x.y.z.RescoreLearner;

public class Rescorer {
    private TripleSet train;
    private TripleSet valid;
    private static final int TOP_K = 100;
    private LinkedList<String> relations = new LinkedList();
    private ResultSet basis;
    private HashMap<String, LinkedHashMap<String, Double>> headPredictionsProviderValid;
    private HashMap<String, LinkedHashMap<String, Double>> tailPredictionsProviderValid;
    private HashMap<String, LinkedHashMap<String, Double>> headPredictionsProviderTest;
    private HashMap<String, LinkedHashMap<String, Double>> tailPredictionsProviderTest;
    private ConcurrentHashMap<String, AlphaBeta> relation2AlphaBeta = new ConcurrentHashMap();

    public static void main(String[] args) throws IOException {
        if (args.length != 7) {
            System.out.println("error: please specify 7 arguments in exactly this order seperated blanks");
            System.out.println(" 1) path to training file");
            System.out.println(" 2) path to validation file");
            System.out.println(" 3) path to ranking file created by a rule learner for the validation set");
            System.out.println(" 4) path to rule-based ranking file with scores from kge model for the validation set");
            System.out.println(" 5) path to ranking file created by a rule learner for the test set");
            System.out.println(" 6) path to rule-based ranking file with scores from kge model for the test set");
            System.out.println(" 7) output path whre the aggregated ranking file is stored");
            System.exit(1);
        }
        String training = args[0];
        String valid = args[1];
        String rankingRulesValid = args[2];
        String rankingRulesTest = args[3];
        String rankingKGEValid = args[4];
        String rankingKGETest = args[5];
        String rankingOutput = args[6];
        TripleSet trainTS = new TripleSet(training);
        TripleSet validTS = new TripleSet(valid);
        Rescorer rs = new Rescorer(trainTS, validTS);
        ResultSet rulesValid = new ResultSet("rules valid", rankingRulesValid, true, 100);
        ResultSet kgeValid = new ResultSet("KGE valid", rankingKGEValid, true, 100);
        ResultSet rulesTest = new ResultSet("rules test", rankingRulesTest, true, 100);
        ResultSet kgeTest = new ResultSet("KGE test", rankingKGETest, true, 100);
        rs.searchJoinparameter(rulesValid, kgeValid, 4);
        rs.writeReorderedTestResult(rulesTest, kgeTest, rankingOutput);
    }

    public static void runBatch() throws IOException {
        String[] models;
        String[] stringArray = models = new String[]{"complex", "conve", "distmult", "hitter", "rescal", "transe"};
        int n = models.length;
        int n2 = 0;
        while (n2 < n) {
            String kgeModel = stringArray[n2];
            System.out.println("*******************************");
            System.out.println("*** going for " + kgeModel.toUpperCase() + " ***");
            System.out.println("*******************************");
            TripleSet train = new TripleSet("data/WN18RR/train.txt");
            TripleSet valid = new TripleSet("data/WN18RR/valid.txt");
            int numOfThreads = 4;
            Rescorer rs = new Rescorer(train, valid);
            String frag = "wn18rr/anyburl-c1-3600-100";
            ResultSet basisValid = new ResultSet("anyburl", "exp/understanding/" + frag + "-valid", true, 100);
            ResultSet providerValid = new ResultSet(kgeModel, "exp/understanding/" + frag + "-valid-" + kgeModel, true, 100);
            ResultSet basisTest = new ResultSet("anyburl", "exp/understanding/" + frag + "-test", true, 100);
            ResultSet providerTest = new ResultSet("kgeModel", "exp/understanding/" + frag + "-test-" + kgeModel, true, 100);
            rs.searchJoinparameter(basisValid, providerValid, numOfThreads);
            rs.writeReorderedTestResult(basisTest, providerTest, "exp/understanding/" + frag + "-test-" + kgeModel + "-MM");
            ++n2;
        }
    }

    public Rescorer(TripleSet train, TripleSet valid) {
        this.train = train;
        this.valid = valid;
        this.headPredictionsProviderValid = new HashMap();
        this.tailPredictionsProviderValid = new HashMap();
        this.headPredictionsProviderTest = new HashMap();
        this.tailPredictionsProviderTest = new HashMap();
    }

    public void searchJoinparameter(ResultSet basis, ResultSet scoreProviderValid, int numOfThreads) throws IOException {
        this.basis = basis;
        this.reorderProvider(scoreProviderValid, basis, true);
        System.out.println(">>> search for relation specific parameters");
        this.relations.addAll(this.train.getRelations());
        boolean r = false;
        Thread[] rslearners = new Thread[numOfThreads];
        System.out.print(">>> creating worker threads ");
        int threadCounter = 0;
        while (threadCounter < 4) {
            System.out.print("#" + threadCounter + " ");
            rslearners[threadCounter] = new RescoreLearner(this);
            rslearners[threadCounter].start();
            ++threadCounter;
        }
        System.out.println();
        while (Rescorer.alive(rslearners)) {
            try {
                Thread.sleep(500L);
            }
            catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        System.out.println(">>> all worker threads are done with their jobs");
    }

    public synchronized String getNextRelation() {
        if (this.relations.size() > 0) {
            String relation = this.relations.poll();
            return relation;
        }
        return null;
    }

    public void searchParameterForRelation(String relation) throws IOException {
        TripleSet trainingSetR = new TripleSet();
        trainingSetR.addTriples(this.train.getTriplesByRelation(relation));
        ResultSet basisR = new ResultSet(this.basis, relation);
        TripleSet validationSetR = new TripleSet();
        validationSetR.addTriples(this.valid.getTriplesByRelation(relation));
        HitsAtK hitsAtK = new HitsAtK();
        hitsAtK.addFilterTripleSet(trainingSetR);
        hitsAtK.addFilterTripleSet(validationSetR);
        double[] beta_values = new double[]{0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0};
        double[] best_beta = new double[2];
        double[] best_mrr = new double[2];
        double[] mrr = new double[2];
        double[] dArray = beta_values;
        int n = beta_values.length;
        int n2 = 0;
        while (n2 < n) {
            double beta = dArray[n2];
            ResultSet rx = this.getRescored(basisR, beta);
            Rescorer.computeScores(rx, validationSetR, hitsAtK);
            mrr = new double[]{hitsAtK.getMRRHeads(), hitsAtK.getMRRTails()};
            int i = 0;
            while (i < 2) {
                if (mrr[i] > best_mrr[i]) {
                    best_beta[i] = beta;
                    best_mrr[i] = mrr[i];
                }
                ++i;
            }
            hitsAtK.reset();
            ++n2;
        }
        this.relation2AlphaBeta.put(String.valueOf(relation) + "+", new AlphaBeta(best_beta[0]));
        this.relation2AlphaBeta.put(String.valueOf(relation) + "-", new AlphaBeta(best_beta[1]));
        System.out.println(">>> " + relation + "(" + this.valid.getTriplesByRelation(relation).size() + ") [+] " + best_mrr[0] + " based on beta=" + best_beta[0] + " [-] " + best_mrr[1] + " based on beta=" + best_beta[1]);
    }

    private static boolean alive(Thread[] threads) {
        Thread[] threadArray = threads;
        int n = threads.length;
        int n2 = 0;
        while (n2 < n) {
            Thread t = threadArray[n2];
            if (t.isAlive()) {
                return true;
            }
            ++n2;
        }
        return false;
    }

    private void reorderProvider(ResultSet scoreProvider, ResultSet basis, boolean validNotTest) {
        System.out.println(">>> index and order the provider result set (" + (validNotTest ? "valid" : "test") + ")");
        for (String triple : scoreProvider.getTriples()) {
            double score;
            String candidate;
            CompletionResult cr = scoreProvider.getCompletionResult(triple);
            LinkedHashMap<String, Double> headPredictions = new LinkedHashMap<String, Double>();
            LinkedHashMap<String, Double> tailPredictions = new LinkedHashMap<String, Double>();
            int i = 0;
            while (i < cr.getHeads().size()) {
                candidate = (String)cr.getHeads().get(i);
                score = (Double)cr.getHeadConfidences().get(i);
                headPredictions.put(candidate, score);
                ++i;
            }
            i = 0;
            while (i < cr.getTails().size()) {
                candidate = (String)cr.getTails().get(i);
                score = (Double)cr.getTailConfidences().get(i);
                tailPredictions.put(candidate, score);
                ++i;
            }
            Rescorer.orderByValueDescending(headPredictions);
            Rescorer.orderByValueDescending(tailPredictions);
            this.normalizeMinMax(headPredictions, basis.getHeadConfidences(triple));
            this.normalizeMinMax(tailPredictions, basis.getTailConfidences(triple));
            if (validNotTest) {
                this.headPredictionsProviderValid.put(triple, headPredictions);
                this.tailPredictionsProviderValid.put(triple, tailPredictions);
                continue;
            }
            this.headPredictionsProviderTest.put(triple, headPredictions);
            this.tailPredictionsProviderTest.put(triple, tailPredictions);
        }
    }

    private void normalizeP(LinkedHashMap<String, Double> predictions) {
        int i = 1;
        for (String c : predictions.keySet()) {
            predictions.put(c, 1.0 / (double)i);
            ++i;
        }
    }

    private void normalize01(LinkedHashMap<String, Double> predictions) {
        double s;
        double max = -1000000.0;
        double min = 1000000.0;
        for (String c : predictions.keySet()) {
            s = predictions.get(c);
            if (max < s) {
                max = s;
            }
            if (!(min > s)) continue;
            min = s;
        }
        for (String c : predictions.keySet()) {
            s = predictions.get(c);
            double normalized = (s - min) / (max - min);
            if (normalized > 1.0) {
                normalized = 1.0;
            }
            if (normalized < 0.0) {
                normalized = 0.0;
            }
            predictions.put(c, normalized);
        }
    }

    private void normalizeMinMax(LinkedHashMap<String, Double> predictions, ArrayList<Double> confidences) {
        double s;
        double targetMax = confidences.size() > 0 ? confidences.get(0) : 0.01;
        double targetMin = confidences.size() == 100 ? confidences.get(99) : 0.0;
        double targetSpan = targetMax - targetMin > 0.0 ? targetMax - targetMin : 0.01;
        double max = -1000000.0;
        double min = 1000000.0;
        for (String c : predictions.keySet()) {
            s = predictions.get(c);
            if (max < s) {
                max = s;
            }
            if (!(min > s)) continue;
            min = s;
        }
        for (String c : predictions.keySet()) {
            s = predictions.get(c);
            double normalized = (s - min) / (max - min);
            if (normalized > 1.0) {
                normalized = 1.0;
            }
            if (normalized < 0.0) {
                normalized = 0.0;
            }
            normalized = normalized * targetSpan + targetMin;
            predictions.put(c, normalized);
        }
    }

    private void writeReorderedTestResult(ResultSet basis, ResultSet scoreProvider, String outputPath) throws FileNotFoundException {
        this.reorderProvider(scoreProvider, basis, false);
        ResultSet rx = this.getRescored(basis);
        rx.write(outputPath);
    }

    private void fillUp(ResultSet rs, ResultSet filler) {
        for (CompletionResult cr : rs) {
            String fcandidate;
            int i;
            CompletionResult fr;
            if (cr.getHeads().size() < 100) {
                System.out.println("Fill up head: " + cr.getTripleAsString() + ": " + cr.getHeads().size());
                HashSet<String> heads = new HashSet<String>();
                for (String candidate : cr.getHeads()) {
                    heads.add(candidate);
                }
                fr = filler.getCompletionResult(cr.getTripleAsString());
                i = 0;
                while (i < 100) {
                    fcandidate = (String)fr.getHeads().get(i);
                    if (!heads.contains(fcandidate)) {
                        cr.getHeads().add(fcandidate);
                        cr.getHeadConfidences().add(0.0);
                    }
                    ++i;
                }
            }
            if (cr.getTails().size() >= 100) continue;
            System.out.println("Fill up tail: " + cr.getTripleAsString() + ": " + cr.getTails().size());
            HashSet<String> tails = new HashSet<String>();
            for (String candidate : cr.getTails()) {
                tails.add(candidate);
            }
            fr = filler.getCompletionResult(cr.getTripleAsString());
            i = 0;
            while (i < 100) {
                fcandidate = (String)fr.getTails().get(i);
                if (!tails.contains(fcandidate)) {
                    cr.getTails().add(fcandidate);
                    cr.getTailConfidences().add(0.0);
                }
                ++i;
            }
        }
    }

    private ResultSet getRescored(ResultSet basis, double beta) {
        ResultSet rs = new ResultSet();
        for (String tripleAsString : basis.getTriples()) {
            CompletionResult thisCr = basis.getCompletionResult(tripleAsString);
            CompletionResult reorderedCr = new CompletionResult(tripleAsString);
            ArrayList<String> reorderedHeads = new ArrayList<String>();
            ArrayList<Double> reorderedHeadConfidences = new ArrayList<Double>();
            this.reorderListsWeighted(reorderedHeads, reorderedHeadConfidences, thisCr.getHeads(), thisCr.getHeadConfidences(), (HashMap<String, Double>)this.headPredictionsProviderValid.get(tripleAsString), beta);
            reorderedCr.setHeads(reorderedHeads);
            reorderedCr.setHeadConfidences(reorderedHeadConfidences);
            ArrayList<String> reorderedTails = new ArrayList<String>();
            ArrayList<Double> reorderedTailConfidences = new ArrayList<Double>();
            this.reorderListsWeighted(reorderedTails, reorderedTailConfidences, thisCr.getTails(), thisCr.getTailConfidences(), (HashMap<String, Double>)this.tailPredictionsProviderValid.get(tripleAsString), beta);
            reorderedCr.setTails(reorderedTails);
            reorderedCr.setTailConfidences(reorderedTailConfidences);
            rs.results.put(tripleAsString, reorderedCr);
        }
        return rs;
    }

    private ResultSet getRescored(ResultSet basis) {
        ResultSet rs = new ResultSet();
        for (String tripleAsString : basis.getTriples()) {
            String r = tripleAsString.split(" ")[1];
            AlphaBeta abH = this.relation2AlphaBeta.get(String.valueOf(r) + "+");
            AlphaBeta abT = this.relation2AlphaBeta.get(String.valueOf(r) + "-");
            CompletionResult thisCr = basis.getCompletionResult(tripleAsString);
            CompletionResult reorderedCr = new CompletionResult(tripleAsString);
            ArrayList<String> reorderedHeads = new ArrayList<String>();
            ArrayList<Double> reorderedHeadConfidences = new ArrayList<Double>();
            this.reorderListsWeighted(reorderedHeads, reorderedHeadConfidences, thisCr.getHeads(), thisCr.getHeadConfidences(), (HashMap<String, Double>)this.headPredictionsProviderTest.get(tripleAsString), abH.beta);
            reorderedCr.setHeads(reorderedHeads);
            reorderedCr.setHeadConfidences(reorderedHeadConfidences);
            ArrayList<String> reorderedTails = new ArrayList<String>();
            ArrayList<Double> reorderedTailConfidences = new ArrayList<Double>();
            this.reorderListsWeighted(reorderedTails, reorderedTailConfidences, thisCr.getTails(), thisCr.getTailConfidences(), (HashMap<String, Double>)this.tailPredictionsProviderTest.get(tripleAsString), abT.beta);
            reorderedCr.setTails(reorderedTails);
            reorderedCr.setTailConfidences(reorderedTailConfidences);
            rs.results.put(tripleAsString, reorderedCr);
        }
        return rs;
    }

    private void reorderListsWeighted(ArrayList<String> reorderedCandidates, ArrayList<Double> reorderedConfidences, ArrayList<String> thisCandidates, ArrayList<Double> thisConfidences, HashMap<String, Double> normalizedProvided, double beta) {
        LinkedHashMap<String, Double> map = new LinkedHashMap<String, Double>();
        int i = 0;
        while (i < thisCandidates.size()) {
            String candidate = thisCandidates.get(i);
            double confidence = thisConfidences.get(i);
            double normalizedScoreProvided = normalizedProvided.get(candidate);
            double score = beta * confidence + (1.0 - beta) * normalizedScoreProvided;
            map.put(candidate, score);
            ++i;
        }
        Rescorer.orderByValueDescending(map);
        for (Map.Entry<String, Double> e : map.entrySet()) {
            reorderedCandidates.add(e.getKey());
            reorderedConfidences.add(e.getValue());
        }
    }

    private static void computeScores(ResultSet rs, TripleSet test, HitsAtK hitsAtK) {
        for (Triple t : test.getTriples()) {
            ArrayList cand1 = rs.getHeadCandidates(t.toString());
            hitsAtK.evaluateHead(cand1, t);
            ArrayList cand2 = rs.getTailCandidates(t.toString());
            hitsAtK.evaluateTail(cand2, t);
        }
    }

    private static void orderByValueDescending(LinkedHashMap<String, Double> m) {
        ArrayList<Map.Entry<String, Double>> entries = new ArrayList<Map.Entry<String, Double>>(m.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Double>>(){

            @Override
            public int compare(Map.Entry<String, Double> lhs, Map.Entry<String, Double> rhs) {
                if (lhs.getValue() - rhs.getValue() > 0.0) {
                    return -1;
                }
                if (lhs.getValue() - rhs.getValue() == 0.0) {
                    return 0;
                }
                return 1;
            }
        });
        m.clear();
        for (Map.Entry entry : entries) {
            m.put((String)entry.getKey(), (Double)entry.getValue());
        }
    }
}

