import torch
import datasets
import numpy as np
from torch import nn
from transformers import AutoTokenizer
from peft import PeftModel
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc

# Default compute device for this module: the current CUDA device when one is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def collate_fn(examples):
    """Stack a list of preference examples into batched tensors on ``device``.

    Every example must provide ``input_ids_chosen``, ``attention_mask_chosen``,
    ``input_ids_rejected`` and ``attention_mask_rejected``. For each key the
    per-example sequences are converted with ``torch.tensor``, stacked along a
    new leading batch dimension and moved to the module-level ``device``.
    ``torch.stack`` requires the sequences of one key to have matching lengths
    (presumably guaranteed by padding beforehand, e.g. ``_pad_examples``).

    Returns:
    dict: the four keys above, each mapped to a stacked tensor.
    """
    fields = (
        "input_ids_chosen",
        "attention_mask_chosen",
        "input_ids_rejected",
        "attention_mask_rejected",
    )
    return {
        field: torch.stack([torch.tensor(example[field]) for example in examples]).to(device)
        for field in fields
    }


def collate_fn_with_prompt(examples):
    """Like ``collate_fn`` but also carries the prompt of the first example.

    Only ``examples[0]["prompt"]`` is kept (unconverted), so the batch is
    presumably expected to share a single prompt — the ``*_subtract`` losses
    build a one-row tensor from it. The four chosen/rejected id and mask
    fields are stacked with ``torch.stack`` and moved to the module-level
    ``device``.

    Returns:
    dict: ``"prompt"`` followed by the four stacked tensor fields.
    """
    batch = {"prompt": examples[0]["prompt"]}
    for field in (
        "input_ids_chosen",
        "attention_mask_chosen",
        "input_ids_rejected",
        "attention_mask_rejected",
    ):
        rows = [torch.tensor(example[field]) for example in examples]
        batch[field] = torch.stack(rows).to(device)
    return batch


def compute_loss(model, inputs):
    """Pairwise preference loss ``mean(-log(sigmoid(r_chosen - r_rejected)))``.

    ``r_chosen`` / ``r_rejected`` are the ``"logits"`` of ``model`` on the
    chosen / rejected sequences (chosen is scored first). ``model.zero_grad()``
    is called before the forward passes.
    """
    model.zero_grad()

    def score(input_ids, attention_mask):
        return model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)["logits"]

    chosen_reward = score(inputs["input_ids_chosen"], inputs["attention_mask_chosen"])
    rejected_reward = score(inputs["input_ids_rejected"], inputs["attention_mask_rejected"])
    margin = chosen_reward - rejected_reward
    return nn.functional.logsigmoid(margin).mean().neg()


def compute_loss_positive(model, inputs):
    """Chosen-only loss ``mean(-log(sigmoid(r_chosen)))``.

    ``r_chosen`` is the ``"logits"`` entry of ``model``'s output on the chosen
    sequences. ``model.zero_grad()`` is called before the forward pass.
    """
    model.zero_grad()
    output = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
        return_dict=True,
    )
    log_probs = nn.functional.logsigmoid(output["logits"])
    return log_probs.mean().neg()


def compute_loss_positive_subtract(model, inputs):
    """Loss ``mean(-log(sigmoid(r_chosen - r_prompt)))``.

    The prompt alone is scored as a pseudo-rejected baseline: its
    ``input_ids``/``attention_mask`` are wrapped into a one-row tensor on the
    module-level ``device`` and broadcast against the chosen rewards. The
    chosen sequences are scored first. ``model.zero_grad()`` is called first.
    """
    model.zero_grad()
    chosen_reward = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
        return_dict=True,
    )["logits"]

    prompt = inputs["prompt"]
    prompt_ids = torch.tensor([prompt["input_ids"]]).to(device)
    prompt_mask = torch.tensor([prompt["attention_mask"]]).to(device)
    baseline_reward = model(input_ids=prompt_ids, attention_mask=prompt_mask, return_dict=True)["logits"]

    margin = chosen_reward - baseline_reward
    return nn.functional.logsigmoid(margin).mean().neg()


def compute_loss_negative_subtract(model, inputs):
    """Loss ``mean(-log(sigmoid(r_prompt - r_rejected)))``.

    The prompt alone is scored as a pseudo-chosen baseline (one-row tensor on
    the module-level ``device``, scored before the rejected sequences), so the
    loss pushes rejected rewards below the prompt's reward.
    ``model.zero_grad()`` is called first.
    """
    model.zero_grad()

    prompt = inputs["prompt"]
    prompt_ids = torch.tensor([prompt["input_ids"]]).to(device)
    prompt_mask = torch.tensor([prompt["attention_mask"]]).to(device)
    baseline_reward = model(input_ids=prompt_ids, attention_mask=prompt_mask, return_dict=True)["logits"]

    rejected_reward = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )["logits"]
    margin = baseline_reward - rejected_reward
    return nn.functional.logsigmoid(margin).mean().neg()


def compute_loss_positive_rejected(model, inputs):
    """Rejected-only loss ``mean(-log(sigmoid(r_rejected)))``.

    Note the sign: this pushes the reward of the *rejected* sequences up.
    ``model.zero_grad()`` is called before the forward pass.
    """
    model.zero_grad()
    rejected_ids = inputs["input_ids_rejected"]
    rejected_mask = inputs["attention_mask_rejected"]
    rejected_reward = model(input_ids=rejected_ids, attention_mask=rejected_mask, return_dict=True)["logits"]
    return -(nn.functional.logsigmoid(rejected_reward).mean())


def compute_loss_positive_rejected_subtract(model, inputs):
    """Loss ``mean(-log(sigmoid(r_rejected - r_prompt)))``.

    The prompt alone is scored first (one-row tensor on the module-level
    ``device``) and used as a baseline that the rejected rewards are pushed
    above. ``model.zero_grad()`` is called first.
    """
    model.zero_grad()

    prompt = inputs["prompt"]
    prompt_ids = torch.tensor([prompt["input_ids"]]).to(device)
    prompt_mask = torch.tensor([prompt["attention_mask"]]).to(device)
    baseline_reward = model(input_ids=prompt_ids, attention_mask=prompt_mask, return_dict=True)["logits"]

    rejected_reward = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )["logits"]
    margin = rejected_reward - baseline_reward
    return nn.functional.logsigmoid(margin).mean().neg()


def compute_loss_negative(model, inputs):
    """Rejected-only loss ``mean(-log(sigmoid(-r_rejected)))``.

    Pushes the reward of the rejected sequences down.
    ``model.zero_grad()`` is called before the forward pass.
    """
    model.zero_grad()
    output = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )
    negated_reward = -output["logits"]
    return nn.functional.logsigmoid(negated_reward).mean().neg()


def compute_loss_l2_chosen(model, inputs, reward_mean, reward_std, alice_score_norm):
    """MSE between the standardized chosen reward and ``alice_score_norm[0]``.

    The chosen reward (``"logits"`` of the model output) is standardized as
    ``(reward - reward_mean) / reward_std`` and compared with
    ``nn.functional.mse_loss``. The loss is moved to the module-level
    ``device``. ``alice_score_norm`` is presumably a (chosen, rejected) pair of
    normalized target scores — TODO confirm with callers.
    """
    model.zero_grad()
    chosen_reward = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
        return_dict=True,
    )["logits"]
    standardized = (chosen_reward - reward_mean) / reward_std
    target = alice_score_norm[0]
    return nn.functional.mse_loss(standardized, target).to(device)


def compute_loss_l2_rejected(model, inputs, reward_mean, reward_std, alice_score_norm):
    """MSE between the standardized rejected reward and ``alice_score_norm[1]``.

    The rejected reward (``"logits"`` of the model output) is standardized as
    ``(reward - reward_mean) / reward_std`` and compared with
    ``nn.functional.mse_loss``. The loss is moved to the module-level
    ``device``. ``alice_score_norm`` is presumably a (chosen, rejected) pair of
    normalized target scores — TODO confirm with callers.
    """
    model.zero_grad()
    rejected_reward = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )["logits"]
    standardized = (rejected_reward - reward_mean) / reward_std
    target = alice_score_norm[1]
    return nn.functional.mse_loss(standardized, target).to(device)


# A second, byte-identical definition of ``compute_loss_positive`` used to sit
# here. Redefining the name silently shadowed the earlier definition in this
# module (ruff F811), so an edit to only one of the two copies would have been
# ignored or misleading. The earlier definition of ``compute_loss_positive`` in
# this module is now the single source of truth.


def compute_loss_average(model, inputs):
    """Chosen-only loss ``mean(-log(sigmoid(r_chosen)))``.

    NOTE(review): despite its name, this computes exactly the same loss as
    ``compute_loss_positive``. Confirm whether a different "average" loss was
    intended. ``model.zero_grad()`` is called before the forward pass.
    """
    model.zero_grad()
    chosen_ids = inputs["input_ids_chosen"]
    chosen_mask = inputs["attention_mask_chosen"]
    chosen_reward = model(input_ids=chosen_ids, attention_mask=chosen_mask, return_dict=True)["logits"]
    return -(nn.functional.logsigmoid(chosen_reward).mean())


def prepare_model(model, device, peft_model_id):
    """Attach the PEFT adapter at ``peft_model_id`` to ``model`` for gradient extraction.

    Steps:
    - load the adapter with ``PeftModel.from_pretrained``;
    - move the wrapped model to ``device``;
    - set ``requires_grad=True`` on every parameter whose name contains ``"lora"``
      or ``"modules_to_save.default"`` (only the latter are printed);
    - switch to eval mode while keeping those gradients enabled;
    - print the trainable-parameter summary.

    Returns:
    The wrapped ``PeftModel``.
    """
    print("Loading model")
    peft_model = PeftModel.from_pretrained(model, peft_model_id)
    peft_model.to(device)

    for param_name, parameter in peft_model.named_parameters():
        if "lora" in param_name:
            parameter.requires_grad = True
            continue
        if "modules_to_save.default" in param_name:
            parameter.requires_grad = True
            print(f"requires_grad of {param_name} is set to True")

    peft_model.eval()
    peft_model.print_trainable_parameters()
    return peft_model


def get_dataset(data_path, tokenizer_path):
    """Load the train/eval splits from disk and right-pad every example.

    Reads ``{data_path}/train_dataset`` and ``{data_path}/eval_dataset`` with
    ``datasets.load_from_disk``. It then maps ``_pad_examples`` (using the
    fast tokenizer loaded from ``tokenizer_path``) over both splits.

    Returns:
    tuple: ``(train_dataset, eval_dataset)`` after padding.
    """
    train_split = datasets.load_from_disk(f"{data_path}/train_dataset")
    eval_split = datasets.load_from_disk(f"{data_path}/eval_dataset")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

    def pad_with_tokenizer(example):
        return _pad_examples(example, tokenizer)

    padded_train = train_split.map(pad_with_tokenizer)
    padded_eval = eval_split.map(pad_with_tokenizer)
    return padded_train, padded_eval


def _pad_examples(example, tokenizer, max_length=1024):
    """Right-pad the chosen/rejected ids and masks of ``example`` to ``max_length``, in place.

    Side effect: sets ``tokenizer.pad_token_id = tokenizer.eos_token_id`` on
    every call. This is GPT-2 style, since GPT-2 has no dedicated pad token.
    NOTE(review): this also overwrites an existing pad token on other
    tokenizers — confirm that is intended.

    Id lists are padded with the pad id and masks with ``0``. Sequences already
    longer than ``max_length`` are left unchanged (not truncated), because a
    negative repeat count yields an empty list.

    Returns:
    The same ``example`` dict, mutated.
    """
    tokenizer.pad_token_id = tokenizer.eos_token_id
    fill_values = (
        ("input_ids_chosen", tokenizer.pad_token_id),
        ("attention_mask_chosen", 0),
        ("input_ids_rejected", tokenizer.pad_token_id),
        ("attention_mask_rejected", 0),
    )
    for key, fill in fill_values:
        current = example[key]
        example[key] = current + [fill] * (max_length - len(current))
    return example


def rapid_normed_datainf(rapid_grad_train, rapid_grad_val, indices):
    """DataInf influence scores with unit-normalized gradients in ``lam`` and the validation mean.

    With ``ĝ = g / ||g||``, the damping is ``lam = 0.1 * mean_k(mean(ĝ_k**2))`` and
    ``v`` is the mean of the normalized validation gradients selected by
    ``indices``. For training gradient ``g_k`` (raw, see note below) the score is::

        -(1/lam) * g_k.v
        + 1/(lam * n_train) * sum_i (g_i.g_k) * (g_i.v) / (lam + g_i.g_i)

    NOTE(review): only ``lam`` and ``v`` use normalized gradients; the training
    gradients in the dot products are the raw ones. Confirm this mix is intended.

    Parameters:
    rapid_grad_train (list[torch.Tensor]): per-example training gradients. They
        are stacked and transposed, so each must be 1-D with a common length.
    rapid_grad_val (list[torch.Tensor]): per-example validation gradients of the same length.
    indices (iterable of int): positions in ``rapid_grad_val`` to average. Positions
        outside the list are skipped in the sum, but every entry (duplicates
        included) counts in the divisor ``len(indices)``.

    Returns:
    list[float]: one influence score per training gradient, in input order.

    Raises:
    ValueError: if ``indices`` is empty (the average would be 0/0 and all scores NaN).
    """
    if len(indices) == 0:
        raise ValueError("indices must contain at least one validation index")
    n_train = len(rapid_grad_train)

    # Damping term computed from unit-normalized training gradients.
    lam = 0
    for grad in rapid_grad_train:
        lam += torch.mean((grad / torch.norm(grad)) ** 2)
    lam = 0.1 / n_train * lam

    # Mean of the normalized validation gradients; a set makes each membership test O(1).
    selected = {int(i) for i in indices}
    val_grad = torch.zeros_like(rapid_grad_val[0])
    for i, grad in enumerate(rapid_grad_val):
        if i in selected:
            val_grad += grad / torch.norm(grad)
    val_grad /= len(indices)

    train_grads = torch.stack(rapid_grad_train)  # row k = training gradient k
    train_grads_dots = torch.matmul(train_grads, train_grads.t())  # Gram matrix g_i.g_k
    val_grad_dots = torch.matmul(train_grads, val_grad)  # g_k.v

    # Second DataInf term for all k at once (replaces a Python loop over n_train):
    # sum_i G[i, k] * w[i] == (G^T w)[k] with w[i] = (g_i.v) / (lam + g_i.g_i).
    weights = val_grad_dots / (lam + train_grads_dots.diag())
    correction = torch.matmul(train_grads_dots.t(), weights) / lam / n_train

    rapidinf = -1 / lam * val_grad_dots + correction
    return rapidinf.tolist()


def rapid_datainf(rapid_grad_train, rapid_grad_val, indices):
    """Approximate DataInf influence of every training gradient on a validation subset.

    With ``lam = 0.1 * mean_k(mean(g_k**2))`` and ``v`` the mean of the
    validation gradients selected by ``indices``, the score of training
    gradient ``g_k`` is::

        -(1/lam) * g_k.v
        + 1/(lam * n_train) * sum_i (g_i.g_k) * (g_i.v) / (lam + g_i.g_i)

    Parameters:
    rapid_grad_train (list[torch.Tensor]): per-example training gradients. They
        are stacked and transposed, so each must be 1-D with a common length.
    rapid_grad_val (list[torch.Tensor]): per-example validation gradients of the same length.
    indices (iterable of int): positions in ``rapid_grad_val`` to average. Positions
        outside the list are skipped in the sum, but every entry (duplicates
        included) counts in the divisor ``len(indices)``.

    Returns:
    list[float]: one influence score per training gradient, in input order.

    Raises:
    ValueError: if ``indices`` is empty (the average would be 0/0 and all scores NaN).
    """
    if len(indices) == 0:
        raise ValueError("indices must contain at least one validation index")
    n_train = len(rapid_grad_train)

    # Damping term: 0.1 times the average over training examples of mean(g**2).
    lam = 0
    for grad in rapid_grad_train:
        lam += torch.mean(grad**2)
    lam = 0.1 / n_train * lam

    # Mean validation gradient over the subset; a set makes each membership test O(1).
    selected = {int(i) for i in indices}
    val_grad = torch.zeros_like(rapid_grad_val[0])
    for i, grad in enumerate(rapid_grad_val):
        if i in selected:
            val_grad += grad
    val_grad /= len(indices)

    train_grads = torch.stack(rapid_grad_train)  # row k = training gradient k
    train_grads_dots = torch.matmul(train_grads, train_grads.t())  # Gram matrix g_i.g_k
    val_grad_dots = torch.matmul(train_grads, val_grad)  # g_k.v

    # Second DataInf term for all k at once (replaces a Python loop over n_train):
    # sum_i G[i, k] * w[i] == (G^T w)[k] with w[i] = (g_i.v) / (lam + g_i.g_i).
    weights = val_grad_dots / (lam + train_grads_dots.diag())
    correction = torch.matmul(train_grads_dots.t(), weights) / lam / n_train

    rapidinf = -1 / lam * val_grad_dots + correction
    return rapidinf.tolist()


def rapid_tracin(rapid_grad_train, rapid_grad_val, indices):
    """TracIn-style score: the first-order DataInf term ``-(1/lam) * g_k.v``.

    ``lam = 0.1 * mean_k(mean(g_k**2))`` only rescales (and negates) the plain
    gradient dot product with ``v``. Here ``v`` is the mean of the validation
    gradients selected by ``indices``. The scores are on the same scale and
    sign convention as ``rapid_datainf``'s first term.

    Parameters:
    rapid_grad_train (list[torch.Tensor]): per-example 1-D training gradients of a common length.
    rapid_grad_val (list[torch.Tensor]): per-example validation gradients of the same length.
    indices (iterable of int): positions in ``rapid_grad_val`` to average. Positions
        outside the list are skipped in the sum, but every entry (duplicates
        included) counts in the divisor ``len(indices)``.

    Returns:
    list[float]: one score per training gradient, in input order.

    Raises:
    ValueError: if ``indices`` is empty (the average would be 0/0 and all scores NaN).
    """
    if len(indices) == 0:
        raise ValueError("indices must contain at least one validation index")
    n_train = len(rapid_grad_train)

    # Damping term: 0.1 times the average over training examples of mean(g**2).
    lam = 0
    for grad in rapid_grad_train:
        lam += torch.mean(grad**2)
    lam = 0.1 / n_train * lam

    # Mean validation gradient over the subset; a set makes each membership test O(1).
    selected = {int(i) for i in indices}
    val_grad = torch.zeros_like(rapid_grad_val[0])
    for i, grad in enumerate(rapid_grad_val):
        if i in selected:
            val_grad += grad
    val_grad /= len(indices)

    train_grads = torch.stack(rapid_grad_train)  # row k = training gradient k
    val_grad_dots = torch.matmul(train_grads, val_grad)  # g_k.v
    return (-1 / lam * val_grad_dots).tolist()


def rapid_selfinf(rapid_grad_train):
    """Return the self-influence ``g_k.g_k`` (squared L2 norm) of each training gradient.

    The values are computed row by row. The earlier approach formed the full
    ``n_train x n_train`` Gram matrix and kept only its diagonal, which cost
    O(n_train**2 * D) time and O(n_train**2) memory for n_train useful numbers.

    Parameters:
    rapid_grad_train (list[torch.Tensor]): per-example 1-D gradients of a common length.

    Returns:
    list[float]: one squared norm per gradient, in input order.
    """
    train_grads = torch.stack(rapid_grad_train)
    return (train_grads * train_grads).sum(dim=1).tolist()


def get_length_indices(data):
    """Split example indices by comparing chosen vs. rejected ``input_ids`` lengths.

    An index goes to the first list when ``input_ids_chosen`` is strictly
    shorter than ``input_ids_rejected``. Otherwise, ties included, it goes to
    the second list. NOTE(review): ``_pad_examples`` pads both fields to a
    common ``max_length``, so this is only meaningful on unpadded data.

    Returns:
    tuple[list[int], list[int]]: ``(shorter_indices, longer_indices)``.
    """
    shorter, longer = [], []
    for idx, example in tqdm(enumerate(data)):
        chosen_is_shorter = len(example["input_ids_chosen"]) < len(example["input_ids_rejected"])
        (shorter if chosen_is_shorter else longer).append(idx)
    return shorter, longer


def get_sycophancy_indices(data):
    """Bucket example indices by comparing ``chosen_score`` with ``rejected_score``.

    - ``chosen_score > rejected_score``  -> more sycophantic
    - ``rejected_score > chosen_score``  -> less sycophantic
    - otherwise (equal, or incomparable such as NaN) -> equal

    Returns:
    tuple[list[int], list[int], list[int]]: ``(less, equal, more)`` index lists.
    """
    less_sycophantic, more_sycophantic, equal_sycophantic = [], [], []
    for idx, example in tqdm(enumerate(data)):
        chosen, rejected = example["chosen_score"], example["rejected_score"]
        if chosen > rejected:
            bucket = more_sycophantic
        elif rejected > chosen:
            bucket = less_sycophantic
        else:
            bucket = equal_sycophantic
        bucket.append(idx)
    return less_sycophantic, equal_sycophantic, more_sycophantic


def get_val_fpr_tpr(val_data, flipped_indices):
    """False/true positive rates of the ``answer_llm`` labels against flipped ground truth.

    ``answer_llm == 1`` is counted as a positive prediction (presumably "this
    example is flipped") and ``answer_llm == 2`` as a negative one. Any other
    value is ignored. Example ``i`` is truly positive when ``i`` is in
    ``flipped_indices``.

    Parameters:
    val_data (iterable of dict): records with an ``"answer_llm"`` entry.
    flipped_indices (iterable of int): positions in ``val_data`` that were flipped.

    Returns:
    tuple[float, float]: ``(fpr, tpr)``. A rate whose denominator is zero
    (no counted negatives, or no counted positives) is ``float("nan")`` instead
    of raising ``ZeroDivisionError``.
    """
    flipped = {int(i) for i in flipped_indices}  # O(1) membership instead of O(len) per example
    fn = tn = fp = tp = 0
    for i, data in enumerate(val_data):
        answer = data["answer_llm"]
        if i in flipped:
            if answer == 1:
                tp += 1
            elif answer == 2:
                fn += 1
        elif answer == 1:
            tn += 1
        elif answer == 2:
            fp += 1
    fpr = fp / (fp + tn) if fp + tn else float("nan")
    tpr = tp / (tp + fn) if tp + fn else float("nan")
    return fpr, tpr


def get_precision_recall_auc(influence, flipped_indices):
    """Precision-recall curve and its AUC for detecting flipped examples from ``influence``.

    Parameters:
    influence (sequence of float): one score per example. Larger scores are
        ranked as more likely to be flipped (sklearn's positive-class convention).
    flipped_indices (iterable of int): positions of the flipped examples (label 1).
        Indices outside ``[0, len(influence))`` are ignored.

    Returns:
    tuple: ``(pr_auc, precision, recall)`` as produced by
    ``precision_recall_curve`` and ``auc``.
    """
    total_points = len(influence)

    # Binary ground truth: 1 for flipped examples, 0 otherwise. Building it from a
    # set is O(n + m), versus O(n * m) for a list-membership test per position.
    true_conditions = np.zeros(total_points)
    for i in {int(j) for j in flipped_indices}:
        if 0 <= i < total_points:
            true_conditions[i] = 1

    precision, recall, _ = precision_recall_curve(true_conditions, influence)
    pr_auc = auc(recall, precision)
    return pr_auc, precision, recall


def plot_precision_recall_curve(influence, flipped_indices, title, precision_llm=None, recall_llm=None, llm_label=None):
    """
    Plot the precision-recall (PR) curve of ``influence`` against the flipped labels and report its AUC.

    Parameters:
    influence (array-like): one score per example (e.g. output of a rapid_* influence
        function). Larger scores are ranked as more likely to be flipped.
    flipped_indices (iterable of int): indices of the examples that were flipped.
    title (str): plot title.
    precision_llm, recall_llm (sequence of float, optional): precision/recall of extra
        operating points to scatter on the plot.
    llm_label (sequence of str, optional): legend label for each extra point. The
        three optional sequences must have equal length, otherwise a warning is
        printed and no extra points are drawn.

    Returns:
    float: AUC value of the precision-recall curve (also printed).
    """
    # None defaults avoid shared mutable default lists.
    precision_llm = [] if precision_llm is None else precision_llm
    recall_llm = [] if recall_llm is None else recall_llm
    llm_label = [] if llm_label is None else llm_label

    pr_auc, precision, recall = get_precision_recall_auc(influence, flipped_indices)

    plt.figure()
    plt.plot(recall, precision, color="darkorange", lw=2, label="PR curve (area = %0.3f)" % pr_auc)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    # Scatter the extra operating points only when the three sequences line up.
    if len(precision_llm) == len(recall_llm) == len(llm_label):
        for point_precision, point_recall, label in zip(precision_llm, recall_llm, llm_label):
            plt.scatter(point_recall, point_precision, lw=2, label=label)
    else:
        print("Warning: precision_llm, recall_llm, and llm_label must be lists of the same length.")

    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.title(title)
    plt.show()
    print("AUC value:", pr_auc)
    return pr_auc


def get_roc_auc(influence, flipped_indices):
    """ROC curve and its AUC for detecting flipped examples from ``influence``.

    Parameters:
    influence (sequence of float): one score per example. Larger scores are
        ranked as more likely to be flipped (sklearn's positive-class convention).
    flipped_indices (iterable of int): positions of the flipped examples (label 1).
        Indices outside ``[0, len(influence))`` are ignored.

    Returns:
    tuple: ``(roc_auc, fpr, tpr)`` as produced by ``roc_curve`` and ``auc``.
    """
    total_points = len(influence)

    # Binary ground truth: 1 for flipped examples, 0 otherwise. Building it from a
    # set is O(n + m), versus O(n * m) for a list-membership test per position.
    true_conditions = np.zeros(total_points)
    for i in {int(j) for j in flipped_indices}:
        if 0 <= i < total_points:
            true_conditions[i] = 1

    fpr, tpr, _ = roc_curve(true_conditions, influence)
    roc_auc = auc(fpr, tpr)
    return roc_auc, fpr, tpr


def plot_roc_curve(influence, flipped_indices, title, fpr_llm=None, tpr_llm=None, llm_label=None):
    """
    Plot the ROC curve of ``influence`` against the flipped labels and return its AUC.

    Parameters:
    influence (array-like): one score per example (e.g. output of a rapid_* influence
        function). Larger scores are ranked as more likely to be flipped.
    flipped_indices (iterable of int): indices of the examples that were flipped.
    title (str): plot title.
    fpr_llm, tpr_llm (sequence of float, optional): false/true positive rates of extra
        operating points to scatter on the plot.
    llm_label (sequence of str, optional): legend label for each extra point. The
        three optional sequences must have equal length, otherwise a warning is
        printed and no extra points are drawn.

    Returns:
    float: AUC value of the ROC curve.
    """
    # None defaults avoid shared mutable default lists.
    fpr_llm = [] if fpr_llm is None else fpr_llm
    tpr_llm = [] if tpr_llm is None else tpr_llm
    llm_label = [] if llm_label is None else llm_label

    roc_auc, fpr, tpr = get_roc_auc(influence, flipped_indices)

    plt.figure()
    plt.plot(fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.3f)" % roc_auc)
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    # Scatter the extra operating points only when the three sequences line up.
    if len(fpr_llm) == len(tpr_llm) == len(llm_label):
        for point_fpr, point_tpr, label in zip(fpr_llm, tpr_llm, llm_label):
            plt.scatter(point_fpr, point_tpr, lw=2, label=label)
    else:
        print("Warning: fpr_llm, tpr_llm, and llm_label must be lists of the same length.")

    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    plt.title(title)
    plt.show()
    # The docstring always promised the AUC; previously None was returned.
    return roc_auc


def calculate_roc_auc(rapid_grad_train, rapid_grad_val, flipped_indices, subset_indices, N, S=0):
    """
    Compute one DataInf ROC AUC for flipped-label detection per validation subset.

    For each entry of ``subset_indices``, the selected validation gradients are
    averaged. DataInf influence scores are then computed for every training
    gradient (same formula as ``rapid_datainf``), and the ROC AUC of those
    scores against the flipped-label ground truth is recorded.

    Parameters:
    rapid_grad_train (list[torch.Tensor]): per-example 1-D training gradients of a common length.
    rapid_grad_val (list[torch.Tensor]): per-example validation gradients of the same length.
    flipped_indices (iterable of int): training indices whose labels were flipped
        (the positive class). Out-of-range indices are ignored.
    subset_indices (iterable of iterables of int): validation index subsets, e.g. a
        2-D array of shape ``(N, S)``; one AUC is computed per subset.
    N (int): not used by the computation; kept so existing callers keep working.
    S (int): not used by the computation; kept so existing callers keep working.

    Returns:
    list[float]: ROC AUC per subset, in the order of ``subset_indices``.

    Note: this function does not reseed or draw from Python's ``random`` or
    NumPy's global RNG. Random index matrices built from ``N``/``S`` were never
    used by the loop below, so they are no longer generated.
    """
    n_train = len(rapid_grad_train)

    # Binary ground truth: 1 for flipped training examples, 0 otherwise (O(n + m)).
    true_conditions = np.zeros(n_train)
    for i in {int(j) for j in flipped_indices}:
        if 0 <= i < n_train:
            true_conditions[i] = 1

    # Subset-independent quantities, computed once.
    lam = 0
    for grad in rapid_grad_train:
        lam += torch.mean(grad**2)
    lam = 0.1 / n_train * lam
    train_grads = torch.stack(rapid_grad_train)  # row k = training gradient k
    train_grads_dots = torch.matmul(train_grads, train_grads.t())  # Gram matrix g_i.g_k
    lam_plus_diag = lam + train_grads_dots.diag()
    one_over_lam_n_train = 1 / lam / n_train

    auc_list = []
    for indices in tqdm(subset_indices):
        # Mean validation gradient over this subset; a set makes membership O(1).
        selected = {int(i) for i in indices}
        val_grad = torch.zeros_like(rapid_grad_val[0])
        for i, grad in enumerate(rapid_grad_val):
            if i in selected:
                val_grad += grad
        val_grad /= len(indices)
        val_grad_dots = torch.matmul(train_grads, val_grad)  # g_k.v

        # DataInf: -(1/lam) g_k.v + 1/(lam n) * sum_i G[i, k] (g_i.v) / (lam + g_i.g_i),
        # with the sum over i done as one matrix-vector product for all k.
        correction = torch.matmul(train_grads_dots.t(), val_grad_dots / lam_plus_diag)
        influence = (-1 / lam * val_grad_dots + one_over_lam_n_train * correction).tolist()

        fpr, tpr, _ = roc_curve(true_conditions, influence)
        auc_list.append(auc(fpr, tpr))

    return auc_list
