from meteor_reasoner.utils import parser as parser
from meteor_reasoner.materialization.materialize import coalescing_d, build_index, naive_join
from collections import defaultdict
import typing
from meteor_reasoner.classes.interval import Interval
from meteor_reasoner.classes.rule import Rule
from meteor_reasoner.classes.atom import Atom
from meteor_reasoner.classes.literal import Literal, Operator
import copy
import random
import decimal


def random_lower_case():
    """Return a uniformly random lower-case letter from 'a' through 'g' (inclusive)."""
    # Seven candidates, drawn via the module-level RNG exactly once.
    return random.choice("abcdefg")


def random_upper_case():
    """Return a uniformly random upper-case letter from 'X' through 'Z' (inclusive)."""
    # Three candidates, drawn via the module-level RNG exactly once.
    return random.choice("XYZ")


def compare(d1: defaultdict, d2: defaultdict) -> bool:
    """Return True when ``d1`` and ``d2`` render to the same multiset of fact lines.

    The comparison works on the textual dump produced by ``dump_dataset_array``,
    so interval endpoints are compared after the rounding ``dump_interval``
    applies. Insertion order of predicates, entities and intervals is ignored:
    two datasets holding the same facts compare equal even if they were built
    in a different order. Duplicate lines still count.
    """
    return sorted(dump_dataset_array(d1)) == sorted(dump_dataset_array(d2))


def dump_operator(operator: Operator) -> str:
    """Render an operator as its ``name`` immediately followed by its dumped interval."""
    label = operator.name
    return label + dump_interval(operator.interval)


def dump_literal(literal: Literal) -> str:
    """Render a literal: its operators (in ``literal.operators`` order), then its atom.

    Operators are concatenated with no separator.
    """
    operator_prefix = "".join(map(dump_operator, literal.operators))
    return operator_prefix + dump_atom(literal.atom)


def dump_rule(rule: Rule) -> str:
    """Render a rule as ``head:-lit1,lit2,...``.

    The head is dumped with ``dump_atom`` and each body literal with
    ``dump_literal``, joined by commas. When ``rule.negative_body`` is
    non-empty, its literals are appended after a ``||`` separator, also
    comma-joined.
    """
    text = dump_atom(rule.head) + ":-" + ",".join(dump_literal(literal) for literal in rule.body)
    # A missing (None) or empty negative body yields just the positive rule.
    if not rule.negative_body:
        return text
    # Render the very collection whose emptiness was just tested, so the
    # check and the rendering cannot disagree.
    return text + "||" + ",".join(dump_literal(literal) for literal in rule.negative_body)


def dump_atom(atom: "Atom") -> str:
    """Render an atom as ``P``, ``P(t1,...,tn)``, ``P@<interval>`` or ``P(t1,...)@<interval>``.

    The argument list is omitted when the atom has no entity (``None`` or
    empty) or its only term is named ``"nan"``. In both cases the atom is
    printed as a bare predicate. Terms are rendered with ``str``. The
    ``@<interval>`` suffix (via ``dump_interval``) is added only when
    ``atom.interval`` is truthy.
    """
    entity = atom.entity
    # One rule for both branches: an absent/empty entity is treated like the
    # single-"nan" placeholder instead of crashing on ``entity[0]``.
    propositional = not entity or (len(entity) == 1 and entity[0].name == "nan")
    if propositional:
        text = atom.predicate
    else:
        text = atom.predicate + "(" + ",".join(str(item) for item in entity) + ")"
    if atom.interval:
        return text + "@" + dump_interval(atom.interval)
    return text


_NEG_INF = decimal.Decimal("-inf")
_POS_INF = decimal.Decimal("inf")


def _dump_endpoint(value, precision: int) -> str:
    """Render one interval endpoint: ``-inf``/``+inf`` for infinities, else ``round(value, precision)``."""
    # Decimal compares exactly with float and int, so float infinities match too.
    if value == _NEG_INF:
        return "-inf"
    if value == _POS_INF:
        return "+inf"
    return str(round(value, precision))


def dump_interval(interval: "Interval", precision: int = 3) -> str:
    """Render an interval as text, e.g. ``[1.000,+inf)`` for Decimal endpoints.

    Args:
        interval: object exposing ``left_open``, ``right_open``,
            ``left_value`` and ``right_value``.
        precision: number of decimal places passed to ``round`` for finite
            endpoints (default 3).

    Returns:
        ``(`` or ``[`` depending on ``left_open``, the left endpoint, a comma,
        the right endpoint, then ``)`` or ``]`` depending on ``right_open``.
    """
    left_bracket = "(" if interval.left_open else "["
    right_bracket = ")" if interval.right_open else "]"
    return (
        left_bracket
        + _dump_endpoint(interval.left_value, precision)
        + ","
        + _dump_endpoint(interval.right_value, precision)
        + right_bracket
    )


def dump_single_data(data: typing.Tuple[str, typing.Tuple, Interval]) -> str:
    """Render one ``(predicate, entity, interval)`` fact as ``P(t1,...)@<interval>``.

    An entity consisting of a single term named ``"nan"`` is rendered as a
    bare predicate (``P@<interval>``). Otherwise each term's ``name`` is
    comma-joined inside parentheses.
    """
    predicate, entity, interval = data
    is_placeholder = len(entity) == 1 and entity[0].name == "nan"
    if is_placeholder:
        head = predicate
    else:
        head = predicate + "(" + ",".join(term.name for term in entity) + ")"
    return head + "@" + dump_interval(interval)


def dump_dataset_array(D: defaultdict) -> typing.List[str]:
    """Return one ``dump_single_data`` line per stored interval of ``D``.

    ``D`` maps predicate -> entity -> iterable of intervals. Lines follow the
    mapping's iteration order: predicate, then entity, then interval.
    """
    return [
        dump_single_data((predicate, entity, interval))
        for predicate in D
        for entity, intervals in D[predicate].items()
        for interval in intervals
    ]


def dump_dataset(D: defaultdict) -> str:
    """Render every fact in ``D`` as text, one fact per line.

    Every line, including the last, is terminated by ``"\\n"``. An empty
    dataset yields ``""``. Line order is the order produced by
    ``dump_dataset_array``.
    """
    # Reuse the shared traversal and join once instead of repeated ``+=``.
    return "".join(line + "\n" for line in dump_dataset_array(D))


def _accumulate(target: defaultdict, delta: defaultdict, predicate: str) -> None:
    """Append ``delta[predicate]``'s intervals onto ``target[predicate]``, entity by entity.

    A predicate missing from ``target`` receives a deep copy of
    ``delta[predicate]``, so ``target`` never shares a mutable mapping with
    ``delta``, or with another accumulator fed from the same delta. For an
    existing predicate, this relies on ``target[predicate]`` creating missing
    entities on access (e.g. a ``defaultdict(list)``).
    """
    if predicate not in target:
        target[predicate] = copy.deepcopy(delta[predicate])
        return
    for entity, intervals in delta[predicate].items():
        target[predicate][entity] = target[predicate][entity] + intervals


def infer(dataset: defaultdict, program: typing.List[Rule], *,
          stable_rounds: int = 10, max_iterations: int = 50):
    """Apply every rule in ``program`` to ``dataset`` repeatedly, updating it in place.

    Each round runs every rule once. ``naive_join`` is called with the rule,
    the current dataset, a fresh ``delta_new`` map and an index from
    ``build_index``. Whatever ``naive_join`` puts into ``delta_new``
    (presumably the facts derived by that rule; confirm against
    ``naive_join``) is merged into both ``dataset`` and the running total
    ``delta_new_tot``. Both are then coalesced with ``coalescing_d``.

    Args:
        dataset: predicate -> entity -> list of intervals; mutated in place.
        program: rules applied, in order, every round.
        stable_rounds: stop after this many consecutive rounds in which
            ``compare`` reports the dataset unchanged (default 10).
        max_iterations: stop, printing a message, once more than this many
            rounds have completed without reaching ``stable_rounds``
            (default 50).

    Returns:
        ``(dataset, delta_new_tot)``: the updated dataset and the coalesced
        union of everything placed into any ``delta_new``.
    """
    remain_cnt = 0
    tot_cnt = 0

    delta_new_tot = defaultdict(lambda: defaultdict(list))
    while True:
        dataset_old = copy.deepcopy(dataset)
        for rule in program:
            delta_new = defaultdict(lambda: defaultdict(list))
            # The index is rebuilt per rule because earlier rules in this
            # round may already have grown the dataset.
            d_index = build_index(dataset)
            naive_join(rule, dataset, delta_new, d_index)

            for predicate in delta_new:
                # Each accumulator gets its own copy of any new predicate's
                # mapping, so updating one never silently updates the other.
                _accumulate(dataset, delta_new, predicate)
                _accumulate(delta_new_tot, delta_new, predicate)
            # Coalesce once per rule (previously once per derived predicate).
            if delta_new:
                coalescing_d(dataset)
                coalescing_d(delta_new_tot)
        if compare(dataset_old, dataset):
            remain_cnt += 1
        else:
            remain_cnt = 0
        if remain_cnt >= stable_rounds:
            break
        tot_cnt += 1
        if tot_cnt > max_iterations:
            print("Maxium iteration exceed")
            break
    return dataset, delta_new_tot


def merge_data(D: defaultdict, data: typing.Tuple[str, typing.Tuple, Interval]):
    """Add one ``(predicate, entity, interval)`` fact to ``D`` in place, then coalesce ``D``.

    A fact with an entity is stored as ``D[predicate][entity] -> [intervals]``.
    A fact without one (``entity`` is ``None`` or otherwise falsy, e.g. an
    empty tuple) is stored as ``D[predicate] -> [intervals]``. For a
    predicate not yet in ``D``, the entity case relies on ``D`` creating the
    inner mapping on access (e.g. ``defaultdict(lambda: defaultdict(list))``).

    Raises:
        ValueError: if the predicate is already stored in the other form
            (entity vs. no-entity).
    """
    predicate, entity, interval = data
    # One notion of "has an entity" for both storage and validation. The
    # no-entity list form is chosen by truthiness, so validation must agree.
    has_entity = bool(entity)
    if predicate not in D:
        if has_entity:
            D[predicate][entity] = [interval]
        else:
            D[predicate] = [interval]
    else:
        stored_without_entity = isinstance(D[predicate], list)
        # A list needs a no-entity fact and a mapping needs an entity fact;
        # any other combination mixes the two forms.
        if stored_without_entity == has_entity:
            raise ValueError(
                "One predicate can not have both entity and Null cases!")

        if has_entity:
            D[predicate].setdefault(entity, []).append(interval)
        else:
            D[predicate].append(interval)
    coalescing_d(D)


_INFINITIES = (decimal.Decimal("-inf"), decimal.Decimal("inf"))


def _covers_left(outer, inner) -> bool:
    """Return True if ``outer``'s left bound admits every point at ``inner``'s left bound."""
    if inner.left_value != outer.left_value:
        return inner.left_value > outer.left_value
    # Equal endpoints: a closed inner bound is covered only by a closed outer
    # bound. An infinite endpoint is never attained, so openness is moot there.
    return inner.left_open or not outer.left_open or outer.left_value in _INFINITIES


def _covers_right(outer, inner) -> bool:
    """Return True if ``outer``'s right bound admits every point at ``inner``'s right bound."""
    if inner.right_value != outer.right_value:
        return inner.right_value < outer.right_value
    return inner.right_open or not outer.right_open or outer.right_value in _INFINITIES


def entail(dataset: defaultdict, data_entry: typing.Tuple[str, typing.Tuple, "Interval"]):
    """Return True if some stored interval for the entry's predicate and entity contains its interval.

    Containment respects endpoint openness. For example, ``[1,3]`` is NOT
    contained in ``(1,5]``, but ``(1,3]`` is. Returns False when the
    predicate or the entity is absent from ``dataset``.
    """
    predicate, entity, interval = data_entry
    if predicate in dataset and entity in dataset[predicate]:
        return any(
            _covers_left(stored, interval) and _covers_right(stored, interval)
            for stored in dataset[predicate][entity]
        )
    return False