import torch
from sklearn.decomposition import PCA
import numpy as np

import matplotlib.pyplot as plt

"""
Helper functions for processing HuggingFace outputs
"""

def get_out_tensor(outputs):
    """
    Join the hidden states of a model output into one tensor.

    :param outputs: object exposing ``hidden_states``, a sequence of tensors
    :return: the tensors concatenated along their first axis
    """
    # NOTE(review): presumably each entry is (batch, tokens, hidden) with
    # batch == 1, so axis 0 of the result indexes layers -- confirm with callers.
    return torch.concatenate(outputs.hidden_states)

def get_out_attention(outputs, stacked_tensor=False):
    """
    Collect element 0 of every attention tensor in a model output.

    :param outputs: object exposing ``attentions``, a sequence of tensors
    :param stacked_tensor: when true, stack the collected tensors along a
        new leading axis instead of returning them as a list
    :return: list of tensors, or a single stacked tensor
    :raises AssertionError: if ``outputs.attentions`` is None
    """
    per_layer = outputs.attentions
    assert per_layer is not None, "output_attentions is None"

    # Index 0 is presumably the (single) batch element -- confirm with callers.
    first_entries = [layer_attn[0] for layer_attn in per_layer]

    return torch.stack(first_entries) if stacked_tensor else first_entries

def get_representations(hidden_states):
    """
    Concatenate hidden states and centre them on the first slice's mean.

    :param hidden_states: sequence of tensors to concatenate along axis 0
    :return: concatenated tensor minus the mean (over axis 0) of its first
        slice ``[0, :, :]``
    """
    stacked = torch.concatenate(hidden_states)

    # One mean vector, taken over the rows of the first slice only; plain
    # broadcasting then subtracts it from every row of every slice.
    first_slice_mean = stacked[0, :, :].mean(dim=0)

    return stacked - first_slice_mean



def get_representation_pca(tensor):
    """
    Project every slice of ``tensor`` onto a 2-component PCA basis.

    The PCA is fitted on the last slice along axis 0 only and then applied
    to each slice in turn.

    :param tensor: 3-D torch tensor; axis 0 is iterated, axis 1 gives the
        rows passed to the PCA (presumably layers x tokens x hidden -- confirm)
    :return: float64 NumPy array of shape ``(tensor.shape[0], tensor.shape[1], 2)``
    """
    reference_slice = -1  # the PCA basis comes from the final slice

    data = tensor.cpu().detach()
    n_slices, n_rows = data.shape[0], data.shape[1]

    projector = PCA(n_components=2, svd_solver="full").fit(data[reference_slice, :, :])

    projected = np.zeros((n_slices, n_rows, 2))
    for idx in range(n_slices):
        projected[idx, :, :] = projector.transform(data[idx, :, :])

    return projected

def plot_pca(tensor, name, dim=2, labels=None):
    """
    Plot trajectories through a PCA projection and save the figure.

    A PCA with ``dim`` components is fitted on the last slice of ``tensor``
    along axis 0. For every index along axis 1, the sequence of vectors
    across axis 0 is projected and drawn as a scatter plus a connecting
    line. Coordinates are shifted so that the mean of the first slice sits
    at the origin.

    :param tensor: 3-D torch tensor (presumably layers x tokens x hidden --
        confirm with callers)
    :param name: output path prefix; the file is ``name + "_pca_paris.jpg"``
    :param dim: number of PCA components / plot dimensions, 2 or 3
    :param labels: optional legend labels, one per index along axis 1;
        missing labels are simply omitted from the legend
    :raises ValueError: if ``dim`` is neither 2 nor 3
    """
    if dim not in (2, 3):
        # Previously fell through with ``ax`` undefined and raised NameError.
        raise ValueError(f"dim must be 2 or 3, got {dim!r}")
    labels = [] if labels is None else labels

    if dim == 2:
        _, ax = plt.subplots(1, 1)
    else:
        ax = plt.figure().add_subplot(projection='3d')
        ax.view_init(elev=20, azim=15)

    data = tensor.detach().cpu().numpy()
    pca = PCA(dim, svd_solver="full")
    pca.fit(data[-1])

    # The centre must live in PCA coordinates to be subtracted from the
    # projected points; the raw feature-space mean has the wrong length.
    # NOTE(review): this assumes the intent was "centre on the first slice's
    # mean" -- confirm against the figures this was meant to reproduce.
    center = pca.transform(data[0].mean(axis=0, keepdims=True))[0]

    for i, trajectory in enumerate(np.swapaxes(data, 0, 1)):
        coords = pca.transform(trajectory) - center

        per_axis = tuple(coords.T)  # one coordinate array per plot axis
        label = labels[i] if i < len(labels) else None
        ax.scatter(*per_axis, s=3)
        ax.plot(*per_axis, label=label)

    if len(labels) > 0:
        plt.legend(fontsize=8, loc="upper right")
    plt.tight_layout()
    plt.savefig(name + "_pca_paris.jpg", dpi=400)


def get_mean_norm(tensor):
    """
    Average vector norm along axis 1.

    :param tensor: tensor whose last axis holds the vectors
    :return: norm over the last axis, averaged over axis 1
    """
    vector_lengths = tensor.norm(dim=-1)
    return vector_lengths.mean(dim=1)

def get_lss(tensor):
    """
    Mean straightness score of the trajectories along axis 0.

    For every index along axis 1 ("word"), each step between consecutive
    points along axis 0 is rescaled to unit length and the rescaled steps
    are accumulated from the first point. The score is the number of steps
    divided by the distance between the accumulated end point and the start;
    it equals 1 when all steps point in the same direction.

    :param tensor: tensor indexed as ``[point, word]`` with vectors on the
        remaining axis
    :return: 0-dim tensor, the mean score over words
    """
    n_points, n_words = tensor.shape[0], tensor.shape[1]

    scores = []
    for word in range(n_words):
        start = tensor[0, word]
        endpoint = start
        for point in range(1, n_points):
            step = tensor[point, word] - tensor[point - 1, word]
            endpoint = endpoint + step / torch.norm(step)

        ratio = (n_points - 1) / torch.norm(endpoint - start)
        scores.append(ratio.item())

    return torch.mean(torch.tensor(scores))


def get_diffs(tensor):
    """
    Average step length between consecutive slices along axis 0.

    :param tensor: tensor whose last axis holds the vectors
    :return: ``|x(t+1) - x(t)|`` averaged over axis 1, one value per step
    """
    step_lengths = torch.diff(tensor, dim=0).norm(dim=-1)
    return step_lengths.mean(dim=1)


def get_equidistance(tensor):
    """
    Relative spread of the step lengths along axis 0.

    Computes ``|x(t+1) - x(t)|`` for consecutive slices and returns the
    variance of those lengths divided by their squared mean, both taken
    over axis 0. A value of 0 means all steps have the same length.

    :param tensor: tensor whose last axis holds the vectors
    :return: variance / mean**2 of the step lengths, one value per index of
        axis 1
    """
    step_lengths = torch.diff(tensor, dim=0).norm(dim=-1)  # |x(t+1) - x(t)|

    spread = step_lengths.var(dim=0)
    average = step_lengths.mean(dim=0)

    return spread / average.pow(2)


def get_expodistance(tensor):
    """
    Relative spread of the per-step exponential growth rate of step lengths.

    For each step ``t`` along axis 0 the log-ratio of the step length to the
    first step length is averaged over axis 1 and divided by ``t + 1``. The
    function returns the variance of these rates divided by their squared
    mean.

    :param tensor: tensor whose last axis holds the vectors and whose axis 0
        is the sequence being differenced
    :return: 0-dim tensor, variance / mean**2 of the growth rates
    """
    nlayer = tensor.shape[0]

    diff = torch.norm(torch.diff(tensor, dim=0), dim=-1)  # |x(t+1) - x(t)|

    # Build the divisor on the same device as the data; a CPU arange would
    # raise a device-mismatch error for GPU inputs.
    step_index = torch.arange(1, nlayer, device=diff.device)

    expo = torch.mean(torch.log(diff / diff[0]), dim=1) / step_index

    var = torch.var(expo)
    mean = torch.mean(expo)

    return var / mean ** 2

def get_expodistance2(tensor):
    """
    Per-index relative spread of the exponential growth rate of step lengths.

    For every step after the first, the log-ratio of its length to the first
    step length is divided by ``t + 1`` (2, 3, ...). Variance and mean of
    these rates are taken over the steps, separately for each index along
    axis 1, and the first index of axis 1 is dropped from the result.

    :param tensor: tensor whose last axis holds the vectors and whose axis 0
        is the sequence being differenced
    :return: variance / mean**2 of the growth rates for indices ``1:`` of
        axis 1
    """
    nlayer = tensor.shape[0]

    diff = torch.norm(torch.diff(tensor, dim=0), dim=-1)  # |x(t+1) - x(t)|

    # Build the divisor on the same device as the data; a CPU arange would
    # raise a device-mismatch error for GPU inputs.
    step_index = torch.arange(2, nlayer, device=diff.device).view(-1, 1)

    expo = torch.log(diff[1:] / diff[0]) / step_index

    var = torch.var(expo, dim=0)
    mean = torch.mean(expo, dim=0)

    # first token seems to cause issues, so it is excluded
    return var[1:] / mean[1:] ** 2


def get_out_logits(outputs):
    """
    Return element 0 of a model output's logits.

    :param outputs: object exposing ``logits``
    :return: ``outputs.logits[0]`` (presumably the single batch element)
    """
    return outputs.logits[0]


def clean_up_tokens(tokens):
    """
    Strip a leading word-boundary marker from each token.

    :param tokens: iterable of token strings
    :return: new list where a leading "▁" or "Ġ" has been removed; all other
        tokens (including empty strings) are returned unchanged
    """
    boundary_markers = ("▁", "Ġ")
    # startswith is safe on empty strings, whereas indexing t[0] is not.
    return [t[1:] if t.startswith(boundary_markers) else t for t in tokens]


def get_next_token_prediction(outputs, tokenizer):
    """
    Pick the highest-scoring token at the last position and decode it.

    :param outputs: LM outputs
    :param tokenizer: LM tokenizer
    :return: next_token, next_token_id
    """
    # Only the final row of the logits is used.
    last_position_logits = get_out_logits(outputs)[-1]

    next_token_id = last_position_logits.argmax()
    return tokenizer.decode(next_token_id), next_token_id


def get_attention_variance(attention_matrices, nlayers):
    """
    Relative deviation of the averaged attention maps from their overall mean.

    Element 0 of every entry is stacked, averaged over axis 1 of the stack,
    and compared with the mean over axis 0. The mean Frobenius norm of the
    deviations is divided by the squared Frobenius norm of that mean.

    :param attention_matrices: sequence of attention tensors (presumably one
        per layer, shaped batch x heads x tokens x tokens -- confirm)
    :param nlayers: unused; kept for interface compatibility
    :return: 0-dim tensor
    """
    stacked = torch.stack([entry[0] for entry in attention_matrices])

    averaged = stacked.mean(dim=1)
    overall = averaged.mean(dim=0)

    deviation = torch.norm(averaged - overall, p='fro', dim=(1, 2)).mean()
    scale = torch.norm(overall, p="fro", dim=(0, 1))

    return deviation / scale ** 2


def get_attention_variance2(attention_matrices, nlayers):
    """
    Per-head relative deviation of attention maps, averaged over heads.

    Element 0 of every entry is stacked; for each index of the next axis
    (head) the mean Frobenius distance of the maps from their mean over the
    stack axis is divided by the squared Frobenius norm of that mean, and
    the ratios are averaged.

    :param attention_matrices: sequence of attention tensors (presumably one
        per layer, shaped batch x heads x tokens x tokens -- confirm)
    :param nlayers: unused; kept for interface compatibility
    :return: 0-dim tensor
    """
    stacked = torch.stack([entry[0] for entry in attention_matrices])

    stack_mean = stacked.mean(dim=0)

    distances = torch.norm(stacked - stack_mean, p='fro', dim=(-2, -1))
    deviation = distances.mean(dim=0)
    scale = torch.norm(stack_mean, p="fro", dim=(-2, -1))

    return (deviation / scale**2).mean()





def get_onehot_prompt(vocab_tokens, input_ids):
    """
    One-hot encode the first sequence of ``input_ids``.

    :param vocab_tokens: sized collection; its length sets the number of
        columns
    :param input_ids: indexable whose element 0 is the sequence of token ids
    :return: float16 tensor of shape ``(len(input_ids[0]), len(vocab_tokens))``
        with a 1 at each position's token id
    """
    prompt_ids = input_ids[0]
    encoding = torch.zeros(len(prompt_ids), len(vocab_tokens), dtype=torch.float16)

    # Each ``row`` is a view into ``encoding``, so the write lands in place.
    for row, token_id in zip(encoding, prompt_ids):
        row[token_id] = 1

    return encoding


def get_model_classifier(model, tokenizer, nlayers, input_ids, attentions):
    """
    Distance between head-averaged attention maps and an LM-head reference.

    The reference is built from the output-projection rows of the prompt
    tokens (one-hot prompt times the LM head matrix): their Gram matrix is
    causally masked (entries above the diagonal set to -inf) and
    row-softmaxed. For each layer the attention averaged over heads is
    compared with this reference in Frobenius norm.

    :param model: model exposing ``config.model_type`` and either
        ``embed_out.weight`` (for "gpt_neox") or ``lm_head.weight``
    :param tokenizer: tokenizer exposing ``get_vocab()``
    :param nlayers: number of entries of ``attentions`` to use
    :param input_ids: token ids; element 0 is the prompt sequence
    :param attentions: per-layer attention tensors; ``attentions[l][0]`` has
        heads on axis 0
    :return: mean Frobenius distance over layers divided by
        ``2 * len(input_ids[0])``
    """
    # Use the first GPU when there is one; otherwise stay on CPU instead of
    # failing on the device transfer.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if model.config.model_type == "gpt_neox":
        lm_matrix = model.embed_out.weight.to(device)
    else:
        lm_matrix = model.lm_head.weight.to(device)

    vocab_tokens = tokenizer.get_vocab()
    nvocab = len(vocab_tokens)

    y_onehot = get_onehot_prompt(vocab_tokens, input_ids)
    y_onehot = y_onehot.to(device=device, dtype=lm_matrix.dtype)

    # Prompt length, not batch size: the one-hot matrix above is built from
    # input_ids[0], so the normalisation must use the same sequence.
    ntokens = len(input_ids[0])

    matrix = torch.matmul(y_onehot, lm_matrix[:nvocab])
    matrix = torch.matmul(matrix, matrix.T)
    mask = torch.triu(torch.ones_like(matrix), diagonal=1)
    matrix = matrix.masked_fill(mask == 1, float('-inf'))
    matrix = torch.softmax(matrix, dim=1)

    diffs = []
    for ilayer in range(nlayers):
        head_mean = torch.mean(attentions[ilayer][0], dim=0).to(device)
        diffs.append(torch.norm(head_mean - matrix, p='fro'))

    return torch.mean(torch.tensor(diffs)) / (2 * ntokens)




def get_representation_variation(tensor):
    """
    Mean variance across axis 0 of the softmaxed Gram matrices.

    For every slice along axis 0 the Gram matrix of its rows is computed and
    softmaxed over the last axis; the variance of each entry across axis 0
    is then averaged.

    :param tensor: 3-D tensor (presumably layers x tokens x hidden -- confirm)
    :return: 0-dim tensor
    """
    gram = torch.matmul(tensor, torch.transpose(tensor, 1, 2))

    similarity = torch.softmax(gram, dim=-1)

    return torch.mean(torch.var(similarity, dim=0))


def get_representation_classifier(model, tokenizer, input_ids, representations):
    """
    Distance between representation similarity maps and an LM-head reference.

    The reference is the row-softmax of the Gram matrix of the prompt
    tokens' output-projection rows (one-hot prompt times the LM head
    matrix); no causal mask is applied. Each slice of ``representations``
    is turned into a row-softmaxed Gram matrix and compared with the
    reference in Frobenius norm.

    :param model: model exposing ``config.model_type`` and either
        ``embed_out.weight`` (for "gpt_neox") or ``lm_head.weight``
    :param tokenizer: tokenizer exposing ``get_vocab()``
    :param input_ids: token ids; element 0 is the prompt sequence
    :param representations: 3-D tensor (presumably layers x tokens x hidden
        -- confirm with callers)
    :return: mean Frobenius distance over axis 0 divided by
        ``2 * len(input_ids[0])``
    """
    # Use the first GPU when there is one; otherwise stay on CPU instead of
    # failing on the device transfer.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vocab_tokens = tokenizer.get_vocab()
    y_onehot = get_onehot_prompt(vocab_tokens, input_ids)

    ntoken = len(input_ids[0])
    nvocab = len(vocab_tokens)

    if model.config.model_type == "gpt_neox":
        lm_matrix = model.embed_out.weight.to(device)
    else:
        lm_matrix = model.lm_head.weight.to(device)

    y_onehot = y_onehot.to(device=device, dtype=lm_matrix.dtype)

    matrix = torch.matmul(y_onehot, lm_matrix[:nvocab])
    matrix = torch.matmul(matrix, matrix.T)
    matrix = torch.softmax(matrix, dim=1)

    product = torch.matmul(representations, torch.transpose(representations, 1, 2))
    reps_prod = torch.softmax(product, dim=-1).to(device)

    # The 2-D reference broadcasts against every slice of reps_prod.
    difference = reps_prod - matrix
    norms = torch.norm(difference, p="fro", dim=(1, 2))

    return torch.mean(norms) / (2 * ntoken)



def get_rep_cs(tensor):
    """
    Relative variation across axis 0 of the cosine-similarity matrices.

    Vectors on the last axis are scaled to unit length, pairwise inner
    products are taken within each slice, and the entry-wise variance across
    axis 0 is reduced to a Frobenius norm scaled by ``2 * tensor.shape[1]**2``.
    The result is divided by the squared mean of all inner products.

    :param tensor: 3-D tensor (presumably layers x tokens x hidden -- confirm)
    :return: 0-dim tensor
    """
    # Unit-normalise every vector so inner products are cosine similarities.
    lengths = torch.norm(tensor, dim=-1).unsqueeze(-1)
    unit = tensor / lengths

    cosines = torch.matmul(unit, torch.transpose(unit, 1, 2))

    spread = torch.norm(torch.var(cosines, dim=0), p="fro")
    scaled_spread = spread / (2 * unit.shape[1] ** 2)

    return scaled_spread / torch.mean(cosines) ** 2


def get_rep_class_cs(model, tokenizer, input_ids, representations):
    """
    Distance between representation cosine similarities and an LM-head reference.

    The reference is the cosine-similarity matrix of the prompt tokens'
    output-projection rows (one-hot prompt ``Y`` times the LM head matrix
    ``W``, rows normalised). Each slice of ``representations`` is turned
    into its own cosine-similarity matrix and compared with the reference
    in Frobenius norm.

    :param model: model exposing ``config.model_type`` and either
        ``embed_out.weight`` (for "gpt_neox") or ``lm_head.weight``
    :param tokenizer: tokenizer exposing ``get_vocab()``
    :param input_ids: token ids; element 0 is the prompt sequence
    :param representations: 3-D tensor (presumably layers x tokens x hidden
        -- confirm with callers)
    :return: mean Frobenius distance over axis 0 divided by
        ``2 * len(input_ids[0]) ** 2``
    """
    # Use the first GPU when there is one; otherwise stay on CPU instead of
    # failing on the device transfer.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vocab_tokens = tokenizer.get_vocab()
    y_onehot = get_onehot_prompt(vocab_tokens, input_ids)

    ntoken = len(input_ids[0])
    nvocab = len(vocab_tokens)

    if model.config.model_type == "gpt_neox":
        lm_matrix = model.embed_out.weight.to(device)
    else:
        lm_matrix = model.lm_head.weight.to(device)

    # Construct W_clf
    y_onehot = y_onehot.to(device=device, dtype=lm_matrix.dtype)

    matrix = torch.matmul(y_onehot, lm_matrix[:nvocab])  # YW
    # Normalize rows
    matrix = matrix / torch.norm(matrix, dim=-1).unsqueeze(-1)
    matrix = torch.matmul(matrix, matrix.T)  # (YW)(YW)^T, diagonal is unit

    # Construct representation cosine similarity
    representations = representations / torch.norm(representations, dim=-1).unsqueeze(-1)
    rep_prod = torch.matmul(
        representations, torch.transpose(representations, 1, 2)
    ).to(device)

    # Compare; the 2-D reference broadcasts against every slice of rep_prod.
    difference = rep_prod - matrix
    norms = torch.norm(difference, p="fro", dim=(1, 2))

    return torch.mean(norms) / (2 * ntoken ** 2)
