import argparse
import json
import os
import random

import chromadb
import numpy as np

from tqdm import tqdm
import re
from faker import Faker

from src.template_utils.modify import random_modify_sample_names, ramdom_modify_rules_pool
from src.template_utils.template import ComplexLogicalTemplate, ComplexLogicalTemplateModifier
from src.utils.tools import generate_ramdom_sequence

# Captures the text between the braces of a `{...}` answer set. `.` does not
# match newlines, so each match is confined to a single line of DLV output.
_ANSWER_SET_PATTERN = re.compile(r'\{(.*)\}')


def _parse_answer_sets(dlv_output):
    """Extract every ``{...}`` answer set found in raw DLV output.

    Returns a list with one entry per match; each entry is the list obtained
    by splitting the text between the braces on ``', '``. The list is empty
    when no single line of the output contains a brace-delimited set.
    """
    return [inner.split(', ') for inner in _ANSWER_SET_PATTERN.findall(dlv_output)]


def _load_templates(templates_path, selected_num):
    """Read the JSONL template file and randomly pick ``selected_num`` lines.

    Blank lines are ignored so they can never be sampled and later fail JSON
    parsing. When fewer templates exist than requested, all of them are used
    (in random order) and a notice is printed instead of letting
    ``random.sample`` raise.
    """
    with open(templates_path, "r", encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]
    if selected_num > len(lines):
        print(f"requested {selected_num} templates but only {len(lines)} are available; using all of them")
        selected_num = len(lines)
    return random.sample(lines, selected_num)


def _build_sample(lt_str, args, collection):
    """Make one generation attempt from a serialized template.

    Returns the finished sample dict, or ``None`` when either DLV check fails
    or DLV reports no answer set for the attempt.
    """
    # Parse on every attempt so each template object is built from freshly
    # decoded JSON rather than from data touched by a previous attempt.
    lt_json = json.loads(lt_str)
    lt = ComplexLogicalTemplate(**lt_json)
    lt = random_modify_sample_names(
        lt, 5, 10, True, True,
        template_class=ComplexLogicalTemplateModifier,
        useWord=args.use_word,
        useRelatedWord=args.use_related_word,
        collection=collection,
    )
    lt = ramdom_modify_rules_pool(
        lt,
        add_fact_range=args.add_fact_range,
        remove_fact_range=args.remove_fact_range,
        add_rule_range=args.add_rule_range,
        p_add_rule_constraints=args.p_add_rule_constraints,
        remove_rule_range=args.remove_rule_range,
        template_class=ComplexLogicalTemplateModifier,
    )

    # First pass: a sample generated with default options provides the queries.
    query_sample = lt.generate_sample()
    query_program = lt.self_check_by_dlv(sample=query_sample)
    if not query_program:
        return None
    query_sets = _parse_answer_sets(lt.DLVhandler.run_given_program(query_program))
    if not query_sets:
        return None

    # Second pass: the sample generated with the configured perturbations
    # provides the answers.
    sample = lt.generate_sample(
        p_neg=args.p_neg,
        p_dneg=args.p_dneg,
        p_add_conclution=args.p_add_conclution,
        add_conclution_max=args.add_conclution_max,
        p_change_variable=args.p_change_variable,
    )
    answer_program = lt.self_check_by_dlv(sample=sample)
    # NOTE(review): the result of this call was never used; the call is kept
    # in its original position in case it has side effects - confirm and drop.
    lt.to_dlv_program_content(sample=sample)
    # Skip the attempt when the program cannot be run.
    if not answer_program:
        return None
    # Skip the attempt when the output contains no answer set.
    answer_sets = _parse_answer_sets(lt.DLVhandler.run_given_program(answer_program))
    if not answer_sets:
        return None

    return {
        'id': 'complex_sample_' + generate_ramdom_sequence(20),
        'id_source': lt_json['id'],
        **sample,
        'queries': query_sets[0],
        'answers': answer_sets[0],
        'predicates': lt.get_predicate_names(),
        'objs': lt.get_obj_names(),
    }


def main(args):
    """Generate samples from complex logical templates and write them as JSONL.

    Args:
        args: Parsed command-line namespace. Reads ``templates_path``,
            ``generation_path``, ``seed``, ``selected_num``,
            ``generation_num``, ``use_word``, ``use_related_word`` and the
            modification / sampling options forwarded to the template helpers.

    Raises:
        ValueError: If ``use_related_word`` is set without ``use_word``.
    """
    # Fail fast, before any file or database is touched.
    if args.use_related_word and not args.use_word:
        raise ValueError("use_word must be True when use_related_word is True")

    templates_path = args.templates_path
    generation_path = args.generation_path
    random.seed(args.seed)
    np.random.seed(args.seed)
    Faker.seed(args.seed)

    # These draws advance the seeded RNG streams; removing them would change
    # which templates are sampled below for a given seed.
    print(random.random())
    print(np.random.random())

    print("loading templates...")
    lts = _load_templates(templates_path, args.selected_num)
    print("loading templates done")
    print(f"total selected templates: {len(lts)}")

    collection = None
    if args.use_related_word:
        # Initialise the client.
        chroma_client = chromadb.PersistentClient(path="../chroma_data")

        # Name of the collection to load.
        collection_name = "wordnet_cosine"
        print("loading collection...")
        collection = chroma_client.get_collection(collection_name)
        print("loading collection done...")

    dataset = []
    for lt_str in tqdm(lts):
        for _ in range(args.generation_num):
            final_sample = _build_sample(lt_str, args, collection)
            if final_sample is not None:
                dataset.append(final_sample)

    # Report the conversion rate.
    print('每个模板生成样本数量：', args.generation_num)
    print('转化数：', len(dataset), '/', len(lts))
    # Guard against an empty selection, which would otherwise divide by zero.
    print('转化率：', len(dataset) / len(lts) if lts else 0.0)
    with open(generation_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(data) + '\n' for data in dataset)

if __name__ == '__main__':
    # Command-line entry point: parse the options, validate flag
    # dependencies, echo the configuration and run the generation.
    parser = argparse.ArgumentParser(description="Process complex logical templates and generate samples.")

    # Input / output locations.
    parser.add_argument('--templates_path', type=str, required=True, help="Path to the input templates JSONL file")
    parser.add_argument('--generation_path', type=str, required=True, help="Path to save the generated dataset")

    # Rule-pool modification options (each range is given as two integers).
    parser.add_argument('--add_fact_range', type=int, nargs=2, default=[0, 0], help="Range for adding facts")
    parser.add_argument('--remove_fact_range', type=int, nargs=2, default=[0, 0], help="Range for removing facts")
    parser.add_argument('--add_rule_range', type=int, nargs=2, default=[0, 0], help="Range for adding rules")
    parser.add_argument('--p_add_rule_constraints', type=float, default=0,
                        help="Probability of adding rule constraints")
    parser.add_argument('--remove_rule_range', type=int, nargs=2, default=[0, 0], help="Range for removing rules")

    # Sample generation options.
    parser.add_argument('--p_neg', type=float, default=0, help="Probability of strong negation")
    parser.add_argument('--p_dneg', type=float, default=0.1, help="Probability of default negation")
    parser.add_argument('--p_add_conclution', type=float, default=0, help="Probability of adding conclusions")
    parser.add_argument('--add_conclution_max', type=int, default=2, help="Maximum number of conclusions to add")
    parser.add_argument('--p_change_variable', type=float, default=0, help="Probability of changing variable positions")

    # Word-based naming options.
    parser.add_argument('--use_word', action='store_true', help="Whether to use word")
    parser.add_argument('--use_related_word', action='store_true',
                        help="Whether to use related words (loads the 'wordnet_cosine' Chroma collection; "
                             "requires --use_word)")
    # Generation number for each template.
    parser.add_argument('--generation_num', type=int, default=1, help="Number of samples to generate for each template")
    # Random seed.
    parser.add_argument('--seed', type=int, default=42, help="Random seed")
    # Number of templates to select.
    parser.add_argument('--selected_num', type=int, default=100, help="Number of template to select")

    args = parser.parse_args()
    # Reject the invalid flag combination with a usage message before any work starts.
    if args.use_related_word and not args.use_word:
        parser.error("--use_related_word requires --use_word")
    print(args)
    main(args)