import json
import random
from textwrap import indent

# Seed the global RNG; nothing in this file currently draws from it.
random.seed(0)

# Directory holding the raw and converted JSON files. It keeps its trailing
# slash because gen() concatenates it directly with file names.
data_dir = "data/"

# Every dataset name the Processor understands (the __main__ block below lists
# its own run order explicitly and does not read this list).
datasets = [
    "boolq", "piqa", "siqa", "hellaswag",
    "winogrande", "arce", "arcc", "obqa",
]

# Sample boolq prompt/target pair, handy for checking Processor output by hand.
# Neither name is referenced by the functions in this file.
ins = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request. \n\n"
    "### Instruction:\n"
    "Please answer the following question with true or false, "
    "question: does ethanol take more energy make that produces?\n\n"
    "Answer format: true/false\n\n"
    "### Response:\n"
)

opt = "the correct answer is false"


class Processor:
    """Rewrite a dataset's 0-shot prompt/target pairs into lettered multiple-choice form.

    Every handler receives the original prompt ``ipt`` and target ``opt`` and
    returns ``(instruction, output)``:

    * ``instruction`` is the prompt cut off before its ``Answer format:``
      trailer, with the choices relabelled ``A.``, ``B.``, ... and a fresh
      ``### Response:`` header appended.
    * ``output`` is ``'Answer: <letter>'``.

    Args:
        dataset: dataset name; selects the handler used by :meth:`process`.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    @staticmethod
    def _relabel_choices(instruction, token, num_choices):
        """Turn inline ``<token>1: .. <token>2: ..`` markers into ``A. ..`` / ``B. ..`` lines.

        The first marker becomes ``A.`` in place. Each later marker, together
        with the space before it, becomes a newline followed by its letter.
        """
        instruction = instruction.replace(f'{token}1:', 'A.')
        for idx in range(1, num_choices):
            instruction = instruction.replace(f' {token}{idx + 1}:', f'\n{chr(ord("A") + idx)}.')
        return instruction

    @staticmethod
    def _answer_letter(opt, keyword, num_choices):
        """Map a target ending in ``<keyword>1`` .. ``<keyword><num_choices>`` to ``'Answer: <letter>'``.

        ``<keyword>1`` maps to ``A``, ``<keyword>2`` to ``B``, and so on.

        Raises:
            ValueError: if ``opt`` ends with none of the expected choice tokens.
        """
        for idx in range(num_choices):
            if opt.endswith(f'{keyword}{idx + 1}'):
                return f'Answer: {chr(ord("A") + idx)}'
        raise ValueError(f'Invalid answer: {opt!r}')

    def _process_boolq(self, ipt, opt):
        """Turn a true/false question into a two-choice (``A. true`` / ``B. false``) question.

        Raises:
            ValueError: if ``opt`` ends with neither ``true`` nor ``false``.
        """
        instruction = ipt.split('\n\nAnswer format:')[0]
        # The prompt continues with ': <question>' after this phrase (see the
        # sample ``ins``), so the replacement must not add its own colon.
        # Adding one produced 'question:: <question>'.
        instruction = instruction.replace('Please answer the following question with true or false, question',
                                          'Please choose the correct answer to the question')
        instruction += '\n\nA. true\nB. false\n\n### Response:\n'

        if opt.endswith('true'):
            opt = 'Answer: A'
        elif opt.endswith('false'):
            opt = 'Answer: B'
        else:
            # Reject unexpected targets like the other handlers do, instead of
            # silently labelling them 'B'.
            raise ValueError(f'Invalid answer: {opt!r}')

        return instruction, opt

    def _process_piqa(self, ipt, opt):
        """Relabel the two trailing ``SolutionN:`` paragraphs as ``A.`` / ``B.`` choices.

        Raises:
            ValueError: if ``opt`` ends with neither ``solution1`` nor ``solution2``.
        """
        instruction = ipt.split('\n\nAnswer format:')[0]
        # The last two paragraphs are the candidate solutions; the rest is preamble.
        *preamble, solution_1, solution_2 = instruction.split('\n\n')
        solution_1 = solution_1.replace('Solution1:', 'A.')
        solution_2 = solution_2.replace('Solution2:', 'B.')
        # Re-join with the same separator we split on. Joining with '' fused
        # the header paragraph into '### Instruction:'.
        instruction = '\n\n'.join(preamble)
        instruction += f'\n\n{solution_1}\n{solution_2}\n\n### Response:\n'

        return instruction, self._answer_letter(opt, 'solution', 2)

    def _process_siqa(self, ipt, opt):
        """Relabel ``Answer1..3:`` markers as ``A.``-``C.`` lines.

        Raises:
            ValueError: if ``opt`` does not end with ``answer1``-``answer3``.
        """
        instruction = ipt.split('\n\nAnswer format:')[0]
        instruction = self._relabel_choices(instruction, 'Answer', 3)
        instruction += '\n\n### Response:\n'
        return instruction, self._answer_letter(opt, 'answer', 3)

    def _process_hellaswag(self, ipt, opt):
        """Relabel ``Ending1..4:`` markers as ``A.``-``D.`` lines.

        Raises:
            ValueError: if ``opt`` does not end with ``ending1``-``ending4``.
        """
        instruction = ipt.split('\n\nAnswer format:')[0]
        instruction = self._relabel_choices(instruction, 'Ending', 4)
        instruction += '\n\n### Response:\n'
        return instruction, self._answer_letter(opt, 'ending', 4)

    def _process_winogrande(self, ipt, opt):
        """Relabel ``Option1..2:`` markers as ``A.`` / ``B.`` lines.

        Raises:
            ValueError: if ``opt`` ends with neither ``option1`` nor ``option2``.
        """
        # Splitting on the bare marker leaves the preceding newlines behind.
        # Strip them so the output ends in exactly '\n\n### Response:\n' like
        # the other handlers.
        instruction = ipt.split('Answer format:')[0].rstrip('\n')
        instruction = self._relabel_choices(instruction, 'Option', 2)
        instruction += '\n\n### Response:\n'
        return instruction, self._answer_letter(opt, 'option', 2)

    def _process_arc(self, ipt, opt):
        """Relabel up to five ``AnswerN:`` markers as ``A.``-``E.`` lines (ARC-e/ARC-c/OBQA).

        Raises:
            ValueError: if ``opt`` does not end with ``answer1``-``answer5``;
                the offending value is included in the message.
        """
        instruction = ipt.split('\n\nAnswer format:')[0]
        instruction = self._relabel_choices(instruction, 'Answer', 5)
        instruction += '\n\n### Response:\n'
        return instruction, self._answer_letter(opt, 'answer', 5)

    def process(self, ipt, opt):
        """Dispatch to the handler for ``self.dataset`` and return ``(instruction, output)``.

        Raises:
            ValueError: if ``self.dataset`` is not supported (previously this
                returned ``None``), or if the handler rejects ``opt``.
        """
        handlers = {
            'boolq': self._process_boolq,
            'piqa': self._process_piqa,
            'siqa': self._process_siqa,
            'hellaswag': self._process_hellaswag,
            'winogrande': self._process_winogrande,
            'arce': self._process_arc,
            'arcc': self._process_arc,
            'obqa': self._process_arc,
        }
        try:
            handler = handlers[self.dataset]
        except KeyError:
            raise ValueError(f'Unsupported dataset: {self.dataset!r}') from None
        return handler(ipt, opt)



def reg(dataset_name, info_dir=None):
    """Add (or overwrite) ``dataset_name`` in ``dataset_info.json``, keeping other entries.

    The new entry points at ``<dataset_name>.json``. It maps the ``prompt`` /
    ``query`` / ``response`` columns to the ``instruction`` / ``input`` /
    ``output`` fields.

    Args:
        dataset_name: key to register; also the stem of the registered JSON file.
        info_dir: directory containing ``dataset_info.json``; defaults to the
            module-level ``data_dir``.

    Raises:
        FileNotFoundError: if ``dataset_info.json`` does not already exist.
    """
    if info_dir is None:
        info_dir = data_dir
    info_path = f'{info_dir}/dataset_info.json'

    # Context managers close the handles deterministically. Closing also
    # flushes the rewritten file, instead of waiting for garbage collection.
    with open(info_path, 'r', encoding='utf-8') as f:
        dataset_info = json.load(f)

    dataset_info[dataset_name] = {
        "file_name": f"{dataset_name}.json",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output"
        }
    }

    # ensure_ascii=False can emit non-ASCII text, so pin UTF-8 rather than
    # relying on the platform's default encoding.
    with open(info_path, 'w', encoding='utf-8') as f:
        json.dump(dataset_info, f, indent=2, ensure_ascii=False)


def gen(dataset, save=True):
    """Convert ``dataset``'s 0-shot train/test splits to the unified format and register them.

    For each split (``train`` then ``test``), this:

    1. Reads ``{data_dir}{dataset}_{split}_0-shot.json``. Both files are
       loaded before any processing, so a missing split fails early.
    2. Rewrites every example's ``instruction`` / ``output`` with
       :class:`Processor`.
    3. Writes the result to ``{data_dir}{dataset}.uni_{split}_0-shot.json``,
       when ``save`` is true.
    4. Registers that name via :func:`reg`.

    Args:
        dataset: dataset name understood by :class:`Processor`.
        save: write the converted examples to disk. Previously the writes were
            commented out, which registered files this function never
            produced. Pass ``False`` to only validate and register.
    """
    processor = Processor(dataset)

    splits = {}
    for split in ('train', 'test'):
        with open(f'{data_dir}{dataset}_{split}_0-shot.json', 'r', encoding='utf-8') as f:
            splits[split] = json.load(f)

    for split, examples in splits.items():
        for ex in examples:
            ex['instruction'], ex['output'] = processor.process(ex['instruction'], ex['output'])
        name = f'{dataset}.uni_{split}_0-shot'
        if save:
            with open(f'{data_dir}{name}.json', 'w', encoding='utf-8') as f:
                json.dump(examples, f, indent=2, ensure_ascii=False)
        reg(name)

if __name__ == '__main__':
    # Convert and register every supported dataset. The order is kept as-is:
    # arcc runs before arce, which differs from the `datasets` list above.
    for name in ('boolq', 'piqa', 'siqa', 'hellaswag', 'winogrande', 'arcc', 'arce', 'obqa'):
        gen(name)