import copy
import random
import string

import numpy as np
from faker import Faker

from src.template_utils.template import LogicalTemplate, ComplexLogicalTemplateModifier
from src.utils.tools import generate_ramdom_sequence, random_TF


def generate_random_str_list(num, min_length, max_length):
    """
    Return ``num`` strings made by ``generate_ramdom_sequence(..., useNumber=False)``.

    Each string's length is drawn independently and uniformly from the inclusive
    range ``[min_length, max_length]``. A non-positive ``num`` yields an empty list.
    """
    return [
        generate_ramdom_sequence(random.randint(min_length, max_length), useNumber=False)
        for _ in range(num)
    ]


def random_modify_sample_names(sample: LogicalTemplate, min_length, max_length, forObjName=False, forRule=False,
                               template_class=LogicalTemplate, useWord=False, useRelatedWord=False, **kwargs):
    """
    Return a copy of ``sample`` with freshly drawn random object and/or predicate names.

    The copy is built with ``template_class(**sample.__dict__())``. Names are redrawn until the
    required number of distinct names is obtained, so with a small name space these loops can
    run for a long time (or never terminate).

    :param sample: template to copy.
    :param min_length: minimum length of names produced by ``generate_random_str_list``.
    :param max_length: maximum length of names produced by ``generate_random_str_list``.
    :param forObjName: if True, set ``max_objnum`` distinct object names, each wrapped in
        double quotes.
    :param forRule: if True, set ``num_target_fact`` distinct predicate names that do not occur
        among the object names chosen in this call.
    :param template_class: class used to build the copy.
    :param useWord: draw names with Faker (``name()`` for objects, ``word()`` for predicates;
        spaces are replaced by underscores) instead of random strings.
    :param useRelatedWord: predicates only, takes precedence over ``useWord``: draw predicate
        names from ``kwargs['collection']``, queried with a random standard-normal embedding.
    :param kwargs: ``collection`` is required when ``useRelatedWord`` is True. It must provide
        ``query(query_embeddings=..., n_results=...)`` returning a mapping with an ``'ids'``
        entry. ``embedding_dim`` is optional: the length of the random query embedding
        (default 1024).
    :return: the new template instance.
    """
    new_sample = template_class(**sample.__dict__())
    # Object names chosen below; predicate names must be disjoint from them.
    total_set = set()
    fake = Faker()

    if forObjName:
        num_objs = new_sample.max_objnum
        while True:
            if useWord:
                candidates = [fake.name().replace(' ', '_') for _ in range(num_objs)]
            else:
                candidates = generate_random_str_list(num_objs, min_length, max_length)
            # dict.fromkeys de-duplicates while keeping generation order. A set would order the
            # names by string hash, which varies between interpreter runs (hash randomization),
            # making the output irreproducible even with a fixed random seed.
            random_obj_names = list(dict.fromkeys(f'"{name}"' for name in candidates))
            if len(random_obj_names) == num_objs:
                total_set.update(random_obj_names)
                break

        new_sample.set_obj_names(random_obj_names)

    if forRule:
        num_preds = new_sample.num_target_fact
        if useRelatedWord:
            assert 'collection' in kwargs, 'collection is needed for useRelatedWord'
            collection = kwargs['collection']
            embedding_dim = kwargs.get('embedding_dim', 1024)
        while True:
            if useRelatedWord:
                results = collection.query(
                    query_embeddings=np.random.randn(embedding_dim).tolist(),
                    n_results=num_preds,
                )
                # Keep the text before the first '_' of each id returned for the single query.
                candidates = [doc_id.split('_')[0] for doc_id in results['ids'][0]]
            elif useWord:
                candidates = [fake.word().replace(' ', '_') for _ in range(num_preds)]
            else:
                candidates = generate_random_str_list(num_preds, min_length, max_length)

            # Order-preserving de-duplication (see the object-name branch above for why).
            random_predicate_names = list(dict.fromkeys(candidates))
            if len(random_predicate_names) == num_preds and total_set.isdisjoint(random_predicate_names):
                break

        new_sample.set_predicate_names(random_predicate_names)

    return new_sample


def _draw_count(count_range, cap=None):
    """
    Draw an integer uniformly from the inclusive range ``[count_range[0], count_range[1]]``.

    When ``cap`` is given, the upper bound is lowered to at most ``cap``. The lower bound is then
    lowered to at most the new upper bound, so the range never becomes empty.
    ``count_range`` itself is never modified.
    """
    low, high = count_range[0], count_range[1]
    if cap is not None:
        high = min(high, cap)
        low = min(low, high)
    return random.randint(low, high)


def ramdom_modify_rules_pool(
        sample: LogicalTemplate,
        add_fact_range=(0, 0),
        remove_fact_range=(0, 0),
        add_rule_range=(0, 0),
        p_add_rule_constraints=0,
        remove_rule_range=(0, 0),
        template_class=ComplexLogicalTemplateModifier):
    '''
    Return a copy of ``sample`` whose rules pool has been randomly perturbed.

    The first ``num_init_fact`` pool entries are treated as facts and entries up to
    ``num_target_fact`` as rules. New facts are built as ``[predicate, None, object_indices]``.
    New rules are built as ``[predicate_or_-1, source_entry_indices, param_positions]``.

    Steps, in this order (each skipped when its range is None):
      1. add facts: copy the predicate of a random initial fact and pick random object indices;
      2. remove facts: set ``entry[1] = -1`` on random initial facts;
      3. remove rules: set ``entry[1] = -1`` on random rules;
      4. add rules over random existing entries. When ``random_TF(p_add_rule_constraints)`` is
         true, the new rule is a constraint (head predicate -1) with at least 2 allowed params.

    :param sample: template to copy. Its rules pool entries are deep-copied before modification.
    :param add_fact_range: inclusive ``(low, high)`` number of facts to add, or None.
    :param remove_fact_range: inclusive ``(low, high)`` number of facts to remove, or None;
        capped at ``num_init_fact``.
    :param add_rule_range: inclusive ``(low, high)`` number of rules to add, or None.
    :param p_add_rule_constraints: passed to ``random_TF`` for each added rule.
    :param remove_rule_range: inclusive ``(low, high)`` number of rules to remove, or None;
        capped at ``num_target_fact - num_init_fact``.
    :param template_class: class used to build the copy.
    :return: the new template. Its ``num_target_fact`` grows by the number of added rules only.
    :raises IndexError: if facts are to be added while ``num_init_fact`` is 0.
    '''
    new_sample = template_class(**sample.__dict__())

    # Deep-copy so that marking entries below cannot leak into ``sample`` in case
    # ``template_class`` shares the pool (or its entries) with the original.
    rules_pool = copy.deepcopy(new_sample.rules_pool)
    num_init = new_sample.num_init_fact
    num_target = new_sample.num_target_fact

    new_facts = []
    if add_fact_range is not None:
        for _ in range(_draw_count(add_fact_range)):
            template_fact = random.choices(rules_pool[:num_init])[0]
            objs = [random.randint(0, new_sample.max_objnum - 1) for _ in range(len(template_fact[2]))]
            new_facts.append([template_fact[0], None, objs])

    if remove_fact_range is not None:
        remove_num = _draw_count(remove_fact_range, cap=num_init)
        for idx in random.sample(range(0, num_init), remove_num):
            rules_pool[idx][1] = -1

    if remove_rule_range is not None:
        remove_num = _draw_count(remove_rule_range, cap=num_target - num_init)
        for idx in random.sample(range(num_init, num_target), remove_num):
            rules_pool[idx][1] = -1

    new_rules = []
    if add_rule_range is not None:
        for _ in range(_draw_count(add_rule_range)):
            head_source = random.choices(rules_pool)[0]
            pnum = random.randint(1, new_sample.max_pnum)
            cnum = random.randint(1, new_sample.max_cnum)
            constraint_flag = random_TF(p_add_rule_constraints)
            if constraint_flag:
                pnum = max(2, pnum)

            ids_from = random.choices(range(0, num_target), k=cnum)
            ps_from = []
            for id_from in ids_from:
                selected_rule = rules_pool[id_from]
                if selected_rule[1] is None:
                    # Fact entry: every argument position is available.
                    ps_from += list(range(len(selected_rule[2])))
                else:
                    # NOTE(review): entries removed above have entry[1] == -1, so a removed
                    # fact lands here and contributes its object indices rather than its
                    # argument positions; confirm this is intended.
                    ps_from += selected_rule[2]
            new_rules.append([
                head_source[0] if not constraint_flag else -1,
                ids_from,
                random.choices(ps_from, k=random.randint(1, pnum)),
            ])

    # New facts go after the new rules and are not counted in num_target_fact.
    rules_pool = rules_pool + new_rules + new_facts
    new_sample.num_target_fact += len(new_rules)
    new_sample.rules_pool = rules_pool

    return new_sample


def check_rules_if_contain_default_negation(rules):
    """Return True if any rule string contains the substring ``'not '``, else False."""
    return any('not ' in rule for rule in rules)


def add_fact_against_default_rule(facts, rules, dlv):
    """
    Build a ground fact for the default-negated literal of a rule.

    The last rule in ``rules`` containing ``'not '`` is used. ``facts``/``rules`` are solved with
    ``dlv.quick_run``, and an answer atom whose predicate name occurs in that rule is chosen
    (the alphabetically first one). Its arguments are bound to the rule head's variables, and
    the first literal after ``'not '`` is instantiated with them. Variables of that literal not
    bound by the head get the object bound to the first head variable.

    :param facts: passed through to ``dlv.quick_run``.
    :param rules: rule strings such as ``'h(X,Y) :- p(X,Y), not q(Y,X).'`` (head separated by
        ``' :-'``, body literals by ``', '``).
    :param dlv: object with ``quick_run(facts, rules)``; presumably returns the answer set as a
        string like ``'{atom1, atom2}'`` (outer delimiters are sliced off, atoms split on ``', '``).
    :return: ``(new_fact, answer_atom)``, or ``(None, None)`` when no rule contains ``'not '``
        or no usable answer atom matches.
    """
    def _args(atom):
        # Arguments between the first '(' and the last ')', so a trailing '.' is ignored.
        return [arg.strip() for arg in atom[atom.index('(') + 1:atom.rindex(')')].split(',')]

    DN_rule = None
    for rule in rules:
        if 'not ' in rule:
            DN_rule = rule  # the last matching rule wins
    if DN_rule is None:
        return None, None

    raw_answer = dlv.quick_run(facts, rules).strip()[1:-1]
    atoms = {atom.strip() for atom in raw_answer.split(', ')}
    # Drop empty strings (an empty answer "{}" yields '') and atoms without an argument list:
    # '' would otherwise substring-match every rule and crash the parsing below. Sorting makes
    # the choice deterministic; iterating a set of str depends on per-process hash randomization.
    candidates = sorted(
        atom for atom in atoms
        if '(' in atom and atom.endswith(')')
        and atom.replace('-', '').split('(')[0] in DN_rule
    )
    if not candidates:
        return None, None
    DN_related_answer = candidates[0]

    DN_related_predicate = DN_rule.split('not ')[1].split(', ')[0]
    head = DN_rule.split(' :-')[0]
    obj_map = dict(zip(_args(head), _args(DN_related_answer)))
    predicate = DN_related_predicate.split('(')[0]

    predicate_objs = []
    for param in _args(DN_related_predicate):
        if param in obj_map:
            predicate_objs.append(obj_map[param])
        else:
            # Not bound by the head: fall back to the object of the first head variable.
            predicate_objs.append(next(iter(obj_map.values())))
    new_fact = f'{predicate}({",".join(predicate_objs)})'
    return new_fact, DN_related_answer