/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre.test;

import edu.stanford.nlp.sempre.Params;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.testng.AssertJUnit;
import org.testng.annotations.Test;

/**
 * Tests {@link Params#update} under the {@code l1Reg} modes {@code "none"}, {@code "nonlazy"} and
 * {@code "lazy"}, with {@code initStepSize} fixed at 1.0.
 *
 * <p>Each test overwrites the global {@code Params.opts}. The previous values are saved first and
 * restored in a {@code finally} block, so a failing assertion cannot leak this configuration into
 * other tests.
 */
public class L1RegularizationTest {
    /** Absolute tolerance used for every weight comparison. */
    private static final double EPSILON = 0.001;

    /** Gradients applied, in order, in every scenario of {@link #zeroLazyL1Test()}. */
    private static final double[][] ZERO_COEFF_GRADIENTS = {
        {1.0, 0.0, 0.0, 0.0},
        {1.0, 0.0, 0.0, 0.0},
        {0.0, -2.0, 0.0, 0.0},
        {1.0, 2.0, 0.0, 0.0},
    };

    // Expected {weight("a"), weight("b")} after each step of ZERO_COEFF_GRADIENTS.
    // Each step adds g / sqrt(s), where s is the feature's running sum of squared gradients.
    // In the L1 modes s also includes an extra +1.
    // NOTE(review): this looks like an AdaGrad-style step; confirm against Params.update.

    /** Expected weights with {@code l1Reg = "none"}. */
    private static final double[][] ZERO_COEFF_EXPECTED_NO_L1 = {
        {1.0 / Math.sqrt(1.0), 0.0},
        {1.0 / Math.sqrt(1.0) + 1.0 / Math.sqrt(2.0), 0.0},
        {1.0 / Math.sqrt(1.0) + 1.0 / Math.sqrt(2.0), -2.0 / Math.sqrt(4.0)},
        {1.0 / Math.sqrt(1.0) + 1.0 / Math.sqrt(2.0) + 1.0 / Math.sqrt(3.0),
            -2.0 / Math.sqrt(4.0) + 2.0 / Math.sqrt(8.0)},
    };

    /** Expected weights with {@code l1Reg = "nonlazy"} or {@code "lazy"} and coefficient 0. */
    private static final double[][] ZERO_COEFF_EXPECTED_L1 = {
        {1.0 / Math.sqrt(2.0), 0.0},
        {1.0 / Math.sqrt(2.0) + 1.0 / Math.sqrt(3.0), 0.0},
        {1.0 / Math.sqrt(2.0) + 1.0 / Math.sqrt(3.0), -2.0 / Math.sqrt(5.0)},
        {1.0 / Math.sqrt(2.0) + 1.0 / Math.sqrt(3.0) + 1.0 / Math.sqrt(4.0),
            -2.0 / Math.sqrt(5.0) + 2.0 / Math.sqrt(9.0)},
    };

    /** Gradients applied, in order, in every scenario of {@link #nonZeroLazyL1Test()}. */
    private static final double[][] NONZERO_COEFF_GRADIENTS = {
        {2.0, 0.0, -3.14, 0.0},
        {1.0, 0.0, 0.0, 0.0},
        {0.0, -2.0, 0.0, 0.0},
        {1.0, 2.0, 0.0, 0.0},
        {0.0, 3.0, 0.0, 0.0},
        {0.0, 0.0, 0.0, 0.0},
        {-5.0, 0.0, 1.0, 0.0},
        {0.0, 0.0, -1.0, 0.0},
    };

    /**
     * Expected {weight("a"), weight("b")} after each step of {@link #NONZERO_COEFF_GRADIENTS} with
     * L1 coefficient 1.0. The same values are expected in both "nonlazy" and "lazy" modes.
     */
    private static final double[][] NONZERO_COEFF_EXPECTED = {
        {1.0 / Math.sqrt(5.0), 0.0},
        {1.0 / Math.sqrt(5.0), 0.0},
        {1.0 / Math.sqrt(5.0) - 1.0 / Math.sqrt(6.0), -1.0 / Math.sqrt(5.0)},
        {1.0 / Math.sqrt(5.0) - 1.0 / Math.sqrt(6.0), 0.0},
        {0.0, 2.0 / Math.sqrt(18.0)},
        {0.0, 1.0 / Math.sqrt(18.0)},
        {-4.0 / Math.sqrt(32.0), 0.0},
        {-3.0 / Math.sqrt(32.0), 0.0},
    };

    /** Values of {@code Params.opts} captured by {@link #saveOptions()}, restored after each test. */
    private Options originalOptions = null;

    /** Captures the {@code Params.opts} fields that {@link #loadOptions(Options)} overwrites. */
    private void saveOptions() {
        this.originalOptions = new Options()
                .initStepSize(Params.opts.initStepSize)
                .l1Reg(Params.opts.l1Reg)
                .l1RegCoeff(Params.opts.l1RegCoeff);
    }

    /** Copies {@code options} into the global {@code Params.opts}. */
    private void loadOptions(Options options) {
        Params.opts.initStepSize = options.initStepSize;
        Params.opts.l1Reg = options.l1Reg;
        Params.opts.l1RegCoeff = options.l1RegCoeff;
    }

    /**
     * Builds a gradient over features "a".."d". Zero components are omitted from the map rather
     * than stored as 0.0.
     */
    private static Map<String, Double> constructGradient(double a, double b, double c, double d) {
        Map<String, Double> gradient = new HashMap<String, Double>();
        putIfNonZero(gradient, "a", a);
        putIfNonZero(gradient, "b", b);
        putIfNonZero(gradient, "c", c);
        putIfNonZero(gradient, "d", d);
        return gradient;
    }

    /** Adds {@code feature -> value} to {@code gradient} unless {@code value} is 0.0. */
    private static void putIfNonZero(Map<String, Double> gradient, String feature, double value) {
        if (value != 0.0) {
            gradient.put(feature, value);
        }
    }

    /**
     * Asserts that {@code params.getWeight(feature)} is within {@link #EPSILON} of {@code expected}.
     *
     * @param r if null the check always runs. Otherwise one {@code r.nextDouble()} is drawn and the
     *     check (including the {@code getWeight} read) runs only when it is below {@code t}.
     * @param t probability of running the check when {@code r} is non-null
     */
    private static void checkWeight(Params params, String feature, double expected, Random r, double t) {
        // Draw before reading so that a skipped check also skips the getWeight call.
        if (r == null || r.nextDouble() < t) {
            double actual = params.getWeight(feature);
            AssertJUnit.assertEquals(expected, actual, EPSILON);
        }
    }

    /**
     * Applies {@code gradients} to {@code params} one step at a time and checks features "a" and
     * "b". Both must be 0.0 before the first update and equal {@code expected[i][0]} /
     * {@code expected[i][1]} after update {@code i}.
     *
     * @param r optional source of randomness for skipping checks, see
     *     {@link #checkWeight(Params, String, double, Random, double)}; null means check everything
     * @param t probability of running each check when {@code r} is non-null
     */
    private static void runScenario(
            Params params, double[][] gradients, double[][] expected, Random r, double t) {
        checkWeight(params, "a", 0.0, r, t);
        checkWeight(params, "b", 0.0, r, t);
        for (int step = 0; step < gradients.length; step++) {
            double[] g = gradients[step];
            params.update(constructGradient(g[0], g[1], g[2], g[3]));
            checkWeight(params, "a", expected[step][0], r, t);
            checkWeight(params, "b", expected[step][1], r, t);
        }
    }

    @Test
    public void zeroLazyL1Test() {
        this.saveOptions();
        try {
            this.loadOptions(new Options().l1Reg("none").l1RegCoeff(0.0));
            runScenario(new Params(), ZERO_COEFF_GRADIENTS, ZERO_COEFF_EXPECTED_NO_L1, null, 1.0);

            this.loadOptions(new Options().l1Reg("nonlazy").l1RegCoeff(0.0));
            runScenario(new Params(), ZERO_COEFF_GRADIENTS, ZERO_COEFF_EXPECTED_L1, null, 1.0);

            this.loadOptions(new Options().l1Reg("lazy").l1RegCoeff(0.0));
            runScenario(new Params(), ZERO_COEFF_GRADIENTS, ZERO_COEFF_EXPECTED_L1, null, 1.0);
        } finally {
            // Restore even when an assertion fails, so other tests see the original options.
            this.loadOptions(this.originalOptions);
        }
    }

    @Test
    public void nonZeroLazyL1Test() {
        this.saveOptions();
        try {
            this.loadOptions(new Options().l1Reg("nonlazy").l1RegCoeff(1.0));
            Params params = new Params();
            runScenario(params, NONZERO_COEFF_GRADIENTS, NONZERO_COEFF_EXPECTED, null, 1.0);
            checkWeight(params, "c", 0.0, null, 1.0);

            // "lazy" must give the same weights as "nonlazy" regardless of which weights are read
            // between updates. Each trial reads a random subset (each check kept with probability
            // t). t runs from 1.0 down toward 0 in steps of 0.02, and the seed fixes the subsets.
            Random r = new Random(42L);
            for (double t = 1.0; t > 0.0; t -= 0.02) {
                this.loadOptions(new Options().l1Reg("lazy").l1RegCoeff(1.0));
                Params lazyParams = new Params();
                runScenario(lazyParams, NONZERO_COEFF_GRADIENTS, NONZERO_COEFF_EXPECTED, r, t);
                checkWeight(lazyParams, "c", 0.0, r, t);
            }
        } finally {
            // Restore even when an assertion fails, so other tests see the original options.
            this.loadOptions(this.originalOptions);
        }
    }

    /**
     * Fluent holder for the three {@code Params.opts} fields these tests set. Defaults are step
     * size 1.0 and L1 regularization off. Static because it never uses the enclosing test instance.
     */
    static class Options {
        public double initStepSize = 1.0;
        public String l1Reg = "none";
        public double l1RegCoeff = 0.0;

        Options() {
        }

        public Options initStepSize(double x) {
            this.initStepSize = x;
            return this;
        }

        public Options l1Reg(String x) {
            this.l1Reg = x;
            return this;
        }

        public Options l1RegCoeff(double x) {
            this.l1RegCoeff = x;
            return this;
        }
    }
}

