import json
import os
import random
import re
from transformers import AutoTokenizer
from huggingface_hub import login
from datasets import load_dataset


def read_jsonl(jsonl_file):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        jsonl_file: Path to a UTF-8 text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(jsonl_file, 'r', encoding='utf-8') as reader:
        # Iterate the file lazily and skip blank lines (e.g. a trailing empty
        # line), which would otherwise make json.loads raise.
        return [json.loads(line) for line in reader if line.strip()]


class DatasetProcessor:
    """Common base for the dataset-specific processors in this module.

    Subclasses override ``load_raw_data_from_disk`` and ``process``; the
    versions defined here do nothing and return ``None``.
    """

    def __init__(self, dataset_name: str, dataset_dir: str = None):
        """Store the dataset identity and reset the global random state.

        Args:
            dataset_name: Name of the dataset; some subclasses use it as the
                guid prefix.
            dataset_dir: Directory holding the raw files, if the subclass reads
                from disk.
        """
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        # Re-seeding the module-level RNG on every construction makes the
        # few-shot demo sampling done by subclasses repeatable.
        random.seed(0)

    def load_raw_data_from_disk(self):
        """Placeholder hook for reading raw data; returns ``None`` here."""
        return None

    def process(self, shot_n: int = 0, sampling=False):
        """Placeholder hook for building examples; returns ``None`` here."""
        return None


class Gsm8kProcessor(DatasetProcessor):
    """Builds instruction-style examples from ``train.jsonl`` / ``test.jsonl``."""

    def load_raw_data_from_disk(self):
        """Read both splits from ``self.dataset_dir`` and tag each example.

        Returns:
            A ``(raw_train_data, raw_test_data)`` tuple of lists of dicts. Each
            dict gains a ``guid`` of the form ``gsm8k_<n>``; numbering runs
            through the train split first and continues through the test split.
        """
        raw_train_data = read_jsonl(os.path.join(self.dataset_dir, 'train.jsonl'))
        raw_test_data = read_jsonl(os.path.join(self.dataset_dir, 'test.jsonl'))

        # One running counter over both splits keeps guids unique across them.
        for idx, ex in enumerate(raw_train_data + raw_test_data):
            ex['guid'] = f'gsm8k_{idx}'

        return raw_train_data, raw_test_data

    def process(self, shot_n: int = 0, sampling=False):
        """Convert raw examples into instruction/input/output records.

        Args:
            shot_n: Number of training examples to attach to every record as
                few-shot ``history``. ``0`` attaches none.
            sampling: Accepted for signature parity with other processors;
                not used here.

        Returns:
            A ``(train_data, test_data)`` tuple of lists of dicts with keys
            ``instruction``, ``input``, ``output``, ``meta`` and, when
            ``shot_n > 0``, ``history``.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``shot_n`` is
                larger than the number of eligible training examples.
        """
        raw_train_data, raw_test_data = self.load_raw_data_from_disk()

        def sample_demos(skip_index):
            """Draw ``shot_n`` training examples, never the one at ``skip_index``."""
            if skip_index is None:
                return random.sample(raw_train_data, shot_n)
            # Sample positions from a population one smaller than the train
            # set, then shift positions at/after the skipped slot by one. This
            # avoids rebuilding (and comparing against) the whole training
            # list for every single example.
            picks = random.sample(range(len(raw_train_data) - 1), shot_n)
            return [raw_train_data[j + 1 if j >= skip_index else j] for j in picks]

        def wrap_raw_example(raw_example, train_index=None):
            """Build one record; ``train_index`` is set only for train examples."""
            example = {
                'instruction': raw_example['question'],
                'input': '',
                'output': raw_example['answer'],
                'meta': {'guid': raw_example['guid']},
            }

            if shot_n > 0:
                demos = sample_demos(train_index)
                example['history'] = [[demo['question'], demo['answer']] for demo in demos]
                example['meta']['history_guid'] = [demo['guid'] for demo in demos]

            return example

        # A training example must not appear among its own demos, so its
        # position is passed along; test examples may draw from the full set.
        train_data = [wrap_raw_example(raw_ex, i) for i, raw_ex in enumerate(raw_train_data)]
        test_data = [wrap_raw_example(raw_ex) for raw_ex in raw_test_data]

        return train_data, test_data


class SkginstructProcessor(DatasetProcessor):
    """Splits the combined ``skginstruct`` JSON files into one file per task."""

    def load_raw_data_from_disk(self):
        """Load the combined train and test JSON files from ``self.dataset_dir``.

        Returns:
            A ``(raw_train_data, raw_test_data)`` tuple with the decoded
            contents of ``skginstruct.json`` and
            ``skginstruct_test_file_7b.json``.
        """
        raw_train_file = f'{self.dataset_dir}/skginstruct.json'
        raw_test_file = f'{self.dataset_dir}/skginstruct_test_file_7b.json'

        # Context managers guarantee the handles are closed after parsing.
        with open(raw_train_file, 'r', encoding='utf-8') as reader:
            raw_train_data = json.load(reader)
        with open(raw_test_file, 'r', encoding='utf-8') as reader:
            raw_test_data = json.load(reader)

        return raw_train_data, raw_test_data

    def split_dataset(self):
        """Group examples by task and write ``<task>_train.json`` / ``<task>_test.json``.

        Train examples are grouped by their ``task_name`` field. Test examples
        are grouped by the file stem of the ``.cfg`` path found in their
        ``arg_path`` field. Output files are written into ``self.dataset_dir``.

        Raises:
            ValueError: If a test example's ``arg_path`` contains no
                ``/<name>.cfg`` component.
        """
        raw_train_data, raw_test_data = self.load_raw_data_from_disk()

        train_tasks, test_tasks = {}, {}

        for ex in raw_train_data:
            train_tasks.setdefault(ex['task_name'], []).append(ex)

        cfg_pattern = re.compile(r'/([^/]+)\.cfg')
        for ex in raw_test_data:
            match = cfg_pattern.search(ex['arg_path'])
            if match is None:
                # Fail with the offending value instead of an opaque
                # AttributeError on None.
                raise ValueError(f'cannot derive task name from arg_path: {ex["arg_path"]!r}')
            test_tasks.setdefault(match.group(1), []).append(ex)

        for split, tasks in (('train', train_tasks), ('test', test_tasks)):
            for task_name, examples in tasks.items():
                output_file = f'{self.dataset_dir}/{task_name}_{split}.json'
                # Closing the handle ensures the dump is flushed to disk
                # before anything tries to read it back.
                with open(output_file, 'w', encoding='utf-8') as writer:
                    json.dump(examples, writer, indent=2, ensure_ascii=False)


class SpiderProcessor(DatasetProcessor):
    """Builds text-to-SQL examples from the ``spider_with_cell`` task files."""

    def load_raw_data_from_disk(self):
        """Load and normalise the train/test files found in ``self.dataset_dir``.

        If either per-task file is missing, ``SkginstructProcessor.split_dataset``
        is run first; it is expected to create them from the combined dump.

        Returns:
            A ``(train, test)`` tuple of lists of dicts with keys ``schema``,
            ``question``, ``sql`` and ``guid``. Guids are ``spider_<n>``,
            numbered through the train split first, then the test split.
        """
        raw_train_file = f'{self.dataset_dir}/spider_with_cell_train.json'
        raw_test_file = f'{self.dataset_dir}/spider_with_cell_test.json'

        if not (os.path.exists(raw_train_file) and os.path.exists(raw_test_file)):
            # dataset_name comes first in the constructor; the directory has
            # to go in as dataset_dir or the splitter looks in 'None/'.
            skginstruct_processor = SkginstructProcessor('skginstruct', self.dataset_dir)
            skginstruct_processor.split_dataset()

        with open(raw_train_file, 'r', encoding='utf-8') as reader:
            raw_train_data = json.load(reader)
        with open(raw_test_file, 'r', encoding='utf-8') as reader:
            raw_test_data = json.load(reader)

        # Train records carry schema and question embedded in one 'input'
        # string; test records expose them as separate fields.
        schema_pattern = re.compile(r'database schema:\n\n(.+?)\n\n')
        question_pattern = re.compile(r'\n\nquestion:\n\n(.+)')

        modified_train_data = []
        for idx, ex in enumerate(raw_train_data):
            input_text = ex['input']
            modified_train_data.append({
                'schema': schema_pattern.search(input_text).group(1),
                'question': question_pattern.search(input_text).group(1),
                'sql': ex['label'],
                'guid': f'spider_{idx}',
            })

        # Test guids continue where the train guids stopped.
        modified_test_data = [
            {
                'schema': ex['serialized_schema'].strip(),
                'question': ex['question'].strip(),
                'sql': ex['query'],
                'guid': f'spider_{idx}',
            }
            for idx, ex in enumerate(raw_test_data, start=len(modified_train_data))
        ]

        return modified_train_data, modified_test_data

    def process(self, shot_n: int = 0, history_with_schema=True):
        """Convert normalised examples into instruction/input/output records.

        Args:
            shot_n: Number of training examples to attach to every record as
                few-shot ``history``. ``0`` attaches none.
            history_with_schema: Currently unused; history entries always
                include the demo's schema.

        Returns:
            A ``(train_data, test_data)`` tuple of lists of dicts with keys
            ``instruction``, ``input``, ``output``, ``meta`` and, when
            ``shot_n > 0``, ``history``.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``shot_n`` is
                larger than the number of eligible training examples.
        """
        raw_train_data, raw_test_data = self.load_raw_data_from_disk()

        prompt = 'Convert the following question to an SQL query using the following database schema.'

        def get_db_id(schema_text: str):
            """Return the second ``|``-separated field of the schema, stripped."""
            return schema_text.split('|')[1].strip()

        def sample_demos(skip_index):
            """Draw ``shot_n`` training examples, never the one at ``skip_index``."""
            if skip_index is None:
                return random.sample(raw_train_data, shot_n)
            # Sample positions from a population one smaller than the train
            # set, then shift positions at/after the skipped slot by one, so
            # the training list is not rebuilt for every example.
            picks = random.sample(range(len(raw_train_data) - 1), shot_n)
            return [raw_train_data[j + 1 if j >= skip_index else j] for j in picks]

        def wrap_example(raw_example, train_index=None):
            """Build one record; ``train_index`` is set only for train examples."""
            schema, question = raw_example['schema'], raw_example['question']
            example = {
                'instruction': f'{prompt}\n\ndatabase schema: {schema}\nquestion: {question}',
                'input': '',
                'output': raw_example['sql'],
                'meta': {
                    'guid': raw_example['guid'],
                    'db_id': get_db_id(schema),
                },
            }

            if shot_n > 0:
                demos = sample_demos(train_index)
                example['history'] = [
                    [
                        f'database schema: {demo["schema"]}\nquestion: {demo["question"]}',
                        demo['sql'],
                    ]
                    for demo in demos
                ]
                example['meta']['history_guid'] = [demo['guid'] for demo in demos]

            return example

        # A training example must not appear among its own demos.
        train_data = [wrap_example(x, i) for i, x in enumerate(raw_train_data)]
        test_data = [wrap_example(x) for x in raw_test_data]

        return train_data, test_data


class Sql2textProcessor(DatasetProcessor):
    """Builds SQL-to-natural-language examples from the ``sql2text`` task files."""

    def load_raw_data_from_disk(self):
        """Load and normalise the train/test files found in ``self.dataset_dir``.

        If either per-task file is missing, ``SkginstructProcessor.split_dataset``
        is run first; it is expected to create them from the combined dump.

        Returns:
            A ``(train, test)`` tuple of lists of dicts with keys ``sql``,
            ``question`` and ``guid``. Guids are ``sql2text_<n>``, numbered
            through the train split first, then the test split.
        """
        raw_train_file = os.path.join(self.dataset_dir, 'sql2text_train.json')
        raw_test_file = os.path.join(self.dataset_dir, 'sql2text_test.json')

        if not (os.path.exists(raw_train_file) and os.path.exists(raw_test_file)):
            # dataset_name comes first in the constructor; the directory has
            # to go in as dataset_dir or the splitter looks in 'None/'.
            skginstruct_processor = SkginstructProcessor('skginstruct', self.dataset_dir)
            skginstruct_processor.split_dataset()

        with open(raw_train_file, 'r', encoding='utf-8') as reader:
            raw_train_data = json.load(reader)
        with open(raw_test_file, 'r', encoding='utf-8') as reader:
            raw_test_data = json.load(reader)

        # Train records embed the query in the 'input' text: take everything
        # from the first case-insensitive 'SELECT' to the end of that line.
        select_pattern = re.compile(r"SELECT.*", re.IGNORECASE)

        modified_train_data = [
            {
                'sql': select_pattern.search(ex['input']).group(0),
                'question': ex['label'],
                'guid': f'sql2text_{idx}',
            }
            for idx, ex in enumerate(raw_train_data)
        ]

        # Test guids continue where the train guids stopped.
        modified_test_data = [
            {
                'sql': ex['query'],
                'question': ex['question'],
                'guid': f'sql2text_{idx}',
            }
            for idx, ex in enumerate(raw_test_data, start=len(modified_train_data))
        ]

        return modified_train_data, modified_test_data

    def process(self, shot_n: int = 0):
        """Convert normalised examples into instruction/input/output records.

        Args:
            shot_n: Number of training examples to attach to every record as
                few-shot ``history``. ``0`` attaches none.

        Returns:
            A ``(train_data, test_data)`` tuple of lists of dicts with keys
            ``instruction``, ``input``, ``output``, ``meta`` and, when
            ``shot_n > 0``, ``history``.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``shot_n`` is
                larger than the number of eligible training examples.
        """
        raw_train_data, raw_test_data = self.load_raw_data_from_disk()

        prompt = 'Given the following SQL, convert it to a natural language sentence.'

        def sample_demos(skip_index):
            """Draw ``shot_n`` training examples, never the one at ``skip_index``."""
            if skip_index is None:
                return random.sample(raw_train_data, shot_n)
            # Sample positions from a population one smaller than the train
            # set, then shift positions at/after the skipped slot by one, so
            # the training list is not rebuilt for every example.
            picks = random.sample(range(len(raw_train_data) - 1), shot_n)
            return [raw_train_data[j + 1 if j >= skip_index else j] for j in picks]

        def wrap_example(raw_example, train_index=None):
            """Build one record; ``train_index`` is set only for train examples."""
            sql = raw_example['sql']
            example = {
                'instruction': f'{prompt}\n\n{sql}',
                'input': '',
                'output': raw_example['question'],
                'meta': {
                    'guid': raw_example['guid'],
                },
            }

            if shot_n > 0:
                demos = sample_demos(train_index)
                example['history'] = [[demo['sql'], demo['question']] for demo in demos]
                example['meta']['history_guid'] = [demo['guid'] for demo in demos]

            return example

        # A training example must not appear among its own demos.
        train_data = [wrap_example(x, i) for i, x in enumerate(raw_train_data)]
        test_data = [wrap_example(x) for x in raw_test_data]

        return train_data, test_data


class CommonsenseProcessor(DatasetProcessor):
    """Builds templated instruction examples from ``train.json`` / ``test.json``."""

    def load_raw_data_from_disk(self):
        """Load both splits from ``self.dataset_dir`` and tag each example.

        Returns:
            A ``(raw_train_data, raw_test_data)`` tuple. Each example gains a
            ``guid`` of the form ``<dataset_name>_<n>``; numbering runs through
            the train split first and continues through the test split.
        """
        raw_train_file = os.path.join(self.dataset_dir, 'train.json')
        raw_test_file = os.path.join(self.dataset_dir, 'test.json')

        # Context managers guarantee the handles are closed after parsing.
        with open(raw_train_file, 'r', encoding='utf-8') as reader:
            raw_train_data = json.load(reader)
        with open(raw_test_file, 'r', encoding='utf-8') as reader:
            raw_test_data = json.load(reader)

        for cnt, ex in enumerate(raw_train_data):
            ex['guid'] = f'{self.dataset_name}_{cnt}'
        # Test guids continue where the train guids stopped.
        for cnt, ex in enumerate(raw_test_data, start=len(raw_train_data)):
            ex['guid'] = f'{self.dataset_name}_{cnt}'

        return raw_train_data, raw_test_data

    def process(self, shot_n: int = 0, sampling=False):
        """Convert raw examples into instruction/input/output records.

        Args:
            shot_n: Number of training examples to attach to every record as
                few-shot ``history``. ``0`` attaches none.
            sampling: Accepted for signature parity with other processors;
                not used here.

        Returns:
            A ``(train_data, test_data)`` tuple of lists of dicts with keys
            ``instruction``, ``input``, ``output``, ``meta`` and, when
            ``shot_n > 0``, ``history``.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``shot_n`` is
                larger than the number of eligible training examples.
        """
        raw_train_data, raw_test_data = self.load_raw_data_from_disk()

        template = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{instruction}\n\n### Response:\n'

        def sample_demos(skip_index):
            """Draw ``shot_n`` training examples, never the one at ``skip_index``."""
            if skip_index is None:
                return random.sample(raw_train_data, shot_n)
            # Sample positions from a population one smaller than the train
            # set, then shift positions at/after the skipped slot by one, so
            # the training list is not rebuilt for every example.
            picks = random.sample(range(len(raw_train_data) - 1), shot_n)
            return [raw_train_data[j + 1 if j >= skip_index else j] for j in picks]

        def wrap_raw_example(raw_example, train_index=None):
            """Build one record; ``train_index`` is set only for train examples."""
            example = {
                'instruction': template.format(instruction=raw_example['instruction']),
                'input': '',
                'output': raw_example['output'],
                'meta': {'guid': raw_example['guid']},
            }

            if shot_n > 0:
                demos = sample_demos(train_index)
                # History entries carry the demo's raw instruction, without
                # the template wrapper applied to the main instruction.
                example['history'] = [[demo['instruction'], demo['output']] for demo in demos]
                example['meta']['history_guid'] = [demo['guid'] for demo in demos]

            return example

        # A training example must not appear among its own demos.
        train_data = [wrap_raw_example(raw_ex, i) for i, raw_ex in enumerate(raw_train_data)]
        test_data = [wrap_raw_example(raw_ex) for raw_ex in raw_test_data]

        return train_data, test_data


class Commonsense170kProcessor(DatasetProcessor):
    """Builds templated instruction examples from ``commonsense_170k.json``."""

    def load_raw_data_from_disk(self):
        """Load the training file from ``self.dataset_dir`` and tag each example.

        Returns:
            The decoded list of examples, each with an added ``guid`` of the
            form ``<dataset_name>_<n>``.
        """
        raw_train_file = os.path.join(self.dataset_dir, 'commonsense_170k.json')
        # The context manager guarantees the handle is closed after parsing.
        with open(raw_train_file, 'r', encoding='utf-8') as reader:
            raw_train_data = json.load(reader)

        for cnt, ex in enumerate(raw_train_data):
            ex['guid'] = f'{self.dataset_name}_{cnt}'

        return raw_train_data

    def process(self, shot_n: int = 0, sampling=False):
        """Convert raw examples into instruction/input/output records.

        Args:
            shot_n: Number of other training examples to attach to every
                record as few-shot ``history``. ``0`` attaches none.
            sampling: Accepted for signature parity with other processors;
                not used here.

        Returns:
            A ``(train_data, None)`` tuple; this dataset has no test split.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``shot_n`` is
                larger than the number of eligible training examples.
        """
        raw_train_data = self.load_raw_data_from_disk()

        template = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. \n\n### Instruction:\n{instruction}\n\n### Response:\n'

        def sample_demos(skip_index):
            """Draw ``shot_n`` training examples, never the one at ``skip_index``."""
            # Sample positions from a population one smaller than the train
            # set, then shift positions at/after the skipped slot by one.
            # Rebuilding a filtered copy of the list per example would be
            # quadratic in the size of the training set.
            picks = random.sample(range(len(raw_train_data) - 1), shot_n)
            return [raw_train_data[j + 1 if j >= skip_index else j] for j in picks]

        def wrap_raw_example(raw_example, train_index):
            """Build one record for the training example at ``train_index``."""
            example = {
                'instruction': template.format(instruction=raw_example['instruction']),
                'input': '',
                'output': raw_example['output'],
                'meta': {'guid': raw_example['guid']},
            }

            if shot_n > 0:
                demos = sample_demos(train_index)
                # History entries carry the demo's raw instruction, without
                # the template wrapper applied to the main instruction.
                example['history'] = [[demo['instruction'], demo['output']] for demo in demos]
                example['meta']['history_guid'] = [demo['guid'] for demo in demos]

            return example

        train_data = [wrap_raw_example(raw_ex, i) for i, raw_ex in enumerate(raw_train_data)]

        return train_data, None


class UniHellaswagProcessor(DatasetProcessor):
    """Builds multiple-choice examples from the ``Rowan/hellaswag`` dataset."""

    def load_raw_data_from_hf(self):
        """Fetch the dataset via ``load_dataset``.

        Returns:
            A ``(train, validation)`` tuple; the validation split is used as
            the test set by ``process``.
        """
        dataset = load_dataset('Rowan/hellaswag')

        return dataset['train'], dataset['validation']

    def process(self, shot_n: int = 0, sampling=False):
        """Convert raw rows into instruction/input/output records.

        Args:
            shot_n: Number of training rows to attach to every record as
                few-shot ``history``. ``0`` attaches none.
            sampling: When true, the training split is first reduced to a
                random subset of at most 10000 rows.

        Returns:
            A ``(train_data, test_data)`` tuple of lists of dicts with keys
            ``instruction``, ``input``, ``output``, ``meta`` and, when
            ``shot_n > 0``, ``history``.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``shot_n`` is
                larger than the number of eligible training rows.
        """
        raw_train_data, raw_test_data = self.load_raw_data_from_hf()

        # random.sample only accepts sequences, and the split object returned
        # by load_dataset is presumably not registered as one (TODO confirm),
        # so copy the rows into a plain list. This also makes the repeated
        # indexing done for few-shot sampling cheap.
        raw_train_data = list(raw_train_data)
        if sampling:
            # Cap at the population size so a small split does not raise.
            raw_train_data = random.sample(raw_train_data, min(10000, len(raw_train_data)))
            # raw_test_data = random.sample(raw_test_data, 1000)

        # Candidate prompt prefixes, currently not applied to the question:
        # prompts = [
        #     'What happens next in this paragraph?',
        #     'Continue writing the next sentence in this paragraph:',
        #     'Continue writing the next sentence.',
        #     'This is a test of commonsense. Complete the next sentence:',
        #     'Write the next sentence in this paragraph:',
        #     'How does the next paragraph end?',
        #     'What most naturally follows?',
        #     'What happens next?',
        #     'What is the most logical next event?',
        #     'Write the next sentence in the following story.'
        # ]

        choices = ['A', 'B', 'C', 'D', 'E']

        def make_guid(raw_example):
            """Derive the guid from the row's ``ind`` field."""
            return f'hellaswag.uni_{raw_example["ind"]}'

        def format_raw_example(raw_example):
            """Return ``(question, answer)`` with endings lettered A, B, C, ..."""
            ctx = raw_example['ctx']
            endings = [
                f'{choices[i]}: {ending}'
                for i, ending in enumerate(raw_example['endings'])
            ]

            label = int(raw_example['label'])
            answer = f'Answer: {choices[label]}'

            # prompt = random.choice(prompts)
            question = f'{ctx}\n' + '\n'.join(endings)

            return question, answer

        def sample_demos(skip_index):
            """Draw ``shot_n`` training rows, never the one at ``skip_index``."""
            if skip_index is None:
                return random.sample(raw_train_data, shot_n)
            # Sample positions from a population one smaller than the train
            # set, then shift positions at/after the skipped slot by one, so
            # the training list is not rebuilt for every row.
            picks = random.sample(range(len(raw_train_data) - 1), shot_n)
            return [raw_train_data[j + 1 if j >= skip_index else j] for j in picks]

        def wrap_raw_example(raw_example, train_index=None):
            """Build one record; ``train_index`` is set only for train rows."""
            question, answer = format_raw_example(raw_example)
            example = {
                'instruction': question,
                'input': '',
                'output': answer,
                'meta': {'guid': make_guid(raw_example)},
            }

            if shot_n > 0:
                demos = sample_demos(train_index)
                example['history'] = [list(format_raw_example(demo)) for demo in demos]
                # Raw rows never get a 'guid' key; build the demo guid the
                # same way as the record's own guid.
                example['meta']['history_guid'] = [make_guid(demo) for demo in demos]

            return example

        # A training row must not appear among its own demos.
        train_data = [wrap_raw_example(raw_ex, i) for i, raw_ex in enumerate(raw_train_data)]
        test_data = [wrap_raw_example(raw_ex) for raw_ex in raw_test_data]

        return train_data, test_data

# class ArcProcessor(DatasetProcessor):
#     def load_raw_data_from_hf(self):
#         if self.dataset_name == 'arcc':
#             hf_name = 'ARC-Challenge'
#         else:
#             hf_name = 'ARC-Easy'
#         dataset = load_dataset('allenai/ai2_arc', hf_name)
#
#         return dataset['train'], dataset['test']
#
#     def process(self, shot_n: int = 0, sampling=False):
#         raw_train_data, raw_test_data = self.load_raw_data_from_hf()
#
#         def format_raw_example(raw_example):
#             choices = raw_example['choices']['label']
#             question = raw_example['question']
#
#             texts = raw_example['choices']['text'].copy()
#
#             for i, text in enumerate(texts):
#                 texts[i] = f'{choices[i]}: {text}'
#
#             answer_key = raw_example['answerKey']
#             answer = f'Answer: {answer_key}'
#
#             question = f'{question}\n' + '\n'.join(texts)
#
#             return question, answer
#
#         def wrap_raw_example(raw_example):
#             example = {}
#
#             question, answer = format_raw_example(raw_example)
#
#             example['instruction'] = question
#             example['input'] = ''
#             example['output'] = answer
#
#             example['meta'] = {'guid': raw_example['id']}
#
#             if shot_n > 0:
#                 example['history'] = []
#                 example['meta']['history_guid'] = []
#                 filtered_train_set = [x for x in raw_train_data if not x == raw_example]
#                 demos = random.sample(filtered_train_set, shot_n)
#                 for demo in demos:
#                     question, answer = format_raw_example(demo)
#                     example['history'].append([question, answer])
#                     example['meta']['history_guid'].append(demo['id'])
#
#             return example
#
#         train_data = [wrap_raw_example(raw_ex) for raw_ex in raw_train_data]
#         test_data = [wrap_raw_example(raw_ex) for raw_ex in raw_test_data]
#
#         return train_data, test_data
#
#
# class BoolQProcessor(DatasetProcessor):
#     def load_raw_data_from_hf(self):
#
#         dataset = load_dataset('google/boolq')
#
#         return dataset['train'], dataset['validation']
#
#     def process(self, shot_n: int = 0, sampling=False):
#         raw_train_data, raw_test_data = self.load_raw_data_from_hf()
#
#         def format_raw_example(raw_example):
#             choices = raw_example['choices']['label']
#             question = raw_example['question']
#
#             texts = raw_example['choices']['text'].copy()
#
#             for i, text in enumerate(texts):
#                 texts[i] = f'{choices[i]}: {text}'
#
#             answer_key = raw_example['answerKey']
#             answer = f'Answer: {answer_key}'
#
#             question = f'{question}\n' + '\n'.join(texts)
#
#             return question, answer
#
#         def wrap_raw_example(raw_example):
#             example = {}
#
#             question, answer = format_raw_example(raw_example)
#
#             example['instruction'] = question
#             example['input'] = ''
#             example['output'] = answer
#
#             example['meta'] = {'guid': raw_example['id']}
#
#             if shot_n > 0:
#                 example['history'] = []
#                 example['meta']['history_guid'] = []
#                 filtered_train_set = [x for x in raw_train_data if not x == raw_example]
#                 demos = random.sample(filtered_train_set, shot_n)
#                 for demo in demos:
#                     question, answer = format_raw_example(demo)
#                     example['history'].append([question, answer])
#                     example['meta']['history_guid'].append(demo['id'])
#
#             return example
#
#         train_data = [wrap_raw_example(raw_ex) for raw_ex in raw_train_data]
#         test_data = [wrap_raw_example(raw_ex) for raw_ex in raw_test_data]
#
#         return train_data, test_data


if __name__ == '__main__':
    # Smoke test: build 4-shot Spider examples and report the token length of
    # the first test instruction.
    # dataset_name is a required positional parameter of the processor
    # constructor; omitting it raises TypeError before anything runs.
    spider_processor = SpiderProcessor(
        dataset_name='spider',
        dataset_dir='/data/local/user/shenyu/projects/MeMv2/raw_datasets/structlm',
    )
    train_data, test_data = spider_processor.process(shot_n=4)

    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
    print(train_data[0])

    print(test_data[0])
    print(len(tokenizer.tokenize(test_data[0]['instruction'])))
