import json
import re

import chromadb
import numpy as np
from tqdm import tqdm
from faker import Faker

from src.template_utils.modify import generate_random_str_list


def get_strs(k, min_length=5, max_length=15, type='none', useWord=False, useRelatedWord=False, **kwargs):
    """Generate ``k`` replacement strings for predicates or object names.

    Args:
        k: Number of strings requested.
        min_length: Forwarded to ``generate_random_str_list`` when random
            strings are generated.
        max_length: Forwarded to ``generate_random_str_list`` when random
            strings are generated.
        type: ``'name'`` yields names wrapped in double quotes; any other
            value yields bare strings.
        useWord: Use Faker output instead of random strings.
        useRelatedWord: Only consulted when ``useWord`` is true and ``type``
            is not ``'name'``. Takes the strings from the ids returned by a
            collection query with a random embedding.
        **kwargs: ``collection`` (required on the related-word path) is the
            object whose ``query`` method is called. ``embedding_dim``
            (optional, default 1024) is the length of the random query
            vector. Other keys are ignored.

    Returns:
        A list of strings. On the related-word path each entry is the part of
        a returned id before its first underscore, so the list may contain
        duplicates when several ids share a prefix.

    Raises:
        ValueError: If the related-word path is selected but no
            ``collection`` was supplied (or it is ``None``).
    """
    if type == 'name':
        if useWord:
            fake = Faker()
            results = [fake.name().replace(' ', '_') for _ in range(k)]
        else:
            results = generate_random_str_list(k, min_length, max_length)
        return [f'"{i}"' for i in results]

    if not useWord:
        return generate_random_str_list(k, min_length, max_length)

    if not useRelatedWord:
        # Faker is only instantiated on the paths that actually use it.
        fake = Faker()
        return [fake.word() for _ in range(k)]

    collection = kwargs.get('collection')
    if collection is None:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError('collection is needed for useRelatedWord')

    embedding_dim = kwargs.get('embedding_dim', 1024)
    results = collection.query(
        query_embeddings=np.random.randn(embedding_dim).tolist(),
        n_results=k,
    )
    return [i.split('_')[0] for i in results['ids'][0]]




def _take_unique(make_batch, needed, max_attempts=10):
    """Return ``needed`` distinct strings gathered from ``make_batch()`` calls."""
    unique = {}
    attempts = 0
    while len(unique) < needed:
        if attempts == max_attempts:
            raise RuntimeError(
                f'could only generate {len(unique)} unique names out of '
                f'{needed} after {max_attempts} attempts'
            )
        attempts += 1
        # A dict keeps first-seen order, so the selection is order-stable.
        unique.update(dict.fromkeys(make_batch()))
    return list(unique)[:needed]


def _make_renamer(mapping):
    """Return a function that applies ``mapping`` to a string in a single pass."""
    # Longest names first so that e.g. "p10" is matched before "p1".
    old_names = sorted((name for name in mapping if name), key=len, reverse=True)
    if not old_names:
        return lambda text: text
    pattern = re.compile('|'.join(re.escape(name) for name in old_names))
    # A callable replacement keeps backslashes in new names literal.
    return lambda text: pattern.sub(lambda match: mapping[match.group(0)], text)


def build_nl_dataset(dataset_paths:list, language_dataset_path, useWord=False, useRelatedWord=False, **kwargs):
    """Rename the predicates and objects of every sample and write the result as JSONL.

    Each sample is read as a JSON object with the keys ``predicates``,
    ``objs``, ``facts``, ``rules`` and ``queries`` (a list of dicts with
    ``label`` and ``query``). New names come from ``get_strs``. Facts and
    query texts have both predicates and objects renamed; rules have only
    predicates renamed. All other sample keys are copied unchanged.

    Renaming is done in one regex pass per string, so a name that is a
    substring of another name, or a new name that happens to contain an old
    one, is not rewritten a second time.

    Args:
        dataset_paths: Paths of the input JSONL files. The first record of
            each file is dropped (presumably a header record - TODO confirm).
        language_dataset_path: Path of the JSONL file to write.
        useWord: Forwarded to ``get_strs``.
        useRelatedWord: Forwarded to ``get_strs``; also loads the
            ``wordnet_cosine`` collection from ``../chroma_data``.
        **kwargs: Forwarded to every ``get_strs`` call.

    Raises:
        ValueError: If ``useRelatedWord`` is set without ``useWord``.
        RuntimeError: If not enough distinct names could be generated for a
            sample.
    """
    if useRelatedWord and not useWord:
        raise ValueError("useWord must be True when useRelatedWord is True")

    # Read the symbolic dataset.
    dataset = []
    for dataset_path in dataset_paths:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            records = [json.loads(line) for line in f if line.strip()]
        dataset += records[1:]

    collection = None
    if useRelatedWord:
        # Initialise the client.
        chroma_client = chromadb.PersistentClient(path="../chroma_data")

        # Collection to draw related words from.
        collection_name = "wordnet_cosine"
        print("loading collection...")
        collection = chroma_client.get_collection(collection_name)
        print("loading collection done...")

    def fresh_names(count, **extra):
        # Ask for three times as many as needed to absorb duplicates; retry
        # when that is still not enough.
        return _take_unique(
            lambda: get_strs(count * 3, useWord=useWord, useRelatedWord=useRelatedWord, **extra, **kwargs),
            count,
        )

    new_dataset = []
    for sample in tqdm(dataset):
        old_predicates = sample['predicates']
        old_objs = sample['objs']

        new_predicates = fresh_names(len(old_predicates), collection=collection)
        new_objs = fresh_names(len(old_objs), type='name')

        predicate_map = dict(zip(old_predicates, new_predicates))
        obj_map = dict(zip(old_objs, new_objs))

        # Predicates take precedence if a name appears in both maps.
        rename_all = _make_renamer({**obj_map, **predicate_map})
        rename_predicates = _make_renamer(predicate_map)

        new_dataset.append({
            **sample,
            'facts': [rename_all(fact) for fact in sample['facts']],
            'rules': [rename_predicates(rule) for rule in sample['rules']],
            'predicates': new_predicates,
            'objs': new_objs,
            'queries': [
                {'label': query['label'], 'query': rename_all(query['query'])}
                for query in sample['queries']
            ],
        })

    print(len(new_dataset))
    # Write the renamed dataset, one JSON object per line.
    with open(language_dataset_path, 'w', encoding='utf-8') as f:
        for sample in new_dataset:
            f.write(json.dumps(sample) + '\n')


if __name__ == '__main__':
    # Script entry point: rewrite the checked symbolic dataset using related
    # words and report completion.
    sources = ['logicalDatasets/filtered_samples/overall_dataset.jsonl.checked']
    target = 'logicalDatasets/filtered_samples/related_word_symbolic_dataset.jsonl'
    build_nl_dataset(
        sources,
        target,
        useWord=True,
        useRelatedWord=True,
    )
    print('finish...')