from .logic import Var
import itertools


class Language():
    """
    Language of logic programs

    Parameters
    ----------
    preds : List[.logic.Predicate]
        set of predicate symbols
    funcs : List[.logic.FunctionSymbol]
        set of function symbols
    consts : List[.logic.Const]
        set of constants
    mode_declarations : Iterable[ModeDeclaration], optional
        mode declarations; each entry exposes ``pred`` and ``dtypes`` and is
        looked up by predicate in ``get_mode_by_pred``. Defaults to an empty
        container created per instance.
    """

    def __init__(self, preds, funcs, consts, mode_declarations=None):
        self.preds = preds
        self.funcs = funcs
        self.consts = consts
        # Create a fresh container per instance: a literal ``{}`` default
        # would be one object shared by every Language built without
        # explicit mode declarations.
        if mode_declarations is None:
            mode_declarations = {}
        self.mode_declarations = mode_declarations
        self.var_gen = VariableGenerator()

    def __str__(self):
        # Each symbol is followed by a single space, as before.
        preds = ''.join(pred.__str__() + ' ' for pred in self.preds)
        funcs = ''.join(func.__str__() + ' ' for func in self.funcs)
        consts = ''.join(const.__str__() + ' ' for const in self.consts)
        return ("Predicates: " + preds
                + "\nFunction Symbols: " + funcs
                + "\nConstants: " + consts)

    def __repr__(self):
        return self.__str__()

    def get_mode_by_pred(self, pred):
        """
        get the first mode declaration whose ``pred`` equals ``pred``

        Raises
        ------
        AssertionError
            if no mode declaration matches
        """
        for mode_dec in self.mode_declarations:
            if mode_dec.pred == pred:
                return mode_dec
        assert False, "No matching predicate in mode declaration"

    def get_var_and_dtype(self, atom):
        '''
        get all variables in an input atom with their dtypes

        Assumes a function-free atom: each variable term's dtype is read
        from the mode declaration of ``atom.pred`` at the same position.

        Returns
        -------
        List[Tuple[term, dtype]]
            (variable, dtype) pairs in argument order
        '''
        return [(arg, self.get_mode_by_pred(atom.pred).dtypes[i])
                for i, arg in enumerate(atom.terms)
                if arg.is_var()]

    def get_by_dtype(self, dtype):
        """
        get constants that match given dtype
        """
        return [c for c in self.consts if c.dtype == dtype]

    def get_by_dtype_name(self, dtype_name):
        """
        get constants that match given dtype name
        """
        return [c for c in self.consts if c.dtype.name == dtype_name]

    def get_args_by_pred(self, pred):
        """
        get constants for a predicate removing the duplicates

        Enumerates the cartesian product of constants per argument dtype and
        keeps only the first tuple for each distinct *set* of constants, so
        permutations of an already-seen combination are dropped.
        """
        consts_list = [self.get_by_dtype(dtype) for dtype in pred.dtypes]
        # frozensets are hashable, making the duplicate test O(1) instead of
        # a linear scan over previously seen sets.
        seen = set()
        result = []
        for args in itertools.product(*consts_list):
            key = frozenset(args)
            if key not in seen:
                seen.add(key)
                result.append(args)
        return result

    def term_index(self, term):
        """
        get index of a term in the language
        """
        terms = self.get_by_dtype(term.dtype)
        return terms.index(term)

    def get_dtype(self, const_name):
        """
        get the dtype of the unique constant named ``const_name``

        Raises
        ------
        AssertionError
            if not exactly one constant has that name
        """
        return self.get_const_by_name(const_name).dtype

    def get_const_by_name(self, const_name):
        """
        get the unique constant named ``const_name``

        Raises
        ------
        AssertionError
            if not exactly one constant has that name
        """
        const = [c for c in self.consts if const_name == c.name]
        assert len(const) == 1, ('Expected exactly one constant named '
                                 + str(const_name) + ', found '
                                 + str(len(const)))
        return const[0]

    def get_pred_by_name(self, pred_name):
        """
        get the unique predicate named ``pred_name``

        Raises
        ------
        AssertionError
            if not exactly one predicate has that name
        """
        pred = [pred for pred in self.preds if pred.name == pred_name]
        assert len(pred) == 1, 'Too many or less match in ' + pred_name
        return pred[0]


class DataType():
    """
    Data type of terms, identified by its name

    Parameters
    ----------
    name : str
        name of the data type; used for equality, hashing and printing
    """

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Compare by name with any object exposing ``name``. For objects
        # without one (None, plain strings, ...) defer to Python's default
        # instead of raising AttributeError, so ``==`` yields False.
        try:
            other_name = other.name
        except AttributeError:
            return NotImplemented
        return self.name == other_name

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        # Consistent with __eq__: equal names give equal hashes.
        return hash(self.__str__())


class ModeDeclaration():
    """
    Mode declaration of a predicate

    Parameters
    ----------
    pred : .logic.Predicate
        predicate the declaration applies to (only ``pred.name`` is used
        for printing)
    dtypes : List[str or DataType]
        data types of the predicate's arguments, rendered with ``str()``
    """

    def __init__(self, pred, dtypes):
        self.pred = pred
        self.dtypes = dtypes

    def __str__(self):
        # join handles an empty dtype list (-> "p()") and str() accepts both
        # plain strings and DataType objects.
        args = ','.join(str(dtype) for dtype in self.dtypes)
        return self.pred.name + '(' + args + ')'

    def __hash__(self):
        return hash(self.__str__())

    def __repr__(self):
        return self.__str__()


class VariableGenerator():
    """
    Factory for fresh, distinctly named variables

    Names are ``base_name`` followed by a running integer, starting at 0.

    Parameters
    ----------
    base_name : str
        prefix shared by every generated variable name
    """

    def __init__(self, base_name='x'):
        self.base_name = base_name
        # index used for the next generated variable
        self.counter = 0

    def generate(self):
        """
        create a variable whose name has not been handed out yet

        Returns
        -------
        generated_var : .logic.Var
            variable named ``base_name + str(counter)``; the counter then
            advances by one
        """
        index = self.counter
        fresh = Var(self.base_name + str(index))
        self.counter = index + 1
        return fresh
