/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.MapUtils;
import fig.basic.Option;
import fig.basic.Pair;
import fig.basic.ValueComparator;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Sparse feature weights of a linear model, keyed by feature name.
 *
 * <p>{@link #update(Map)} takes a per-feature step along the given gradient
 * ({@code weight += stepSize * gradient}). The step size is either adaptive (AdaGrad, from the
 * running sum of squared gradients) or {@code initStepSize / numUpdates^stepSizeReduction}.
 * Optional L1 regularization is applied either to every stored weight after each update
 * ({@code nonlazy}), or per feature only when that feature is read or updated ({@code lazy}),
 * with a periodic full catch-up pass.
 *
 * <p>A feature that is not stored has weight {@link Options#defaultWeight}, or a fresh value
 * drawn uniformly from [-1, 1) when {@link Options#initWeightsRandomly} is set.
 *
 * <p>{@code update}, {@code getWeight}, {@code getWeights} and {@code finalizeWeights} are
 * {@code synchronized}; the read/write/init/copy methods are not.
 */
public class Params {
    /** Global options shared by all instances. */
    public static Options opts = new Options();

    /** Splits one serialized {@code feature<TAB>value} line; Guava splitters are immutable. */
    private static final Splitter TAB_SPLITTER = Splitter.on('\t');

    private L1Reg l1Reg;
    private Map<String, Double> weights;
    /** Per-feature running sum of squared gradients (AdaGrad denominator). */
    Map<String, Double> sumSquaredGradients;
    /** Per-feature running sum of gradients; only used when dual averaging is enabled. */
    Map<String, Double> sumGradients;
    /** Number of completed calls to {@link #update(Map)}. */
    int numUpdates;
    /** Lazy L1 only: value of {@code numUpdates} when L1 was last applied to each feature. */
    Map<String, Integer> l1UpdateTimeMap;

    /**
     * Creates an empty parameter vector using the L1 mode named by {@link Options#l1Reg}.
     *
     * @throws RuntimeException if {@code opts.l1Reg} is not {@code lazy}, {@code nonlazy} or
     *     {@code none}
     */
    public Params() {
        this.l1Reg = parseReg(opts.l1Reg);
        this.weights = new HashMap<String, Double>();
        this.sumSquaredGradients = new HashMap<String, Double>();
        this.sumGradients = new HashMap<String, Double>();
        this.l1UpdateTimeMap = new HashMap<String, Integer>();
    }

    /** Maps the {@code l1Reg} option string to its enum value. */
    private L1Reg parseReg(String l1Reg) {
        if ("lazy".equals(l1Reg)) {
            return L1Reg.LAZY;
        }
        if ("nonlazy".equals(l1Reg)) {
            return L1Reg.NONLAZY;
        }
        if ("none".equals(l1Reg)) {
            return L1Reg.NONE;
        }
        // Include the rejected value so a misconfigured run is easy to diagnose.
        throw new RuntimeException("not legal l1reg: " + l1Reg);
    }

    /**
     * Seeds the weights from explicit (feature, value) pairs.
     *
     * @throws RuntimeException if any weight is already stored
     */
    public void init(List<Pair<String, Double>> initialization) {
        if (!weights.isEmpty()) {
            throw new RuntimeException("Initialization is not legal when there are non-zero weights");
        }
        for (Pair<String, Double> pair : initialization) {
            weights.put(pair.getFirst(), pair.getSecond());
        }
    }

    /**
     * Loads weights from a file of {@code feature<TAB>value} lines. Each loaded value is stored
     * in the current weights, overwriting any existing entry for the same feature. Blank lines
     * are skipped, and fields after the second are ignored.
     *
     * @param path file to read (opened with {@code IOUtils.openIn})
     * @throws RuntimeException wrapping any {@link IOException}, or if a non-blank line has no tab
     */
    public void read(String path) {
        readWeights(path, false, null);
    }

    /**
     * Like {@link #read(String)}, but stores every loaded weight twice: under its original
     * feature name and under {@code prefix + feature}.
     */
    public void read(String path, String prefix) {
        readWeights(path, true, prefix);
    }

    /** Shared body of both {@code read} overloads, wrapped in a LogInfo track. */
    private void readWeights(String path, boolean storePrefixed, String prefix) {
        LogInfo.begin_track("Reading parameters from %s", path);
        try {
            loadWeightFile(path, storePrefixed, prefix);
            LogInfo.logs("Read %s weights", weights.size());
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            // Keep LogInfo's track nesting balanced even when reading fails.
            LogInfo.end_track();
        }
    }

    /** Parses {@code path} line by line into {@link #weights}; always closes the reader. */
    private void loadWeightFile(String path, boolean storePrefixed, String prefix) throws IOException {
        BufferedReader in = IOUtils.openIn(path);
        try {
            String line;
            while ((line = in.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                List<String> fields = Lists.newArrayList(TAB_SPLITTER.split(line));
                if (fields.size() < 2) {
                    throw new RuntimeException("Malformed parameter line in " + path + ": " + line);
                }
                String feature = fields.get(0);
                double value = Double.parseDouble(fields.get(1));
                weights.put(feature, value);
                if (storePrefixed) {
                    weights.put(prefix + feature, value);
                }
            }
        } finally {
            in.close();
        }
    }

    /**
     * Applies one step along {@code gradient}, then any configured L1 regularization, and
     * increments {@link #numUpdates}.
     *
     * <p>Features whose gradient squares to 0 are skipped. With dual averaging, a feature's
     * weight is set to {@code stepSize * sumOfGradients} instead of being incremented. With lazy
     * L1, every {@code lazyL1FullUpdateFreq} updates (if positive) all pending L1 steps are
     * applied via {@link #finalizeWeights()}.
     *
     * @throws RuntimeException if a step is infinite or NaN, or if dual averaging is combined
     *     with a non-adaptive, decaying step size
     */
    public synchronized void update(Map<String, Double> gradient) {
        for (Map.Entry<String, Double> entry : gradient.entrySet()) {
            String f = entry.getKey();
            double g = entry.getValue();
            if (g * g == 0.0) {
                continue;
            }
            if (l1Reg == L1Reg.LAZY) {
                // Bring f's weight up to date with skipped L1 steps before moving it.
                lazyL1Update(f);
            }
            double stepSize = computeStepSize(f, g);
            if (opts.dualAveraging) {
                if (!opts.adaptiveStepSize && opts.stepSizeReduction != 0.0) {
                    throw new RuntimeException("Dual averaging not supported when step-size changes across iterations for features for which the gradient is zero");
                }
                MapUtils.incr(sumGradients, f, g);
                weights.put(f, stepSize * sumGradients.get(f));
                continue;
            }
            double delta = stepSize * g;
            // Reject NaN as well as +/-Infinity; either would silently poison the stored weight.
            if (Double.isInfinite(delta) || Double.isNaN(delta)) {
                LogInfo.logs("WEIRD FEATURE UPDATE: feature=%s, currentWeight=%s, stepSize=%s, gradient=%s", f, getWeight(f), stepSize, g);
                throw new RuntimeException("Gradient absolute value is too large or too small");
            }
            MapUtils.incr(weights, f, delta);
            if (l1Reg == L1Reg.LAZY) {
                l1UpdateTimeMap.put(f, numUpdates);
            }
        }
        if (l1Reg == L1Reg.NONLAZY) {
            // Snapshot the keys: clipUpdate may remove entries while we iterate.
            List<String> features = new ArrayList<String>(weights.keySet());
            for (String f : features) {
                double stepSize = computeStepSize(f, 0.0);
                double update = opts.l1RegCoeff * -Math.signum(MapUtils.getDouble(weights, f, opts.defaultWeight));
                clipUpdate(f, stepSize * update);
            }
        }
        ++numUpdates;
        if (l1Reg == L1Reg.LAZY && opts.lazyL1FullUpdateFreq > 0 && numUpdates % opts.lazyL1FullUpdateFreq == 0) {
            LogInfo.begin_track("Fully apply L1 regularization.");
            finalizeWeights();
            System.gc();
            LogInfo.end_track();
        }
    }

    /**
     * Returns the step size for {@code feature}. In adaptive mode this first adds
     * {@code gradient^2} to the feature's running sum of squared gradients.
     */
    private double computeStepSize(String feature, double gradient) {
        if (opts.adaptiveStepSize) {
            MapUtils.incr(sumSquaredGradients, feature, gradient * gradient);
            double sumSq = sumSquaredGradients.get(feature);
            // With L1 on, the NONLAZY pass calls this with gradient 0 for features that may never
            // have received a gradient; the +1 keeps the denominator non-zero.
            if (l1Reg != L1Reg.NONE) {
                return opts.initStepSize / Math.sqrt(sumSq + 1.0);
            }
            return opts.initStepSize / Math.sqrt(sumSq);
        }
        // Before the first update numUpdates is 0, and 0^r == 0 for r > 0 would make the step
        // infinite; clamp the base to 1 (no effect when r == 0 or after the first update).
        return opts.initStepSize / Math.pow(Math.max(numUpdates, 1), opts.stepSizeReduction);
    }

    /**
     * Adds {@code update} to the weight of {@code f}. If that would flip the weight's sign, the
     * feature is removed instead. Features that are absent or have weight 0 are left untouched.
     */
    private void clipUpdate(String f, double update) {
        double currWeight = MapUtils.getDouble(weights, f, 0.0);
        if (currWeight == 0.0) {
            return;
        }
        if (currWeight * (currWeight + update) < 0.0) {
            weights.remove(f);
        } else {
            MapUtils.incr(weights, f, update);
        }
    }

    /**
     * Applies, in one clipped step, all L1 steps that {@code f} has missed since it was last
     * regularized. The step size is {@code initStepSize / sqrt(sumSquaredGradients + 1)}.
     */
    private void lazyL1Update(String f) {
        if (MapUtils.getDouble(weights, f, 0.0) == 0.0) {
            return;
        }
        if (sumSquaredGradients.get(f) == null || l1UpdateTimeMap.get(f) == null) {
            // First time lazy L1 sees this weight (e.g. it came from read/init or dual
            // averaging): start its clock now. Keep any existing squared-gradient history.
            l1UpdateTimeMap.put(f, numUpdates);
            if (sumSquaredGradients.get(f) == null) {
                sumSquaredGradients.put(f, 0.0);
            }
            return;
        }
        int numOfIter = numUpdates - l1UpdateTimeMap.get(f);
        if (numOfIter == 0) {
            return;
        }
        if (numOfIter < 0) {
            throw new RuntimeException("l1UpdateTimeMap is out of sync.");
        }
        double stepSize = numOfIter * opts.initStepSize / Math.sqrt(sumSquaredGradients.get(f) + 1.0);
        double update = -opts.l1RegCoeff * Math.signum(MapUtils.getDouble(weights, f, 0.0));
        clipUpdate(f, stepSize * update);
        if (weights.containsKey(f)) {
            l1UpdateTimeMap.put(f, numUpdates);
        } else {
            // The weight was clipped away; it restarts from scratch if it reappears.
            l1UpdateTimeMap.remove(f);
        }
    }

    /**
     * Returns the weight of {@code f}, first applying any pending lazy L1 steps.
     *
     * <p>A missing feature gets the default weight, or, with random initialization, a new
     * random value on every call (the value is not stored). The random draw happens even for
     * stored features, so the generator's sequence advances on every call.
     */
    public synchronized double getWeight(String f) {
        if (l1Reg == L1Reg.LAZY) {
            lazyL1Update(f);
        }
        if (opts.initWeightsRandomly) {
            return MapUtils.getDouble(weights, f, 2.0 * opts.initRandom.nextDouble() - 1.0);
        }
        return MapUtils.getDouble(weights, f, opts.defaultWeight);
    }

    /**
     * Applies all pending lazy L1 steps, then returns the live internal weight map (not a
     * copy).
     */
    public synchronized Map<String, Double> getWeights() {
        finalizeWeights();
        return weights;
    }

    /** Writes all stored weights to {@code out}, sorted by value in reverse order. */
    public void write(PrintWriter out) {
        write(null, out);
    }

    /**
     * Writes one {@code [prefix<TAB>]feature<TAB>value} line per stored weight, sorted by value
     * in reverse order.
     *
     * <p>NOTE: this writes the stored values as-is. With lazy L1, call {@link #getWeights()} or
     * {@link #finalizeWeights()} first if pending regularization should be reflected.
     */
    public void write(String prefix, PrintWriter out) {
        List<Map.Entry<String, Double>> entries = Lists.newArrayList(weights.entrySet());
        Collections.sort(entries, new ValueComparator<String, Double>(true));
        for (Map.Entry<String, Double> entry : entries) {
            double value = entry.getValue();
            out.println((prefix == null ? "" : prefix + "\t") + entry.getKey() + "\t" + value);
        }
    }

    /** Writes all stored weights to the file at {@code path}; the writer is always closed. */
    public void write(String path) {
        LogInfo.begin_track("Params.write(%s)", path);
        try {
            PrintWriter out = IOUtils.openOutHard(path);
            try {
                write(out);
            } finally {
                out.close();
            }
        } finally {
            LogInfo.end_track();
        }
    }

    /** Logs every stored weight, sorted by value in reverse order, inside a "Params" track. */
    public void log() {
        LogInfo.begin_track("Params");
        List<Map.Entry<String, Double>> entries = Lists.newArrayList(weights.entrySet());
        Collections.sort(entries, new ValueComparator<String, Double>(true));
        for (Map.Entry<String, Double> entry : entries) {
            double value = entry.getValue();
            LogInfo.logs("%s\t%s", entry.getKey(), value);
        }
        LogInfo.end_track();
    }

    /** With lazy L1, brings every stored weight up to date; otherwise a no-op. */
    public synchronized void finalizeWeights() {
        if (l1Reg == L1Reg.LAZY) {
            // Snapshot the keys: lazyL1Update may remove clipped features.
            List<String> features = new ArrayList<String>(weights.keySet());
            for (String f : features) {
                lazyL1Update(f);
            }
        }
    }

    /**
     * Returns a new {@code Params} holding a copy of the finalized weights only. Gradient
     * history is not copied.
     */
    public Params copyParams() {
        Params result = new Params();
        for (String feature : getWeights().keySet()) {
            result.weights.put(feature, getWeight(feature));
        }
        return result;
    }

    /**
     * Returns a new {@code Params} with the finalized weights of all features starting with
     * {@code prefix}, stored under their names with the prefix stripped.
     */
    public Params copyParamsByPrefix(String prefix) {
        Params result = new Params();
        for (String feature : getWeights().keySet()) {
            if (!feature.startsWith(prefix)) {
                continue;
            }
            String newFeature = feature.substring(prefix.length());
            result.weights.put(newFeature, getWeight(feature));
        }
        return result;
    }

    /** Returns true if no weight is stored. */
    public boolean isEmpty() {
        return weights.isEmpty();
    }

    /**
     * Returns a new {@code Params} with the same features as this one, each given a fresh value
     * drawn uniformly from [-1, 1) by an unseeded {@link Random}.
     */
    public Params getRandomWeightParams() {
        Random rand = new Random();
        Params result = new Params();
        for (String feature : getWeights().keySet()) {
            result.weights.put(feature, 2.0 * rand.nextDouble() - 1.0);
        }
        return result;
    }

    /** How L1 regularization is applied by {@link Params#update(Map)}. */
    public static enum L1Reg {
        /** Apply L1 to a feature only when it is read or updated, catching up on skipped steps. */
        LAZY,
        /** Apply one L1 step to every stored weight after each update. */
        NONLAZY,
        /** No L1 regularization. */
        NONE;

    }

    /** Command-line options controlling initialization, step sizes and regularization. */
    public static class Options {
        @Option(gloss="By default, all features have this weight")
        public double defaultWeight = 0.0;
        @Option(gloss="Randomly initialize the weights")
        public boolean initWeightsRandomly = false;
        @Option(gloss="Random number generator used when randomly initializing weights")
        public Random initRandom = new Random(1L);
        @Option(gloss="Initial step size")
        public double initStepSize = 1.0;
        @Option(gloss="How fast to reduce the step size")
        public double stepSizeReduction = 0.0;
        @Option(gloss="Use the AdaGrad algorithm (different step size for each coordinate)")
        public boolean adaptiveStepSize = true;
        @Option(gloss="Use dual averaging")
        public boolean dualAveraging = false;
        @Option(gloss="L1 regularization mode: lazy, nonlazy, or none")
        public String l1Reg = "none";
        @Option(gloss="L1 reg coefficient")
        public double l1RegCoeff = 0.0;
        @Option(gloss="Lazy L1 full update frequency")
        public int lazyL1FullUpdateFreq = 5000;
    }
}

