import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from typing import List
import torch.nn as nn
import math
from einops import rearrange
import random
import ot
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.sparse as sp
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
import os
import copy
from typing import Tuple, List

# Compute device shared by every model and tensor created in this module.
# GPU index 2 is preferred; when fewer than three GPUs are visible a hard-coded
# 'cuda:2' fails with "invalid device ordinal" on first use, so fall back to
# GPU 0 in that case.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')


'''Classes'''
# sin activation
class Sine(nn.Module):
    """Elementwise sine activation with a frequency multiplier: ``sin(w0 * x)``.

    Args:
        w0: scale applied to the input before taking the sine.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = x * self.w0
        return scaled.sin()

# Siren layer:
class Siren(nn.Module):
    """Single fully connected SIREN layer computing ``activation(x @ W.T + b)``.

    Weights (and the bias, when enabled) are drawn from ``U(-w_std, w_std)``:
    ``sqrt(1 / dim_in)`` for ReLU layers, ``1 / dim_in`` for the first layer of
    a network, and ``sqrt(c / dim_in) / w0`` otherwise.

    Args:
        dim_in: number of input features.
        dim_out: number of output features.
        w0: frequency of the ``Sine`` activation; also scales the init range.
        c: constant in the non-first-layer init range.
        is_first: whether this layer is the first of a network.
        use_bias: whether to learn an additive bias.
        activation: one of ``'sine'``, ``'relu'``, ``'id'``, ``'sigmoid'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, dim_in, dim_out, w0=1.0, c=6.0, is_first=False, use_bias=True, activation='sine'):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.act = activation

        # Initialise plain tensors first, then wrap them as parameters.
        w_init = torch.zeros(dim_out, dim_in)
        b_init = torch.zeros(dim_out) if use_bias else None
        self.init_(w_init, b_init, c=c, w0=w0)

        self.weight = nn.Parameter(w_init)
        self.bias = nn.Parameter(b_init) if use_bias else None

        builders = (
            ('sine', lambda: Sine(w0)),
            ('relu', lambda: nn.ReLU(inplace=False)),
            ('id', nn.Identity),
            ('sigmoid', nn.Sigmoid),
        )
        for name, build in builders:
            if activation == name:
                self.activation = build()
                break
        else:
            raise ValueError('No mlp activation specified')

    def init_(self, weight, bias, c, w0):
        """Fill ``weight`` (and ``bias`` when given) in place from U(-w_std, w_std)."""
        fan_in = self.dim_in
        if self.act == 'relu':
            bound = math.sqrt(1 / fan_in)
        elif self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        return self.activation(F.linear(x, self.weight, self.bias))

# Siren network:    
class SirenNet(nn.Module):
    """Multi-layer SIREN: a stack of ``Siren`` hidden layers plus an output layer.

    Args:
        dim_in: number of input features (e.g. 2 for ``[x, y]`` coordinates).
        dim_hidden: hidden-layer widths; one ``Siren`` layer is built per entry.
        dim_out: number of output features.
        num_layers: kept for backward compatibility only. The depth is always
            ``len(dim_hidden)``, and the output layer's input width is the width
            of the last hidden layer actually built.
        w0: sine frequency for every layer after the first.
        w0_initial: sine frequency for the first hidden layer.
        use_bias: whether the layers learn a bias.
        activation: activation name for the hidden layers (see ``Siren``).
        final_activation: activation name for the output layer; ``None`` means identity.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0=1., w0_initial=30., use_bias=True, activation='sine', final_activation='sigmoid'):
        super().__init__()

        self.dim_hidden = dim_hidden
        self.num_layers = len(dim_hidden)
        self.layers = nn.ModuleList([])
        for ind in range(self.num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden[ind - 1]
            self.layers.append(Siren(
                dim_in=layer_dim_in,
                dim_out=dim_hidden[ind],
                w0=layer_w0,
                use_bias=use_bias,
                is_first=is_first,
                activation=activation
            ))

        final_activation = 'id' if not exists(final_activation) else final_activation
        # The output layer consumes the output of the last hidden layer built
        # above. Indexing with the caller's ``num_layers`` instead caused a shape
        # mismatch (or IndexError) whenever it differed from len(dim_hidden).
        self.last_layer = Siren(dim_in=dim_hidden[-1], dim_out=dim_out, w0=w0, use_bias=use_bias, activation=final_activation)

    def forward(self, x, mods=None):
        """Evaluate the network on ``x``.

        Args:
            x: input tensor whose last dimension is ``dim_in``.
            mods: optional per-layer modulations. A tuple supplies one entry per
                hidden layer; any other value is reused for every layer. Each
                non-``None`` entry must be 1-D (``'d -> () d'``) and multiplies
                the corresponding hidden layer's output.

        Returns:
            Output of the final layer.
        """
        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)
            if exists(mod):
                x = x * rearrange(mod, 'd -> () d')

        return self.last_layer(x)

# GNN network using GCNConv:
class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers followed by a linear head rescaled into [0, 1].

    Args:
        in_channels: node feature size.
        hidden_channels: non-empty list of widths, one ``GCNConv`` per entry.
        num_dimensions: output size of the final linear layer.
    """

    def __init__(self, in_channels, hidden_channels, num_dimensions):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(GCNConv(in_channels, hidden_channels[0]))
        for i in range(1, len(hidden_channels)):
            self.layers.append(GCNConv(hidden_channels[i - 1], hidden_channels[i]))

        # Final fully connected layer
        self.fc1 = torch.nn.Linear(hidden_channels[-1], num_dimensions)

    def forward(self, x, edge_index):
        """Return ``(x_pre_fc, x_post_fc)``.

        ``x_pre_fc`` is the ReLU output of the last graph convolution.
        ``x_post_fc`` is ``sigmoid(fc1(x_pre_fc))``, min-max rescaled over the
        whole tensor so its smallest entry is 0 and its largest is 1. When every
        entry is equal (e.g. a single-node graph) the range is zero. It is
        clamped to a tiny epsilon so the result is all zeros instead of NaN.
        """
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        x_pre_fc = x
        x_post_fc = torch.sigmoid(self.fc1(x_pre_fc))
        lo = x_post_fc.min()
        span = (x_post_fc.max() - lo).clamp_min(1e-12)
        x_post_fc = (x_post_fc - lo) / span

        return x_pre_fc, x_post_fc
    
# ISGL network without pooling:
class ISGL(nn.Module):
    """Couples a node-embedding GNN with a network evaluated on node-position pairs.

    Args:
        model1: GNN called as ``model1(x, edge_index)`` returning ``(pre, post)``.
            ``post`` is presumably ``(n, 1)`` (one scalar per node); TODO confirm.
        model2: network evaluated on the ``(n * n, 2)`` tensor of position pairs.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, g):
        """Return ``(model2(pairs), post)`` for graph ``g`` (needs ``g.x`` and ``g.edge_index``).

        With ``z = post.squeeze(-1)``, row ``i * n + j`` of ``pairs`` is
        ``(z[i], z[j])``, i.e. pairs enumerate node pairs in row-major order.
        """
        _, latent = self.model1(g.x, g.edge_index)
        n_nodes = latent.size(0)
        pos = latent.squeeze(-1)

        rows = pos.repeat_interleave(n_nodes)   # z0,z0,...,z1,z1,...
        cols = pos.repeat(n_nodes)              # z0,z1,...,z0,z1,...
        pairs = torch.stack((rows, cols), dim=1).to(device)

        return self.model2(pairs), latent

# GNN network:
class GNN_class(torch.nn.Module):
    """Two-layer GCN classifier returning log-softmax scores over dim 1.

    Args:
        in_channels: node feature size.
        hidden_channels: width of the hidden ``GCNConv`` layer.
        out_channels: number of output classes.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        logits = self.conv2(hidden, edge_index)
        # Log-probabilities for use with an NLL-style classification loss.
        return logits.log_softmax(dim=1)
    
# MLP network:
class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers.

    Every hidden ``Linear`` is followed by a softmax over the feature axis
    (``dim=-1``, the axis ``nn.Linear`` maps). When ``is_classification`` is
    true, the output layer is followed by ``nn.Softmax(dim=1)``.

    Args:
        layer_sizes: widths ``[in, hidden..., out]``; at least two entries.
        is_classification: append the output softmax. Defaults to ``True``,
            which was previously hard-coded.
    """

    def __init__(self, layer_sizes, is_classification=True):
        super().__init__()
        layers = []
        # Hidden layers: every consecutive width pair except the last one.
        for d_in, d_out in zip(layer_sizes[:-2], layer_sizes[1:-1]):
            layers.append(nn.Linear(d_in, d_out))
            # Explicit dim: a bare nn.Softmax() relies on the deprecated
            # implicit-dim rule (warns, and picks dim 0 -- the batch axis --
            # for 3-D input).
            layers.append(nn.Softmax(dim=-1))

        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))  # Final layer

        self.is_classification = is_classification
        if self.is_classification:
            layers.append(nn.Softmax(dim=1))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
    
'''Functions'''
def nx2torch2(adjacency_matrix, features=None):
    """Wrap an adjacency matrix as a torch_geometric ``Data`` object moved to ``device``.

    Args:
        adjacency_matrix: square adjacency matrix, dense array-like or scipy sparse.
        features: optional node features. When ``None``, each node gets one
            standard-normal random feature, so ``x`` has shape ``(n, 1)``.

    Returns:
        ``Data`` with ``x`` (node features), ``edge_index`` (indices of the
        nonzero entries, transposed to 2 x E) and ``y`` (the row-major flattened
        adjacency matrix).
    """
    dense = adjacency_matrix.toarray() if sp.issparse(adjacency_matrix) else adjacency_matrix
    adj = torch.tensor(dense, dtype=torch.float)
    num_nodes = adj.size(0)
    edge_index = adj.nonzero().t().contiguous()

    if features is not None:
        x = torch.tensor(features, dtype=torch.float)
    else:
        x = torch.randn(num_nodes, 1)

    return Data(x=x, edge_index=edge_index, y=adj.reshape(-1)).to(device)


def _graph2xy_permutation(graph, model_ISGL, sortDeg):
    """Return the node order ``graph2XY`` uses for one dense graph.

    Ascending column sums when ``sortDeg`` is truthy. Otherwise, ascending
    scalar output of ``model_ISGL.model1`` on ``nx2torch2(graph)``.
    """
    if sortDeg:
        return np.argsort([np.sum(graph[:, j]) for j in range(graph.shape[0])])
    data_i = nx2torch2(graph)
    # Only the ordering is needed, so skip building an autograd graph.
    with torch.no_grad():
        _, output_gnn_post = model_ISGL.model1(data_i.x, data_i.edge_index)
    return torch.argsort(output_gnn_post.squeeze(-1)).cpu().numpy()


def _graph2xy_pool_size(hsize, N):
    """Translate the ``hsize`` code into an average-pooling window for an N-node graph.

    Codes: 1, 2, 3 give a fixed window of that size; 4 gives int(log N);
    5 gives int(sqrt N). The window is kept within [1, N] so small graphs
    still pool to at least a 1x1 histogram.
    """
    if hsize == 1:
        h = 1
    elif hsize == 2:
        h = 2
    elif hsize == 3:
        h = 3
    elif hsize == 4:
        h = int(np.log(N))
    elif hsize == 5:
        h = int(np.sqrt(N))
    else:
        # An unknown code used to fall through to an UnboundLocalError on ``h``.
        raise ValueError('hsize must be one of 1, 2, 3, 4, 5; got {!r}'.format(hsize))
    return max(1, min(h, N))


def _graph2xy_upper_samples(H):
    """Return ``(X, y)`` for the upper triangle (diagonal included) of square ``H``.

    Row order is row-major over ``i <= j``. ``X`` rows are cell-centre
    coordinates ``[x_j, y_i]`` with ``x_j = (j + 0.5) / k``. ``y`` holds the
    matching ``H[i, j]`` values as a column.
    """
    k = H.shape[0]
    centres = (np.arange(k) + 0.5) / k
    xx, yy = np.meshgrid(centres, centres)
    upper_triangle_mask = np.triu(np.ones((k, k), dtype=bool))
    X = np.column_stack((xx[upper_triangle_mask], yy[upper_triangle_mask]))
    y = H[upper_triangle_mask].reshape(-1, 1)
    return X, y


def graph2XY(graphs, num_nodes, model_ISGL, latent_val=None, hsize=4, sortDeg=0):
    """Build (coordinate, value, weight) training samples from a list of graphs.

    Each graph is processed in four steps:
      1. Reorder its nodes: by ascending column sum when ``sortDeg`` is truthy,
         otherwise by the ascending scalar output of ``model_ISGL.model1``.
      2. Average-pool it with an ``h x h`` window into a ``k x k`` histogram.
      3. Emit the histogram's upper triangle (diagonal included) as samples
         ``([x, y], H[i, j])`` at cell centres.
      4. Give every sample the weight ``graph_size / num_nodes``.

    Args:
        graphs: list of square adjacency matrices (dense arrays or scipy sparse).
        num_nodes: normaliser for the per-sample weight.
        model_ISGL: model whose ``model1`` orders the nodes (unused when ``sortDeg``).
        latent_val: optional per-graph values appended as extra input columns.
        hsize: window code. 1, 2, 3 give a fixed window of that size, 4 gives
            ``int(log N)``, 5 gives ``int(sqrt N)``.
        sortDeg: truthy to order nodes by column sums instead of the model.

    Returns:
        ``(X_all, y_all, w_all, perm)``: float tensors on ``device``
        concatenated over all graphs, plus the numpy node permutation of the
        LAST graph only.

    Raises:
        ValueError: if ``hsize`` is not one of 1-5.
    """
    X_all = []
    y_all = []
    w_all = []

    for graph_idx, graph in enumerate(graphs):
        if sp.issparse(graph):
            graph = graph.toarray()  # Convert sparse matrix to dense
        weight = graph.shape[0] / num_nodes

        perm = _graph2xy_permutation(graph, model_ISGL, sortDeg)
        graph_sorted = graph[perm, :][:, perm]

        # Average pooling to obtain the histogram.
        N = graph_sorted.shape[0]
        h = _graph2xy_pool_size(hsize, N)
        A = torch.tensor(graph_sorted, dtype=torch.float, device=device).view(1, 1, N, N)
        pooled = F.avg_pool2d(A, kernel_size=(h, h))
        # reshape instead of squeeze(): a 1x1 result must stay 2-D. squeeze()
        # would give a 0-d array, and ``H.shape[0]`` would raise IndexError.
        H = pooled.cpu().numpy().reshape(pooled.shape[-2], pooled.shape[-1])

        X, y = _graph2xy_upper_samples(H)
        w = np.tile(weight, (X.shape[0], 1))
        if latent_val is not None:
            zz = np.tile(latent_val[graph_idx], (X.shape[0], 1))
            X = np.column_stack((X, zz))

        X_all.append(X)
        y_all.append(y)
        w_all.append(w)

    # Convert lists to tensors
    X_all = torch.tensor(np.concatenate(X_all, axis=0), dtype=torch.float, device=device)
    y_all = torch.tensor(np.concatenate(y_all, axis=0), dtype=torch.float, device=device)
    w_all = torch.tensor(np.concatenate(w_all, axis=0), dtype=torch.float, device=device)

    return X_all, y_all, w_all, perm

def exists(val):
    """Report whether ``val`` is anything other than ``None`` (0, '' and [] count as existing)."""
    if val is None:
        return False
    return True

def cast_tuple(val, repeat=1):
    """Return ``val`` unchanged if it is a tuple, otherwise a tuple of ``repeat`` copies of it."""
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(repeat))

def gw_distance(graphon: np.ndarray, estimation: np.ndarray) -> float:
    """Gromov-Wasserstein distance between two square matrices under uniform node weights.

    Calls POT's ``gromov_wasserstein2`` with the square loss and returns the
    square root of the resulting GW^2 value.

    Args:
        graphon: (n, n) reference matrix.
        estimation: (m, m) estimated matrix; ``m`` may differ from ``n``.

    Returns:
        ``sqrt(GW^2)``.
    """
    p = np.ones((graphon.shape[0],)) / graphon.shape[0]
    q = np.ones((estimation.shape[0],)) / estimation.shape[0]
    loss_fun = 'square_loss'
    dw2 = ot.gromov.gromov_wasserstein2(graphon, estimation, p, q, loss_fun, log=False, armijo=False)
    # GW^2 is non-negative in exact arithmetic, but the solver's loss can come
    # out as a tiny negative round-off value. Clamp it so np.sqrt does not
    # return NaN. np.maximum still propagates a genuine NaN.
    return np.sqrt(np.maximum(dw2, 0.0))

def coords_prediction(inr_dim_hidden, gnn_dim_hidden, n_epochs, epoch_show, w0, graphs, lr):
    """Jointly train a GCN (node -> latent position) and a SIREN (position pair -> edge value).

    Every graph is converted once with ``nx2torch2``. Each epoch visits the
    graphs in shuffled order and takes one Adam step per graph on the MSE
    between the ISGL output for all node pairs and ``data_i.y``. An epoch's
    pass stops early as soon as a loss is NaN.

    Args:
        inr_dim_hidden: hidden widths of the SIREN.
        gnn_dim_hidden: hidden widths of the GCN.
        n_epochs: number of epochs.
        epoch_show: print the summed epoch loss every ``epoch_show`` epochs
            (epoch 1 is always printed).
        w0: ``w0_initial`` frequency of the SIREN's first layer.
        graphs: list of adjacency matrices.
        lr: Adam learning rate.

    Returns:
        ``(model_ISGL, loss_e)``: the trained model and the summed loss of the last epoch.
    """
    inr = SirenNet(dim_in=2,  # input [x, y] coordinate
                   dim_hidden=inr_dim_hidden,
                   dim_out=1,  # output graphon (edge) probability
                   num_layers=len(inr_dim_hidden),
                   final_activation='sigmoid',
                   w0_initial=w0).to(device)
    gnn = GCN(1, gnn_dim_hidden, 1).to(device)
    model_ISGL = ISGL(gnn, inr).to(device)

    dataset = [nx2torch2(g) for g in graphs]
    optimizer = torch.optim.Adam(model_ISGL.parameters(), lr=lr)
    mse = torch.nn.MSELoss()
    model_ISGL.train()

    for epoch in range(1, n_epochs + 1):
        random.shuffle(dataset)
        loss_e = 0
        for data_i in dataset:
            pred, _ = model_ISGL(data_i)
            loss = mse(pred.squeeze(-1), data_i.y)
            if torch.isnan(loss):
                break  # abandon the remainder of this epoch's pass
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_e += loss.item()
        if epoch == 1 or epoch % epoch_show == 0:
            print("epoch: ", epoch, " loss: ", loss_e)

    return model_ISGL, loss_e

def train_graphon(inr_dim_hidden, w0, X_all, y_all, w_all, n_epochs, epoch_show, lr, batch_size, wLoss=0, isparametric=0):
    """Fit a SIREN graphon to (coordinate, value, weight) samples with mini-batch Adam.

    Args:
        inr_dim_hidden: hidden widths of the SIREN.
        w0: ``w0_initial`` frequency of its first layer.
        X_all: inputs with ``2 + isparametric`` columns.
        y_all: targets, compared elementwise with the model output.
        w_all: per-sample weights, used only when ``wLoss`` is truthy.
        n_epochs: number of epochs.
        epoch_show: print interval for the summed epoch loss (epoch 1 is always printed).
        lr: Adam learning rate.
        batch_size: DataLoader batch size; batches are shuffled every epoch.
        wLoss: truthy to multiply the squared error by ``w`` before averaging.
        isparametric: number of extra input columns beyond ``[x, y]``.

    Returns:
        The trained ``SirenNet``.
    """
    inr_model = SirenNet(dim_in=2 + isparametric,  # input [x, y] coordinate (+ extras)
                         dim_hidden=inr_dim_hidden,
                         dim_out=1,  # output graphon (edge) probability
                         num_layers=len(inr_dim_hidden),
                         final_activation='sigmoid',
                         w0_initial=w0).to(device)
    dataset = torch.utils.data.TensorDataset(X_all, y_all, w_all)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(inr_model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss(reduction='none')

    for epoch in range(1, n_epochs + 1):
        loss_e = 0
        for X, y, w in train_loader:
            X, y, w = X.to(device), y.to(device), w.to(device)
            optimizer.zero_grad()
            per_sample = criterion(inr_model(X), y)
            if wLoss:
                batch_loss = torch.mean(per_sample * w)
            else:
                batch_loss = torch.mean(per_sample)
            batch_loss.backward()
            optimizer.step()
            loss_e += batch_loss.item()
        if epoch == 1 or epoch % epoch_show == 0:
            print('Epoch: {}, Loss: {:.5f}'.format(epoch, loss_e))

    return inr_model

def get_graphon(Res, model, gpu=False):
    """Sample ``model`` on a ``Res x Res`` grid of cell centres into a symmetric matrix.

    The model is queried (after ``model.eval()``, without gradients) only at
    upper-triangle cells ``i <= j``, with input ``[c_j, c_i]`` where
    ``c_k = (k + 0.5) / Res``. The lower triangle is filled by symmetry and
    the diagonal is set to 0.

    Args:
        Res: grid resolution.
        model: module mapping a (P, 2) float tensor to an output with P elements.
            It is left in eval mode.
        gpu: if true, move the coordinates to the module-level ``device``;
            otherwise keep them on CPU.

    Returns:
        ``(Res, Res)`` numpy array.
    """
    centres = (np.arange(Res) + 0.5) / Res
    cols, rows = np.meshgrid(centres, centres)
    upper = np.triu(np.ones((Res, Res), dtype=bool))
    coords = torch.tensor(np.column_stack((cols[upper], rows[upper])), dtype=torch.float)
    coords = coords.to(device) if gpu else coords.cpu()

    model.eval()
    with torch.no_grad():
        values = model(coords)

    graphon = np.zeros((Res, Res))
    graphon[np.triu_indices(Res, 0)] = values.cpu().numpy().reshape(-1)
    graphon = graphon + graphon.T
    np.fill_diagonal(graphon, 0)
    return graphon




def align_graphs(graphs: List[np.ndarray],
                 padding: bool = False, N: int = None) -> Tuple[List[np.ndarray], List[np.ndarray], int, int]:
    """
    Align multiple graphs by sorting their nodes by descending node degrees

    :param graphs: a list of (dense numpy) binary adjacency matrices
    :param padding: whether to zero-pad every graph to the same size
    :param N: optional target size. When given, the padded size is at least N,
        and every aligned graph and degree vector is cut to its first N nodes.
    :return:
        aligned_graphs: a list of aligned adjacency matrices, one per graph
        normalized_node_degrees: a list of (n, 1) sorted normalized node degrees
            (as node distributions), one per graph
        max_num: the largest graph size, raised to N when N is given
        min_num: the smallest graph size
    """
    num_nodes = [graph.shape[0] for graph in graphs]
    max_num = max(num_nodes)
    min_num = min(num_nodes)
    # Only compare with N when it is given: max(int, None) raises TypeError.
    if N:
        max_num = max(max_num, N)

    aligned_graphs = []
    normalized_node_degrees = []
    for graph in graphs:
        num_i = graph.shape[0]

        node_degree = 0.5 * np.sum(graph, axis=0) + 0.5 * np.sum(graph, axis=1)
        node_degree /= np.sum(node_degree)
        idx = np.argsort(node_degree)[::-1]  # descending degree

        sorted_node_degree = node_degree[idx].reshape(-1, 1)
        # Fancy indexing already returns a copy, so no deepcopy is needed.
        sorted_graph = graph[idx, :][:, idx]

        if padding:
            normalized_node_degree = np.zeros((max_num, 1))
            normalized_node_degree[:num_i, :] = sorted_node_degree

            aligned_graph = np.zeros((max_num, max_num))
            aligned_graph[:num_i, :num_i] = sorted_graph

            normalized_node_degrees.append(normalized_node_degree)
            aligned_graphs.append(aligned_graph)
        else:
            normalized_node_degrees.append(sorted_node_degree)
            aligned_graphs.append(sorted_graph)

    if N:
        # Cut every graph and every degree vector to its first N nodes. Slicing
        # the list itself would drop whole graphs instead.
        aligned_graphs = [g[:N, :N] for g in aligned_graphs]
        normalized_node_degrees = [d[:N] for d in normalized_node_degrees]

    return aligned_graphs, normalized_node_degrees, max_num, min_num



def graph_numpy2tensor(graphs: List[np.ndarray]) -> torch.Tensor:
    """
    Stack a list of equally sized adjacency matrices into one float32 tensor.

    :param graphs: [K (N, N) adjacency matrices]
    :return:
        graph_tensor: [K, N, N] float32 tensor
    """
    stacked = np.array(graphs)
    return torch.from_numpy(stacked).to(torch.float32)

def universal_svd(aligned_graphs: List[np.ndarray], threshold: float = 2.02) -> np.ndarray:
    """
    Estimate a graphon by universal singular value thresholding.

    The graphs are averaged entrywise. Singular values below
    ``threshold * sqrt(N)`` are zeroed, the matrix is rebuilt from the
    remaining components, and the result is clipped to [0, 1].

    Reference:
    Chatterjee, Sourav.
    "Matrix estimation by universal singular value thresholding."
    The Annals of Statistics 43.1 (2015): 177-214.

    :param aligned_graphs: a list of equally sized (N, N) adjacency matrices
    :param threshold: multiplier of sqrt(N) below which singular values are discarded
    :return: graphon: the estimated (N, N) graphon model as a numpy array
    """
    aligned_graphs = graph_numpy2tensor(aligned_graphs)
    num_graphs = aligned_graphs.size(0)

    if num_graphs > 1:
        sum_graph = torch.mean(aligned_graphs, dim=0)
    else:
        sum_graph = aligned_graphs[0, :, :]  # (N, N)

    num_nodes = sum_graph.size(0)

    # torch.svd is deprecated. torch.linalg.svd returns V^H directly, and the
    # reduced factorisation reconstructs the same matrix.
    u, s, vh = torch.linalg.svd(sum_graph, full_matrices=False)
    singular_threshold = threshold * (num_nodes ** 0.5)
    s[s < singular_threshold] = 0
    graphon = u @ torch.diag(s) @ vh
    graphon = graphon.clamp(0, 1)
    return graphon.numpy()