from os import system
from random import shuffle
from typing import Optional

from datasets import load_dataset, Dataset

from src.util.globals import BIO_FORGET_CORPUS_PATH, HF_USERNAME
from ..util.globals import HF_USERNAME, OUTPUT_DIR, WMDP_OPTIONS, WMDP_TASKS

# Prompt pairs used when wrapping raw corpus text as a chat (see
# process_text_dataset -> process_example -> text_to_dialogue):
#   REQUEST_PROMPT_*    becomes the user turn;
#   GENERATION_PREFIX_* is prepended directly to the corpus text in the
#                       assistant turn, hence the trailing space.
# NOTE(review): "reseach" below is a typo. It is deliberately left as-is
# because this string is written into the generated training examples;
# changing it alters the data relative to earlier runs.
REQUEST_PROMPT_BIO = "Write a research article in the field of biology."
GENERATION_PREFIX_BIO = "Of course, here is a reseach article in the field of biology. "

REQUEST_PROMPT_CYBER = "Write a github page about cybersecurity."
GENERATION_PREFIX_CYBER = "Of course, here is a github page about cybersecurity. "

REQUEST_PROMPT_WIKI = "Write a wikipedia article."
GENERATION_PREFIX_WIKI = "Of course, here is a wikipedia article. "


def text_to_dialogue(text, request_prompt, generation_prefix):
    """Turn a raw-text record into a two-turn ``[request, response]`` dialogue.

    Args:
        text: A mapping with a ``"text"`` entry (e.g. a dataset row), not a
            plain string.
        request_prompt: Message placed first (the request).
        generation_prefix: String prepended to ``text["text"]`` to build the
            second message (the response).

    Returns:
        ``[request_prompt, generation_prefix + text["text"]]``.
    """
    response = generation_prefix + text["text"]
    return [request_prompt, response]
        
def dialogue_to_chat(dialogue, system_prompt):
    """Convert a flat list of turns into a role-tagged chat message list.

    A system message comes first; the dialogue turns then alternate
    ``user`` / ``assistant`` starting with ``user``.

    Args:
        dialogue: Iterable of message strings in conversation order.
        system_prompt: Content of the leading system message.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts.
    """
    alternating_roles = ('user', 'assistant')
    header = [{'role': 'system', 'content': system_prompt}]
    turns = [
        {'role': alternating_roles[position % 2], 'content': text}
        for position, text in enumerate(dialogue)
    ]
    return header + turns

def process_example(example, request_prompt, generation_prefix, system_prompt):
    """Attach a chat-formatted ``"messages"`` entry to a raw-text example.

    The example is modified in place and also returned, so the function can
    serve as a ``Dataset.map`` callback (as process_text_dataset uses it).

    Args:
        example: Mapping with a ``"text"`` entry.
        request_prompt: User turn.
        generation_prefix: Prefix for the assistant turn, followed by the text.
        system_prompt: System message content.

    Returns:
        The same ``example`` object with ``"messages"`` set.
    """
    two_turns = text_to_dialogue(example, request_prompt, generation_prefix)
    example["messages"] = dialogue_to_chat(two_turns, system_prompt)
    return example

def process_text_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], request_prompt:str, generation_prefix:str, system_prompt:str):
    """Convert a raw-text dataset into chat-formatted examples.

    Each row's ``"text"`` is wrapped into a system/user/assistant conversation
    stored under ``"messages"``. Only the ``"messages"`` and ``"text"`` columns
    are kept, and the rows are shuffled (no seed is passed).

    Args:
        dataset: Dataset with a ``"text"`` column.
        limit: If given, keep only the first ``limit`` rows (taken before
            shuffling). A limit larger than the dataset keeps every row.
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        request_prompt: User turn used for every example.
        generation_prefix: Text prepended to each row's ``"text"`` in the
            assistant turn.
        system_prompt: System message used for every example.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split`` when
        ``train_val_ratio`` is given.
    """
    if limit is not None:
        # Clamp (as process_dpo_medmcqa_dataset does) so an oversized limit
        # does not ask select() for out-of-range indices.
        dataset = dataset.select(range(min(limit, len(dataset))))

    def one_argument_process_example(example):
        return process_example(example, request_prompt, generation_prefix, system_prompt)

    dataset = dataset.map(one_argument_process_example)
    dataset = dataset.select_columns(["messages", "text"])
    dataset = dataset.shuffle()
    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset

def process_pile_text_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], request_prompt:str, generation_prefix:str, system_prompt:str):
    """Truncate, shuffle and optionally split a plain-text dataset.

    Unlike process_text_dataset, rows are left untouched: no ``"messages"``
    column is built and all columns are kept. ``request_prompt``,
    ``generation_prefix`` and ``system_prompt`` are accepted but unused.

    Args:
        dataset: Source dataset.
        limit: If given, keep only the first ``limit`` rows (before
            shuffling). A limit larger than the dataset keeps every row.
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.

    Returns:
        The shuffled ``Dataset``, or the result of ``train_test_split`` when
        ``train_val_ratio`` is given.
    """
    if limit is not None:
        # Clamp so an oversized limit does not request out-of-range rows.
        dataset = dataset.select(range(min(limit, len(dataset))))
    dataset = dataset.shuffle()

    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset

def process_dpo_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], request_prompt:str, generation_prefix:str, system_prompt:str, key_with_answer:str):
    """Turn a prompt/response-pair dataset into single-answer chat examples.

    The user turn is each row's ``"prompt"`` and the assistant turn is the
    column named by ``key_with_answer``. Callers in this file pass
    ``"chosen"`` or ``"rejected"``. Only ``"messages"`` is kept and rows are
    shuffled. ``request_prompt`` and ``generation_prefix`` are unused.

    Args:
        dataset: Dataset with a ``"prompt"`` column and ``key_with_answer``.
        limit: If given, keep only the first ``limit`` rows (before
            shuffling). A limit larger than the dataset keeps every row.
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        system_prompt: System message used for every example.
        key_with_answer: Column used as the assistant response.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split``.
    """
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))

    def one_argument_process_example(example):
        example["messages"] = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': example["prompt"]},
            {'role': 'assistant', 'content': example[key_with_answer]},
        ]
        return example

    dataset = dataset.map(one_argument_process_example)
    dataset = dataset.select_columns(["messages"])

    dataset = dataset.shuffle()
    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset

def process_wmdp_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], system_prompt:str, key_with_answer:str, context:str='biology'):
    """Format a WMDP-style multiple-choice dataset as chat examples.

    The user turn has the following parts, in order:
    - a header naming ``context``;
    - the stripped question;
    - one ``"<letter>. <choice>"`` line per choice, with letters from
      ``WMDP_OPTIONS``;
    - a trailing ``"Answer:"``.

    The assistant turn is the letter and text of the correct choice, with no
    trailing newline. Adds ``"messages"`` and ``"text"`` (prompt + answer).
    All original columns are kept and rows are shuffled.

    Args:
        dataset: Rows with ``"question"``, ``"choices"`` and an index into
            ``"choices"`` under ``key_with_answer``.
        limit: If given, keep only the first ``limit`` rows (before
            shuffling). A limit larger than the dataset keeps every row.
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        system_prompt: System message used for every example.
        key_with_answer: Column holding the correct choice index.
        context: Subject named in the prompt header.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split``.
    """
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))

    def one_argument_process_example(example):
        full_question = f"The following are multiple choice questions (with answers) about {context}.\n\n"
        full_question += example["question"].strip() + "\n"
        for idx, choice in enumerate(example["choices"]):
            full_question += f"{WMDP_OPTIONS[idx]}. {choice}\n"
        full_question += "Answer:"

        answer_idx = example[key_with_answer]
        full_answer = f"{WMDP_OPTIONS[answer_idx]}. {example['choices'][answer_idx]}"

        example["messages"] = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': full_question},
            {'role': 'assistant', 'content': full_answer},
        ]
        example["text"] = full_question + full_answer
        return example

    dataset = dataset.map(one_argument_process_example)
    dataset = dataset.shuffle()
    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset

def process_dpo_tofu_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], request_prompt:str, generation_prefix:str, system_prompt:str, key_with_answer:str):
    """Format a TOFU multiple-choice QA dataset as chat examples.

    The user turn is the stripped question, then one
    ``"<letter>. <choice>"`` line per choice (letters from ``WMDP_OPTIONS``),
    then ``"Answer:"``. The assistant turn is ``"<letter>. <choice>\\n"`` for
    the correct choice. Keeps ``"messages"`` plus ``"text"`` when the source
    dataset already has it, then shuffles. ``request_prompt`` and
    ``generation_prefix`` are unused.

    Args:
        dataset: Rows with ``"question"``, ``"choices"`` and the correct
            choice index under ``key_with_answer``.
        limit: If given, keep only the first ``limit`` rows (before
            shuffling). A limit larger than the dataset keeps every row.
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        system_prompt: System message used for every example.
        key_with_answer: Column holding the correct choice index. Callers in
            this file pass ``"answer"``.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split``.
    """
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))

    def one_argument_process_example(example):
        choices = example['choices']
        # Honour key_with_answer instead of hard-coding 'answer'; callers in
        # this file pass "answer", so their results are unchanged.
        answer_idx = example[key_with_answer]

        prompt = example["question"].strip() + "\n"
        for idx, choice in enumerate(choices):
            prompt += f"{WMDP_OPTIONS[idx]}. {choice}\n"
        prompt += "Answer:"

        answer = f"{WMDP_OPTIONS[answer_idx]}. {choices[answer_idx]}\n"
        example["messages"] = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': answer},
        ]
        return example

    dataset = dataset.map(one_argument_process_example)
    # "text" is not produced above; keep it only if the source provides it so
    # select_columns does not fail on datasets without that column.
    columns = [column for column in ("messages", "text") if column in dataset.column_names]
    dataset = dataset.select_columns(columns)
    dataset = dataset.shuffle()
    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset


def process_dpo_gpqa_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], task:str, request_prompt:str, generation_prefix:str, system_prompt:str, key_with_answer:str):
    """Format a GPQA/MMLU-style multiple-choice dataset as chat examples.

    The user turn is the stripped question, then one
    ``"<letter>. <choice>"`` line per choice (letters from ``WMDP_OPTIONS``),
    then ``"Answer:"``. The assistant turn is ``"<letter>. <choice>\\n"`` for
    the correct choice. Keeps ``"messages"`` plus ``"text"`` when the source
    dataset already has it, then shuffles. ``task``, ``request_prompt`` and
    ``generation_prefix`` are currently unused.

    Args:
        dataset: Rows with ``"question"``, ``"choices"`` and the correct
            choice index under ``key_with_answer``.
        limit: If given, keep only the first ``limit`` rows (before
            shuffling). A limit larger than the dataset keeps every row.
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        system_prompt: System message used for every example.
        key_with_answer: Column holding the correct choice index. Callers in
            this file pass ``"answer"``.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split``.
    """
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))

    def one_argument_process_example(example):
        choices = example['choices']
        # Honour key_with_answer instead of hard-coding 'answer'; callers in
        # this file pass "answer", so their results are unchanged.
        answer_idx = example[key_with_answer]

        prompt = example["question"].strip() + "\n"
        for idx, choice in enumerate(choices):
            prompt += f"{WMDP_OPTIONS[idx]}. {choice}\n"
        prompt += "Answer:"

        answer = f"{WMDP_OPTIONS[answer_idx]}. {choices[answer_idx]}\n"
        example["messages"] = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': answer},
        ]
        return example

    dataset = dataset.map(one_argument_process_example)
    # "text" is not produced above; keep it only if the source provides it so
    # select_columns does not fail on datasets without that column.
    columns = [column for column in ("messages", "text") if column in dataset.column_names]
    dataset = dataset.select_columns(columns)
    dataset = dataset.shuffle()
    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset

def process_dpo_medmcqa_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], task:str, request_prompt:str, generation_prefix:str, system_prompt:str, key_with_answer:str):
    """Format MedMCQA-style rows as chat examples.

    Each row supplies four options in ``opa``..``opd`` and the correct option
    index in ``cop``. The user turn is the stripped question, one
    ``"<letter>. <option>"`` line per option (letters from ``WMDP_OPTIONS``),
    then ``"Answer:"``. The assistant turn is ``"<letter>. <option>\\n"`` for
    the correct option; when ``task == "medmcqa_gen"`` it is replaced by the
    row's ``"gen_answer"``. Only ``"messages"`` is kept and rows are shuffled.
    ``request_prompt``, ``generation_prefix`` and ``key_with_answer`` are
    unused.

    Args:
        dataset: Source rows as described above.
        limit: If given, keep at most ``limit`` leading rows (clamped to the
            dataset size).
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        task: Selects the generated-answer variant when ``"medmcqa_gen"``.
        system_prompt: System message used for every example.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split``.
    """
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))

    use_generated_answer = task == "medmcqa_gen"

    def one_argument_process_example(example):
        correct = example['cop']
        options = [example[column] for column in ('opa', 'opb', 'opc', 'opd')]
        correct_text = options[correct]

        prompt_lines = [example["question"].strip()]
        prompt_lines.extend(
            f"{WMDP_OPTIONS[position]}. {option}" for position, option in enumerate(options)
        )
        prompt_lines.append("Answer:")
        prompt = "\n".join(prompt_lines)

        reply = f"{WMDP_OPTIONS[correct]}. {correct_text}\n"
        if use_generated_answer:
            reply = example['gen_answer']

        example["messages"] = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': reply},
        ]
        return example

    processed = dataset.map(one_argument_process_example)
    processed = processed.select_columns("messages").shuffle()
    if train_val_ratio is not None:
        processed = processed.train_test_split(test_size=train_val_ratio)
    return processed

def process_dpo_medmcqa_gen_dataset(dataset:Dataset, limit:Optional[int], train_val_ratio:Optional[float], task:str, request_prompt:str, generation_prefix:str, system_prompt:str, key_with_answer:str):
    """Format question / free-form-answer rows as chat examples.

    The user turn is the row's ``"question"`` as-is (not stripped). The
    assistant turn is the row's ``"gen_answer"``. Keeps ``"messages"`` plus
    ``"text"`` when the source dataset already has it, then shuffles.
    ``task``, ``request_prompt``, ``generation_prefix`` and
    ``key_with_answer`` are unused.

    Args:
        dataset: Rows with ``"question"`` and ``"gen_answer"`` columns.
        limit: If given, keep at most ``limit`` leading rows (clamped to the
            dataset size).
        train_val_ratio: If given, passed as ``test_size`` to
            ``train_test_split`` and the split result is returned.
        system_prompt: System message used for every example.

    Returns:
        The processed ``Dataset``, or the result of ``train_test_split``.
    """
    if limit is not None:
        max_len = min(limit, len(dataset))
        dataset = dataset.select(range(max_len))

    def one_argument_process_example(example):
        example["messages"] = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': example["question"]},
            {'role': 'assistant', 'content': example['gen_answer']},
        ]
        return example

    dataset = dataset.map(one_argument_process_example)
    # "text" is not produced above; keep it only if the source provides it so
    # select_columns does not fail on datasets without that column.
    columns = [column for column in ("messages", "text") if column in dataset.column_names]
    dataset = dataset.select_columns(columns)
    dataset = dataset.shuffle()
    if train_val_ratio is not None:
        dataset = dataset.train_test_split(test_size=train_val_ratio)

    return dataset

def get_bio_retain_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load the WMDP ``bio-retain-corpus`` train split as chat examples.

    ``limit`` is applied once, inside process_text_dataset. The previous extra
    ``select`` here picked the same leading rows and was redundant.
    """
    dataset = load_dataset("cais/wmdp-corpora", "bio-retain-corpus", split="train")
    return process_text_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_BIO, GENERATION_PREFIX_BIO, system_prompt)

def _load_bio_retain_translation(config:Optional[str], limit:Optional[int]):
    """Load the raw ``train`` split of a ``zekeZZ/wmdp-corpora-spanish`` config.

    No chat formatting, shuffling or splitting is applied. ``config=None``
    selects the dataset's default configuration. When ``limit`` is given, at
    most that many leading rows are kept (clamped to the dataset size).
    """
    dataset = load_dataset("zekeZZ/wmdp-corpora-spanish", config, split="train")
    if limit is not None:
        dataset = dataset.select(range(min(limit, len(dataset))))
    return dataset

def get_bio_retain_dataset_spanish(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Return the default config of ``zekeZZ/wmdp-corpora-spanish`` as raw rows.

    ``train_val_ratio`` and ``system_prompt`` are accepted but not used.
    """
    return _load_bio_retain_translation(None, limit)

def get_bio_retain_dataset_russian(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Return the ``'russian'`` config of ``zekeZZ/wmdp-corpora-spanish`` as raw rows.

    ``train_val_ratio`` and ``system_prompt`` are accepted but not used.
    """
    return _load_bio_retain_translation('russian', limit)


def get_bio_retain_dataset_chinese(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Return the ``'chinese'`` config of ``zekeZZ/wmdp-corpora-spanish`` as raw rows.

    ``train_val_ratio`` and ``system_prompt`` are accepted but not used.
    """
    return _load_bio_retain_translation('chinese', limit)

def get_bio_retain_dataset_spanish_to_english(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Return the ``'spanish-to-english'`` config of ``zekeZZ/wmdp-corpora-spanish`` as raw rows.

    ``train_val_ratio`` and ``system_prompt`` are accepted but not used.
    """
    return _load_bio_retain_translation('spanish-to-english', limit)

def get_cyber_retain_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load the WMDP ``cyber-retain-corpus`` train split as cyber-prompt chat examples."""
    corpus = load_dataset("cais/wmdp-corpora", "cyber-retain-corpus", split="train")
    return process_text_dataset(
        corpus,
        limit,
        train_val_ratio,
        request_prompt=REQUEST_PROMPT_CYBER,
        generation_prefix=GENERATION_PREFIX_CYBER,
        system_prompt=system_prompt,
    )

def get_bio_forget_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load the bio forget corpus from local JSON files under BIO_FORGET_CORPUS_PATH as bio-prompt chat examples."""
    corpus = load_dataset("json", data_dir=str(BIO_FORGET_CORPUS_PATH), split="train")
    return process_text_dataset(
        corpus,
        limit,
        train_val_ratio,
        request_prompt=REQUEST_PROMPT_BIO,
        generation_prefix=GENERATION_PREFIX_BIO,
        system_prompt=system_prompt,
    )

def get_cyber_forget_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load the WMDP ``cyber-forget-corpus`` train split as cyber-prompt chat examples."""
    corpus = load_dataset("cais/wmdp-corpora", "cyber-forget-corpus", split="train")
    return process_text_dataset(
        corpus,
        limit,
        train_val_ratio,
        request_prompt=REQUEST_PROMPT_CYBER,
        generation_prefix=GENERATION_PREFIX_CYBER,
        system_prompt=system_prompt,
    )

def get_wikitext_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``wikitext-2-raw-v1`` train split as wiki-prompt chat examples."""
    corpus = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    return process_text_dataset(
        corpus,
        limit,
        train_val_ratio,
        request_prompt=REQUEST_PROMPT_WIKI,
        generation_prefix=GENERATION_PREFIX_WIKI,
        system_prompt=system_prompt,
    )

## Multiple choice datasets

def get_bio_retain_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``J4Q8/bio_retain_dpo`` and train on its ``"chosen"`` responses."""
    dataset = load_dataset("J4Q8/bio_retain_dpo", split="train")
    return process_dpo_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_BIO, GENERATION_PREFIX_BIO, system_prompt, "chosen")

def get_cyber_retain_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``J4Q8/cyber_retain_dpo`` and train on its ``"chosen"`` responses."""
    dataset = load_dataset("J4Q8/cyber_retain_dpo", split="train")
    return process_dpo_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_CYBER, GENERATION_PREFIX_CYBER, system_prompt, "chosen")

def get_bio_forget_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load the ``zekeZZ/wmdp-bio`` test split as biology multiple-choice chat examples."""
    dataset = load_dataset("zekeZZ/wmdp-bio", split="test")
    return process_wmdp_dataset(dataset, limit, train_val_ratio, system_prompt, "answer", context="biology")

def get_cyber_forget_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``J4Q8/cyber_forget_dpo`` and train on its ``"rejected"`` responses."""
    dataset = load_dataset("J4Q8/cyber_forget_dpo", split="train")
    return process_dpo_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_CYBER, GENERATION_PREFIX_CYBER, system_prompt, "rejected")

def get_wikitext_bio_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``J4Q8/wikitext_bio_dpo`` and train on its ``"chosen"`` responses."""
    dataset = load_dataset("J4Q8/wikitext_bio_dpo", split="train")
    return process_dpo_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "chosen")

def get_wikitext_cyber_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``J4Q8/wikitext_cyber_dpo`` and train on its ``"chosen"`` responses."""
    dataset = load_dataset("J4Q8/wikitext_cyber_dpo", split="train")
    return process_dpo_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "chosen")

def get_tofu_wiki_mc_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/tofu_wiki_qa_shuffled`` as multiple-choice chat examples."""
    dataset = load_dataset("zekeZZ/tofu_wiki_qa_shuffled", split="train")
    return process_dpo_tofu_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_tofu_wiki_mc_dataset_trak(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/tofu-wiki-qa-shuffled-sorted-trak-2`` as multiple-choice chat examples."""
    dataset = load_dataset("zekeZZ/tofu-wiki-qa-shuffled-sorted-trak-2", split="train")
    return process_dpo_tofu_dataset(dataset, limit, train_val_ratio, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_gpqa_bio_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str="", task:str="biology"):
    """Load the ``gpqa-bio`` config (test split) of ``zekeZZ/gpqa_all`` as chat examples."""
    dataset = load_dataset("zekeZZ/gpqa_all", "gpqa-bio", split="test")
    return process_dpo_gpqa_dataset(dataset, limit, train_val_ratio, task, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_gpqa_chem_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str="", task:str="chemistry"):
    """Load the ``gpqa-chem`` config (test split) of ``zekeZZ/gpqa_all`` as chat examples."""
    dataset = load_dataset("zekeZZ/gpqa_all", "gpqa-chem", split="test")
    return process_dpo_gpqa_dataset(dataset, limit, train_val_ratio, task, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_gpqa_physics_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str="", task:str="physics"):
    """Load the ``gpqa-physics`` config (test split) of ``zekeZZ/gpqa_all`` as chat examples."""
    dataset = load_dataset("zekeZZ/gpqa_all", "gpqa-physics", split="test")
    return process_dpo_gpqa_dataset(dataset, limit, train_val_ratio, task, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_gpqa_all_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str="", task:str="all"):
    """Load the ``gpqa-all`` config (test split) of ``zekeZZ/gpqa_all`` as chat examples."""
    dataset = load_dataset("zekeZZ/gpqa_all", "gpqa-all", split="test")
    return process_dpo_gpqa_dataset(dataset, limit, train_val_ratio, task, REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_mmlu_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/mmlu_sort_by_gpt2xl_gpqa_bio`` (test split) via the GPQA formatter."""
    dataset = load_dataset("zekeZZ/mmlu_sort_by_gpt2xl_gpqa_bio",  split="test")
    return process_dpo_gpqa_dataset(dataset, limit, train_val_ratio, "mmlu", REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_medmcqa_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/medmcqa-gen-by-zephyr-ft-gpqa-all-sorted-contrastive`` with lettered-option answers."""
    dataset = load_dataset("zekeZZ/medmcqa-gen-by-zephyr-ft-gpqa-all-sorted-contrastive", split="train")
    return process_dpo_medmcqa_dataset(dataset, limit, train_val_ratio, "medmcqa", REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "answer")

def get_medmcqa_gen_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/medmcqa-gen-by-zephyr-ft-gpqa-all`` with its ``gen_answer`` responses."""
    dataset = load_dataset("zekeZZ/medmcqa-gen-by-zephyr-ft-gpqa-all", split="train")
    return process_dpo_medmcqa_gen_dataset(dataset, limit, train_val_ratio, "medmcqa_gen", REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "text")

def get_medmcqa_trak_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/medmcqa-gen-by-zephyr-ft-gpqa-all-sorted-trak`` with its ``gen_answer`` responses."""
    dataset = load_dataset("zekeZZ/medmcqa-gen-by-zephyr-ft-gpqa-all-sorted-trak", split="train")
    return process_dpo_medmcqa_gen_dataset(dataset, limit, train_val_ratio, "medmcqa_trak", REQUEST_PROMPT_WIKI, GENERATION_PREFIX_WIKI, system_prompt, "text")

def get_pile_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/pile_10000_subset`` as raw (unformatted) text rows."""
    corpus = load_dataset("zekeZZ/pile_10000_subset", split="train")
    return process_pile_text_dataset(
        corpus,
        limit,
        train_val_ratio,
        request_prompt=REQUEST_PROMPT_WIKI,
        generation_prefix=GENERATION_PREFIX_WIKI,
        system_prompt=system_prompt,
    )

def get_pile_trak_dataset(limit:Optional[int] = None, train_val_ratio:Optional[float]=None, system_prompt:str=""):
    """Load ``zekeZZ/pile-10000-subset-sorted-trak`` as raw (unformatted) text rows."""
    corpus = load_dataset("zekeZZ/pile-10000-subset-sorted-trak", split="train")
    return process_pile_text_dataset(
        corpus,
        limit,
        train_val_ratio,
        request_prompt=REQUEST_PROMPT_WIKI,
        generation_prefix=GENERATION_PREFIX_WIKI,
        system_prompt=system_prompt,
    )

# Manual debugging entry point: build the first 100 bio-forget chat examples.
if __name__ == "__main__":
    debug_dataset = get_bio_forget_dataset(100)