import json
import os
import random

from sympy.physics.control import ramp_response_numerical_data
from tqdm import tqdm

# Directory holding every JSON file this script reads and writes.
data_dir = 'data'

# Fixed seed so the shuffles and demo sampling below are reproducible.
random.seed(0)

def reg(dataset_name, info_dir=None):
    """Register ``dataset_name`` in ``dataset_info.json``.

    Adds (or overwrites) an entry pointing at ``{dataset_name}.json``. The entry's
    column mapping is prompt->instruction, query->input, response->output and
    history->history. Every other entry already in the file is kept.

    Args:
        dataset_name: Dataset name; also the stem of its JSON file.
        info_dir: Directory containing ``dataset_info.json``. Defaults to the
            module-level ``data_dir``.
    """
    if info_dir is None:
        info_dir = data_dir
    info_path = os.path.join(info_dir, 'dataset_info.json')

    # Explicit UTF-8. The file is written with ensure_ascii=False, so relying
    # on the platform default encoding could fail on, or mangle, non-ASCII text.
    with open(info_path, 'r', encoding='utf-8') as f:
        dataset_info = json.load(f)

    dataset_info[dataset_name] = {
        "file_name": f"{dataset_name}.json",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": "history"
        }
    }

    # Serialize before opening for write, so a serialization error cannot leave
    # the registry truncated.
    text = json.dumps(dataset_info, indent=2, ensure_ascii=False)
    with open(info_path, 'w', encoding='utf-8') as f:
        f.write(text)


class SingleGenerator:
    """Build n-shot test sets whose demos come from the same dataset's train split.

    For each dataset in ``all_datasets`` it:

    1. reads ``{dataset}_train_0-shot.json`` and ``{dataset}_test_0-shot.json``
       from ``data_dir``;
    2. attaches ``n_shot`` randomly sampled train examples to every test example
       (``ex['history']`` pairs and ``ex['meta']['history_guid']``);
    3. writes ``{dataset}_test_{n_shot}-shot.json`` and registers it via ``reg``.
    """

    def __init__(self, n_shot=8):
        """
        Args:
            n_shot: Demonstrations attached to each test example (default 8).
        """
        self.n_shot = n_shot
        self.all_datasets = [
            'boolq',
            'piqa',
            'siqa',
            'hellaswag',
            'winogrande',
            'arce',
            'arcc',
            'obqa',
        ]

    @staticmethod
    def _load_json(path):
        """Read and return the JSON document at ``path`` (UTF-8)."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _dump_json(obj, path):
        """Write ``obj`` to ``path`` as indented, non-ASCII-preserving UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=2, ensure_ascii=False)

    def get_demo(self, dataset: str) -> list:
        """Return the train examples of ``dataset``, shuffled with the module RNG."""
        train_data_file = os.path.join(data_dir, f'{dataset}_train_0-shot.json')
        examples = self._load_json(train_data_file)
        # The shuffle does not change what random.sample can draw, but it
        # advances the seeded RNG. Keep it so seeded outputs stay reproducible.
        random.shuffle(examples)
        return examples

    def generate(self):
        """Write and register an n-shot test file for every dataset."""
        for dataset in self.all_datasets:
            demos = self.get_demo(dataset)
            test_data_file = os.path.join(data_dir, f'{dataset}_test_0-shot.json')
            test_data = self._load_json(test_data_file)
            for ex in tqdm(test_data, desc=f'{dataset}'):
                # random.sample raises ValueError if there are fewer than n_shot demos.
                cur_demos = random.sample(demos, self.n_shot)
                ex['history'] = [[demo['instruction'], demo['output']] for demo in cur_demos]
                ex['meta']['history_guid'] = [demo['meta']['guid'] for demo in cur_demos]
            # test_data was updated in place, so it is already the n-shot set.
            out_name = f'{dataset}_test_{self.n_shot}-shot'
            self._dump_json(test_data, os.path.join(data_dir, f'{out_name}.json'))
            reg(out_name)
            

class HeldoutGenerator:
    """Build n-shot test sets whose demos are drawn from other datasets.

    The constructor loads a capped, shuffled demo pool for each ``*.uni``
    dataset. ``add_demo`` then attaches demos sampled from a chosen set of
    candidate datasets to the test examples of the given test datasets.
    """

    def __init__(self, n_shot=8, max_demos_per_dataset=1000):
        """
        Args:
            n_shot: Demonstrations attached to each test example (default 8).
            max_demos_per_dataset: Maximum number of shuffled train examples
                kept per dataset (default 1000).
        """
        self.n_shot = n_shot
        all_datasets = ['boolq.uni', 'piqa.uni', 'siqa.uni', 'hellaswag.uni', 'winogrande.uni', 'arce.uni', 'arcc.uni',
                        'obqa.uni']
        self.demos = {}
        for dataset in all_datasets:
            train_data_file = os.path.join(data_dir, f'{dataset}_train_0-shot.json')
            examples = self._load_json(train_data_file)
            random.shuffle(examples)
            # Cap per-dataset contribution. Presumably this keeps the pooled
            # demos balanced across datasets -- confirm intent.
            self.demos[dataset] = examples[:max_demos_per_dataset]

    @staticmethod
    def _load_json(path):
        """Read and return the JSON document at ``path`` (UTF-8)."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _dump_json(obj, path):
        """Write ``obj`` to ``path`` as indented, non-ASCII-preserving UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=2, ensure_ascii=False)

    def add_demo(self, split_id: int, test_datasets: list, candidate_datasets: list):
        """Write n-shot test files using demos pooled from ``candidate_datasets``.

        For each test dataset it writes
        ``{dataset}.split.{split_id}_test_{n_shot}-shot.json`` and registers it
        via ``reg``.

        Args:
            split_id: Index embedded in the output file names.
            test_datasets: Datasets whose ``*_test_0-shot.json`` examples get demos.
            candidate_datasets: Keys of ``self.demos`` whose examples form the pool.

        Raises:
            KeyError: If a candidate dataset was not loaded in ``__init__``.
            ValueError: From ``random.sample`` if the pool has fewer than
                ``n_shot`` demos.
        """
        all_demos = [demo for dataset in candidate_datasets for demo in self.demos[dataset]]
        # Shuffling does not change the sampling distribution, but it advances
        # the seeded RNG. Keep it for reproducible outputs.
        random.shuffle(all_demos)

        for dataset in test_datasets:
            print('processing', dataset)
            test_data_file = os.path.join(data_dir, f'{dataset}_test_0-shot.json')
            test_data = self._load_json(test_data_file)
            for ex in tqdm(test_data, desc=f'{dataset}'):
                cur_demos = random.sample(all_demos, self.n_shot)
                ex['history'] = [[demo['instruction'], demo['output']] for demo in cur_demos]
                ex['meta']['history_guid'] = [demo['meta']['guid'] for demo in cur_demos]

            out_name = f'{dataset}.split.{split_id}_test_{self.n_shot}-shot'
            self._dump_json(test_data, os.path.join(data_dir, f'{out_name}.json'))
            reg(out_name)


def gen_heldout_data():
    """Write the three held-out splits using HeldoutGenerator's default settings.

    In every split the eight ``*.uni`` datasets are divided into four test
    datasets and four demo-source datasets. The two groups are disjoint, so
    demos never come from the dataset being evaluated.
    """
    # Each entry is (datasets to evaluate, datasets supplying demonstrations).
    # The list index is the split id.
    splits = [
        (["piqa.uni", "siqa.uni", "winogrande.uni", "arce.uni"],
         ["arcc.uni", "obqa.uni", "hellaswag.uni", "boolq.uni"]),
        (["boolq.uni", "piqa.uni", "siqa.uni", "arcc.uni"],
         ["winogrande.uni", "obqa.uni", "hellaswag.uni", "arce.uni"]),
        (["boolq.uni", "piqa.uni", "arce.uni", "arcc.uni"],
         ["winogrande.uni", "hellaswag.uni", "siqa.uni", "obqa.uni"]),
    ]
    heldout = HeldoutGenerator()
    for split_id, (eval_sets, demo_sets) in enumerate(splits):
        heldout.add_demo(split_id, eval_sets, demo_sets)

def gen_single_data():
    """Write the in-domain test sets using SingleGenerator's default settings."""
    SingleGenerator().generate()


if __name__ == "__main__":
    # Only the in-domain sets are produced here; gen_heldout_data() is
    # defined above but not invoked.
    gen_single_data()