from accelerate.utils import prepare_deepspeed_cmd_env

from dataset_processor import (
    Gsm8kProcessor,
    SpiderProcessor,
    Sql2textProcessor,
    CommonsenseProcessor,
    Commonsense170kProcessor,
    UniHellaswagProcessor
)
import os
import json

# Directory prefixes shared by the entries below. The first is an absolute
# location on one specific machine; the second is resolved against the
# current working directory at run time.
_ABSOLUTE_RAW_ROOT = '/data/local/user/shenyu/projects/MeMv2/raw_datasets'
_RELATIVE_RAW_ROOT = 'raw_datasets'

# Maps a dataset name (as passed to process_dataset) to the directory holding
# its raw files. 'spider' and 'sql2text' read from the same 'structlm' folder.
# NOTE: 'hellaswag' was once pointed at the absolute root as well; it now uses
# the relative one like the other commonsense datasets.
raw_data_path = {
    name: f'{_ABSOLUTE_RAW_ROOT}/{folder}'
    for name, folder in (
        ('gsm8k', 'gsm8k'),
        ('spider', 'structlm'),
        ('sql2text', 'structlm'),
    )
}
raw_data_path.update(
    (name, f'{_RELATIVE_RAW_ROOT}/{folder}')
    for name, folder in (
        ('hellaswag', 'hellaswag'),
        ('arce', 'ARC-Easy'),
        ('arcc', 'ARC-Challenge'),
        ('boolq', 'boolq'),
        ('piqa', 'piqa'),
        ('siqa', 'social_i_qa'),
        ('winogrande', 'winogrande'),
        ('obqa', 'openbookqa'),
        ('commonsense_170k', 'commonsense_170k'),
    )
)


def _build_processor(dataset_name: str):
    """Instantiate the processor class that handles *dataset_name*.

    Raises:
        NotImplementedError: if no processor is wired up for the name.
    """
    if dataset_name == 'gsm8k':
        return Gsm8kProcessor(dataset_name, raw_data_path['gsm8k'])
    if dataset_name == 'spider':
        return SpiderProcessor(dataset_name, raw_data_path['spider'])
    if dataset_name == 'sql2text':
        return Sql2textProcessor(dataset_name, raw_data_path['sql2text'])
    if dataset_name in ('hellaswag', 'arcc', 'arce', 'boolq', 'piqa', 'siqa', 'winogrande', 'obqa'):
        return CommonsenseProcessor(dataset_name, raw_data_path[dataset_name])
    if dataset_name == 'commonsense_170k':
        return Commonsense170kProcessor(dataset_name, raw_data_path['commonsense_170k'])
    if dataset_name == 'hellaswag.uni':
        # This processor is constructed without a raw-data directory.
        return UniHellaswagProcessor(dataset_name)
    raise NotImplementedError(f'Unsupported dataset: {dataset_name!r}')


def _load_json_or_empty(path: str) -> dict:
    """Return the parsed JSON stored at *path*, or ``{}`` if it does not exist."""
    if not os.path.exists(path):
        return {}
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


def _dump_json(payload, path: str) -> None:
    """Write *payload* to *path* as indented UTF-8 JSON and close the file."""
    with open(path, 'w', encoding='utf-8') as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)


def process_dataset(dataset_name: str, shots: list, output_dir: str, sampling: bool = False):
    """Process one dataset for each shot count and record the outputs.

    For every entry in *shots* a processor is built for *dataset_name*, its
    ``process(shot_n=..., sampling=...)`` result is written to
    ``{dataset_name}_train_{shot_n}-shot.json`` (and the ``_test_`` twin when a
    test split is returned), and ``dataset_info.json`` in *output_dir* is
    updated with an entry per written file. Existing entries in that index
    are preserved.

    Args:
        dataset_name: Key selecting the processor; see ``_build_processor``.
        shots: Shot counts to generate; one set of files is written per item.
        output_dir: Destination directory, created if missing.
        sampling: Forwarded unchanged to the processor's ``process`` call.

    Raises:
        NotImplementedError: if *dataset_name* has no processor. This is only
            reached when *shots* is non-empty, since the processor is built
            per shot.
    """
    os.makedirs(output_dir, exist_ok=True)
    info_path = os.path.join(output_dir, 'dataset_info.json')

    print('Processing dataset:', dataset_name)

    for shot_n in shots:
        print('Processing shot {}'.format(shot_n))

        # A fresh processor per shot: the processor classes are opaque here,
        # so no state is assumed to be safely reusable between runs.
        data_processor = _build_processor(dataset_name)
        train_data, test_data = data_processor.process(shot_n=shot_n, sampling=sampling)

        # Re-read the index every time so entries written by earlier shots
        # (or earlier invocations) are carried forward.
        dataset_info = _load_json_or_empty(info_path)

        train_key = f'{dataset_name}_train_{shot_n}-shot'
        test_key = f'{dataset_name}_test_{shot_n}-shot'

        _dump_json(train_data, os.path.join(output_dir, f'{train_key}.json'))
        if test_data is not None:
            _dump_json(test_data, os.path.join(output_dir, f'{test_key}.json'))

        columns = {
            'prompt': 'instruction',
            'query': 'input',
            'response': 'output',
        }
        # Only the first record is inspected; an empty training split simply
        # gets no 'history' column instead of failing with IndexError.
        if train_data and 'history' in train_data[0]:
            columns['history'] = 'history'

        dataset_info[train_key] = {
            'file_name': f'{train_key}.json',
            'columns': columns
        }
        if test_data is not None:
            dataset_info[test_key] = {
                'file_name': f'{test_key}.json',
                'columns': dict(columns)
            }

        # Persist after every shot so completed work survives a later failure.
        _dump_json(dataset_info, info_path)


if __name__ == '__main__':
    # Outputs land in a directory relative to the current working directory.
    # (An absolute alternative used before:
    #  '/data/local/user/shenyu/projects/MeMv2/data')
    output_dir = 'data/'
    shots = [0]

    # Other dataset names that have been run through this entry point, each
    # with sampling=False: hellaswag, arce, arcc, boolq, piqa, siqa,
    # winogrande, obqa, commonsense_170k.
    process_dataset(
        dataset_name='hellaswag.uni',
        shots=shots,
        output_dir=output_dir,
    )
