from torch.utils.data import Dataset
import datasets
from datasets import load_dataset
import transformers
from typing import Dict
import torch
import numpy as np
from tqdm import tqdm
import json
import random
random.seed(0)

from typing import List

# Prompt templates consumed by ShortCircuitingDataset via str.format().
# "<SEPARATOR>" marks the boundary between the prompt part and the response
# part; the dataset either replaces it or splits on it when tokenizing.
mistral_short_circuit_template = (
    "{user_tag} {instruction} {assistant_tag}"
    "<SEPARATOR>{response}"
)
# Disabled Mistral variants that prepend a <<SYS>> system prompt:
# mistral_short_circuit_template_with_sys = "{user_tag} <<SYS>>\nAlways assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.\n<</SYS>>\n\n{instruction} {assistant_tag}<SEPARATOR>{response}"
# mistral_sys_prompt = "{user_tag} <<SYS>>\nAlways assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.\n<</SYS>>\n\n"

llama_short_circuit_template = (
    "{user_tag}\n\n{instruction}{assistant_tag}\n\n"
    "<SEPARATOR>{response}"
)

class ShortCircuitingDataset(Dataset):
    """Paired retain / short-circuit text dataset.

    At construction time three pools of formatted strings are built:

    * ``orig_s_retain`` -- UltraChat ``test_sft`` conversations, XSTest
      borderline prompts whose completion is labelled ``1_full_compliance``
      and, for Llama models with ``use_refusal_retain``, the harmful prompts
      paired with their ``llama3_output`` field (presumably refusals --
      confirm against the data file).
    * ``short_circuit_orig`` -- harmful prompts paired with their ``output``
      field.
    * ``val_orig`` -- prompt/output pairs from the harmless validation file.

    Strings built from the short-circuit template contain a ``<SEPARATOR>``
    marker between prompt and response; ``__getitem__`` uses it to tokenize.
    """

    def __init__(self,
                 tokenizer: "transformers.PreTrainedTokenizer",
                 num_examples,
                 lorra_args,
                 model_name_or_path,
                 sc_train_subset: List[str] = None,
                 use_refusal_retain: bool = True,
                 ):
        """Load and format every data pool.

        Args:
            tokenizer: Tokenizer used for chat templating, BOS stripping and,
                later, in ``__getitem__``. For Mistral models its
                ``chat_template`` is overwritten in place.
            num_examples: Maximum number of UltraChat retain examples.
            lorra_args: Object exposing ``user_tag`` and ``assistant_tag``.
            model_name_or_path: Must contain ``llama`` or ``mistral``
                (case-insensitive); selects the prompt template.
            sc_train_subset: If given, only harmful examples whose
                ``category`` is in this list are kept, and the refusal-retain
                set is not capped at 2000 entries.
            use_refusal_retain: Add the refusal-retain pool (Llama only).

        Raises:
            ValueError: If the model name contains neither ``llama`` nor
                ``mistral``.
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path.lower()
        is_llama = 'llama' in self.model_name_or_path

        if is_llama:
            print("USING LLAMA TEMPLATE")
            template = llama_short_circuit_template
        elif 'mistral' in self.model_name_or_path:
            print("USING MISTRAL TEMPLATE")
            template = mistral_short_circuit_template
            # fix spacing issue in template
            tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
        else:
            raise ValueError("Model name not recognized")

        self.user_tag = lorra_args.user_tag
        self.assistant_tag = lorra_args.assistant_tag
        self.tokenizer = tokenizer
        bos_token = tokenizer.bos_token

        # Formatting modes: 0 = prompt + response, 2 = response only (empty
        # instruction). Llama mixes both; Mistral uses only mode 0.
        # NOTE(review): modes are drawn from NumPy's global RNG, which is not
        # seeded in this class; seed it if the mix must be reproducible.
        switch_select = [0, 2] if is_llama else [0]

        # ======================= Retain ======================= #
        self.orig_s_retain = self._build_ultrachat_retain(
            tokenizer, template, switch_select, num_examples)
        random.shuffle(self.orig_s_retain)
        self._preview("orig_s_retain[0]", self.orig_s_retain)
        print("Orig s length:", len(self.orig_s_retain))

        # ======================= Borderline Retain ======================= #
        borderline_orig_s = self._build_borderline_retain(
            template, switch_select, bos_token)
        self.orig_s_retain += borderline_orig_s
        random.shuffle(self.orig_s_retain)
        self._preview("borderline_orig_s[0]", borderline_orig_s)
        print("Orig s length:", len(self.orig_s_retain))

        harmful_path = "/data/private_models/short_circuiting/data/5kharmful_filtered_1_1_filter_harmbench_categorized.json"

        # ======================= Refusal Retain ======================= #
        if is_llama and use_refusal_retain:
            with open(harmful_path) as file:
                dataset = json.load(file)
            random.shuffle(dataset)
            # Without a category subset, cap at 2000 examples; with one, keep
            # everything (per the original note, each category has < 2000).
            if sc_train_subset is None:
                dataset = dataset[:2000]

            refusal_retain_orig = self._format_harmful(
                dataset * 2, 'llama3_output', template, switch_select,
                sc_train_subset, bos_token)
            self.orig_s_retain += refusal_retain_orig
            random.shuffle(self.orig_s_retain)
            self._preview("refusal_orig_s[0]", refusal_retain_orig)
            print("Orig s length:", len(self.orig_s_retain))

        # ======================= Short Circuit ======================= #
        with open(harmful_path) as file:
            dataset = json.load(file)
        self.short_circuit_orig = self._format_harmful(
            dataset, 'output', template, switch_select,
            sc_train_subset, bos_token)
        random.shuffle(self.short_circuit_orig)
        self._preview("short_circuit_orig[0]", self.short_circuit_orig)
        print("Short circuit length:", len(self.short_circuit_orig))

        # ======================= Val ======================= #
        with open("/data/private_models/short_circuiting/data/5kharmful_filtered_harmless.json") as file:
            dataset = json.load(file)
        self.val_orig = [
            self._strip_bos(
                template.format(
                    user_tag=self.user_tag, assistant_tag=self.assistant_tag,
                    instruction=d['prompt'], response=d['output']),
                bos_token)
            for d in tqdm(dataset)
        ]

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _strip_bos(text, bos_token):
        """Return ``text`` with every ``bos_token`` occurrence removed.

        A missing BOS token (``None`` or empty) leaves ``text`` unchanged
        rather than raising ``TypeError`` from ``str.replace``.
        """
        return text.replace(bos_token, "") if bos_token else text

    @staticmethod
    def _preview(label, items):
        """Print ``label`` with the first item, or a placeholder if ``items`` is empty."""
        print(label, items[0] if items else "<empty>")

    @staticmethod
    def _format_example(template, switch, user_tag, assistant_tag, instruction, response):
        """Render one example with ``template`` in formatting mode ``switch``.

        Mode 0 keeps ``instruction``; mode 2 replaces it with an empty string
        (response only).

        Raises:
            ValueError: For any other mode.
        """
        if switch == 2:
            instruction = ""
        elif switch != 0:
            raise ValueError(f"Unsupported formatting mode: {switch}")
        return template.format(
            user_tag=user_tag, assistant_tag=assistant_tag,
            instruction=instruction, response=response)

    def _build_ultrachat_retain(self, tokenizer, template, switch_select, num_examples):
        """Format at most ``num_examples`` UltraChat ``test_sft`` conversations.

        Conversations with fewer than two messages are skipped. Mode 0 renders
        the whole conversation with the tokenizer's chat template; mode 2
        renders only the second message as a response with an empty prompt.
        """
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft")
        retain = []
        for example in ds:
            # Check before appending so exactly num_examples items are kept.
            if len(retain) >= num_examples:
                break
            messages = example["messages"]
            if len(messages) < 2:
                continue
            switch = np.random.choice(switch_select)
            if switch == 0:
                text = tokenizer.apply_chat_template(messages, tokenize=False)
            else:
                text = self._format_example(
                    template, switch, self.user_tag, self.assistant_tag,
                    "", messages[1]["content"])
            retain.append(self._strip_bos(text, tokenizer.bos_token))
        return retain

    def _build_borderline_retain(self, template, switch_select, bos_token):
        """Format XSTest rows labelled ``1_full_compliance``, each repeated 50 times."""
        import csv
        with open('../data/xstest_v2_completions_gpt4_gpteval.csv', newline='') as f:
            rows = [dict(row) for row in csv.DictReader(f)
                    if row['final_label'] == "1_full_compliance"]

        borderline = []
        for d in rows * 50:
            switch = np.random.choice(switch_select)
            text = self._format_example(
                template, switch, self.user_tag, self.assistant_tag,
                d['prompt'], d['completion'])
            borderline.append(self._strip_bos(text, bos_token))
        return borderline

    def _format_harmful(self, dataset, response_key, template, switch_select,
                        sc_train_subset, bos_token):
        """Format harmful-prompt records, pairing ``d['prompt']`` with ``d[response_key]``.

        When ``sc_train_subset`` is given, records whose ``category`` is not
        in it are dropped.
        """
        formatted = []
        for d in tqdm(dataset):
            switch = np.random.choice(switch_select)
            text = self._format_example(
                template, switch, self.user_tag, self.assistant_tag,
                d['prompt'], d[response_key])
            if sc_train_subset is not None and d['category'] not in sc_train_subset:
                continue
            formatted.append(self._strip_bos(text, bos_token))
        return formatted

    # ------------------------------------------------------------ Dataset API

    def __len__(self):
        """Number of paired items: the shorter of the retain and short-circuit pools."""
        return min(len(self.orig_s_retain), len(self.short_circuit_orig))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Tokenize the ``i``-th retain / short-circuit pair plus a validation example.

        Retain and validation strings have ``<SEPARATOR>`` replaced by the
        model-specific joiner and are left-padded / truncated to 1024 tokens.
        The short-circuit string is split at ``<SEPARATOR>``: the prompt is
        left-padded to 512 tokens, the response right-padded to 512 tokens
        without special tokens, and the two are concatenated along dim 1.

        Side effect: mutates ``self.tokenizer.padding_side`` and leaves it
        set to ``"left"``.
        """
        retain_text = self.orig_s_retain[i]
        sc_text = self.short_circuit_orig[i]
        val_text = self.val_orig[i % len(self.val_orig)]  # validation pool is cycled

        # mistral has a space after assistant token
        joiner = ' ' if 'mistral' in self.model_name_or_path else ''
        tok = self.tokenizer

        tok.padding_side = "left"
        retain_enc = tok(
            retain_text.replace('<SEPARATOR>', joiner),
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt",
        )
        val_enc = tok(
            val_text.replace('<SEPARATOR>', joiner),
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt",
        )

        sc_parts = sc_text.split('<SEPARATOR>')
        # Left-pad the prompt and right-pad the response so that, after
        # concatenation, prompt and response tokens are adjacent.
        prompt_enc = tok(
            sc_parts[0],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        tok.padding_side = "right"
        response_enc = tok(
            sc_parts[1],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
            add_special_tokens=False,
        )
        tok.padding_side = "left"

        sc_input_ids = torch.cat(
            [prompt_enc["input_ids"], response_enc["input_ids"]], dim=1)
        sc_attention_mask = torch.cat(
            [prompt_enc["attention_mask"], response_enc["attention_mask"]], dim=1)

        return dict(
            input_ids=retain_enc["input_ids"],
            attention_mask=retain_enc["attention_mask"],
            input_ids_val=val_enc["input_ids"],
            attention_mask_val=val_enc["attention_mask"],
            input_ids_short_circuit=sc_input_ids,
            attention_mask_short_circuit=sc_attention_mask,
        )
