import random

import numpy as np

from src.dlv_utils.dlv import DLVHandler
from src.utils.tools import generate_ramdom_sequence, random_add_str, random_TF, are_sets_connected, \
    modify_sets_to_connect
import string
import json


class LogicalTemplate:
    """Propositional reasoning template described by a square adjacency matrix.

    Predicates ``0 .. num_init_fact - 1`` are emitted as facts; every predicate
    ``j >= num_init_fact`` is derived by one rule whose body contains each
    predicate ``i`` with ``matrix[i][j]`` truthy. Facts and queries are
    instantiated on the first object name; rules use the variable ``X``.
    """

    def __init__(self, sample_type, matrix, num_init_fact, num_target_fact, max_pnum, max_cnum, max_objnum,  **kargs):
        """
        :param sample_type: descriptor of the template, echoed by ``__str__`` and ``__dict__``.
        :param matrix: either an array-like square matrix, or a ``[rows, cols]`` pair of
            index lists (the format produced by ``__dict__``) marking the cells equal to 1.
        :param num_init_fact: number of leading predicates emitted as facts.
        :param num_target_fact: total number of predicates (matrix side length).
        :param max_pnum: stored for ``__str__`` / ``__dict__``.
        :param max_cnum: stored for ``__dict__``.
        :param max_objnum: number of default object names generated when ``obj_names`` is not given.
        :param kargs: optional ``predicate_names``, ``obj_names`` and ``id``.
        """
        self.sample_type = sample_type
        if isinstance(matrix, list):
            # Sparse form: parallel row / column index lists of the cells set to 1.
            self.matrix = np.zeros((num_target_fact, num_target_fact))
            for i, j in zip(*matrix):
                self.matrix[i][j] = 1
        else:
            self.matrix = np.array(matrix)
        self.num_init_fact = num_init_fact
        self.num_target_fact = num_target_fact
        self.max_pnum = max_pnum
        self.max_cnum = max_cnum
        self.max_objnum = max_objnum

        self.DLVhandler = DLVHandler()

        self.predicate_names = kargs.get('predicate_names')
        if 'obj_names' in kargs:
            self.obj_names = kargs['obj_names']
        else:
            self.obj_names = [f"\"name_{i}\"" for i in range(max_objnum)]
        self.id = kargs.get('id')

    def __dict__(self):
        """Return a JSON-friendly description of the template.

        A fresh random id is generated on every call when ``self.id`` is None.
        """
        return {
            'id': ('template_sample_' + generate_ramdom_sequence(20)) if self.id is None else self.id,
            'sample_type': self.sample_type,
            # Sparse [rows, cols] form, accepted back by __init__.
            'matrix': [i.tolist() for i in np.where(self.matrix == 1)],
            'num_init_fact': self.num_init_fact,
            'num_target_fact': self.num_target_fact,
            'max_pnum': self.max_pnum,
            'max_cnum': self.max_cnum,
            'predicate_names': self.predicate_names,
            'obj_names': self.obj_names,
            'max_objnum': self.max_objnum,
        }

    def __str__(self):
        return f"range: {self.num_init_fact} -> {self.num_target_fact}, type: {self.sample_type}, max_pnum: {self.max_pnum}"

    def __eq__(self, other):
        # Equality (and hashing) is defined by the string summary only.
        return str(self) == str(other)

    def __hash__(self):
        return hash(str(self))

    def get_predicate_names(self):
        """Return the configured predicate names, or ``w_0 .. w_{n-1}`` when unset."""
        if self.predicate_names is None:
            return [f'w_{i}' for i in range(self.num_target_fact)]
        return self.predicate_names

    def set_predicate_names(self, facts_name: list[str]):
        assert len(facts_name) == self.num_target_fact, 'facts_name length must be equal to num_target_fact'
        assert len(facts_name) == len(set(facts_name)), 'facts_name must be unique'
        self.predicate_names = facts_name

    def set_obj_names(self, objs_name: list[str]):
        self.obj_names = objs_name

    def get_obj_names(self):
        return self.obj_names

    def generate_sample(self, **kwargs) -> dict:
        """Render the template as facts, rules and queries.

        :param kwargs: ignored; accepted for interface compatibility with subclasses.
        :return: dict with ``'facts'``, ``'rules'`` and ``'queries'`` lists of strings.
        """
        facts_name = self.get_predicate_names()
        obj_names = self.get_obj_names()

        sample = {
            'facts': [f'{facts_name[i]}({obj_names[0]})' for i in range(self.num_init_fact)],
        }

        rules = []
        # Column j of the matrix lists the body predicates of the rule deriving predicate j.
        for offset, froms in enumerate(self.matrix.T[self.num_init_fact:]):
            idx_to = offset + self.num_init_fact
            body = ", ".join(f"{facts_name[idx_from]}(X)" for idx_from, is_from in enumerate(froms) if is_from)
            rules.append(f'{facts_name[idx_to]}(X) :- {body}')
        sample['rules'] = rules

        sample['queries'] = [f'{facts_name[i]}({obj_names[0]})' for i in range(self.num_target_fact)]
        return sample

    def _load_sample(self, sample):
        """Feed the 'facts' and 'rules' entries of ``sample`` into the DLV handler; other keys are ignored."""
        for key, values in sample.items():
            for value in values:
                if key == 'facts':
                    self.DLVhandler.add_fact(value)
                elif key == 'rules':
                    self.DLVhandler.add_rule(value)

    def self_check_by_dlv(self, **kwargs):
        """Run a sample through DLV and return its program text when the output looks like an answer set.

        :param kwargs: ``sample=`` to check an existing sample; otherwise forwarded to ``generate_sample``.
        :return: the program content if the DLV output contains both '{' and '}', else None.
        """
        sample = kwargs['sample'] if 'sample' in kwargs else self.generate_sample(**kwargs)
        try:
            self._load_sample(sample)
            self.DLVhandler.save_program()
            program_content = self.DLVhandler.program_content
            dlv_output = self.DLVhandler.run_and_get_results()
        finally:
            # The handler is reused across calls: always reset it, even if DLV fails.
            self.DLVhandler.clear_program()

        if '{' in dlv_output and '}' in dlv_output:
            return program_content
        return None

    def to_dlv_program_content(self, **kwargs):
        """Return the DLV program text for a sample without running DLV.

        :param kwargs: ``sample=`` to render an existing sample; otherwise forwarded to ``generate_sample``.
        """
        sample = kwargs['sample'] if 'sample' in kwargs else self.generate_sample(**kwargs)
        try:
            self._load_sample(sample)
            return self.DLVhandler.program_content
        finally:
            self.DLVhandler.clear_program()


class ComplexLogicalTemplate(LogicalTemplate):
    """First-order reasoning template described by a pool of facts and rules.

    Each ``rules_pool`` entry is indexed as ``(head, premises, params)``:
    entries with ``premises is None`` are facts whose ``params`` index
    ``obj_names``; for the other entries ``premises`` lists pool positions of
    the body atoms and ``params`` are variable indices of the head. Rules are
    taken from ``rules_pool[num_init_fact:]`` and their head predicate is named
    by pool position.
    """

    def __init__(self, sample_type, rules_pool, num_init_fact, num_target_fact, max_pnum, max_cnum, max_objnum, **kargs):
        super().__init__(
            sample_type=sample_type,
            matrix=None,
            num_init_fact=num_init_fact,
            num_target_fact=num_target_fact,
            max_pnum=max_pnum,
            max_cnum=max_cnum,
            max_objnum=max_objnum,
            **kargs
        )
        self.rules_pool = rules_pool

    def generate_sample(self, **kwargs) -> dict:
        '''
        Render ``rules_pool`` as DLV facts and rules.

        :param kwargs: ignored; accepted for interface compatibility with subclasses.
        :return: dict with ``'facts'`` and ``'rules'`` lists of strings.
        '''
        facts_name = self.get_predicate_names()
        obj_names = self.get_obj_names()

        def get_param(i):
            # Variable names: 0 -> A, ..., 25 -> Z, 26 -> AA, 27 -> AB, ...
            if i >= 26:
                return get_param(i // 26 - 1) + string.ascii_letters[26 + i % 26]
            return string.ascii_letters[26 + i]

        sample = {}
        fact_desc_list = []
        for fact in self.rules_pool:
            if fact[1] is not None:
                continue
            args = ','.join([obj_names[i] for i in fact[2]])
            fact_desc_list.append(f'{facts_name[fact[0]]}({args})')
        sample['facts'] = fact_desc_list

        rules = []
        for offset, rule in enumerate(self.rules_pool[self.num_init_fact:]):
            to_param = ','.join([get_param(i) for i in rule[2]])
            rule_str_list = []
            for idx_from in rule[1]:
                fidx, _, fparams = self.rules_pool[idx_from]
                rule_str_list.append(f'{facts_name[fidx]}({",".join([get_param(i) for i in fparams])})')
            rules.append(f'{facts_name[offset + self.num_init_fact]}({to_param}) :- {", ".join(rule_str_list)}')

        sample['rules'] = rules
        return sample

    def __str__(self):
        return json.dumps(self.generate_sample())

    def __hash__(self):
        return hash(str(self))

    def __dict__(self):
        """Return a JSON-friendly description of the template.

        Same keys as the base class except 'matrix', plus 'rules_pool'. Built
        directly rather than via ``super().__dict__()``: the base version calls
        ``np.where`` on this class's placeholder 0-d ``np.array(None)`` matrix
        (a deprecated NumPy call), only for the result to be discarded.
        A fresh random id is generated on every call when ``self.id`` is None.
        """
        return {
            'id': ('template_complex_sample_' + generate_ramdom_sequence(20)) if self.id is None else self.id,
            'sample_type': self.sample_type,
            'num_init_fact': self.num_init_fact,
            'num_target_fact': self.num_target_fact,
            'max_pnum': self.max_pnum,
            'max_cnum': self.max_cnum,
            'predicate_names': self.predicate_names,
            'obj_names': self.obj_names,
            'max_objnum': self.max_objnum,
            'rules_pool': self.rules_pool,
        }

    def self_check_by_dlv(self, **kwargs) -> 'str | None':
        '''
        Run a sample through DLV and return its program text when the output looks like an answer set.

        :param kwargs: ``sample=`` to check an existing sample; otherwise forwarded to ``generate_sample``.
        :return: the program content if the DLV output contains both '{' and '}', else None.
        '''
        if 'sample' in kwargs:
            r = kwargs['sample']
        else:
            r = self.generate_sample(**kwargs)

        try:
            for k, v in r.items():
                for vv in v:
                    if k == 'facts':
                        self.DLVhandler.add_fact(vv)
                    elif k == 'rules':
                        self.DLVhandler.add_rule(vv)
            self.DLVhandler.save_program()
            program_content = self.DLVhandler.program_content
            dlv_output = self.DLVhandler.run_and_get_results()
        finally:
            # The handler is reused across calls: always reset it, even if DLV fails.
            self.DLVhandler.clear_program()

        if '{' in dlv_output and '}' in dlv_output:
            return program_content
        return None


class ComplexLogicalTemplateModifier(ComplexLogicalTemplate):
    """ComplexLogicalTemplate whose rendered samples can be randomly perturbed.

    Perturbations (all disabled by the default arguments): strong negation,
    default negation of body atoms, extra disjunctive heads and re-numbered
    rule variables. ``self.rules_pool`` itself is never modified.
    """

    def __init__(self, **kargs):
        super().__init__(**kargs)

    def generate_sample(self, p_neg=0, p_dneg=0, p_add_conclution=0, add_conclution_max=1, p_change_variable=0,
                        **kwargs) -> dict:
        '''
        Render ``rules_pool`` as DLV facts and rules, applying random perturbations.

        :param p_neg: 0~1 Strong Negation Probability (``-`` prefix), rolled per fact, head and body atom
        :param p_dneg: 0~1 Default Negation Probability (``not `` prefix), rolled per body atom
        :param p_add_conclution: 0~1 probability that a rule gets extra disjunctive heads
        :param add_conclution_max: maximum number of extra heads drawn for a rule
            (clamped to the number of available rule heads)
        :param p_change_variable: 0~1 probability that a rule's variables are randomly re-numbered
        :return: dict with ``'facts'`` and ``'rules'`` lists of strings
        '''

        facts_name = self.get_predicate_names()
        obj_names = self.get_obj_names()

        def get_param(i):
            # Variable names: 0 -> A, ..., 25 -> Z, 26 -> AA, 27 -> AB, ...
            if i >= 26:
                return get_param(i // 26 - 1) + string.ascii_letters[26 + i % 26]
            return string.ascii_letters[26 + i]

        def get_predicate(fidx, ptype):
            # Default negation is only allowed on body atoms, never on facts or heads.
            strs = []
            if ptype not in ['conclustion', 'fact']:
                strs.append(random_add_str("not ", p_dneg))

            strs.append(random_add_str("-", p_neg))

            strs.append(facts_name[fidx])
            return ''.join(strs)

        def get_conclusion(idx, to_param=None):
            rule = self.rules_pool[idx]
            to_param = ','.join([get_param(i) for i in rule[2]]) if to_param is None else to_param
            return get_predicate(idx, 'conclustion') + f'({to_param})'

        sample = {}
        fact_desc_list = []
        # Facts are the pool entries without premises; their params index obj_names.
        for fact in self.rules_pool:
            if fact[1] is not None:
                continue
            fact_desc = f'{get_predicate(fact[0], "fact")}('
            fact_desc += ','.join([obj_names[i] for i in fact[2]])
            fact_desc += ')'
            fact_desc_list.append(fact_desc)
        sample['facts'] = fact_desc_list
        rules = []

        for idx, rule in enumerate(self.rules_pool[self.num_init_fact:self.num_target_fact]):
            if type(rule[1]) is int:
                continue
            idx_froms = rule[1]

            change_variable_flag = random_TF(p_change_variable)
            params = []
            random_integers = None
            # Randomly re-number the variables of this rule's body atoms.
            if change_variable_flag:
                params_sum = 0
                for idx_from in idx_froms:
                    _, _, fparams = self.rules_pool[idx_from]
                    # Work on a copy: rewriting fparams in place would corrupt
                    # self.rules_pool for later rules and later samples.
                    params.append(list(fparams))
                    params_sum += len(fparams)

                random_integers = [random.randint(0, params_sum) for _ in range(params_sum)]
                # Compact the drawn values to 0..k-1 so variable names stay dense.
                int2idx = {j: i for i, j in enumerate(list(set(random_integers)))}
                random_integers = [int2idx[i] for i in random_integers]
                i = 0
                for ps in params:
                    for j in range(len(ps)):
                        ps[j] = random_integers[i]
                        i += 1
                if not are_sets_connected(params):
                    params = modify_sets_to_connect(params)

            rule_str_list = []
            for idx_idx_from, idx_from in enumerate(idx_froms):
                fidx, _, fparams = self.rules_pool[idx_from]
                if change_variable_flag:
                    assert len(params[idx_idx_from]) == len(fparams), 'change variable error'
                    fparams = params[idx_idx_from]

                rule_str_list.append(
                    f'{get_predicate(fidx, "condition")}({",".join([get_param(i) for i in fparams])})')

            to_param = rule[2]
            if change_variable_flag:
                to_param = random.choices(list(set(random_integers)), k=len(to_param))
            conclusions = get_conclusion(rule[0], to_param=",".join([get_param(i) for i in to_param]))

            # Optionally add extra disjunctive heads taken from the other rules.
            new_conclusion = None
            if random_TF(p_add_conclution):
                candidate_ids = [self.rules_pool[i][0] for i in range(self.num_init_fact, self.num_target_fact)]
                # Clamp so random.sample never asks for more heads than exist.
                new_c_num = min(random.randint(1, add_conclution_max), len(candidate_ids))
                new_cs = random.sample(candidate_ids, new_c_num)
                new_params = []
                for c in new_cs:
                    c_rule = self.rules_pool[c]
                    if change_variable_flag:
                        new_params.append(random.choices(list(set(random_integers)), k=len(c_rule[2])))
                    else:
                        new_params.append(c_rule[2])

                # Skip the rule's own head so it is not repeated in the disjunction.
                new_conclusion = '|'.join([
                    get_conclusion(c, to_param=",".join([get_param(p) for p in new_params[k]]))
                    for k, c in enumerate(new_cs)
                    if c != idx + self.num_init_fact])

            if new_conclusion:
                # The p_add_conclution coin was already flipped above; flipping it
                # again here would make the effective probability p_add_conclution ** 2.
                conclusions += '|' + new_conclusion

            if rule[0] < 0:
                # Negative head index: emit an integrity constraint (no head).
                rule_str = f':- {", ".join(rule_str_list)}'
            else:
                rule_str = f'{conclusions} :- {", ".join(rule_str_list)}'
            rules.append(rule_str)

        sample['rules'] = rules
        return sample

def generate_sample(num_init_fact, num_target_fact, max_cnum=2, max_pnum=1):
    """Randomly build the adjacency matrix of a LogicalTemplate.

    Every predicate ``target >= num_init_fact`` receives between 1 and
    ``max_cnum`` distinct premises (never more than ``target``) drawn from the
    predicates before it; an edge ``premise -> target`` is stored as
    ``matrix[premise, target] = 1``.

    :param max_pnum: accepted for signature symmetry; not used.
    :return: dict with ``sample_type`` (the non-zero out-degrees, in predicate
        order), ``matrix`` and the two counts.
    """
    adjacency = np.zeros((num_target_fact, num_target_fact))

    for target in range(num_init_fact, num_target_fact):
        n_premises = min(random.randint(1, max_cnum), target)
        for premise in random.sample(range(target), n_premises):
            adjacency[premise, target] = 1

    out_degrees = adjacency.sum(1).astype(int).tolist()
    return {
        'sample_type': [degree for degree in out_degrees if degree != 0],
        'matrix': adjacency,
        'num_init_fact': num_init_fact,
        'num_target_fact': num_target_fact
    }


def generate_sample_complex(num_init_fact, num_target_fact, max_cnum=2, max_pnum=3, max_objnum=2, rule_window=999):
    """Randomly build a ``rules_pool`` for a ComplexLogicalTemplate.

    The first ``num_init_fact`` entries are facts ``(id, None, obj_indices)``
    with 1..max_pnum object indices drawn with replacement from
    ``range(max_objnum)``. Each later entry is a rule ``(ic, premises, head_params)``:
    ``premises`` are the sorted ids of up to ``max_cnum`` (at least 1) distinct
    entries drawn from the last ``rule_window`` entries, and ``head_params``
    are distinct variable indices below the largest premise arity.

    :raises ValueError: if a rule must be generated but its window holds no
        earlier entry (``num_init_fact == 0`` or ``rule_window < 1``).
    :return: dict with ``sample_type``, ``rules_pool`` and the two counts.
    """
    rules_pool = []
    for rule_id in range(num_init_fact):
        to_idxs = random.choices(range(0, max_objnum), k=random.randint(1, max_pnum))
        rules_pool.append((rule_id, None, to_idxs))

    # Entry ids equal their position in rules_pool, so rule ``ic`` has id ``ic``.
    for ic in range(num_init_fact, num_target_fact):
        cnum = random.randint(1, max_cnum)

        rule_start = 0 if ic < rule_window else ic - rule_window
        window = rules_pool[rule_start:]
        if not window:
            raise ValueError(
                f'rule {ic} has no earlier entries to draw premises from '
                f'(num_init_fact={num_init_fact}, rule_window={rule_window})')
        # Never request more premises than the window holds (this also caps cnum at ic);
        # otherwise random.sample raises when cnum > rule_window.
        cnum = min(cnum, len(window))

        from_idxs = random.sample(window, k=cnum)

        current_maxpnum = max(len(entry[2]) for entry in from_idxs)
        pnum = random.randint(1, current_maxpnum)
        to_idxs = random.sample(range(0, current_maxpnum), pnum)

        froms = sorted(entry[0] for entry in from_idxs)
        rules_pool.append((ic, froms, to_idxs))

    return {
        'sample_type': f'init_{num_init_fact}_target_{num_target_fact}_cnum_{max_cnum}_pnum_{max_pnum}',
        'rules_pool': rules_pool,
        'num_init_fact': num_init_fact,
        'num_target_fact': num_target_fact
    }


def to_dlv_program_content_by_facts_rules(facts, rules):
    """Build DLV program text from explicit facts and rules using a throw-away handler.

    :param facts: iterable of fact strings, passed in order to ``DLVHandler.add_fact``.
    :param rules: iterable of rule strings, passed in order to ``DLVHandler.add_rule``
        (after all facts).
    :return: the handler's ``program_content`` once everything has been added.
    """
    dlv = DLVHandler()
    for add, items in ((dlv.add_fact, facts), (dlv.add_rule, rules)):
        for item in items:
            add(item)

    program_text = dlv.program_content
    dlv.clear_program()
    return program_text
