import os.path

from lark import Lark
from .exp_parser import ExpTree
from .language import Language, DataType
from .logic import Predicate, NeuralPredicate, FuncSymbol, Const


class DataUtils(object):
    """Utilities for loading first-order logic languages from text files.

    Reads predicate, constant, function-symbol, background-knowledge and
    clause definitions from a dataset directory. It parses them into the
    objects imported from ``.logic`` / ``.language`` using two Lark parsers
    built from one grammar file.

    Args:
        lark_path (str): Path to the Lark grammar file. The grammar must
            provide ``atom`` and ``clause`` start rules.
        lang_base_path (str): Root directory of the language files. It is
            combined with the other parts by plain string concatenation, so
            it must end with a path separator (e.g. ``'data/lang/'``).
        dataset_type (str): A dataset type (e.g. kandinsky or clevr).
        dataset (str): The dataset to be used.

    Attributes:
        base_path (str): ``lang_base_path + dataset_type + '/' + dataset + '/'``.
        lp_atom (Lark): Parser whose start rule is ``atom``.
        lp_clause (Lark): Parser whose start rule is ``clause``.
    """

    # File names used by ``load_language_mi`` for each ``meta_arg`` value:
    # (predicates, neural predicates, constants, function symbols).
    _MI_LANGUAGE_FILES = {
        1: ('meta_preds_train.txt', 'mi_preds_train.txt',
            'mi_consts_train2.txt', 'funcs.txt'),
        2: ('meta_preds2.txt', 'mi_preds2.txt',
            'mi_consts2.txt', 'funcs.txt'),
        3: ('meta_do_preds.txt', 'mi_do_preds.txt',
            'mi_do_consts.txt', 'do_funcs.txt'),
        'Planner': ('meta_Planner_preds.txt', 'mi_Planner_preds.txt',
                    'mi_Planner_consts.txt', 'Planner_funcs.txt'),
    }

    # Clause file used by ``get_clauses_mi`` for each ``meta_arg`` value.
    _MI_CLAUSE_FILES = {
        1: 'mi_clauses.txt',
        2: 'mi_clauses2.txt',
        3: 'mi_do_clauses.txt',
        'Planner': 'mi_Planner_clauses.txt',
    }

    def __init__(self, lark_path, lang_base_path, dataset_type='kandinsky', dataset='twopairs'):
        self.base_path = lang_base_path + dataset_type + '/' + dataset + '/'
        # Read the grammar once; both parsers share it with different start rules.
        with open(lark_path, encoding="utf-8") as grammar_file:
            grammar = grammar_file.read()
        self.lp_atom = Lark(grammar, start="atom")
        self.lp_clause = Lark(grammar, start="clause")

    @staticmethod
    def _read_lines(path):
        """Return the non-blank lines of ``path`` with trailing newlines removed.

        Blank (or whitespace-only) lines are skipped because none of the line
        parsers in this class can handle an empty line.
        """
        with open(path, encoding="utf-8") as f:
            return [line.rstrip('\n') for line in f if line.strip()]

    def load_clauses(self, path, lang):
        """Read a clause file and parse each line into a clause.

        Args:
            path (str): File with one clause per line.
            lang (Language): Language passed to ``ExpTree`` for the transform.

        Returns:
            list: Parsed clauses in file order, or an empty list if ``path``
            is not an existing file.
        """
        if not os.path.isfile(path):
            return []
        return [self.parse_clause(line, lang) for line in self._read_lines(path)]

    def load_atoms(self, path, lang):
        """Read an atom file and parse each line into an atom.

        Args:
            path (str): File with one atom per line.
            lang (Language): Language passed to ``ExpTree`` for the transform.

        Returns:
            list: Parsed atoms in file order, or an empty list if ``path`` is
            not an existing file.
        """
        if not os.path.isfile(path):
            return []
        atoms = []
        for line in self._read_lines(path):
            # Lines are expected to end with a terminating '.', which is removed
            # before parsing with the ``atom`` start rule (assumption: the atom
            # rule does not accept the period -- confirm against the grammar).
            # Only a '.' is removed so that an unterminated line keeps its ')'.
            if line.endswith('.'):
                line = line[:-1]
            tree = self.lp_atom.parse(line)
            atoms.append(ExpTree(lang).transform(tree))
        return atoms

    def load_preds(self, path):
        """Load ordinary predicates from ``path`` (one per line, see ``parse_pred``).

        Raises:
            OSError: If ``path`` cannot be opened.
        """
        return [self.parse_pred(line) for line in self._read_lines(path)]

    def load_neural_preds(self, path):
        """Load neural predicates from ``path`` (one per line, see ``parse_neural_pred``).

        Raises:
            OSError: If ``path`` cannot be opened.
        """
        return [self.parse_neural_pred(line) for line in self._read_lines(path)]

    def load_consts(self, path):
        """Load constants from ``path``; each line is parsed by ``parse_const``.

        Raises:
            OSError: If ``path`` cannot be opened.
        """
        consts = []
        for line in self._read_lines(path):
            consts.extend(self.parse_const(line))
        return consts

    def load_consts_mi(self, path):
        """Load constants from ``path``; each line is parsed by ``parse_const_mi``.

        Raises:
            OSError: If ``path`` cannot be opened.
        """
        consts = []
        for line in self._read_lines(path):
            consts.extend(self.parse_const_mi(line))
        return consts

    def load_funcs(self, path):
        """Load function symbols from ``path`` (one per line, see ``parse_func``).

        Returns:
            list: Parsed function symbols, or an empty list if ``path`` is
            not an existing file.
        """
        if not os.path.isfile(path):
            return []
        return [self.parse_func(line) for line in self._read_lines(path)]

    @staticmethod
    def _parse_pred_fields(line):
        """Split a ``name:arity:dtype1,dtype2,...`` line into its parts.

        Returns:
            tuple: ``(name, arity, dtypes)`` with ``arity`` as an int and
            ``dtypes`` as a list of ``DataType``.
        """
        line = line.replace('\n', '')
        pred, arity, dtype_names_str = line.split(':')
        dtypes = [DataType(dt) for dt in dtype_names_str.split(',')]
        assert int(arity) == len(
            dtypes), 'Invalid arity and dtypes in ' + pred + '.'
        return pred, int(arity), dtypes

    def parse_pred(self, line):
        """Parse a ``name:arity:dtype1,dtype2,...`` line into a ``Predicate``."""
        return Predicate(*self._parse_pred_fields(line))

    def parse_neural_pred(self, line):
        """Parse a ``name:arity:dtype1,dtype2,...`` line into a ``NeuralPredicate``."""
        return NeuralPredicate(*self._parse_pred_fields(line))

    def parse_funcs(self, line):
        """Parse a ``f1:arity1,f2:arity2,...`` line into function symbols.

        NOTE(review): ``parse_func`` builds ``FuncSymbol`` with input/output
        dtypes, while this passes only name and arity -- confirm that
        ``FuncSymbol`` still accepts two arguments.
        """
        funcs = []
        for func_arity in line.split(','):
            func, arity = func_arity.split(':')
            funcs.append(FuncSymbol(func, int(arity)))
        return funcs

    def parse_func(self, line):
        """Parse a ``name:arity:in_dtype1,in_dtype2,...:out_dtype`` line into a ``FuncSymbol``."""
        name, arity, in_dtypes, out_dtype = line.replace("\n", "").split(':')
        in_dtypes = [DataType(in_dtype) for in_dtype in in_dtypes.split(',')]
        out_dtype = DataType(out_dtype)
        return FuncSymbol(name, int(arity), in_dtypes, out_dtype)

    def parse_const(self, line):
        """Parse a ``dtype:const1,const2,...`` line into a list of ``Const``."""
        line = line.replace('\n', '')
        dtype_name, const_names_str = line.split(':')
        dtype = DataType(dtype_name)
        const_names = const_names_str.split(',')
        return [Const(const_name, dtype) for const_name in const_names]

    def parse_const_mi(self, line):
        """Parse a ``dtype:const1;const2;...;`` line into a list of ``Const``.

        The final ';' terminator is optional. A line with no constants after
        the colon (e.g. ``atom:``) yields an empty list.
        """
        line = line.replace('\n', '')
        # Entries are ';'-terminated (see the writers below); drop only that
        # terminator so an unterminated line keeps its last character.
        if line.endswith(';'):
            line = line[:-1]
        dtype_name, const_names_str = line.split(':')
        dtype = DataType(dtype_name)
        const_names = const_names_str.split(';') if const_names_str else []
        return [Const(const_name, dtype) for const_name in const_names]

    def parse_clause(self, clause_str, lang):
        """Parse a single clause string with the ``clause`` parser."""
        tree = self.lp_clause.parse(clause_str)
        return ExpTree(lang).transform(tree)

    def get_clauses(self, lang):
        """Load clauses from ``clauses.txt`` in ``base_path``."""
        return self.load_clauses(self.base_path + 'clauses.txt', lang)

    def get_clauses_mi(self, lang, meta_arg):
        """Load the meta-interpreter clause file selected by ``meta_arg``.

        Args:
            lang (Language): Language used for parsing.
            meta_arg: One of ``1``, ``2``, ``3`` or ``'Planner'``.

        Returns:
            list or None: Parsed clauses, or ``None`` if ``meta_arg`` is not
            one of the values above.
        """
        file_name = self._MI_CLAUSE_FILES.get(meta_arg)
        if file_name is None:
            return None
        return self.load_clauses(self.base_path + file_name, lang)

    def get_clauses_mi_train(self, lang):
        """Load clauses from ``mi_clauses_train.txt`` in ``base_path``."""
        return self.load_clauses(self.base_path + 'mi_clauses_train.txt', lang)

    def get_bk(self, lang):
        """Load background-knowledge atoms from ``bk.txt`` in ``base_path``."""
        return self.load_atoms(self.base_path + 'bk.txt', lang)

    def load_language(self):
        """Build a ``Language`` from the predicate and constant files.

        Reads ``preds.txt``, ``neural_preds.txt`` and ``consts.txt`` from
        ``base_path``. No function symbols are included.
        """
        preds = self.load_preds(self.base_path + 'preds.txt') + \
            self.load_neural_preds(self.base_path + 'neural_preds.txt')
        consts = self.load_consts(self.base_path + 'consts.txt')
        return Language(preds, [], consts)

    def load_language_mi(self, meta_arg):
        """Build the meta-interpreter ``Language`` selected by ``meta_arg``.

        Args:
            meta_arg: One of ``1``, ``2``, ``3`` or ``'Planner'``.

        Raises:
            ValueError: If ``meta_arg`` is not one of the values above.
        """
        try:
            preds_file, neural_preds_file, consts_file, funcs_file = \
                self._MI_LANGUAGE_FILES[meta_arg]
        except KeyError:
            raise ValueError('Unknown meta_arg: ' + repr(meta_arg)) from None
        preds = self.load_preds(self.base_path + preds_file) + \
            self.load_neural_preds(self.base_path + neural_preds_file)
        consts = self.load_consts_mi(self.base_path + consts_file)
        funcs = self.load_funcs(self.base_path + funcs_file)
        return Language(preds, funcs, consts)

    def load_language_mi_train(self, meta_arg=0):
        """Build the meta-interpreter training ``Language``.

        Constants come from ``mi_consts_train.txt`` when ``meta_arg == 0``.
        Otherwise they come from ``mi_consts_train2.txt``.
        """
        preds = self.load_preds(self.base_path + 'meta_preds_train.txt') + \
            self.load_neural_preds(self.base_path + 'mi_preds_train.txt')
        consts_file = 'mi_consts_train.txt' if meta_arg == 0 else 'mi_consts_train2.txt'
        consts = self.load_consts_mi(self.base_path + consts_file)
        funcs = self.load_funcs(self.base_path + 'funcs.txt')
        return Language(preds, funcs, consts)

    def generate_const_for_mi(self, atoms, V_T):
        """Write ``mi_consts.txt`` describing the atoms and their valuations.

        The first two entries of ``atoms`` and ``V_T`` are skipped. Three lines
        are written:

        * ``atom:`` -- every atom whose predicate is exactly ``NeuralPredicate``;
        * ``atom_head:`` -- every atom whose predicate is exactly ``Predicate``,
          except ``diff_color``;
        * ``proof:`` -- ``(atom,value)`` pairs for the neural atoms.

        Each entry is ';'-terminated.

        Args:
            atoms (list): Ground atoms, index-aligned with ``V_T``.
            V_T: Valuation sequence; each element must support ``float()``.
        """
        atoms = atoms[2:]
        values = V_T[2:]
        with open(self.base_path + 'mi_consts.txt', 'w', encoding="utf-8") as f:
            # Exact type checks (not isinstance) are deliberate: they keep the
            # neural and ordinary predicate groups disjoint.
            f.write('atom:')
            for atom in atoms:
                if type(atom.pred) is NeuralPredicate:
                    f.write(str(atom) + ';')

            f.write('\n' + 'atom_head:')
            for atom in atoms:
                if type(atom.pred) is Predicate and atom.pred.name != 'diff_color':
                    f.write(str(atom) + ';')

            f.write('\n' + 'proof:')
            for i, atom in enumerate(atoms):
                if type(atom.pred) is NeuralPredicate:
                    f.write('(' + str(atom) + ',' + str(float(values[i])) + ');')

    def generate_const_for_mi_train(self, atoms, V_T, meta_arg=0):
        """Write the fixed training-constant file for the path-finding task.

        With ``meta_arg == 0``, writes ``mi_consts_train.txt`` (goal ``e``,
        nodes ``a``-``f``). Otherwise writes ``mi_consts_train2.txt`` (goal
        ``h``, nodes ``a``-``h``).

        ``atoms`` and ``V_T`` are not used; they are kept so the signature
        stays compatible with existing callers.
        """
        if meta_arg == 0:
            file_name, goal, nodes = 'mi_consts_train.txt', 'e', 'a;b;c;d;e;f;'
        else:
            file_name, goal, nodes = 'mi_consts_train2.txt', 'h', 'a;b;c;d;e;f;g;h;'
        lines = [
            'path:*;',
            'stack:*;',
            'start:a;',
            'paths:*;',
            'goal:' + goal + ';',
            'node:' + nodes,
        ]
        with open(self.base_path + file_name, 'w', encoding="utf-8") as f:
            f.write('\n'.join(lines))