import json
import os
import random

from src.dataset_utils.dataset import get_dataset


def _subset_keys(datatype, subset_num):
    """Yield ``(index_key, output_name)`` for every subset of *datatype*.

    ``index_key`` is built from only the first two underscore-separated
    parts of the datatype (e.g. ``random_str``), so datatypes that differ
    only in their trailing part (``..._symbolic`` / ``..._language``) map to
    the same key and therefore reuse the same sampled indices.
    ``output_name`` keeps the full datatype and is used as the file stem.
    """
    prefix = '_'.join(datatype.split('_')[:2])
    for sub_id in range(subset_num):
        suffix = '_' + str(sub_id)
        yield prefix + suffix, datatype + suffix


def _select_indices(cache, key, pool, k):
    """Return the indices stored under *key*, sampling them on first use.

    Args:
        cache: dict mapping index keys to previously sampled index lists;
            updated in place when *key* is new.
        key: index key as produced by ``_subset_keys``.
        pool: sized collection to sample positions from; only ``len(pool)``
            is used, and only when *key* is not cached yet.
        k: number of distinct indices to draw.

    Raises:
        ValueError: if *k* exceeds ``len(pool)``.
    """
    if key not in cache:
        size = len(pool)
        if k > size:
            # random.sample would raise ValueError as well; say which subset
            # is too small instead of relying on its generic message.
            raise ValueError(
                'cannot sample {} indices for {!r} from only {} entries'.format(
                    k, key, size))
        cache[key] = random.sample(range(size), k=k)
    return cache[key]


def _write_jsonl(output_dir, records_by_name):
    """Write each record list to ``<output_dir>/<name>.jsonl``, one JSON per line."""
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    for name, records in records_by_name.items():
        saved_path = os.path.join(output_dir, name + '.jsonl')
        with open(saved_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(record) + '\n' for record in records)


def _collect_answers(queries, is_language):
    """Turn labelled queries into generation targets.

    Queries labelled ``'T'`` are kept as they are; queries labelled ``'F'``
    are negated (``' is '`` -> ``' is not '`` and ``' are '`` -> ``' are not '``
    for language datatypes, a leading ``'-'`` otherwise). Double quotes are
    stripped from language queries. Queries with any other label are skipped.

    Returns:
        A pair ``(answers, answer_types)`` of equally long lists.
    """
    answers = []
    answer_types = []
    for q in queries:
        query_text = q['query']
        if is_language:
            query_text = query_text.replace('"', '')

        label = q['label']
        if label == 'F':
            if is_language:
                query_text = query_text.replace(' is ', ' is not ')
                query_text = query_text.replace(' are ', ' are not ')
            else:
                query_text = '-' + query_text
        elif label != 'T':
            continue

        answers.append(query_text)
        answer_types.append(label)
    return answers, answer_types


def main(args):
    """Build and save the normal, nm and generation evaluation subsets.

    For every datatype, ``subset_num`` subsets are sampled with a fixed seed
    and written as JSONL files under ``./datasets/small_normal_eval``,
    ``./datasets/tiny_nm_eval`` and ``./datasets/small_generation_eval``.
    Within each phase, datatypes that share their first two name parts reuse
    the same sampled indices, so the datasets behind them are assumed to be
    index-aligned and of compatible size (NOTE(review): confirm against
    ``get_dataset``).

    Args:
        args: unused; kept for the existing call signature.
    """
    query_num = 100
    subset_num = 3
    seed = 42
    # Seed once up front: all three phases draw from the same random stream,
    # so the order of the sampling calls below determines the output.
    random.seed(seed)
    datatypes = [
        'random_str_symbolic',
        'random_word_symbolic',
        'related_word_symbolic',
        'random_str_language',
        'random_word_language',
        'related_word_language',
    ]

    # Phase 1: normal evaluation, query_num samples per subset.
    normal_evaluation_path = './datasets/small_normal_eval'
    selected_idxs_dict = {}
    selected_queries_dict = {}
    for datatype in datatypes:
        dataset = get_dataset('LabelClassification', datatype)
        dr_samples = dataset.get_samples_for_default_reasoning(slice(None))

        for index_key, output_name in _subset_keys(datatype, subset_num):
            selected_idxs = _select_indices(
                selected_idxs_dict, index_key, dr_samples, query_num)
            selected_queries_dict[output_name] = [
                dr_samples[i] for i in selected_idxs
            ]

    print('save normal eval sets...')
    _write_jsonl(normal_evaluation_path, selected_queries_dict)
    print('save normal eval sets complete...')

    # Phase 2: nm evaluation. Only query_num // 2 dataset entries are drawn;
    # each one is expanded through get_samples_for_default_reasoning(i) and
    # all returned samples are concatenated.
    print('build nm eval dataset')
    nm_evaluation_path = './datasets/tiny_nm_eval'

    selected_idxs_dict = {}
    selected_queries_dict = {}
    for datatype in datatypes:
        dataset = get_dataset('LabelNMClassification', datatype)

        for index_key, output_name in _subset_keys(datatype, subset_num):
            selected_idxs = _select_indices(
                selected_idxs_dict, index_key, dataset, query_num // 2)
            selected_queries = []
            for i in selected_idxs:
                selected_queries.extend(
                    dataset.get_samples_for_default_reasoning(i))
            selected_queries_dict[output_name] = selected_queries

    print('save nm eval sets...')
    _write_jsonl(nm_evaluation_path, selected_queries_dict)

    # Phase 3: generation evaluation. Twice as many candidates as needed are
    # drawn because entries without any 'T'/'F' query are dropped; collection
    # stops as soon as query_num usable samples have been gathered.
    print('build generation dataset')
    generation_evaluation_path = './datasets/small_generation_eval'
    selected_idxs_dict = {}
    selected_samples_dict = {}
    for datatype in datatypes:
        dataset = get_dataset('LabelClassification', datatype)
        is_language = 'language' in datatype

        # Uses subset_num (previously a literal 3) so all phases stay in sync.
        for index_key, output_name in _subset_keys(datatype, subset_num):
            selected_idxs = _select_indices(
                selected_idxs_dict, index_key, dataset, query_num * 2)

            selected_samples = []
            for i in selected_idxs:
                sample = dataset[i]
                answers, answer_types = _collect_answers(
                    sample['queries'], is_language)
                if not answers:
                    continue
                selected_samples.append({
                    'id': sample['id'],
                    'facts': sample['facts'],
                    'rules': sample['rules'],
                    'answers': answers,
                    'answer_types': answer_types,
                })

                if len(selected_samples) == query_num:
                    break
            selected_samples_dict[output_name] = selected_samples

    print('save generation eval sets...')
    _write_jsonl(generation_evaluation_path, selected_samples_dict)
    print('finish...')


# Script entry point. main() never reads its `args` parameter, so None is
# passed as a placeholder.
if __name__ == '__main__':
    main(None)