from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np

import torch
import torch.nn.functional as F
import numpy as np
from scipy.linalg import expm
import numpy as np


def classify_sentence_pair(tokenizer, model, sentence1, sentence2, device, max_length=None):
    """Classify a single sentence pair with a sequence-classification model.

    Args:
        tokenizer: Callable used to encode the pair. It is invoked with
            ``return_tensors='pt'``, ``truncation=True`` and ``padding=True``;
            its result must support ``.to(device)`` and ``**`` unpacking.
        model: Callable that receives the encoded inputs as keyword
            arguments and returns an object exposing ``.logits``.
        sentence1: First element of the pair.
        sentence2: Second element of the pair.
        device: Target handed to ``.to(...)`` on the encoded inputs.
        max_length: Optional truncation length. It is forwarded to the
            tokenizer only when truthy.

    Returns:
        A tuple ``(predicted_class_id, logits, None)``: the argmax over
        dimension 1 as a Python scalar, the squeezed logits tensor, and a
        ``None`` placeholder (no hidden-state representation is computed).
    """
    encode_kwargs = {"return_tensors": "pt", "truncation": True, "padding": True}
    if max_length:
        encode_kwargs["max_length"] = max_length
    encoded = tokenizer(sentence1, sentence2, **encode_kwargs).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        raw_logits = model(**encoded).logits

    # .item() only works for a single prediction, i.e. one pair per call.
    predicted_class_id = raw_logits.argmax(dim=1).item()

    # The third slot stays in place so callers can keep unpacking three values.
    return predicted_class_id, raw_logits.squeeze(), None


# def are_equivalent(tokenizer, model, labels, sentence1, sentence2, device):

#     prediction1, logits1, rep1 = classify_sentence_pair(tokenizer, model, labels, sentence1, sentence2, device)
#     prediction2, logits2, rep2 = classify_sentence_pair(tokenizer, model, labels, sentence2, sentence1, device)

#     #print(prediction1, prediction2)

#     contradiction = "CONTRADICTION"
#     entailment = "ENTAILMENT"
#     if prediction1 != contradiction and prediction2 != contradiction and (prediction1 == entailment or prediction2 == entailment):
#         eq = 1
#     else:
#         eq = 0
    
#     return eq, (logits1, logits2), (rep1, rep2)


def get_labels_logits_and_reps(tokenizer, model, labels, strings_list, device, max_length=None):
    """Classify every ordered pair of strings and collect labels and logits.

    Args:
        tokenizer: Forwarded to ``classify_sentence_pair``.
        model: Forwarded to ``classify_sentence_pair``.
        labels: Accepted for interface compatibility; not used here.
        strings_list: Strings to compare; every ordered pair ``(i, j)`` is
            classified, including ``i == j``.
        device: Forwarded to ``classify_sentence_pair``.
        max_length: Forwarded to ``classify_sentence_pair``.

    Returns:
        A tuple ``(label_matrix, logit_tensor, rep_tensor)``:
        ``label_matrix`` is an ``(n, n)`` float array whose entry ``[i, j]``
        is the class id predicted for ``(strings_list[i], strings_list[j])``;
        ``logit_tensor`` is an ``(n, n, dim)`` array of the matching logits,
        or ``None`` when ``strings_list`` is empty; ``rep_tensor`` is always
        ``None`` because representations are not collected.
    """
    # This function is modified from the following GitHub repository:
    # URL: https://github.com/jlko/semantic_uncertainty/tree/master

    count = len(strings_list)
    label_matrix = np.full((count, count), -1.0)
    logit_tensor = None
    rep_tensor = None

    for row, first in enumerate(strings_list):
        for col, second in enumerate(strings_list):
            label, logit, _ = classify_sentence_pair(
                tokenizer, model, first, second, device, max_length=max_length
            )
            label_matrix[row, col] = label

            # The logit width is only known after the first prediction,
            # so the output tensor is allocated lazily.
            if logit_tensor is None:
                logit_tensor = np.zeros((count, count, logit.shape[0]))
            logit_tensor[row, col, :] = logit.detach().cpu().numpy()

    return label_matrix, logit_tensor, rep_tensor


def get_labels_logits_and_reps_cross(tokenizer, model, labels, strings_list_A, strings_list_B, device, max_length=None):
    """Classify every pair drawn from two equally long lists of strings.

    Args:
        tokenizer: Forwarded to ``classify_sentence_pair``.
        model: Forwarded to ``classify_sentence_pair``.
        labels: Accepted for interface compatibility; not used here.
        strings_list_A: Strings used as the first element of each pair.
        strings_list_B: Strings used as the second element of each pair;
            must have the same length as ``strings_list_A``.
        device: Forwarded to ``classify_sentence_pair``.
        max_length: Forwarded to ``classify_sentence_pair``.

    Returns:
        A tuple ``(label_matrix, logit_tensor, rep_tensor)``:
        ``label_matrix`` is an ``(n, n)`` float array whose entry ``[i, j]``
        is the class id predicted for
        ``(strings_list_A[i], strings_list_B[j])``; ``logit_tensor`` is an
        ``(n, n, dim)`` array of the matching logits, or ``None`` when the
        lists are empty; ``rep_tensor`` is always ``None``.

    Raises:
        AssertionError: If the two lists differ in length.
    """
    # This function is modified from the following GitHub repository:
    # URL: https://github.com/jlko/semantic_uncertainty/tree/master

    assert len(strings_list_A) == len(strings_list_B)

    count = len(strings_list_A)
    label_matrix = np.full((count, count), -1.0)
    logit_tensor = None
    rep_tensor = None

    for row, first in enumerate(strings_list_A):
        for col, second in enumerate(strings_list_B):
            label, logit, _ = classify_sentence_pair(
                tokenizer, model, first, second, device, max_length=max_length
            )
            label_matrix[row, col] = label

            # The logit width is only known after the first prediction,
            # so the output tensor is allocated lazily.
            if logit_tensor is None:
                logit_tensor = np.zeros((count, count, logit.shape[0]))
            logit_tensor[row, col, :] = logit.detach().cpu().numpy()

    return label_matrix, logit_tensor, rep_tensor


def equivalance_rule(prediction1, prediction2):
    """Decide whether two directional predictions make a pair equivalent.

    Args:
        prediction1: Label string predicted for one direction of the pair.
        prediction2: Label string predicted for the opposite direction.

    Returns:
        ``1`` when neither prediction is ``"CONTRADICTION"`` and at least one
        is ``"ENTAILMENT"``; ``0`` otherwise.
    """
    contradiction_free = prediction1 != "CONTRADICTION" and prediction2 != "CONTRADICTION"
    some_entailment = prediction1 == "ENTAILMENT" or prediction2 == "ENTAILMENT"
    return 1 if contradiction_free and some_entailment else 0


def get_equivalence_matrix_from_labels(label_matrix, labels):
    """Turn a matrix of predicted class ids into a 0/1 equivalence matrix.

    Args:
        label_matrix: Square array whose entry ``[i, j]`` is the class id
            predicted for the ordered pair ``(i, j)``. The ids may be stored
            as floats; they are cast to ``int`` before the lookup.
        labels: Mapping or sequence from integer class id to the label
            string understood by ``equivalance_rule``.

    Returns:
        An ``(n, n)`` float array whose entry ``[i, j]`` is the result of
        ``equivalance_rule`` applied to the labels of ``(i, j)`` and
        ``(j, i)``.
    """
    n = label_matrix.shape[0]
    equivalence_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # Cast to int: a float id (e.g. 2.0) cannot index a list or
            # tuple, while an int works for sequences and int-keyed dicts.
            forward = labels[int(label_matrix[i, j])]
            backward = labels[int(label_matrix[j, i])]
            equivalence_matrix[i, j] = equivalance_rule(forward, backward)

    return equivalence_matrix


def softmax(x):
    """Return the softmax of ``x``, normalised over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    exponentials = np.exp(x - np.max(x))
    total = exponentials.sum()
    return exponentials / total


def get_similarity_matrix_from_logits(logit_tensor, entailment_index=2):
    """Build a symmetric similarity matrix from pairwise class logits.

    For every ordered pair ``(i, j)`` the logits along the last axis are
    turned into probabilities with a softmax, and the similarity is the
    average of the probability at ``entailment_index`` in both directions:
    ``0.5 * (p[i, j] + p[j, i])``.

    Args:
        logit_tensor: Array of shape ``(n, n, num_classes)`` holding the
            logits for each ordered pair.
        entailment_index: Position of the class whose probability is used
            as the similarity signal. Defaults to ``2``, the index this
            function previously hard-coded (treated as "entailment" here;
            confirm it matches the label order of the model in use).

    Returns:
        An ``(n, n)`` array of similarities.
    """
    logits = np.asarray(logit_tensor, dtype=float)

    # Softmax over the class axis; subtracting the per-pair maximum keeps
    # the exponentials from overflowing.
    exponentials = np.exp(logits - logits.max(axis=-1, keepdims=True))
    entail_prob = exponentials[..., entailment_index] / exponentials.sum(axis=-1)

    # Average the two directions of each pair.
    return 0.5 * (entail_prob + entail_prob.T)


def get_similarity_matrix_from_logits_cross(cross_entailment, target_name, verifier_name, entailment_index=2):
    """Build a similarity matrix from two cross-classification logit tensors.

    Entry ``[i, j]`` averages the softmax probability at
    ``entailment_index`` of ``cross_entailment[target_name][i, j]`` and of
    ``cross_entailment[verifier_name][j, i]`` (the second tensor is read
    transposed).

    Args:
        cross_entailment: Mapping from a name to an ``(n, n, num_classes)``
            logit array.
        target_name: Key of the tensor read as ``[i, j]``.
        verifier_name: Key of the tensor read as ``[j, i]``.
        entailment_index: Position of the class whose probability is used
            as the similarity signal. Defaults to ``2``, the index this
            function previously hard-coded (treated as "entailment" here;
            confirm it matches the label order of the model in use).

    Returns:
        An ``(n, n)`` array of similarities.
    """

    def class_probability(tensor):
        # Softmax over the class axis with the per-pair maximum subtracted
        # for numerical stability, then pick the requested class.
        logits = np.asarray(tensor, dtype=float)
        exponentials = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return exponentials[..., entailment_index] / exponentials.sum(axis=-1)

    forward = class_probability(cross_entailment[target_name])
    backward = class_probability(cross_entailment[verifier_name]).T
    return 0.5 * (forward + backward)


def run_greedy_on_equivalence(equivalence_matrix, score_matrix, subset_indices=None):
    """Greedily cluster items using a precomputed equivalence matrix.

    Items are visited in order. Each item that has no cluster yet opens a
    new one, and every later item equivalent to it (matrix entry equal to
    ``1``) is put into that cluster. A later item is re-assigned even if an
    earlier cluster had already claimed it.

    Args:
        equivalence_matrix: Square array; ``[a, b] == 1`` marks ``a`` and
            ``b`` as equivalent.
        score_matrix: Array indexed like ``equivalence_matrix``; the score of
            each accepted pair is recorded.
        subset_indices: Optional list of row/column indices to cluster.
            Defaults to all indices of ``equivalence_matrix``.

    Returns:
        A tuple ``(semantic_set_ids, scores)``: the cluster id of each item
        (in ``subset_indices`` order) and a dict mapping each cluster id to
        the list of scores between its opening item and the items it pulled
        in.
    """
    if subset_indices is None:
        subset_indices = list(range(equivalence_matrix.shape[0]))

    total = len(subset_indices)
    cluster_of = [-1] * total  # -1 marks "not assigned yet"
    scores = {}

    for anchor in range(total):
        if cluster_of[anchor] != -1:
            continue

        # Cluster ids are handed out consecutively, one per opened cluster.
        cluster_id = len(scores)
        cluster_of[anchor] = cluster_id
        collected = scores[cluster_id] = []

        row = subset_indices[anchor]
        for other in range(anchor + 1, total):
            col = subset_indices[other]
            is_equivalent = equivalence_matrix[row, col]
            pair_score = score_matrix[row, col]
            if is_equivalent == 1:
                cluster_of[other] = cluster_id
                collected.append(pair_score)

    return cluster_of, scores


def semantic_entropy(semantic_ids):
    """Estimate semantic uncertainty from how often each cluster is assigned.

    The categorical distribution over clusters is estimated from the
    relative frequency of each semantic id, and the uncertainty is the
    entropy of that distribution. Token likelihoods are not used; the
    estimate relies solely on the cluster assignments. Entropy is larger
    when probability mass is spread out between many clusters and smaller
    when it is concentrated on a few.

    Input:
        semantic_ids: List of semantic ids, e.g. [0, 1, 2, 1].
    Output:
        cluster_entropy: Entropy, e.g. (-p log p).sum() for p = [1/4, 2/4, 1/4].
    """
    # This function is modified from the following GitHub repository:
    # URL: https://github.com/jlko/semantic_uncertainty/tree/master

    occurrences = np.bincount(semantic_ids)
    # Ids that never occur would contribute log(0), so they are dropped.
    occupied = occurrences[occurrences != 0]
    probabilities = occupied / len(semantic_ids)
    assert np.isclose(probabilities.sum(), 1)
    return -np.sum(probabilities * np.log(probabilities))


def mean_pairwise_distance(matrix_sim):
    """Return one minus the mean of all entries of ``matrix_sim``."""
    flat = matrix_sim.flatten()
    average_similarity = np.mean(flat)
    return 1 - average_similarity
                

def mean_std(matrix_sim, lmbda):
    """Return ``-(mean - lmbda * std)`` over all entries of ``matrix_sim``.

    Args:
        matrix_sim: Array of similarities; it is flattened before use.
        lmbda: Weight applied to the standard deviation.
    """
    flat = matrix_sim.flatten()
    average = np.mean(flat)
    spread = np.std(flat)
    return -(average - lmbda * spread)


def heat_kernel(W, t):
    """Return ``expm(-t * L)`` where ``L = D - W``.

    ``D`` is the diagonal matrix of the row sums of ``W``.

    Args:
        W: Square weight matrix.
        t: Scalar multiplying the Laplacian inside the matrix exponential.
    """
    row_sums = W.sum(axis=1)
    laplacian = np.diag(row_sums) - W
    return expm(-t * laplacian)


def unit_trace_normalization(K):
    """Rescale ``K`` so that every diagonal entry equals ``1 / K.shape[1]``.

    Each entry ``K[i, j]`` is divided by ``sqrt(K[i, i] * K[j, j])`` and then
    by the number of columns, so a square input ends up with trace one.
    """
    inverse_root = 1 / np.sqrt(np.diag(K))

    # scale[i, j] == 1 / sqrt(K[i, i] * K[j, j])
    scale = np.outer(inverse_root, inverse_root)

    return K * scale / K.shape[1]


def KLE(matrix_sim, t=0.3):
    """Return the von Neumann entropy of the normalised heat kernel.

    The heat kernel of ``matrix_sim`` is computed with ``heat_kernel``,
    rescaled with ``unit_trace_normalization``, and the entropy
    ``-sum(lambda * log(lambda))`` is evaluated over its eigenvalues.

    Args:
        matrix_sim: Square similarity matrix. ``eigvalsh`` is used, which
            treats the kernel as symmetric.
        t: Value passed to ``heat_kernel`` as ``t``.
    """
    kernel = unit_trace_normalization(heat_kernel(matrix_sim, t=t))

    spectrum = np.linalg.eigvalsh(kernel)

    # The 1e-10 offset keeps the logarithm finite for zero eigenvalues.
    return -np.sum(spectrum * np.log(spectrum + 1e-10))


def EigV(matrix_sim):
    """Return ``sum(max(0, 1 - lambda))`` over the Laplacian eigenvalues.

    The Laplacian is ``D - matrix_sim`` with ``D`` the diagonal matrix of
    row sums. ``eigvalsh`` is used, which treats it as symmetric.
    """
    degree = np.diag(matrix_sim.sum(axis=1))
    laplacian = degree - matrix_sim

    spectrum = np.linalg.eigvalsh(laplacian)
    contributions = np.maximum(0, 1 - spectrum)
    return np.sum(contributions)


def Ecc(matrix_sim, cutoff):
    """Return a variance computed from the low-eigenvalue Laplacian eigenvectors.

    The Laplacian ``D - matrix_sim`` (``D`` = diagonal of row sums) is
    decomposed with ``np.linalg.eig``. The real parts of the eigenvectors
    whose eigenvalue is below ``cutoff`` are stacked as the rows of ``V``,
    centred on their mean, and the squared deviations are summed and divided
    by the number of selected eigenvectors.

    Args:
        matrix_sim: Square similarity matrix.
        cutoff: Eigenvalues strictly below this value are selected.

    NOTE(review): the mean is taken across the selected eigenvectors (axis 0
    of ``V``), not across samples. Confirm this is the intended quantity.
    """
    laplacian = np.diag(matrix_sim.sum(axis=1)) - matrix_sim

    eigenvalues, eigenvectors = np.linalg.eig(laplacian)
    selected = eigenvectors[:, np.where(eigenvalues < cutoff)[0]].real

    # One selected eigenvector per row.
    V = selected.T
    centred = V - np.mean(V, axis=0)

    return np.sum(centred ** 2 / V.shape[0])



def quantum_entropy(matrix_sim):
    """Return ``-sum(s * log(s))`` over the singular values of the normalised matrix.

    ``matrix_sim`` is first rescaled with ``unit_trace_normalization``; the
    entropy is then computed from the singular values of the result.
    """
    normalised = unit_trace_normalization(matrix_sim)

    singular_values = np.linalg.svd(normalised, full_matrices=False, compute_uv=False)

    # A small offset keeps the logarithm finite for zero singular values.
    eps = 1e-10
    return -np.sum(singular_values * np.log(singular_values + eps))

