import os
import sys
sys.path.insert(0, "TruthfulQA")

import torch
import torch.nn as nn
import torch.nn.functional as F
import llama
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
import llama
import pandas as pd
import warnings
from einops import rearrange
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import Trace, TraceDict
import sklearn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.linear_model import LogisticRegression
import pickle
from functools import partial
from AA_helpers import get_promt_in_template, output_is_jailbroken
import matplotlib.pyplot as plt
from sklearn.svm import SVC

from truthfulqa import utilities, models, metrics
import openai
from truthfulqa.configs import BEST_COL, ANSWER_COL, INCORRECT_COL
import wandb

# Short model keys -> Hugging Face Hub model ids (get_jb_results loads the
# "llama2_chat_7B" entry with LlamaTokenizer.from_pretrained).
ENGINE_MAP = dict(
    llama_7B='baffo32/decapoda-research-llama-7B-hf',
    alpaca_7B='circulus/alpaca-7b',
    vicuna_7B='AlekseyKorshuk/vicuna-7b',
    llama2_chat_7B='meta-llama/Llama-2-7b-chat-hf',
    llama2_chat_70B='meta-llama/Llama-2-70b-chat-hf',
)

from truthfulqa.utilities import (
    format_prompt,
    format_prompt_with_answer_strings,
    split_multi_answer,
    format_best,
    find_start,
)
from truthfulqa.presets import preset_map, COMPARE_PRIMER
from truthfulqa.models import find_subsequence, set_columns, MC_calcs
from truthfulqa.evaluate import format_frame, data_to_dict


def load_nq(dataset=None):
    """Load the ITI Natural Questions validation split as a DataFrame.

    Args:
        dataset: Optional iterable of rows, each a mapping with "question",
            "answer" (an iterable of accepted answers) and "false_answer" keys.
            When None (the default) the "validation" split of the
            "OamPatel/iti_nq_open_val" Hugging Face dataset is loaded.

    Returns:
        A DataFrame with columns ["question", "answer", "false_answer"] and a
        0-based index. Each "answer" cell holds a list of accepted answers.
    """
    if dataset is None:
        dataset = load_dataset("OamPatel/iti_nq_open_val")["validation"]
    # Collect all rows first and build the frame once. Calling pd.concat inside
    # the loop copies the whole frame on every iteration (O(n^2)).
    records = [
        {
            "question": row["question"],
            "answer": list(row["answer"]),
            "false_answer": row["false_answer"],
        }
        for row in dataset
    ]
    return pd.DataFrame(records, columns=["question", "answer", "false_answer"])

def load_triviaqa(dataset=None):
    """Load the ITI TriviaQA validation split as a DataFrame.

    Args:
        dataset: Optional iterable of rows, each a mapping with "question",
            "answer" (a mapping whose "aliases" entry is an iterable of accepted
            answers) and "false_answer" keys. When None (the default) the
            "validation" split of the "OamPatel/iti_trivia_qa_val" Hugging Face
            dataset is loaded.

    Returns:
        A DataFrame with columns ["question", "answer", "false_answer"] and a
        0-based index. Each "answer" cell holds the list of answer aliases.
    """
    if dataset is None:
        dataset = load_dataset("OamPatel/iti_trivia_qa_val")["validation"]
    # Collect all rows first and build the frame once. Calling pd.concat inside
    # the loop copies the whole frame on every iteration (O(n^2)).
    records = [
        {
            "question": row["question"],
            "answer": list(row["answer"]["aliases"]),
            "false_answer": row["false_answer"],
        }
        for row in dataset
    ]
    return pd.DataFrame(records, columns=["question", "answer", "false_answer"])

def format_truthfulqa(question, choice):
    """Render one question/answer pair as ``"Q: <question> A: <choice>"``."""
    template = "Q: {} A: {}"
    return template.format(question, choice)

def format_truthfulqa_end_q(question, choice, rand_question):
    """Render a Q/A pair followed by a trailing, unanswered question.

    Produces ``"Q: <question> A: <choice> Q: <rand_question>"``.
    """
    template = "Q: {} A: {} Q: {}"
    return template.format(question, choice, rand_question)


def tokenized_tqa(dataset, tokenizer):
    """Tokenize every (question, MC2 choice) pair of a TruthfulQA MC dataset.

    Each choice is rendered with ``format_truthfulqa`` and tokenized with
    ``return_tensors='pt'``. The very first rendered prompt is printed as a
    sanity check.

    Returns:
        (prompts, labels): parallel lists holding the ``input_ids`` tensor of
        each prompt and the matching ``mc2_targets`` label.
    """
    all_prompts, all_labels = [], []
    for q_idx in range(len(dataset)):
        example = dataset[q_idx]
        question = example['question']
        targets = example['mc2_targets']
        choices, labels = targets['choices'], targets['labels']

        assert len(choices) == len(labels), (len(choices), len(labels))

        for c_idx, (choice, label) in enumerate(zip(choices, labels)):
            text = format_truthfulqa(question, choice)
            if q_idx == 0 and c_idx == 0:
                print(text)
            all_prompts.append(tokenizer(text, return_tensors = 'pt').input_ids)
            all_labels.append(label)

    return all_prompts, all_labels


def tokenized_AA(prompts, tokenizer):
    """Tokenize each raw prompt string into its ``input_ids`` tensor.

    Side effect: sets ``tokenizer.pad_token`` to the tokenizer's BOS token.

    Returns:
        A list with one ``input_ids`` object per prompt, in input order.
    """
    tokenizer.pad_token = tokenizer.bos_token
    return [tokenizer(text, return_tensors = 'pt').input_ids for text in prompts]


def tokenized_tqa_gen_end_q(dataset, tokenizer):
    """Tokenize TruthfulQA generation answers, each followed by a random question.

    For every question, one random question from ``dataset`` is drawn with
    ``np.random.randint``. It may be the same question. Every correct answer
    (label 1) and then every incorrect answer (label 0) is rendered with
    ``format_truthfulqa_end_q`` and tokenized.

    Returns:
        (prompts, labels, categories): parallel lists of ``input_ids`` tensors,
        0/1 labels and the source question's category.
    """
    all_prompts, all_labels, all_categories = [], [], []
    for q_idx in range(len(dataset)):
        example = dataset[q_idx]
        question = example['question']
        category = example['category']
        # Exactly one RNG draw per question, taken before its answers are processed.
        rand_question = dataset[np.random.randint(len(dataset))]['question']

        for answers, label in ((example['correct_answers'], 1), (example['incorrect_answers'], 0)):
            for answer in answers:
                text = format_truthfulqa_end_q(question, answer, rand_question)
                all_prompts.append(tokenizer(text, return_tensors = 'pt').input_ids)
                all_labels.append(label)
                all_categories.append(category)

    return all_prompts, all_labels, all_categories


def tokenized_tqa_gen(dataset, tokenizer):
    """Tokenize every TruthfulQA generation answer as a ``Q: ... A: ...`` prompt.

    For each question, its correct answers are emitted first with label 1, then
    its incorrect answers with label 0.

    Returns:
        (prompts, labels, categories): parallel lists of ``input_ids`` tensors,
        0/1 labels and the source question's category.
    """
    all_prompts, all_labels, all_categories = [], [], []
    for q_idx in range(len(dataset)):
        example = dataset[q_idx]
        question = example['question']
        category = example['category']

        for answers, label in ((example['correct_answers'], 1), (example['incorrect_answers'], 0)):
            for answer in answers:
                text = format_truthfulqa(question, answer)
                all_prompts.append(tokenizer(text, return_tensors = 'pt').input_ids)
                all_labels.append(label)
                all_categories.append(category)

    return all_prompts, all_labels, all_categories


def get_llama_activations_bau(model, prompt):
    """Run one forward pass and capture per-layer activations as numpy arrays.

    Traces ``model.layers.{i}.self_attn.head_out`` and ``model.layers.{i}.mlp``
    for every layer with baukit's ``TraceDict``. The ``head_out`` module
    presumably comes from the patched local ``llama`` package; confirm.

    Args:
        model: Model exposing ``config.num_hidden_layers``; it is called on CUDA.
        prompt: Token-id tensor, moved to ``'cuda'`` here. Presumably shaped
            (1, seq_len); the ``squeeze()`` calls drop every size-1 axis,
            including seq_len when it is 1.

    Returns:
        (hidden_states, head_wise_hidden_states, mlp_wise_hidden_states), each
        stacked along a new leading axis and squeezed.
    """
    n_layers = model.config.num_hidden_layers
    HEADS = [f"model.layers.{i}.self_attn.head_out" for i in range(n_layers)]
    MLPS = [f"model.layers.{i}.mlp" for i in range(n_layers)]

    def _stack_traced(traces, names):
        # Gather each traced module output on CPU and stack them layer-first.
        per_layer = [traces[name].output.squeeze().detach().cpu() for name in names]
        return torch.stack(per_layer, dim = 0).squeeze().numpy()

    with torch.no_grad():
        prompt = prompt.to('cuda')
        with TraceDict(model, HEADS + MLPS) as traces:
            output = model(prompt, output_hidden_states = True)
        stacked = torch.stack(output.hidden_states, dim = 0).squeeze()
        hidden_states = stacked.detach().cpu().numpy()
        head_wise_hidden_states = _stack_traced(traces, HEADS)
        mlp_wise_hidden_states = _stack_traced(traces, MLPS)

    return hidden_states, head_wise_hidden_states, mlp_wise_hidden_states


def get_jb_results(tokenised_prompts, model, interventions=None, intervention_fn=None, start_edit_location="lt"):
    """Sample completions for each prompt and score them with output_is_jailbroken.

    For every prompt, 5 completions of up to 100 new tokens are sampled
    (temperature 0.6, top_p 0.9). While generating, the layers named in
    ``interventions`` are edited through baukit's ``TraceDict``. The decoded
    completions and their jailbreak flags are printed.

    Args:
        tokenised_prompts: Iterable of token-id tensors; each is moved to CUDA.
        model: Generation model (``model.generate`` is called).
        interventions: Mapping whose keys are module names to edit. ``None`` or
            empty means no intervention. Defaults to ``None`` to avoid a shared
            mutable default.
        intervention_fn: Edit callback, called with
            ``start_edit_location=start_edit_location``. Required when
            ``interventions`` is non-empty.
        start_edit_location: Forwarded to ``intervention_fn``.

    Returns:
        One list per prompt holding ``output_is_jailbroken`` results for each of
        the 5 sampled completions.

    Raises:
        TypeError: If ``interventions`` is non-empty but ``intervention_fn`` is None.
    """
    tokenizer = llama.LlamaTokenizer.from_pretrained(ENGINE_MAP["llama2_chat_7B"])
    tokenizer.pad_token = tokenizer.bos_token

    if not interventions:
        def intervene(head_output, layer_name):
            # Identity edit: leave traced outputs untouched.
            return head_output
        layers_to_intervene = []
    else:
        if intervention_fn is None:
            raise TypeError("intervention_fn must be provided when interventions is non-empty")
        intervene = partial(intervention_fn, start_edit_location=start_edit_location)
        layers_to_intervene = list(interventions.keys())

    jbs = []
    with torch.no_grad():
        for input_ids in tqdm(tokenised_prompts):
            input_ids = input_ids.to('cuda')
            with TraceDict(model, layers_to_intervene, edit_output=intervene):
                # bos/eos ids are hard-coded; presumably they match the Llama-2
                # tokenizer's <s>/</s> ids. TODO confirm.
                generated = model.generate(input_ids,
                                           num_return_sequences=5,
                                           max_new_tokens=100,
                                           bos_token_id=1,
                                           do_sample=True,
                                           eos_token_id=2,
                                           temperature=0.6,
                                           top_p=0.9,
                                           pad_token_id=tokenizer.pad_token_id)
            # Drop the echoed prompt so only newly generated tokens are decoded.
            new_tokens = generated[:, input_ids.shape[-1]:]

            model_gen_str = [tokenizer.decode(elem, skip_special_tokens=True) for elem in new_tokens]
            print(model_gen_str)

            jb = [output_is_jailbroken(elem) for elem in model_gen_str]
            print(jb)

            jbs.append(jb)

    return jbs


def flattened_idx_to_layer_head(flattened_idx, num_heads):
    """Convert a flat head index into its ``(layer, head)`` pair."""
    return divmod(flattened_idx, num_heads)

def layer_head_to_flattened_idx(layer, head, num_heads):
    """Convert a ``(layer, head)`` pair into a flat, layer-major head index."""
    offset_of_layer = num_heads * layer
    return offset_of_layer + head

def train_probes(seed, train_set_idxs, val_set_idxs, head_wise_activations, labels, num_layers, num_heads):
    """Fit one logistic-regression probe per attention head.

    Args:
        seed: ``random_state`` for every ``LogisticRegression``.
        train_set_idxs, val_set_idxs: Row indices into ``head_wise_activations``
            and ``labels``.
        head_wise_activations: Array indexed ``[sample, layer, head, dim]``.
        labels: Array of per-sample class labels.
        num_layers, num_heads: Number of layers and heads per layer to probe.

    Returns:
        (probes, accs): the fitted probes in layer-major order and a numpy array
        of their validation accuracies. A histogram of the accuracies is saved
        to ``head_accs.png``.
    """
    all_head_accs = []
    probes = []

    all_X_train = head_wise_activations[train_set_idxs]
    all_X_val = head_wise_activations[val_set_idxs]
    y_train = labels[train_set_idxs]
    y_val = labels[val_set_idxs]

    for layer in tqdm(range(num_layers)):
        for head in range(num_heads):
            X_train = all_X_train[:, layer, head, :]
            X_val = all_X_val[:, layer, head, :]

            clf = LogisticRegression(random_state=seed, max_iter=1000).fit(X_train, y_train)
            y_val_pred = clf.predict(X_val)
            all_head_accs.append(accuracy_score(y_val, y_val_pred))
            probes.append(clf)

    # Draw on a dedicated figure and close it afterwards. Otherwise the histogram
    # is stacked onto whatever figure is current (stale plots from earlier
    # calls) and figures accumulate in memory.
    fig, ax = plt.subplots()
    ax.hist(all_head_accs, bins=20)
    fig.savefig('head_accs.png')
    plt.close(fig)

    all_head_accs_np = np.array(all_head_accs)

    return probes, all_head_accs_np

def train_probes_svm(seed, train_set_idxs, val_set_idxs, head_wise_activations, labels, num_layers, num_heads):
    """Fit one linear SVM probe per attention head and score its separability.

    Args:
        seed: ``random_state`` for every ``SVC``.
        train_set_idxs, val_set_idxs: Row indices into ``head_wise_activations``
            and ``labels``.
        head_wise_activations: Array indexed ``[sample, layer, head, dim]``.
        labels: Array of per-sample class labels.
        num_layers, num_heads: Number of layers and heads per layer to probe.

    Returns:
        (probes, accs, separability): the fitted probes in layer-major order, a
        numpy array of validation accuracies, and a numpy array holding, per
        head, the mean absolute ``decision_function`` value on the training set.
        Histograms are saved to ``head_accs.png`` and ``head_separability.png``.
    """
    all_head_accs = []
    all_head_separability = []
    probes = []

    all_X_train = head_wise_activations[train_set_idxs]
    all_X_val = head_wise_activations[val_set_idxs]
    y_train = labels[train_set_idxs]
    y_val = labels[val_set_idxs]

    for layer in tqdm(range(num_layers)):
        for head in range(num_heads):
            X_train = all_X_train[:, layer, head, :]
            X_val = all_X_val[:, layer, head, :]

            clf = SVC(kernel='linear', random_state=seed).fit(X_train, y_train)
            y_val_pred = clf.predict(X_val)

            all_head_accs.append(accuracy_score(y_val, y_val_pred))
            probes.append(clf)

            # decision_function returns w.x + b. That is proportional to, not
            # equal to, the signed distance from the hyperplane; it is not
            # divided by ||w||, which differs per head.
            scores = clf.decision_function(X_train)
            all_head_separability.append(np.mean(np.abs(scores)))

    # Each histogram gets its own figure, closed after saving. This keeps stale
    # plots out of the images and avoids leaking figures across calls.
    fig, ax = plt.subplots()
    ax.hist(all_head_accs, bins=20)
    fig.savefig('head_accs.png')
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.hist(all_head_separability, bins=20)
    fig.savefig('head_separability.png')
    plt.close(fig)

    all_head_accs_np = np.array(all_head_accs)
    all_head_separability_np = np.array(all_head_separability)

    return probes, all_head_accs_np, all_head_separability_np


def get_top_heads(train_idxs, val_idxs, activations, labels, num_layers, num_heads, seed, use_random_dir=False):
    """Choose which attention heads to intervene on.

    SVM probes are trained per head via ``train_probes_svm``. The number of
    heads to pick is the count whose separability score lies strictly above
    the 25th percentile, i.e. roughly 75% of all heads (fewer if values tie).
    Those heads are then taken in descending order of validation accuracy, or
    drawn uniformly at random without replacement when ``use_random_dir`` is set.

    Returns:
        (top_heads, probes): list of ``(layer, head)`` pairs and the fitted probes.
    """
    INTERVENTION_THRESHOLD = 0.25
    total_heads = num_heads * num_layers

    probes, head_accs, head_separability = train_probes_svm(
        seed, train_idxs, val_idxs, activations, labels,
        num_layers=num_layers, num_heads=num_heads,
    )

    separability_grid = head_separability.reshape(num_layers, num_heads)
    cutoff = np.percentile(separability_grid, INTERVENTION_THRESHOLD * 100)
    num_to_intervene = np.sum(separability_grid > cutoff)

    if use_random_dir:
        # Random heads, no replacement.
        chosen = np.random.choice(total_heads, total_heads, replace=False)[:num_to_intervene]
    else:
        chosen = np.argsort(head_accs.reshape(total_heads))[::-1][:num_to_intervene]

    top_heads = [flattened_idx_to_layer_head(flat_idx, num_heads) for flat_idx in chosen]
    return top_heads, probes

def get_interventions_dict(top_heads, probes, tuning_activations, num_heads, use_center_of_mass, use_random_dir, com_directions):
    """Build the per-layer intervention spec for the selected heads.

    For every ``(layer, head)`` a direction is chosen and normalized to unit
    length. It is the center-of-mass direction, a random N(0, 1) direction, or
    the probe's ``coef_``. The standard deviation of ``tuning_activations``
    projected onto that direction is recorded alongside it.

    Args:
        top_heads: Iterable of ``(layer, head)`` pairs.
        probes: Flat sequence of fitted linear probes indexed by
            ``layer * num_heads + head``. Only read when neither
            ``use_center_of_mass`` nor ``use_random_dir`` is set.
        tuning_activations: Array indexed ``[sample, layer, head, dim]``.
        num_heads: Heads per layer, used to flatten ``(layer, head)``.
        use_center_of_mass: Take directions from ``com_directions``.
        use_random_dir: Use random Gaussian directions. Ignored when
            ``use_center_of_mass`` is set.
        com_directions: Flat array of directions indexed like ``probes``.

    Returns:
        Dict mapping ``"model.layers.{layer}.self_attn.head_out"`` to a list of
        ``(head, unit_direction, proj_val_std)`` tuples sorted by head. Keys are
        in order of first appearance in ``top_heads``.
    """
    head_dim = tuning_activations.shape[-1]
    interventions = {}
    for layer, head in top_heads:
        key = f"model.layers.{layer}.self_attn.head_out"
        # Same layer-major flattening as layer_head_to_flattened_idx.
        flat_idx = layer * num_heads + head

        if use_center_of_mass:
            direction = com_directions[flat_idx]
        elif use_random_dir:
            # Match the activation head dimension instead of a hard-coded 128.
            direction = np.random.normal(size=(head_dim,))
        else:
            direction = probes[flat_idx].coef_

        direction = direction / np.linalg.norm(direction)
        activations = tuning_activations[:, layer, head, :]
        proj_vals = activations @ direction.T
        proj_val_std = np.std(proj_vals)

        # The hook presumably adds a multiple of proj_val_std * direction. TODO confirm.
        interventions.setdefault(key, []).append((head, direction.squeeze(), proj_val_std))

    for key in interventions:
        interventions[key] = sorted(interventions[key], key=lambda entry: entry[0])

    return interventions

def get_separated_activations(labels, head_wise_activations):
    """Split flat per-choice labels and activations back into per-question groups.

    Group sizes come from the ``mc2_targets`` labels of the TruthfulQA
    multiple-choice validation split, pinned to a fixed dataset revision. The
    regrouped labels are asserted to equal the dataset's labels.

    Returns:
        (separated_head_wise_activations, separated_labels, idxs_to_split_at).
        ``idxs_to_split_at`` holds cumulative group ends. Because its final
        entry equals the total length, ``np.split`` also yields one trailing
        empty chunk.
    """
    dataset = load_dataset('truthful_qa', 'multiple_choice', revision="e89fbc73ff8b063f0ab9c586b3dd1552ed0334f2")['validation']
    actual_labels = [dataset[i]['mc2_targets']['labels'] for i in range(len(dataset))]

    idxs_to_split_at = np.cumsum([len(group) for group in actual_labels])

    flat_labels = list(labels)
    group_starts = [0] + list(idxs_to_split_at[:-1])
    separated_labels = [flat_labels[start:end] for start, end in zip(group_starts, idxs_to_split_at)]
    assert separated_labels == actual_labels

    separated_head_wise_activations = np.split(head_wise_activations, idxs_to_split_at)

    return separated_head_wise_activations, separated_labels, idxs_to_split_at

def get_com_directions(num_layers, num_heads, train_set_idxs, val_set_idxs, head_wise_activations, labels):
    """Compute per-head center-of-mass (mean difference) directions.

    For each head, the direction is mean(activations with label 1) minus
    mean(activations with label 0). Both means are taken over the union of the
    train and validation indices.

    Args:
        num_layers, num_heads: Number of leading layers / heads to include.
        train_set_idxs, val_set_idxs: Row indices into the activations/labels.
        head_wise_activations: Array indexed ``[sample, layer, head, dim]``.
        labels: Array-like of per-sample 0/1 labels.

    Returns:
        Array of shape ``(num_layers * num_heads, dim)`` in layer-major order.
    """
    # Index selection and label masks are the same for every head, so compute
    # them once instead of inside a per-(layer, head) loop.
    usable_idxs = np.concatenate([train_set_idxs, val_set_idxs], axis=0)
    usable_acts = head_wise_activations[usable_idxs][:, :num_layers, :num_heads, :]
    usable_labels = np.asarray(labels)[usable_idxs]

    true_mass_mean = np.mean(usable_acts[usable_labels == 1], axis=0)
    false_mass_mean = np.mean(usable_acts[usable_labels == 0], axis=0)

    com_directions = (true_mass_mean - false_mass_mean).reshape(num_layers * num_heads, usable_acts.shape[-1])

    return com_directions
