import transformers
import torch
import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from pprint import pprint
import pickle
import json
import typing
from pathlib import Path
from matplotlib.font_manager import FontProperties
import matplotlib.pylab as pylab
import sys

from torch.utils.data import Dataset, DataLoader
import argparse

# Location of the CounterFact dataset; downloaded when no local copy exists.
REMOTE_URL = "https://rome.baulab.info/data/dsets/counterfact.json"
# Device that models and tokenized batches are moved to. Falls back to the CPU
# when CUDA is unavailable so the script can still run on a GPU-less machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Hugging Face access token forwarded to `from_pretrained` (None = no token).
access_token = None
# Upper bound on the number of examples drawn per experiment.
# NOTE: the misspelled name is kept because other code refers to it.
NUM_SUBSAMPES = 2000000
# Default `cache_dir` handed to `from_pretrained`.
MODEL_PATH = None
# Number of prompts evaluated per DataLoader batch.
BATCH_SIZE = 128

import os
# Set before any tokenizer is created. Presumably this silences the Hugging
# Face tokenizers fork warning triggered by the DataLoader worker processes
# used below -- TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import re

# Language name -> country mentioned in the "Do not think of <country>."
# sentence that PrePendDataset builds for relation P1412.
# Every key starts with a space because the lookup key is the space-prefixed
# target string (see PrePendDataset.__getitem__).
# NOTE(review): " Armenian" maps to "The US", which does not follow the
# language -> home-country pattern of the other entries -- confirm intended.
language_to_country_dict = {
    " Dutch": "The Netherlands",
    " French": "France",
    " Chinese": "China",
    " English": "The UK",
    " Russian": "Russia",
    " Korean": "Korea",
    " Greek": "Greece",
    " Italian": "Italy",
    " German": "Germany",
    " Swedish": "Sweden",
    " Finnish": "Finland",
    " Armenian": "The US",
    " Hebrew": "Israel",
    " Turkish": "Turkey",
    " Norwegian": "Norway",
    " Hindi": "India",
    " Romanian": "Romania"
}


def string_to_filename(input_string, max_length=255, default_name="default"):
    """Turn an arbitrary string into something usable as a file name.

    Each of the characters ``< > : " / \\ | ? *`` is replaced by an
    underscore, the result is cut to ``max_length`` characters and trailing
    whitespace is stripped.

    Args:
        input_string: Text to sanitize.
        max_length: Maximum number of characters kept.
        default_name: Returned when nothing is left after sanitizing.

    Returns:
        The sanitized name, or ``default_name`` if it would be empty.
    """
    # One translation table does the same job as the character-class regex.
    replacements = str.maketrans({ch: "_" for ch in '<>:"/\\|?*'})
    candidate = input_string.translate(replacements)[:max_length].rstrip()
    return candidate or default_name

class PrePendDataset(Dataset):
    """Wraps CounterFact records and prefixes each prompt with extra sentences.

    Each item is ``(prompt, tokens)``. ``tokens`` is whatever
    ``model.tokenize_list`` returns for ``[true_target, false_target]``, and
    ``prompt`` is the record's prompt with ``prepend_times`` copies of a
    sentence placed in front of it.

    Args:
        data: Sequence of dicts with the keys ``prompt``, ``subject``,
            ``target_true['str']`` and ``target_new['str']``.
        model: Object exposing ``model_name`` and ``tokenize_list``.
        prepend_times: How many copies of the sentence are prepended.
        false_target: Mention the false target when true, else the true one.
        rel_test: Use the relation-specific sentence templates.
        rel_id: Relation id selecting the template when ``rel_test`` is set.
    """

    _LLAMA = "meta-llama/Llama-2-7b-hf"

    def __init__(self, data, model, prepend_times = 1, false_target = True, rel_test=False, rel_id = -1):
        self.data = data
        self.model = model
        self.prepend_times = prepend_times
        self.false_target = false_target
        self.rel_test = rel_test
        self.rel_id = rel_id

    def __len__(self):
        return len(self.data)

    def _sentence(self, record, target, is_llama):
        """Return the sentence to prepend once, or None when there is none."""
        if not self.rel_test:
            return f"Do not think of {record['subject']} and {target}."

        if self.rel_id == "P1412":
            # The country table is keyed by space-prefixed language names;
            # Llama targets have no leading space, so add one for the lookup.
            key = " " + target if is_llama else target
            if key not in language_to_country_dict:
                return None
            return f"Do not think of {language_to_country_dict[key]}."

        if self.false_target:
            templates = {"P190" : f"The twin city of {record['subject']} is not {target}.",
                         "P103": f"{record['subject']} cannot speak {target}.",
                         "P641": f"{record['subject']} does not play {target}.",
                         "P131": f"{record['subject']} is not located in {target}.",}
        else:
            templates = {"P190" : f"The twin city of {record['subject']} is {target}.",
                         "P103": f"{record['subject']} can speak {target}.",
                         "P641": f"{record['subject']} plays {target}.",
                         "P131": f"{record['subject']} is located in {target}.",}
        return templates[self.rel_id]

    def __getitem__(self, idx):
        record = self.data[idx]
        prompt = record['prompt']
        is_llama = self.model.model_name == self._LLAMA

        token_true = record['target_true']['str']
        token_false = record['target_new']['str']
        if not is_llama:
            # Every model except Llama gets the targets with a leading space.
            token_true = " " + token_true
            token_false = " " + token_false

        tokens = self.model.tokenize_list([token_true, token_false])

        # NOTE(review): for non-Llama models the target already starts with a
        # space, so the sentences below contain two consecutive spaces before
        # it. Left untouched because changing it would alter the prompts.
        mentioned = token_false if self.false_target else token_true
        for _ in range(self.prepend_times):
            sentence = self._sentence(record, mentioned, is_llama)
            if sentence is not None:
                prompt = sentence + " " + prompt

        return prompt, tokens

class CounterFactDataset(Dataset):
    """The CounterFact records loaded from ``<data_dir>/counterfact.json``.

    If the file is missing it is downloaded from ``REMOTE_URL`` first,
    creating ``data_dir`` when necessary.

    Args:
        data_dir: Directory that holds (or will hold) ``counterfact.json``.
        size: If given, keep only the first ``size`` records.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        data_dir = Path(data_dir)
        cf_loc = data_dir / "counterfact.json"
        if not cf_loc.exists():
            print(f"{cf_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, str(cf_loc))

        # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252
        # on Windows) would fail or mangle non-ASCII subject names.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        if size is not None:
            self.data = self.data[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]
    
def get_dataset():
    """Load CounterFact from the current directory and flatten its records.

    Returns:
        A list with one dict per record, holding ``relation_id``, ``subject``,
        ``prompt`` (the template with the subject filled in), ``target_true``
        and ``target_new`` taken from the record's ``requested_rewrite``.
    """
    rewrites = (entry['requested_rewrite'] for entry in CounterFactDataset(data_dir="."))
    return [
        {
            'relation_id': rewrite['relation_id'],
            'subject': rewrite['subject'],
            'prompt': rewrite['prompt'].format(rewrite['subject']),
            'target_true': rewrite['target_true'],
            'target_new': rewrite['target_new'],
        }
        for rewrite in rewrites
    ]

def subsample_dataset(data, num_samples):
    """Draw ``num_samples`` distinct elements of ``data`` in random order.

    The order comes from ``torch.randperm``, so it is reproducible under a
    fixed torch seed.

    Args:
        data: Indexable collection with a length.
        num_samples: Number of elements wanted. Values larger than
            ``len(data)`` return every element (shuffled) instead of raising
            ``IndexError``; values below zero return an empty list.

    Returns:
        A new list containing the selected elements.
    """
    count = max(0, min(num_samples, len(data)))
    # Plain Python ints work with any indexable container, not only with
    # those that accept 0-dim tensors as an index.
    chosen = torch.randperm(len(data))[:count].tolist()
    return [data[i] for i in chosen]

class LLM:
    """Thin wrapper around a Hugging Face causal LM and its tokenizer.

    Provides helpers to turn target strings into single token ids, to greedily
    generate text, and to read the next-token logits/probabilities of selected
    tokens for a batch of prompts.
    """

    # Models for which the wanted token id is the first id the tokenizer emits.
    _FIRST_ID_MODELS = ("openai-community/gpt2",)
    # Models for which it is the second id (presumably because the tokenizer
    # prepends a BOS token -- confirm against the tokenizer configs).
    _SECOND_ID_MODELS = ("google/gemma-2b", "google/gemma-2b-it", "meta-llama/Llama-2-7b-hf")

    def __init__(self, model_name, cache_dir = None, torch_flag = True):
        """Load tokenizer and model for ``model_name`` and move the model to ``device``.

        Args:
            model_name: Hugging Face model identifier.
            cache_dir: Download cache directory; ``MODEL_PATH`` is used when
                this is None (previously the argument was silently ignored).
            torch_flag: Unused; kept for backward compatibility.
        """
        cache_location = MODEL_PATH if cache_dir is None else cache_dir
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir = cache_location, token = access_token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding keeps the last position of every row a real token, which
        # is the position the next-token scores are read from.
        self.tokenizer.padding_side="left"

        self.model = AutoModelForCausalLM.from_pretrained(model_name,
                                                          pad_token_id = self.tokenizer.eos_token_id,
                                                          cache_dir = cache_location,
                                                          token = access_token).to(self.device)
        self.model_name = model_name

    def tokenize_list(self, l):
        """Return one token id per string in ``l``.

        Raises:
            ValueError: If ``model_name`` is not one of the supported models
                (the original code failed with ``UnboundLocalError`` here).
        """
        if self.model_name in self._FIRST_ID_MODELS:
            position = 0
        elif self.model_name in self._SECOND_ID_MODELS:
            position = 1
        else:
            raise ValueError(f"tokenize_list does not support model {self.model_name!r}")
        return [ids[position] for ids in self.tokenizer(l)['input_ids']]

    def get_embeddings(self):
        """Return a copy of the input token embedding matrix."""
        if self.model_name == "openai-community/gpt2":
            embeddings = self.model.transformer.wte.weight.clone()
        else:
            embeddings = self.model.model.embed_tokens.weight.clone()
        return embeddings

    def get_topk_tokens(self, prompt, k = 5):
        """Return ``torch.topk`` (values, indices) over the next-token logits."""
        logits = self.get_token_logits(prompt)
        return logits.topk(k = k, dim=1)

    def check_input_length_consistency(self, model_inputs):
        """Raise if the tokenized prompt is longer than the tokenizer's maximum.

        Raises:
            AssertionError: Same exception type as before, but raised
                explicitly so the check survives ``python -O``.
        """
        # The sequence length is the last axis. The previous
        # ``len(input_ids.squeeze())`` measured the batch size for batched
        # input and failed outright on a single one-token prompt.
        input_length = model_inputs.input_ids.shape[-1]
        if input_length > self.tokenizer.model_max_length:
            raise AssertionError("Prompt is much longer than context length")

    def get_outputs(self, prompt, max_tokens=2):
        """Greedily generate up to ``max_tokens`` new tokens and decode the result."""
        model_inputs = self.tokenizer(prompt,
                                      return_tensors="pt").to(self.device)
        self.check_input_length_consistency(model_inputs)

        greedy_output = self.model.generate(**model_inputs,
                                            do_sample=False,
                                            pad_token_id = self.tokenizer.eos_token_id,
                                            max_new_tokens=max_tokens)
        outputs = self.tokenizer.batch_decode(greedy_output, skip_special_tokens=True)
        return outputs

    def get_token_logits(self, prompt, tokens = None, prob=False):
        """Return next-token scores for ``prompt`` (a string or a batch of strings).

        Args:
            prompt: Prompt text or list of prompt texts.
            tokens: None, or a pair ``(tokens_a, tokens_b)`` of token ids
                (ints or one id per batch row).
            prob: Also return softmax probabilities.

        Returns:
            With ``tokens=None``: the full ``(batch, vocab)`` logits, or
            ``(logits, probs)`` when ``prob`` is set.
            Otherwise ``(logits_a, logits_b)``, or
            ``((logits_a, logits_b), (probs_a, probs_b))`` when ``prob`` is set.
        """
        model_inputs = self.tokenizer(prompt,
                                      return_tensors="pt", truncation=True, padding=True).to(self.device)

        self.check_input_length_consistency(model_inputs)

        greedy_output = self.model.generate(input_ids=model_inputs['input_ids'],
                                            attention_mask=model_inputs['attention_mask'],
                                            do_sample=False,
                                            max_new_tokens=1,
                                            pad_token_id = self.tokenizer.eos_token_id,
                                            output_hidden_states = False,
                                            output_scores = True,
                                            return_dict_in_generate=True)

        # scores[0] holds the scores of the single generated step with one row
        # per prompt. It is deliberately NOT squeezed: squeezing dropped the
        # batch axis for a batch of one and broke the row indexing below.
        logits = greedy_output.scores[0]
        probs = torch.softmax(logits, dim = 1) if prob else None

        if tokens is None:
            return (logits, probs) if prob else logits

        rows = torch.arange(logits.size(0))
        picked_logits = (logits[rows, tokens[0]], logits[rows, tokens[1]])
        if prob:
            picked_probs = (probs[rows, tokens[0]], probs[rows, tokens[1]])
            return picked_logits, picked_probs

        return picked_logits

    def get_token_probs(self, prompt, tokens = None):
        """Return next-token probabilities: all of them, or those of ``tokens``."""
        _, probs = self.get_token_logits(prompt, tokens, prob=True)
        return probs

    def compare_tokens(self, prompt, token1, token2):
        """Compare the next-token logits of ``token1`` and ``token2``.

        Returns:
            ``(count, logits, probs)`` where ``count`` is the number of
            prompts whose ``token1`` logit is strictly larger than the
            ``token2`` logit, and ``logits``/``probs`` are the two-element
            tuples from ``get_token_logits``.
        """
        tokens = [token1, token2]
        logits, probs = self.get_token_logits(prompt, tokens, prob=True)

        return torch.sum(logits[0] > logits[1]), logits, probs

def test_prompts(model):
    """Print the 1-token greedy continuation of a few fixed sanity prompts.

    Args:
        model: Object exposing ``model_name`` and ``get_outputs(prompt, n)``.

    Raises:
        ValueError: If no prompts are defined for ``model.model_name``
            (previously this surfaced as ``UnboundLocalError`` after the
            header had already been printed).
    """
    # Earlier prompt variants, kept for reference:
    # prompts = ["The Eiffel Tower is in the city of", "The Eiffel Tower is not in Hong Kong. The Eiffel Tower is in the city of"]
    # prompts = ["The Eiffel Tower. The original language of Mr. Romeo was"]
    # prompts = ["Chicago. The original language of The Icelandic Dream was"]
    # prompts = ["Chicago. The original language of Invisible Cities was written in"]
    prompts_by_model = {
        "openai-community/gpt2": [
            "The Eiffel Tower is in the city of",
            "The Eiffel Tower is not in Chicago. Therefore, the Eiffel Tower is in the city of",
        ],
        "google/gemma-2b": [
            "The Eiffel Tower is in the city of",
            "The Eiffel Tower is not in Chicago. Therefore, the Eiffel Tower is in the city of",
        ],
        "google/gemma-2b-it": [
            "The Eiffel Tower is in the city of",
            "The Eiffel Tower is not in Chicago. However, the Chicago river is in Chicago. Therefore, the Eiffel Tower is in the city of",
        ],
        # The negation sentence is repeated eight times for Llama.
        "meta-llama/Llama-2-7b-hf": [
            "The Eiffel Tower is in the city of",
            "The Eiffel Tower is not in Chicago. " * 8 + "Therefore, the Eiffel Tower is in the city of",
        ],
    }
    if model.model_name not in prompts_by_model:
        raise ValueError(f"No test prompts defined for model {model.model_name!r}")
    prompts = prompts_by_model[model.model_name]

    print("\n")
    print(f"Testing {model.model_name}")
    for prompt in prompts:
        print("Prompt:", prompt)
        print("1-token greedy output:", model.get_outputs(prompt, 1))
    print("--------\n")

    # for prompt in prompts:
    #     print("Prompt:", prompt)
    #     print("2-token greedy output:", model.get_outputs(prompt, 2))
    #     topk_tokens = model.get_topk_tokens(prompt).indices
    #     topk_tokens = model.tokenizer.batch_decode(topk_tokens, skip_special_tokens=True)
    #     print("topk tokens", topk_tokens)
    #     print("---")
    #     if model.model_name != "meta-llama/Llama-2-7b-hf":
    #         tokens = model.tokenize_list([' Hong', ' Paris'])
    #     else:
    #         tokens = model.tokenize_list(['Hong', 'Paris'])
    #     print("Tokens we're interested to compare:", model.tokenizer.batch_decode(tokens, skip_special_tokens=True))
    #     print("logits of these tokens:", model.get_token_logits(prompt, tokens))
    #     print("probabilities of these tokens:", model.get_token_probs(prompt, tokens))
    #     print("Comparison output (bool, float):", model.compare_tokens(prompt, tokens[0], tokens[1]))
    #     print("--------------------------\n\n")

def prepend_expts_context(sub_data, model, num_prepends = 5, false_target = True, rel_id = -1):
    """Measure how prepended sentences change the model's target preference.

    For every prepend count ``0..num_prepends`` the prompts are built by
    ``PrePendDataset`` and the next-token logits/probabilities of the true and
    the false target are collected for each example.

    Args:
        sub_data: Sequence of flattened CounterFact records.
        model: ``LLM``-like object providing ``compare_tokens``.
        num_prepends: Largest number of prepended sentences to evaluate.
        false_target: Mention the false target in the prepended sentence when
            true, the true target otherwise.
        rel_id: Relation id; any value other than ``-1`` switches
            ``PrePendDataset`` to its relation-specific templates.

    Returns:
        Dict with ``acc`` (per prepend count, the fraction of examples whose
        true-target logit exceeds the false-target logit) and the per-example
        lists ``all_logitis_true``, ``all_logitis_false``, ``all_probs_true``
        and ``all_probs_false`` (one inner list per prepend count). The
        misspelled keys are kept because ``plot_graphs`` and the saved pickles
        use them.
    """
    use_relation_templates = rel_id != -1
    if use_relation_templates:
        print("Relation ID:", rel_id)
    num_samples = len(sub_data)

    acc = []
    all_logits_true, all_logits_false = [], []
    all_probs_true, all_probs_false = [], []

    for n_prepends in range(num_prepends + 1):
        if use_relation_templates:
            dataset = PrePendDataset(data=sub_data, model=model, prepend_times = n_prepends, false_target = false_target, rel_test=True, rel_id=rel_id)
        else:
            dataset = PrePendDataset(data=sub_data, model=model, prepend_times = n_prepends, false_target = false_target)
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

        n_correct = 0
        logits_true, logits_false = [], []
        probs_true, probs_false = [], []
        for prompt, tokens in tqdm(dataloader):
            nb_corrects, logits, probs = model.compare_tokens(prompt, tokens[0], tokens[1])

            # Count with a plain int so the accuracy does not depend on
            # tensor dtype or device.
            n_correct += int(nb_corrects)

            logits_true.extend(logits[0].cpu().tolist())
            logits_false.extend(logits[1].cpu().tolist())
            probs_true.extend(probs[0].cpu().tolist())
            probs_false.extend(probs[1].cpu().tolist())

        # An empty dataset used to raise ZeroDivisionError here; the accuracy
        # is undefined in that case, so report NaN.
        acc.append(n_correct / num_samples if num_samples else float("nan"))

        all_logits_true.append(logits_true)
        all_logits_false.append(logits_false)
        all_probs_true.append(probs_true)
        all_probs_false.append(probs_false)

    dump_data = {
        'acc': acc,
        'all_logitis_true': all_logits_true,
        'all_logitis_false': all_logits_false,
        'all_probs_true': all_probs_true,
        'all_probs_false': all_probs_false,
    }

    return dump_data

def _correct_indicator(model_dump_data, num_prep):
    """Per-example 1.0/0.0 flags: true-target logit above false-target logit."""
    logits_true = np.array(model_dump_data['all_logitis_true'][num_prep])
    logits_false = np.array(model_dump_data['all_logitis_false'][num_prep])
    return (logits_true > logits_false).astype(float)


def _prob_gap(model_dump_data, num_prep):
    """Per-example difference between true- and false-target probability."""
    probs_true = np.array(model_dump_data['all_probs_true'][num_prep])
    probs_false = np.array(model_dump_data['all_probs_false'][num_prep])
    return probs_true - probs_false


def _plot_mean_with_sem(dump_data, per_example, ylabel, title, out_path):
    """Plot mean +/- standard error of ``per_example`` per prepend count and save it."""
    for model_name, model_dump_data in dump_data.items():
        # Use each model's own number of steps so that results of different
        # lengths do not index past the end of the shorter one.
        x = list(range(len(model_dump_data["acc"])))
        mean, lo, hi = [], [], []
        for num_prep in x:
            values = per_example(model_dump_data, num_prep)
            mu, err = np.mean(values), np.std(values) / np.sqrt(len(values))
            mean.append(mu)
            lo.append(mu - err)
            hi.append(mu + err)

        sns.lineplot(x=x, y=mean, label=model_name)
        plt.fill_between(x, lo, hi, alpha=0.2)

    plt.xlabel('Number of prepends', fontsize=20)
    plt.ylabel(ylabel, fontsize=20)
    plt.title(title, fontsize=20)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.clf()


def plot_graphs(dump_data, filename, title, plot_prob_diff = False):
    """Plot the results of ``prepend_expts_context`` for one or more models.

    Writes ``results/<filename>.png`` (fraction of examples whose true-target
    logit exceeds the false-target logit, with a standard-error band) and, if
    ``plot_prob_diff`` is set, ``results/<filename>_probs.png`` (mean
    probability difference true minus false).

    Args:
        dump_data: Mapping model name -> result dict of ``prepend_expts_context``.
        filename: Base name of the output image(s), without extension.
        title: Figure title.
        plot_prob_diff: Also produce the probability-difference figure.
    """
    params = {'legend.fontsize': 'x-large',
            'figure.figsize': (8, 6),
            'axes.labelsize': 'x-large',
            'xtick.labelsize':'x-large',
            'ytick.labelsize':'x-large'}
    pylab.rcParams.update(params)

    # savefig does not create missing directories.
    Path("results").mkdir(parents=True, exist_ok=True)

    _plot_mean_with_sem(dump_data, _correct_indicator, 'Efficacy Score', title,
                        f"results/{filename}.png")

    if not plot_prob_diff:
        return

    _plot_mean_with_sem(dump_data, _prob_gap, 'Efficacy Magnitude', title,
                        f"results/{filename}_probs.png")

def hijack_context(data, model_names, all_models):
    """Run the generic prepend experiment for every model and save the results.

    A random subsample of ``data`` (at most ``NUM_SUBSAMPES`` records) is
    evaluated twice per model: first with the false target mentioned in the
    prepended sentence, then with the true target. Each run is pickled to
    ``results/<model>_sub_{false,true}_accs.pickle`` and plotted.

    Args:
        data: Flattened CounterFact records.
        model_names: Names of the models to evaluate.
        all_models: Mapping model name -> loaded model wrapper.
    """
    # Create the output directory up front: otherwise a missing ``results/``
    # is only discovered when dumping, after the whole evaluation has run.
    Path("results").mkdir(parents=True, exist_ok=True)

    num_samples = min(NUM_SUBSAMPES, len(data))
    sub_data = subsample_dataset(data, num_samples)

    # --- sentences mentioning the false target ---
    sub_false_accs = {}
    for name in model_names:
        print(f"Running {name}")
        sub_false_accs[name] = prepend_expts_context(sub_data, all_models[name])

    for name, dump_data in sub_false_accs.items():
        with open('results/' + string_to_filename(name) + '_sub_false_accs.pickle', 'wb') as handle:
            pickle.dump(dump_data, handle)

    # NOTE(review): the titles say "Recall ..." while the prepended sentence
    # built by PrePendDataset reads "Do not think of ..." -- confirm wording.
    plot_graphs(sub_false_accs, filename = "sub_false_accs", title = "Prepending 'Recall $s_i$ and $f_i$.'")

    # LaTeX table rows: accuracy without and with a single prepend.
    print("False accs:")
    for name in model_names:
        print(f"{name} & {sub_false_accs[name]['acc'][0]:.2f} & {sub_false_accs[name]['acc'][1]:.2f}\\\\")

    # --- sentences mentioning the true target ---
    sub_true_accs = {}
    for name in model_names:
        print(f"Running {name}")
        sub_true_accs[name] = prepend_expts_context(sub_data, all_models[name], false_target=False)

    for name, dump_data in sub_true_accs.items():
        with open('results/' + string_to_filename(name) + '_sub_true_accs.pickle', 'wb') as handle:
            pickle.dump(dump_data, handle)

    plot_graphs(sub_true_accs, filename = "sub_true_accs", title = "Prepending 'Recall $s_i$ and $t_i$'")

def hijack_sentences(data, model_names, all_models):
    """Run the relation-specific prepend experiments and save the results.

    For each of the relations P190, P103, P641, P131 and P1412 a random
    subsample (at most ``NUM_SUBSAMPES`` records) is evaluated for every
    model, first with sentences about the false target and then, over all
    relations again, with sentences about the true target. Every run is
    pickled and plotted under ``results/``.

    Args:
        data: Flattened CounterFact records.
        model_names: Names of the models to evaluate.
        all_models: Mapping model name -> loaded model wrapper.
    """
    # Create the output directory before any (expensive) evaluation starts.
    Path("results").mkdir(parents=True, exist_ok=True)

    sub_rel_ids = ["P190", "P103", "P641", "P131", "P1412"]
    sub_data = {}
    for rel_id in sub_rel_ids:
        cur_data = [d for d in data if d['relation_id'] == rel_id]
        num_samples = min(NUM_SUBSAMPES, len(cur_data))
        sub_data[rel_id] = subsample_dataset(cur_data, num_samples)

    # (false_target, file-name template, title template, plot_prob_diff)
    # NOTE(review): only the "false" name starts with an underscore, so the
    # "true" pickles are named "<model>sentence_true_<rel>.pickle" without a
    # separator. Kept byte-identical in case other scripts read these files;
    # confirm before normalizing.
    sweeps = (
        (True, "_sentence_false_{}", "Hijacking based on {}", True),
        (False, "sentence_true_{}", "Prepending answer based on {}", False),
    )

    for false_target, name_template, title_template, plot_prob_diff in sweeps:
        for rel_id in sub_rel_ids:
            print(f"Running {rel_id}")
            filename = name_template.format(rel_id)
            print(filename)

            results = {}
            for name in model_names:
                print(f"Running {name}")
                results[name] = prepend_expts_context(
                    sub_data[rel_id], all_models[name],
                    false_target=false_target, rel_id=rel_id)

            for name, dump_data in results.items():
                with open('results/' + string_to_filename(name) + f'{filename}.pickle', 'wb') as handle:
                    pickle.dump(dump_data, handle)

            plot_graphs(results, filename = filename,
                        title = title_template.format(rel_id),
                        plot_prob_diff = plot_prob_diff)
            print("\n")

def main(args):
    """Load the selected model(s) and run the selected experiment.

    ``args.model_id`` 0-3 picks a single model; any other value runs all four.
    ``args.test_type`` 0 runs ``hijack_context``, 1 runs ``hijack_sentences``
    and 2 runs ``test_prompts`` on every loaded model; other values only load
    the models and the dataset.
    """
    every_model = ["openai-community/gpt2", "google/gemma-2b", "google/gemma-2b-it", "meta-llama/Llama-2-7b-hf"]
    single_model = {
        0: "openai-community/gpt2",
        1: "google/gemma-2b",
        2: "google/gemma-2b-it",
        3: "meta-llama/Llama-2-7b-hf",
    }
    if args.model_id in single_model:
        model_names = [single_model[args.model_id]]
    else:
        model_names = every_model

    all_models = {}
    for name in model_names:
        all_models[name] = LLM(name)

    # Get COUNTERFACT data
    data = get_dataset()

    if args.test_type == 0:
        hijack_context(data, model_names, all_models)
    elif args.test_type == 1:
        hijack_sentences(data, model_names, all_models)
    elif args.test_type == 2:
        for model in all_models.values():
            test_prompts(model)
    
if __name__ == "__main__":
    # Seed every random number generator the script relies on.
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Both options are integers defaulting to 0; see main() for their meaning.
    parser = argparse.ArgumentParser()
    for flag in ('--model_id', '--test_type'):
        parser.add_argument(flag, default=0, type=int)

    args = parser.parse_args()

    main(args)