/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.utils.molecules.preprocessing;

import cz.cvut.fel.ida.logic.Clause;
import cz.cvut.fel.ida.logic.Constant;
import cz.cvut.fel.ida.logic.Literal;
import cz.cvut.fel.ida.logic.Term;
import cz.cvut.fel.ida.logic.io.PseudoPrologParser;
import cz.cvut.fel.ida.utils.generic.tuples.Pair;
import cz.cvut.fel.ida.utils.math.StringUtils;
import cz.cvut.fel.ida.utils.math.Sugar;
import cz.cvut.fel.ida.utils.molecules.preprocessing.molecules.Atom;
import cz.cvut.fel.ida.utils.molecules.preprocessing.molecules.Bond;
import cz.cvut.fel.ida.utils.molecules.preprocessing.molecules.Molecule;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.Writer;
import java.nio.file.Paths;
import java.text.NumberFormat;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Converts molecules stored in Tripos MOL2 files into pseudo-Prolog example clauses, and
 * pseudo-Prolog examples into a TILDE-style {@code begin(model(..)) ... end(model(..))} format.
 *
 * <p>Atom and bond type predicates produced by {@link #moleculeToLRNNClause(Molecule)} are
 * accumulated in the static sets {@code allAtomTypes} / {@code allBondTypes} so that matching
 * embedding templates can be written afterwards. This shared static state is not synchronized.
 */
public class ConvertMol2ToPsPr {
    private static final Logger LOG = Logger.getLogger(ConvertMol2ToPsPr.class.getName());
    public static String defaultMolName = "molecule";
    public static String defaultPredictionTargetName = "predict";
    static Set<String> allAtomTypes = new HashSet<String>();
    static Set<String> allBondTypes = new HashSet<String>();
    public static int embeddingDim = 3;
    static String atomName = "a";
    static String bondName = "b";

    // Parser states of readMolecules: which @<TRIPOS> section the current line belongs to.
    private static final int SECTION_OTHER = 0;
    private static final int SECTION_MOLECULE = 1;
    private static final int SECTION_ATOM = 2;
    private static final int SECTION_BOND = 3;

    /** Prefix shared by all MOL2 record type indicators. */
    private static final String TRIPOS_PREFIX = "@<TRIPOS>";

    /** An ATOM record is read up to field index 8 (the charge). */
    private static final int MIN_ATOM_FIELDS = 9;
    /** A BOND record is read up to field index 3 (the bond type). */
    private static final int MIN_BOND_FIELDS = 4;

    /**
     * Parses a decimal number as written in MOL2 files ('.' as the decimal separator).
     *
     * <p>{@link Locale#ROOT} is used instead of the JVM default locale. In locales whose decimal
     * separator is ',' the default locale would misread {@code "1.25"}. Like
     * {@link NumberFormat#parse(String)}, only the leading numeric portion of the text is consumed.
     *
     * @param numberString text to parse
     * @return the parsed value
     * @throws ParseException if the text does not start with a number
     */
    public static double parseDoubleProper(String numberString) throws ParseException {
        NumberFormat format = NumberFormat.getInstance(Locale.ROOT);
        return format.parse(numberString).doubleValue();
    }

    /**
     * Converts {@code <dirPath>/<filename>.mol2} into {@code <filename>Examples.txt} in the same
     * directory, then writes the atom/bond embedding templates into {@code dirPath}.
     *
     * <p>The optional companion files select the output:
     * <ul>
     *   <li>no {@code .q} file: only the examples are written;</li>
     *   <li>a {@code .q} file but no {@code .ids} file: every {@code .q} line is additionally written
     *       to {@code <filename>Queries.txt}, followed by the prediction target name;</li>
     *   <li>both {@code .q} and {@code .ids}: the deprecated labelled format is written to the
     *       examples file.</li>
     * </ul>
     * All readers opened here are closed before this method returns.
     *
     * @param dirPath  directory containing the input files and receiving the outputs
     * @param filename base name of the input files (without extension)
     * @throws IOException    if the {@code .mol2} file cannot be read or an output cannot be written
     * @throws ParseException if the MOL2 content is malformed
     */
    public static void convertMol2InDir(String dirPath, String filename) throws IOException, ParseException {
        File examplesFile = Paths.get(dirPath, filename + "Examples.txt").toFile();
        // try-with-resources skips closing resources that are null (missing optional files).
        try (FileReader mol2Reader = new FileReader(Paths.get(dirPath, filename + ".mol2").toFile());
             FileReader idsReader = openIfExists(Paths.get(dirPath, filename + ".ids").toFile());
             FileReader queryReader = openIfExists(Paths.get(dirPath, filename + ".q").toFile())) {
            PrintWriter examplesWriter = getWriter(examplesFile);
            if (queryReader == null) {
                convertMol2ToProlog(mol2Reader, examplesWriter);
            } else if (idsReader == null) {
                convertMol2ToProlog(mol2Reader, examplesWriter);
                PrintWriter queriesWriter = getWriter(Paths.get(dirPath, filename + "Queries.txt").toFile());
                convertClassLabels(queryReader, queriesWriter);
            } else {
                convertMol2ToProlog(mol2Reader, idsReader, queryReader, examplesWriter);
            }
        }
        exportLRNNembeddings(dirPath);
    }

    /**
     * Opens {@code file} for reading.
     *
     * @return the reader, or {@code null} when the file cannot be opened (the file is optional)
     */
    private static FileReader openIfExists(File file) {
        try {
            return new FileReader(file);
        } catch (FileNotFoundException ignored) {
            LOG.fine("Optional input not found: " + file);
            return null;
        }
    }

    /**
     * Opens a {@link PrintWriter} on {@code path}, creating missing parent directories first.
     * An existing file is truncated.
     *
     * @param path file to write
     * @return a writer on {@code path}; the caller must close it
     * @throws IOException if the file cannot be opened for writing
     */
    protected static PrintWriter getWriter(File path) throws IOException {
        File parent = path.getParentFile();
        // getParentFile() is null for a bare relative file name.
        if (parent != null) {
            parent.mkdirs();
        }
        return new PrintWriter(path);
    }

    /**
     * Writes the {@code atomEmbeddings} and {@code bondEmbeddings} template files into
     * {@code dirPath}. There is one rule per type accumulated so far in {@code allAtomTypes} /
     * {@code allBondTypes}, and the rules are sorted so that the output is deterministic.
     */
    private static void exportLRNNembeddings(String dirPath) throws IOException {
        writeEmbeddingTemplates(new File(dirPath + "/atomEmbeddings"), allAtomTypes, "atom_embed", "A");
        writeEmbeddingTemplates(new File(dirPath + "/bondEmbeddings"), allBondTypes, "bond_embed", "B");
    }

    /** Writes one weighted embedding rule per type, followed by the head's weight declaration. */
    private static void writeEmbeddingTemplates(File file, Set<String> types, String head, String variable) throws IOException {
        List<String> sortedTypes = new ArrayList<String>(types);
        Collections.sort(sortedTypes);
        try (PrintWriter out = getWriter(file)) {
            for (String type : sortedTypes) {
                out.println("{" + embeddingDim + ",1} " + head + "(" + variable + ") :- " + type + "(" + variable + ").");
            }
            out.println(head + "/1 {" + embeddingDim + ",1}");
            out.flush();
        }
    }

    /**
     * Reads all molecules from {@code mol2Reader} and writes each as one LRNN example clause,
     * terminated by {@code "."}, to {@code examplesWriter}. The writer is always closed; the
     * reader is left open for the caller.
     */
    public static void convertMol2ToProlog(FileReader mol2Reader, PrintWriter examplesWriter) throws IOException, ParseException {
        try {
            for (Molecule mol : readMolecules(mol2Reader)) {
                examplesWriter.println(moleculeToLRNNClause(mol) + ".");
            }
            examplesWriter.flush();
        } finally {
            examplesWriter.close();
        }
    }

    /**
     * Writes each molecule as {@code <label> <clause>}. The label is looked up by molecule name in
     * {@link #readClassLabels(Reader, Reader)}. A molecule without a label is printed with
     * {@code null} as its label. The writer is always closed.
     *
     * @deprecated superseded by the LRNN format of {@link #convertMol2ToProlog(FileReader, PrintWriter)}
     */
    @Deprecated
    public static void convertMol2ToProlog(Reader mol2Reader, Reader idsReader, Reader classLabelsReader, PrintWriter writer) throws IOException, ParseException {
        try {
            Map<String, String> classLabels = readClassLabels(idsReader, classLabelsReader);
            for (Molecule mol : readMolecules(mol2Reader)) {
                writer.println(classLabels.get(mol.getName()) + " " + moleculeToClause(mol));
            }
            writer.flush();
        } finally {
            writer.close();
        }
    }

    /**
     * Encodes a molecule as a ground clause.
     *
     * <p>Each atom yields {@code <type>(aN)} and {@code charge(aN, q)}, where the type has '.'
     * replaced by '_' and is lower-cased. Each bond yields two directed
     * {@code bond(aX, aY, bIDl|r)} literals, each with a {@code b_<bondType>(bID..)} literal.
     * The type predicates are recorded in {@code allAtomTypes} / {@code allBondTypes}.
     */
    public static Clause moleculeToLRNNClause(Molecule molecule) {
        HashSet<Literal> literals = new HashSet<Literal>();
        for (Atom atom : molecule.atoms()) {
            // Locale.ROOT: the default-locale lower-casing would map 'I' (iodine) to a dotless 'ı' under a Turkish locale.
            String atomType = Constant.construct(atom.getType()).name().replace('.', '_').toLowerCase(Locale.ROOT);
            Term atomConst = Constant.construct(atomName + atom.getName());
            literals.add(new Literal(atomType, atomConst));
            literals.add(new Literal("charge", atomConst, Constant.construct(String.valueOf(atom.getCharge()))));
            allAtomTypes.add(atomType);
        }
        for (Bond bond : molecule.bonds()) {
            String bondType = bondName + "_" + Constant.construct(bond.getType()).name();
            Term first = Constant.construct(atomName + bond.getA().getName());
            Term second = Constant.construct(atomName + bond.getB().getName());
            Term forward = Constant.construct(bondName + bond.getBondId() + "l");
            Term backward = Constant.construct(bondName + bond.getBondId() + "r");
            literals.add(new Literal("bond", first, second, forward));
            literals.add(new Literal(bondType, forward));
            literals.add(new Literal("bond", second, first, backward));
            literals.add(new Literal(bondType, backward));
            allBondTypes.add(bondType);
        }
        return new Clause(literals);
    }

    /**
     * Encodes a molecule in the older {@code atm/3} + {@code bond/5} clause format.
     *
     * @deprecated superseded by {@link #moleculeToLRNNClause(Molecule)}
     */
    @Deprecated
    public static Clause moleculeToClause(Molecule molecule) {
        HashSet<Literal> literals = new HashSet<Literal>();
        for (Atom atom : molecule.atoms()) {
            literals.add(new Literal("atm", Constant.construct(atom.getName()), Constant.construct(atom.getType()), Constant.construct(String.valueOf(atom.getCharge()))));
        }
        for (Bond bond : molecule.bonds()) {
            // NOTE(review): each atom name is emitted twice (A, A, B, B); possibly one of each pair was meant to be
            // another attribute. Kept as-is for format compatibility — confirm against the consumers of this format.
            literals.add(new Literal("bond", Constant.construct(bond.getA().getName()), Constant.construct(bond.getA().getName()), Constant.construct(bond.getB().getName()), Constant.construct(bond.getB().getName()), Constant.construct(bond.getType())));
            literals.add(new Literal("bond", Constant.construct(bond.getB().getName()), Constant.construct(bond.getB().getName()), Constant.construct(bond.getA().getName()), Constant.construct(bond.getA().getName()), Constant.construct(bond.getType())));
        }
        return new Clause(literals);
    }

    /**
     * Copies every line of {@code idsReader} to {@code queriesWriter} as
     * {@code <line> <defaultPredictionTargetName>.}. The writer is always closed.
     *
     * @param idsReader     source of query lines (called with the {@code .q} file by
     *                      {@link #convertMol2InDir(String, String)})
     * @param queriesWriter destination of the query lines
     */
    public static void convertClassLabels(Reader idsReader, PrintWriter queriesWriter) throws IOException {
        try {
            for (String line : Sugar.readLines(idsReader)) {
                queriesWriter.println(line + " " + defaultPredictionTargetName + ".");
            }
            queriesWriter.flush();
        } finally {
            queriesWriter.close();
        }
    }

    /**
     * Pairs ids with class labels line by line, stopping at the shorter input.
     *
     * @param idsReader         one id per line, or {@code null} to generate the ids
     *                          {@code "<defaultMolName> 1"}, {@code "<defaultMolName> 2"}, ...
     * @param classLabelsReader one label per line
     * @return map from id to label (a later duplicate id overwrites an earlier one)
     */
    public static Map<String, String> readClassLabels(Reader idsReader, Reader classLabelsReader) throws IOException {
        Iterator<String> ids;
        if (idsReader == null) {
            // Unbounded generator; the loop below is bounded by the labels instead.
            ids = new Iterator<String>() {
                private int counter = 1;

                @Override
                public boolean hasNext() {
                    return true;
                }

                @Override
                public String next() {
                    return defaultMolName + " " + this.counter++;
                }
            };
        } else {
            ids = Sugar.readLines(idsReader).iterator();
        }
        Iterator<String> labels = Sugar.readLines(classLabelsReader).iterator();
        HashMap<String, String> result = new HashMap<String, String>();
        while (ids.hasNext() && labels.hasNext()) {
            result.put(ids.next(), labels.next());
        }
        return result;
    }

    /**
     * Parses the MOLECULE, ATOM and BOND sections of a (multi-molecule) MOL2 stream.
     *
     * <p>The first non-blank line after {@code @<TRIPOS>MOLECULE} is used as the molecule name.
     * Records of any other {@code @<TRIPOS>} section (e.g. SUBSTRUCTURE) are skipped. Fields are
     * separated by any whitespace (spaces or tabs). Molecules are returned in file order. When two
     * molecules share a name, the later one replaces the earlier one.
     *
     * @param mol2Reader MOL2 text; it is not closed here
     * @return the parsed molecules, empty if the input contains none
     * @throws ParseException if a record is malformed, appears before any molecule, or a number cannot be parsed
     */
    public static List<Molecule> readMolecules(Reader mol2Reader) throws IOException, ParseException {
        BufferedReader reader = new BufferedReader(mol2Reader);
        LinkedHashMap<String, Molecule> molecules = new LinkedHashMap<String, Molecule>();
        Molecule mol = null;
        int section = SECTION_OTHER;
        int lineNumber = 0;
        String line;
        while ((line = reader.readLine()) != null) {
            lineNumber++;
            line = line.trim();
            if (line.isEmpty()) {
                continue;
            }
            if (line.startsWith(TRIPOS_PREFIX)) {
                if (line.startsWith("@<TRIPOS>MOLECULE")) {
                    if (mol != null) {
                        molecules.put(mol.getName(), mol);
                    }
                    section = SECTION_MOLECULE;
                } else if (line.startsWith("@<TRIPOS>ATOM")) {
                    section = SECTION_ATOM;
                } else if (line.startsWith("@<TRIPOS>BOND")) {
                    section = SECTION_BOND;
                } else {
                    // Without this, records of e.g. SUBSTRUCTURE would be misparsed as bonds.
                    section = SECTION_OTHER;
                }
                continue;
            }
            switch (section) {
                case SECTION_MOLECULE:
                    // Only the name line is used; the remaining MOLECULE lines (counts, type...) are ignored.
                    mol = new Molecule(line);
                    section = SECTION_OTHER;
                    break;
                case SECTION_ATOM: {
                    // atom_id atom_name x y z atom_type subst_id subst_name charge ...
                    String[] fields = splitRecord(line, MIN_ATOM_FIELDS, mol, lineNumber);
                    Atom atom = new Atom(fields[0], fields[5], parseDoubleProper(fields[8]), parseDoubleProper(fields[2]), parseDoubleProper(fields[3]), parseDoubleProper(fields[4]));
                    mol.addAtom(atom);
                    break;
                }
                case SECTION_BOND: {
                    // bond_id origin_atom_id target_atom_id bond_type ...
                    String[] fields = splitRecord(line, MIN_BOND_FIELDS, mol, lineNumber);
                    Atom origin = mol.getAtom(fields[1]);
                    Atom target = mol.getAtom(fields[2]);
                    // Assumes getAtom returns null for an unknown id — fail here rather than NPE later.
                    if (origin == null || target == null) {
                        throw new ParseException("MOL2 bond references an unknown atom: " + line, lineNumber);
                    }
                    mol.addBond(new Bond(fields[0], origin, target, fields[3]));
                    break;
                }
                default:
                    break;
            }
        }
        if (mol != null) {
            molecules.put(mol.getName(), mol);
        }
        return new ArrayList<Molecule>(molecules.values());
    }

    /**
     * Splits a MOL2 data record into whitespace-separated fields. It checks that a molecule has
     * already been started and that the record has at least {@code minFields} fields.
     */
    private static String[] splitRecord(String line, int minFields, Molecule mol, int lineNumber) throws ParseException {
        if (mol == null) {
            throw new ParseException("MOL2 record before any @<TRIPOS>MOLECULE section: " + line, lineNumber);
        }
        String[] fields = line.split("\\s+");
        if (fields.length < minFields) {
            throw new ParseException("Expected at least " + minFields + " fields in MOL2 record: " + line, lineNumber);
        }
        return fields;
    }

    /**
     * Rewrites pseudo-Prolog examples as TILDE-style models. The label lines {@code +}/{@code +1}
     * and {@code -}/{@code -1} become {@code positive.}/{@code negative.}, and other labels are
     * copied. Each non-numeric constant is lower-cased and has its dots removed. The literals of
     * each example are sorted. {@code writer} is flushed but not closed.
     */
    public static void convertPsPr2Tilde(Reader reader, Writer writer) throws IOException {
        List<Pair<Clause, String>> examples = PseudoPrologParser.read(reader);
        PrintWriter pw = new PrintWriter(writer);
        int index = 0;
        for (Pair<Clause, String> pair : examples) {
            String label = (String) pair.s;
            Clause clause = (Clause) pair.r;
            pw.println("begin(model(example_" + index + ")).");
            if (label.equals("+") || label.equals("+1")) {
                pw.println("positive.");
            } else if (label.equals("-") || label.equals("-1")) {
                pw.println("negative.");
            } else {
                pw.println(label + ".");
            }
            List<String> literals = new ArrayList<String>();
            for (Literal l : clause.literals()) {
                Literal newLit = new Literal(l.predicate().name, l.arity());
                for (int i = 0; i < l.arity(); ++i) {
                    String termName = l.get(i).name();
                    if (StringUtils.isNumeric(termName)) {
                        newLit.set(l.get(i), i);
                    } else {
                        newLit.set((Term) Constant.construct(termName.toLowerCase(Locale.ROOT).replace(".", "")), i);
                    }
                }
                literals.add(newLit + ".");
            }
            Collections.sort(literals);
            for (String literalLine : literals) {
                pw.println(literalLine);
            }
            pw.println("end(model(example_" + index + ")).");
            ++index;
        }
        pw.flush();
    }
}

