/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sempre;

import edu.stanford.nlp.sempre.AtomicSemType;
import edu.stanford.nlp.sempre.ConstantFn;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.DerivationStream;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.FeatureExtractor;
import edu.stanford.nlp.sempre.FeatureVector;
import edu.stanford.nlp.sempre.Formula;
import edu.stanford.nlp.sempre.Formulas;
import edu.stanford.nlp.sempre.JoinFormula;
import edu.stanford.nlp.sempre.LambdaFormula;
import edu.stanford.nlp.sempre.MultipleDerivationStream;
import edu.stanford.nlp.sempre.SemType;
import edu.stanford.nlp.sempre.SemTypeHierarchy;
import edu.stanford.nlp.sempre.SemanticFn;
import edu.stanford.nlp.sempre.TopSemType;
import edu.stanford.nlp.sempre.TypeInference;
import edu.stanford.nlp.sempre.UnionSemType;
import fig.basic.LispTree;
import fig.basic.LogInfo;
import fig.basic.Option;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Semantic function that joins a binary derivation with a unary derivation.
 *
 * <p>Arguments accepted in the grammar rule (every child after the function name):
 * <ul>
 *   <li>{@code binary,unary} / {@code unary,binary}: the order in which the two children
 *       appear (default: binary first);</li>
 *   <li>{@code unaryCanBeArg0}: also try joining the unary against the reversed binary
 *       ({@link Formulas#reverseFormula}, with the reversed type);</li>
 *   <li>{@code unaryCanBeArg1}: try joining the unary against the binary as-is;</li>
 *   <li>{@code forward}: shorthand for {@code binary,unary} plus {@code unaryCanBeArg1};</li>
 *   <li>{@code backward}: shorthand for {@code unary,binary} plus {@code unaryCanBeArg1};</li>
 *   <li>{@code betaReduce}: build the result with {@link Formulas#lambdaApply} instead of a
 *       {@link JoinFormula}; the binary formula must then be a {@link LambdaFormula};</li>
 *   <li>{@code (arg0 ...)}: the subtree initializes a {@link ConstantFn} whose derivation is
 *       used as the first child, so the rule itself must supply exactly one child.</li>
 * </ul>
 * At least one of {@code unaryCanBeArg0} / {@code unaryCanBeArg1} must be enabled.
 */
public class JoinFn
extends SemanticFn {
    public static Options opts = new Options();
    // True when the unary child comes before the binary child.
    private boolean unaryFirst = false;
    // Try the join with the binary reversed.
    private boolean unaryCanBeArg0 = false;
    // Try the join with the binary as written.
    private boolean unaryCanBeArg1 = false;
    // Use lambda application instead of a JoinFormula.
    private boolean betaReduce = false;
    // Optional constant that supplies the first child; null when the rule takes two children.
    private ConstantFn arg0Fn = null;

    /** Returns the constant supplying the first child, or null if none was configured. */
    public ConstantFn getArg0Fn() {
        return this.arg0Fn;
    }

    /**
     * Parses the rule arguments (see the class comment).
     *
     * @throws RuntimeException on an unknown flag, on a malformed non-leaf argument, or when
     *     neither {@code unaryCanBeArg0} nor {@code unaryCanBeArg1} ends up enabled
     */
    @Override
    public void init(LispTree tree) {
        super.init(tree);
        for (int j = 1; j < tree.children.size(); ++j) {
            LispTree argTree = tree.child(j);
            if (argTree.isLeaf()) {
                this.applyFlag(argTree.value);
            } else if (!argTree.children.isEmpty() && "arg0".equals(argTree.child(0).value)) {
                this.arg0Fn = new ConstantFn();
                this.arg0Fn.init(argTree);
            } else {
                // Also covers an empty list "()", which would otherwise fail in child(0).
                throw new RuntimeException("Invalid argument: " + argTree);
            }
        }
        if (!this.unaryCanBeArg0 && !this.unaryCanBeArg1) {
            throw new RuntimeException("At least one of unaryCanBeArg0 and unaryCanBeArg1 must be set");
        }
    }

    /** Applies a single leaf (flag) argument of the rule; throws on an unknown flag. */
    private void applyFlag(String arg) {
        switch (arg) {
            case "binary,unary":
                this.unaryFirst = false;
                break;
            case "unary,binary":
                this.unaryFirst = true;
                break;
            case "unaryCanBeArg0":
                this.unaryCanBeArg0 = true;
                break;
            case "unaryCanBeArg1":
                this.unaryCanBeArg1 = true;
                break;
            case "forward":
                this.unaryFirst = false;
                this.unaryCanBeArg1 = true;
                break;
            case "backward":
                this.unaryFirst = true;
                this.unaryCanBeArg1 = true;
                break;
            case "betaReduce":
                this.betaReduce = true;
                break;
            default:
                throw new RuntimeException("Invalid argument: " + arg);
        }
    }

    /** Returns a lazy stream of the joins of the callable's children. */
    @Override
    public DerivationStream call(Example ex, SemanticFn.Callable c) {
        return new LazyJoinFnDerivs(ex, c);
    }

    /**
     * Lazily computes every permitted join of the binary and unary children on the first
     * request, then hands the results out one at a time, ordered by
     * {@link Derivation#derivScoreComparator}.
     */
    public class LazyJoinFnDerivs
    extends MultipleDerivationStream {
        private int currIndex = 0;
        private List<Derivation> derivations = new ArrayList<Derivation>();
        // Set once doJoins has run, so the joins are never recomputed.
        private boolean joinsComputed = false;
        private Example ex;
        private SemanticFn.Callable callable;
        Derivation unaryDeriv;
        Derivation binaryDeriv;

        /**
         * Resolves which child is the binary and which is the unary.
         *
         * @throws RuntimeException if the number of children does not match the configuration
         *     (one child when {@code arg0} is set, two otherwise)
         */
        public LazyJoinFnDerivs(Example ex, SemanticFn.Callable c) {
            this.ex = ex;
            this.callable = c;
            Derivation first;
            Derivation second;
            if (JoinFn.this.arg0Fn != null) {
                if (c.getChildren().size() != 1) {
                    throw new RuntimeException("Expected one argument (already have " + JoinFn.this.arg0Fn + "), but got args: " + c.getChildren());
                }
                // The constant is used as child 0; the rule's single child becomes child 1.
                DerivationStream constantStream = JoinFn.this.arg0Fn.call(ex, SemanticFn.CallInfo.NULL_INFO);
                first = (Derivation) constantStream.next();
                second = c.child(0);
            } else {
                if (c.getChildren().size() != 2) {
                    throw new RuntimeException("Expected two arguments, but got: " + c.getChildren());
                }
                first = c.child(0);
                second = c.child(1);
            }
            this.unaryDeriv = JoinFn.this.unaryFirst ? first : second;
            this.binaryDeriv = JoinFn.this.unaryFirst ? second : first;
        }

        /** Upper bound on the number of joins: one per direction (arg0 / arg1). */
        @Override
        public int estimatedSize() {
            return 2;
        }

        /** Returns the next join, or null once all joins have been returned. */
        @Override
        public Derivation createDerivation() {
            // Run the joins exactly once. A "currIndex == 0" test would re-run them (and
            // re-log type-check failures) on every call whenever no join succeeded.
            if (!this.joinsComputed) {
                this.joinsComputed = true;
                this.doJoins(this.binaryDeriv, this.unaryDeriv);
            }
            if (this.currIndex >= this.derivations.size()) {
                return null;
            }
            return this.derivations.get(this.currIndex++);
        }

        /**
         * Alternative to {@code binaryType.apply(unaryType)}, enabled by
         * {@code opts.specializedTypeCheck}.
         *
         * <p>A top unary type is rejected (bottom type). An atomic unary is wrapped in a
         * {@link UnionSemType}. When the binary's argument type is atomic, the join succeeds
         * (returning the binary's return type) iff some atomic base type of the unary lists
         * that argument type among its supertypes in {@link SemTypeHierarchy}. In every other
         * case the result falls back to {@code binaryType.apply} on the (possibly wrapped)
         * unary type.
         */
        SemType specializedTypeCheck(SemType binaryType, SemType unaryType) {
            SemType argType = binaryType.getArgType();
            if (unaryType instanceof TopSemType) {
                return SemType.bottomType;
            }
            if (unaryType instanceof AtomicSemType) {
                unaryType = new UnionSemType(unaryType);
            }
            if (unaryType instanceof UnionSemType && argType instanceof AtomicSemType) {
                String argName = ((AtomicSemType) argType).name;
                for (SemType base : ((UnionSemType) unaryType).baseTypes) {
                    if (base instanceof AtomicSemType
                            && SemTypeHierarchy.singleton.getSupertypes(((AtomicSemType) base).name).contains(argName)) {
                        return binaryType.getRetType();
                    }
                }
                return SemType.bottomType;
            }
            return binaryType.apply(unaryType);
        }

        /**
         * Builds one join derivation.
         *
         * @param featureDesc value for the {@code joinPos} feature (skipped when null)
         * @return the new derivation, or null when the type check or type inference fails
         * @throws RuntimeException if {@code betaReduce} is set and the binary is not a
         *     {@link LambdaFormula}
         */
        private Derivation doJoin(Derivation binaryDeriv, Formula binaryFormula, SemType binaryType, Derivation unaryDeriv, Formula unaryFormula, SemType unaryType, String featureDesc) {
            SemType type = JoinFn.opts.specializedTypeCheck
                    ? this.specializedTypeCheck(binaryType, unaryType)
                    : binaryType.apply(unaryType);
            if (!type.isValid()) {
                if (JoinFn.opts.showTypeCheckFailures) {
                    LogInfo.warnings("JoinFn: type check failed: [%s : %s] JOIN [%s : %s]", binaryFormula, binaryType, unaryFormula, unaryType);
                }
                return null;
            }
            Formula f;
            if (JoinFn.this.betaReduce) {
                if (!(binaryFormula instanceof LambdaFormula)) {
                    throw new RuntimeException("Expected LambdaFormula as the binary, but got: " + binaryFormula + ", unary is " + unaryFormula);
                }
                f = Formulas.lambdaApply((LambdaFormula) binaryFormula, unaryFormula);
            } else {
                f = new JoinFormula(binaryFormula, unaryFormula);
            }
            if (JoinFn.opts.typeInference) {
                // Replace the coarse type with the one inferred from the full formula.
                SemType fullType = TypeInference.inferType(f);
                if (JoinFn.opts.verbose >= 2) {
                    LogInfo.logs("JoinFn.typeInference: %s => %s [coarse type = %s]", f, fullType, type);
                }
                if (!fullType.isValid()) {
                    return null;
                }
                type = fullType;
            }
            if (JoinFn.opts.verbose >= 3) {
                LogInfo.logs("JoinFn: binary: %s [%s], unary: %s [%s], result: %s [%s]", binaryFormula, binaryType, unaryFormula, unaryType, f, type);
            }
            FeatureVector features = new FeatureVector();
            if (FeatureExtractor.containsDomain("joinPos") && featureDesc != null) {
                features.add("joinPos", featureDesc);
            }
            Derivation newDeriv = new Derivation.Builder()
                    .withCallable(this.callable)
                    .formula(f)
                    .type(type)
                    .localFeatureVector(features)
                    .createDerivation();
            if (SemanticFn.opts.trackLocalChoices) {
                newDeriv.addLocalChoice("JoinFn " + this.spanString(binaryDeriv) + " " + binaryDeriv.formula + " AND " + this.spanString(unaryDeriv) + " " + unaryDeriv.formula);
            }
            return newDeriv;
        }

        /** Returns the derivation's token span, or "-" when it has no span (start == -1). */
        private String spanString(Derivation deriv) {
            return deriv.start == -1 ? "-" : deriv.startEndString(this.ex.getTokens());
        }

        /**
         * Tries the enabled join directions (reversed binary first, then binary as-is),
         * collects the successful ones and sorts them by score.
         */
        private void doJoins(Derivation binaryDeriv, Derivation unaryDeriv) {
            String binaryPos = this.ex.languageInfo.getCanonicalPos(binaryDeriv.start);
            String unaryPos = this.ex.languageInfo.getCanonicalPos(unaryDeriv.start);
            String posDesc = "binary=" + binaryPos + ",unary=" + unaryPos;
            if (JoinFn.this.unaryCanBeArg0) {
                Derivation reversed = this.doJoin(binaryDeriv, Formulas.reverseFormula(binaryDeriv.formula), binaryDeriv.type.reverse(), unaryDeriv, unaryDeriv.formula, unaryDeriv.type, posDesc + "_reverse");
                if (reversed != null) {
                    this.derivations.add(reversed);
                }
            }
            if (JoinFn.this.unaryCanBeArg1) {
                Derivation forward = this.doJoin(binaryDeriv, binaryDeriv.formula, binaryDeriv.type, unaryDeriv, unaryDeriv.formula, unaryDeriv.type, posDesc);
                if (forward != null) {
                    this.derivations.add(forward);
                }
            }
            Collections.sort(this.derivations, Derivation.derivScoreComparator);
        }
    }

    /** Command-line options for {@link JoinFn}. */
    public static class Options {
        // >= 2 logs type-inference results; >= 3 logs every successful join.
        @Option(gloss="Verbose")
        public int verbose = 0;
        // Log a warning whenever a join fails the type check.
        @Option
        public boolean showTypeCheckFailures = false;
        // Run TypeInference on the joined formula and reject it if the inferred type is invalid.
        @Option
        public boolean typeInference = true;
        // Use specializedTypeCheck instead of SemType.apply for the coarse type check.
        @Option
        public boolean specializedTypeCheck = false;
    }
}

