/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import edu.stanford.nlp.sempre.HasScore;
import edu.stanford.nlp.sempre.ParserAgenda;
import fig.basic.MapUtils;
import fig.basic.NumUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Static helpers for manipulating {@code String -> Double} maps, turning log-scores into
 * probabilities, sampling indices from those probabilities, and doing arithmetic in log space.
 */
public final class ReinforcementUtils {
    /** log(Double.MAX_VALUE); a log-space gap larger than this makes exp() of it underflow/overflow. */
    private static final double LOG_MAX_VALUE = Math.log(Double.MAX_VALUE);

    /** Largest probability accepted by computeProb; the slack above 1 absorbs rounding error. */
    private static final double MAX_PROB = 1.0001;

    /** log(2): switch-over point between the two accurate formulas for log(1 - exp(d)). */
    private static final double LOG_2 = Math.log(2.0);

    private ReinforcementUtils() {
        // Static utility class; not instantiable.
    }

    /**
     * Adds every value of {@code addedMap} into {@code mutatedMap} under the key
     * {@code prefix + key}, using {@code MapUtils.incr} (presumably treating absent keys as 0).
     *
     * @param mutatedMap map updated in place
     * @param addedMap   values to add; a {@code null} value causes a NullPointerException
     * @param prefix     string prepended to every key of {@code addedMap}
     */
    public static void addToDoubleMap(Map<String, Double> mutatedMap, Map<String, Double> addedMap, String prefix) {
        for (Map.Entry<String, Double> entry : addedMap.entrySet()) {
            double delta = entry.getValue();
            MapUtils.incr(mutatedMap, prefix + entry.getKey(), delta);
        }
    }

    /**
     * Subtracts every value of {@code subtractedMap} from {@code mutatedMap} under the same key.
     *
     * @param mutatedMap    map updated in place
     * @param subtractedMap values to subtract; a {@code null} value causes a NullPointerException
     */
    public static void subtractFromDoubleMap(Map<String, Double> mutatedMap, Map<String, Double> subtractedMap) {
        for (Map.Entry<String, Double> entry : subtractedMap.entrySet()) {
            double delta = -entry.getValue();
            MapUtils.incr(mutatedMap, entry.getKey(), delta);
        }
    }

    /**
     * Subtracts every value of {@code subtractedMap} from {@code mutatedMap} under the key
     * {@code prefix + key}.
     *
     * @param mutatedMap    map updated in place
     * @param subtractedMap values to subtract; a {@code null} value causes a NullPointerException
     * @param prefix        string prepended to every key of {@code subtractedMap}
     */
    public static void subtractFromDoubleMap(Map<String, Double> mutatedMap, Map<String, Double> subtractedMap, String prefix) {
        for (Map.Entry<String, Double> entry : subtractedMap.entrySet()) {
            double delta = -entry.getValue();
            MapUtils.incr(mutatedMap, prefix + entry.getKey(), delta);
        }
    }

    /**
     * Returns a new map with every value of {@code map} multiplied by {@code factor}.
     * The input map is not modified.
     *
     * @param map    source values; a {@code null} value causes a NullPointerException
     * @param factor multiplier applied to each value
     * @return a fresh, mutable {@link HashMap}
     */
    public static Map<String, Double> multiplyDoubleMap(Map<String, Double> map, double factor) {
        Map<String, Double> res = new HashMap<>();
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            res.put(entry.getKey(), entry.getValue() * factor);
        }
        return res;
    }

    /**
     * Samples an index with probability {@code exp(score_i - denominator)}.
     *
     * @param rand        source of randomness
     * @param scorables   items whose {@code getScore()} is a log-score
     * @param denominator log normaliser (presumably {@link #computeLogExpSum} of the same list)
     * @return the sampled index
     * @throws RuntimeException if a probability is out of range, or if the probabilities sum to
     *                          less than the drawn uniform value (e.g. a wrong denominator)
     */
    public static int sampleIndex(Random rand, List<? extends HasScore> scorables, double denominator) {
        double randD = rand.nextDouble();
        double sum = 0.0;
        int index = 0;
        // Iterate rather than get(i) so non-random-access lists stay linear.
        for (HasScore scorable : scorables) {
            sum += computeProb(scorable, denominator);
            if (randD < sum) {
                return index;
            }
            index++;
        }
        throw new RuntimeException(sum + " < " + randD);
    }

    /**
     * Samples an index with probability {@code exp(scores[i] - denominator)}.
     *
     * @param rand        source of randomness
     * @param scores      log-scores
     * @param denominator log normaliser for {@code scores}
     * @return the sampled index
     * @throws RuntimeException if a probability is out of range, or if the probabilities sum to
     *                          less than the drawn uniform value
     */
    public static int sampleIndex(Random rand, double[] scores, double denominator) {
        double randD = rand.nextDouble();
        double sum = 0.0;
        for (int i = 0; i < scores.length; ++i) {
            sum += computeProb(scores[i], denominator);
            if (randD < sum) {
                return i;
            }
        }
        throw new RuntimeException(sum + " < " + randD);
    }

    /**
     * Samples an index directly from a probability vector.
     *
     * @param rand  source of randomness
     * @param probs probabilities, expected to sum to (approximately) 1
     * @return the sampled index
     * @throws RuntimeException if {@code probs} sums to less than the drawn uniform value
     */
    public static int sampleIndex(Random rand, double[] probs) {
        double randD = rand.nextDouble();
        double sum = 0.0;
        for (int i = 0; i < probs.length; ++i) {
            sum += probs[i];
            if (randD < sum) {
                return i;
            }
        }
        throw new RuntimeException(sum + " < " + randD);
    }

    /**
     * Returns {@code exp(deriv.getScore() - denominator)} after checking that it is a valid
     * probability.
     *
     * @throws RuntimeException if the result is NaN or exceeds 1 by more than the tolerance
     */
    public static double computeProb(HasScore deriv, double denominator) {
        return computeProb(deriv.getScore(), denominator);
    }

    /**
     * Returns {@code exp(score - denominator)} after checking that it is a valid probability.
     *
     * @param score       log-score of one item
     * @param denominator log normaliser; the result exceeds 1 only if this is too small
     * @return the probability
     * @throws RuntimeException if the result is NaN or exceeds 1 by more than the tolerance
     */
    public static double computeProb(double score, double denominator) {
        double prob = Math.exp(score - denominator);
        // exp() never returns a negative value, so only the upper bound needs checking.
        // The negated comparison also rejects NaN (e.g. from inf - inf).
        if (!(prob <= MAX_PROB)) {
            throw new RuntimeException("Probability is out of range, prob=" + prob + ",score=" + score + ", denom=" + denominator);
        }
        return prob;
    }

    /**
     * Computes {@code log(sum_i exp(score_i))} by folding {@code NumUtils.logAdd} over the scores.
     *
     * @return the log-sum-exp, or negative infinity for an empty list
     */
    public static double computeLogExpSum(List<? extends HasScore> scorables) {
        double sum = Double.NEGATIVE_INFINITY;
        for (HasScore hasScore : scorables) {
            sum = NumUtils.logAdd(sum, hasScore.getScore());
        }
        return sum;
    }

    /**
     * Converts log-scores into a normalised probability vector (a softmax).
     * Scores are shifted by their maximum before exponentiation to avoid overflow.
     *
     * @return one probability per element, in list order
     * @throws RuntimeException if the maximum score is infinite: the list is empty, every score is
     *                          negative infinity, or some score is positive infinity
     */
    public static double[] expNormalize(List<? extends HasScore> scorables) {
        return expNormalizeScores(scorables, scorables.size());
    }

    /**
     * Converts the log-scores of an agenda into a normalised probability vector (a softmax),
     * in the agenda's iteration order.
     *
     * @throws RuntimeException if the maximum score is infinite (see the List overload)
     */
    public static double[] expNormalize(ParserAgenda<? extends HasScore> scorables) {
        return expNormalizeScores(scorables, scorables.size());
    }

    /** Shared softmax implementation: iterates {@code scorables} twice, writing {@code size} slots. */
    private static double[] expNormalizeScores(Iterable<? extends HasScore> scorables, int size) {
        double max = Double.NEGATIVE_INFINITY;
        for (HasScore hasScore : scorables) {
            max = Math.max(max, hasScore.getScore());
        }
        if (Double.isInfinite(max)) {
            throw new RuntimeException("Scoreables is probably empty");
        }
        double[] res = new double[size];
        int i = 0;
        for (HasScore hasScore : scorables) {
            res[i++] = Math.exp(hasScore.getScore() - max);
        }
        NumUtils.normalize(res);
        return res;
    }

    /**
     * Computes {@code log(exp(a) - exp(b))} without leaving log space.
     *
     * @param a larger log-value
     * @param b smaller log-value; must be strictly less than {@code a}
     * @return the log of the difference
     * @throws RuntimeException if {@code a <= b}
     */
    public static double logSub(double a, double b) {
        if (a <= b) {
            throw new RuntimeException("First argument must be strictly greater than second argument");
        }
        if (Double.isInfinite(b) || a - b > LOG_MAX_VALUE) {
            // exp(b - a) underflows to 0, so the difference is exp(a).
            return a;
        }
        // log(exp(a) - exp(b)) = a + log(1 - exp(d)) with d = b - a < 0.
        // Use -expm1(d) when d is near 0 (1 - exp(d) would cancel).
        // Use log1p(-exp(d)) otherwise (log(1 - x) loses precision for small x).
        double d = b - a;
        double log1mExp = d > -LOG_2 ? Math.log(-Math.expm1(d)) : Math.log1p(-Math.exp(d));
        return a + log1mExp;
    }
}

