import torch
from torch.autograd.functional import jacobian
from src.utils import add_gradients, scale_gradients
import sys
import torch.nn as nn
from torch.autograd import grad
from torch.nn import functional as F


class NodeRepClass:
    """Expose one node's model output as a function of its k-hop neighbors' features.

    Used as the callable handed to ``torch.autograd.functional.jacobian``:
    ``set_node_idx`` selects the node, then ``get_node_rep`` overwrites the
    neighbor rows of the stored graph's features and runs ``model`` on it.
    """

    def __init__(self, model, graph, k_hop_neighbors_list):
        self.model = model
        # get_node_rep writes into this copy rather than the caller's graph object.
        # NOTE(review): assumes graph.clone() also copies ``x`` — TODO confirm
        # for the graph type used by callers.
        self.graph = graph.clone()
        # Indexed by node id; each entry holds that node's k-hop neighbor indices.
        self.k_hop_neighbors_list = k_hop_neighbors_list

    def set_node_idx(self, node_idx):
        """Select the node whose representation ``get_node_rep`` returns."""
        self.node_idx = node_idx
        self.k_hop_nodes = self.k_hop_neighbors_list[node_idx]

    def get_node_rep(self, node_feature):
        """Write ``node_feature`` into the neighbor rows and return the node's output.

        The write is in place on the stored graph's ``x``, so values written by
        earlier calls persist in rows this call does not overwrite.
        """
        features = self.graph.x
        features[self.k_hop_nodes] = node_feature
        outputs = self.model(self.graph)
        return outputs[self.node_idx]


def dirichlet_energy(model, graph, edge_index) -> torch.Tensor:
    """
    Compute the Dirichlet energy of the node embeddings produced by ``model``.

    For every index pair ``(src[e], dst[e])`` in ``edge_index`` the squared L2
    distance between the two embeddings is taken; these are summed and the
    total is divided by the number of nodes (rows of the embedding).

    Parameters
    ----------
    model : callable
        Called as ``model(graph)``; must return a 2-D tensor whose rows are
        node embeddings.
    graph :
        Passed unchanged to ``model``.
    edge_index : torch.Tensor
        Unpacked into ``(src, dst)`` index tensors (presumably shape (2, E) —
        TODO confirm against callers).

    Returns
    -------
    energy : torch.Tensor
        The scalar Dirichlet energy of the graph.
    """
    embeddings = model(graph)
    src, dst = edge_index
    per_edge_sq_dist = (embeddings[src] - embeddings[dst]).pow(2).sum(dim=1)
    # Normalised by node count, not edge count.
    return per_edge_sq_dist.sum() / embeddings.size(0)


def pairwise_cosine_similarity(matrix, eps=1e-8):
    """
    Computes the pairwise cosine similarity between rows of an n x d matrix.

    Args:
        matrix (torch.Tensor): An n x d tensor.
        eps (float): Lower bound applied to every row norm before dividing.
            Without it an all-zero row (e.g. a dead representation) divides by
            zero and fills its row/column of the result with NaN; with it such
            rows get similarity 0 to everything, including themselves. Rows
            whose norm exceeds ``eps`` are unaffected.

    Returns:
        torch.Tensor: An n x n tensor containing pairwise cosine similarities.
    """
    # Normalize each row to unit length, guarding against zero-norm rows.
    norms = torch.norm(matrix, dim=1, keepdim=True).clamp_min(eps)  # (n x 1)
    normalized_matrix = matrix / norms  # (n x d)

    # Compute the pairwise cosine similarity (n x n)
    return torch.mm(normalized_matrix, normalized_matrix.T)

def graph_smoothing_level(model, graph):
    """
    Mean cosine similarity over all ordered pairs of distinct node representations.

    Args:
        model: called as ``model(graph)``; must return an n x d tensor.
        graph: passed unchanged to ``model``.

    Returns:
        torch.Tensor: scalar; NaN when there are fewer than two nodes
        (the pair count n * (n - 1) is then zero).
    """
    node_rep = model(graph)
    num_nodes = node_rep.shape[0]
    cos_sim = pairwise_cosine_similarity(node_rep)
    # Drop self-similarities using the actual diagonal rather than assuming
    # each entry is exactly 1: x.x / ||x||^2 is only approximately 1 in
    # floating point and is not 1 for a degenerate (all-zero) row.
    off_diagonal_sum = cos_sim.sum() - torch.diagonal(cos_sim).sum()
    gsl = off_diagonal_sum / (num_nodes * (num_nodes - 1))

    return gsl

def mean_validation_loss(model, graph, val_idxs=None):
    """
    Cross-entropy of ``model`` on the validation nodes.

    ``model`` is switched to eval mode and left there. No ``torch.no_grad()``
    is used, so the returned loss stays attached to the autograd graph.

    Args:
        model: module called as ``model(graph)``, returning per-node logits.
        graph: provides ``y`` and, when ``val_idxs`` is None, ``val_mask``.
        val_idxs: optional index/mask selecting validation nodes; defaults
            to ``graph.val_mask``.

    Returns:
        torch.Tensor: scalar mean cross-entropy loss.
    """
    model.eval()
    selector = graph.val_mask if val_idxs is None else val_idxs
    logits = model(graph)
    return nn.CrossEntropyLoss()(logits[selector], graph.y[selector])

def mvl_with_kl_reg(model, graph, reg=1e-3, val_idxs=None):
    """
    Validation cross-entropy plus a KL penalty between test and validation
    mean predictive distributions.

    ``model`` is switched to eval mode and left there; a single forward pass
    supplies both the validation and the test logits.

    Args:
        model: module called as ``model(graph)``, returning per-node logits.
        graph: provides ``y``, ``test_mask`` and, when ``val_idxs`` is None,
            ``val_mask``.
        reg: weight of the KL term.
        val_idxs: optional index/mask selecting validation nodes; defaults
            to ``graph.val_mask``.

    Returns:
        torch.Tensor: ``val_loss + reg * KL(p_test || p_val)``, where ``p_*``
        is the softmax averaged over the nodes of that split.
    """
    criterion = nn.CrossEntropyLoss()

    model.eval()
    logit = model(graph)

    # Reuse the forward pass above in both cases (the val_idxs branch used to
    # run the model a second time).
    val_selector = graph.val_mask if val_idxs is None else val_idxs
    val_logit = logit[val_selector]
    val_target = graph.y[val_selector]

    val_loss = criterion(val_logit, val_target)
    test_logit = logit[graph.test_mask]

    mean_val_softmax = F.softmax(val_logit, dim=1).mean(dim=0)
    mean_test_softmax = F.softmax(test_logit, dim=1).mean(dim=0)

    # F.kl_div(log q, p) computes KL(p || q): here p = test marginal, q = val marginal.
    val_test_div = F.kl_div(mean_val_softmax.log(), mean_test_softmax, reduction="sum")

    return val_loss + reg * val_test_div

def kl_reg(model, graph):
    """
    KL divergence between the test and validation mean predictive distributions.

    ``model`` is switched to eval mode and left there. Each split's softmax
    is averaged over its nodes; the result is ``F.kl_div(log p_val, p_test)``
    with ``reduction="sum"``, i.e. KL(p_test || p_val).

    Args:
        model: module called as ``model(graph)``, returning per-node logits.
        graph: provides ``val_mask`` and ``test_mask``.

    Returns:
        torch.Tensor: scalar divergence.
    """
    model.eval()
    logits = model(graph)

    def _mean_probs(mask):
        # Softmax per node, then average over the selected nodes.
        return F.softmax(logits[mask], dim=1).mean(dim=0)

    val_marginal = _mean_probs(graph.val_mask)
    test_marginal = _mean_probs(graph.test_mask)
    return F.kl_div(val_marginal.log(), test_marginal, reduction="sum")

def mean_test_loss(model, graph):
    """
    Cross-entropy of ``model`` on the nodes selected by ``graph.test_mask``.

    ``model`` is switched to eval mode and left there; gradients are not
    disabled.

    Returns:
        torch.Tensor: scalar mean cross-entropy loss.
    """
    model.eval()
    mask = graph.test_mask
    predictions = model(graph)[mask]
    return nn.CrossEntropyLoss()(predictions, graph.y[mask])


def local_k_hop_grad(target_node, model, graph, k_hop_neighbors, num_classes, params, calculate_param_grad=True, calculate_weight_grad=True):
    """
    Mean absolute input-gradient of one node's output w.r.t. its k-hop neighbors.

    For each class ``c`` in ``range(num_classes)``, the gradient of
    ``model(graph)[target_node][c]`` w.r.t. ``graph.x`` is restricted to the
    rows ``k_hop_neighbors[target_node]``, and its absolute values are summed.
    The total over all classes is divided by the number of neighbors.

    Args:
        target_node: index of the node whose output is differentiated.
        model: called as ``model(graph)``; row ``target_node`` of the result
            must have at least ``num_classes`` entries.
        graph: provides ``x`` (and ``edge_weight`` when
            ``calculate_weight_grad``). ``graph.x.requires_grad`` is switched on
            if needed, and ``graph.x.grad`` is reset to None before returning.
        k_hop_neighbors: indexable by node; each entry is an index tensor.
        num_classes: number of output entries to differentiate.
        params: tensors to differentiate the measure with respect to.
        calculate_param_grad: also return d(measure)/d(params) / #neighbors.
        calculate_weight_grad: also return
            d(measure)/d(graph.edge_weight) / #neighbors.

    Returns:
        (mean_grad, mean_param_grad, mean_weight_grad). ``mean_grad`` is
        detached; the other two are None unless requested.

    Raises:
        ValueError: if ``target_node`` has no k-hop neighbors.
    """
    neighbors = k_hop_neighbors[target_node]
    num_neighbors = neighbors.numel()
    # Validate before mutating ``graph``. Raise rather than sys.exit(0): exiting
    # from a helper kills the whole program and reports success.
    if num_neighbors == 0:
        raise ValueError('Error: target node does not have k hop neighbors')

    if not graph.x.requires_grad:
        graph.x.requires_grad = True

    # The measure itself only needs to be differentiable (double backward)
    # when parameter or edge-weight gradients are requested.
    need_higher_order = calculate_param_grad or calculate_weight_grad

    res_param_grad = None
    mean_param_grad = None
    res_weight_grad = None
    mean_weight_grad = None

    abs_sum_grad = 0
    model_output = model(graph)[target_node]
    for c in range(num_classes):
        if need_higher_order:
            x_grad = grad(model_output[c], graph.x, create_graph=True)[0]
        else:
            # retain_graph keeps the forward graph for the next class.
            x_grad = grad(model_output[c], graph.x, retain_graph=True)[0]
        abs_sum_grad += torch.abs(x_grad[neighbors]).sum()

    if calculate_param_grad:
        if calculate_weight_grad:
            # Keep the graph alive for the edge-weight gradient below.
            param_grad = grad(abs_sum_grad, params, allow_unused=True, retain_graph=True)
        else:
            param_grad = grad(abs_sum_grad, params, allow_unused=True)
        res_param_grad = add_gradients(res_param_grad, param_grad)
    if calculate_weight_grad:
        weight_grad = grad(abs_sum_grad, graph.edge_weight)
        res_weight_grad = add_gradients(res_weight_grad, weight_grad)

    mean_grad = abs_sum_grad.detach() / num_neighbors
    if calculate_param_grad:
        mean_param_grad = scale_gradients(res_param_grad, 1 / num_neighbors)
    if calculate_weight_grad:
        mean_weight_grad = scale_gradients(res_weight_grad, 1 / num_neighbors)

    graph.x.grad = None

    return mean_grad, mean_param_grad, mean_weight_grad


def feature_eliminate(graph, eliminate_idxs):
    """Return a copy of ``graph`` whose feature rows ``eliminate_idxs`` are zeroed.

    The input object is not modified, provided ``graph.clone()`` copies
    ``x`` (NOTE(review): confirm for the graph type used by callers).
    """
    masked_graph = graph.clone()
    masked_graph.x[eliminate_idxs] = 0
    return masked_graph

def feature_ablation(node_idxs, model, graph, k_hop_neighbors):
    """Mean change in node representation when its k-hop neighbors' features are zeroed.

    For each ``v`` in ``node_idxs`` that has k-hop neighbors, the contribution is
    ``||GNN(G)_v - GNN(G'_v)||``, where ``G'_v`` is a copy of ``G`` with the
    feature rows ``k_hop_neighbors[v]`` zeroed. Nodes without neighbors are
    skipped. One extra forward pass is run per contributing node.

    Returns:
        torch.Tensor: scalar mean contribution. If no node contributes, the
        list passed to ``torch.stack`` is empty and it raises.
    """
    full_reps = model(graph)
    contributions = []

    for node in node_idxs:
        neighbors = k_hop_neighbors[node]
        if neighbors.numel() > 0:
            ablated_rep = model(feature_eliminate(graph, neighbors))[node]
            contributions.append(torch.norm(full_reps[node] - ablated_rep, p='fro'))

    return torch.stack(contributions).mean()


def k_hop_grad(node_idxs, model, graph, k_hop_neighbors, num_classes, params, calculate_param_grad=True, calculate_weight_grad=True, square=False):
    """
    Mean Frobenius norm of the Jacobian blocks between nodes and their k-hop neighbors.

    f(graph, \\theta) = 1/(\\sum_v |N_k(v)|) \\sum_v \\sum_{s \\in N_k(v)}
                        || \\partial f(graph, \\theta)_v / \\partial x_s ||_F

    with ``v`` ranging over ``node_idxs`` (nodes without k-hop neighbors are
    skipped). The Jacobian of ``model(graph)[v]`` w.r.t. the neighbor feature
    rows is computed with ``torch.autograd.functional.jacobian`` through
    ``NodeRepClass``.

    Args:
        node_idxs: iterable of node indices.
        model: called as ``model(graph)``, returning per-node outputs.
        graph: provides ``x`` (and ``edge_weight`` when
            ``calculate_weight_grad``). ``graph.x.requires_grad`` is switched on
            if needed, and ``graph.x.grad`` is reset to None before returning.
        k_hop_neighbors: indexable by node; each entry is an index tensor.
        num_classes: not used by this function; kept for call compatibility.
        params: tensors to differentiate the measure with respect to. Passed
            to ``grad`` without ``allow_unused``, so every entry must
            influence the measure.
        calculate_param_grad: also return the per-neighbor mean of
            d(per-node norm sum)/d(params).
        calculate_weight_grad: also return the same for ``graph.edge_weight``.
        square: not used by this function; kept for call compatibility.

    Returns:
        (mean_f_norm, mean_param_grad, mean_weight_grad). ``mean_f_norm`` is
        built from ``.item()`` values, so it is detached. The other two are
        None unless requested.

    Raises:
        ValueError: if no node in ``node_idxs`` has k-hop neighbors.
    """
    if not graph.x.requires_grad:
        graph.x.requires_grad = True

    k_hop_cnt = 0
    res_param_grad = None
    mean_param_grad = None
    res_weight_grad = None
    mean_weight_grad = None
    f_norm_results = []

    # Wraps a clone of ``graph``; get_node_rep writes the Jacobian input into
    # that clone's features before each forward pass.
    node_rep_extracter = NodeRepClass(model, graph, k_hop_neighbors)
    # The Jacobian must itself be differentiable only when parameter or
    # edge-weight gradients are requested.
    need_higher_order = calculate_param_grad or calculate_weight_grad

    for v in node_idxs:
        neighbors = k_hop_neighbors[v]
        if neighbors.numel() == 0:
            continue
        node_rep_extracter.set_node_idx(v)

        if need_higher_order:
            jacob = jacobian(node_rep_extracter.get_node_rep, graph.x[neighbors], create_graph=True)
        else:
            jacob = jacobian(node_rep_extracter.get_node_rep, graph.x[neighbors])

        # Sum of per-neighbor Frobenius norms for this node; differentiated below.
        node_f_norm = 0
        for k_hop_idx in range(neighbors.numel()):
            f_norm = torch.norm(jacob[:, k_hop_idx], p='fro')
            f_norm_results.append(f_norm.item())
            node_f_norm += f_norm

        if calculate_param_grad:
            param_grad = grad(node_f_norm, params, retain_graph=True)
            res_param_grad = add_gradients(res_param_grad, param_grad)
        if calculate_weight_grad:
            weight_grad = grad(node_f_norm, graph.edge_weight)
            res_weight_grad = add_gradients(res_weight_grad, weight_grad)

        k_hop_cnt += neighbors.numel()

    if k_hop_cnt == 0:
        # Raise rather than sys.exit(0): exiting from a helper kills the whole
        # program and reports success.
        raise ValueError('Error: k hop neighborhood does not exist.')

    mean_f_norm = torch.tensor(sum(f_norm_results) / len(f_norm_results))
    if calculate_param_grad:
        mean_param_grad = scale_gradients(res_param_grad, 1 / k_hop_cnt)
    if calculate_weight_grad:
        mean_weight_grad = scale_gradients(res_weight_grad, 1 / k_hop_cnt)

    graph.x.grad = None

    return mean_f_norm, mean_param_grad, mean_weight_grad