from typing import List
import numpy as np
import copy
import torch_geometric.transforms as T
from torch_geometric.utils import degree, to_dense_adj
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data
import random
# Device shared by the helpers in this module: the GPU with index 2 when CUDA
# is available, otherwise the CPU.
# NOTE(review): the GPU index is hard-coded; presumably moving data to it fails
# on hosts that expose fewer than three GPUs -- confirm before reusing elsewhere.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')



class NormalizedDegree(object):
    """Transform that overwrites ``data.x`` with standardised node degrees.

    The degree is taken from the first row of ``data.edge_index`` and then
    shifted by ``mean`` and scaled by ``std``.
    """

    def __init__(self, mean, std):
        """Store the shift (``mean``) and scale (``std``) used when called."""
        self.mean, self.std = mean, std

    def __call__(self, data):
        """Set ``data.x`` to a single-column tensor of standardised degrees.

        Args:
            data: Object exposing ``edge_index``; it is modified in place.

        Returns:
            The same ``data`` object, with ``x`` replaced.
        """
        source_nodes = data.edge_index[0]
        raw_degree = degree(source_nodes, dtype=torch.float)
        standardised = (raw_degree - self.mean) / self.std
        # One feature per node: reshape the flat degree vector into a column.
        data.x = standardised.view(-1, 1)
        return data

def stat_graph(graphs_list: List[Data]):
    """Summarise node and edge counts over a list of graphs.

    Args:
        graphs_list: Graphs exposing ``num_nodes`` and ``edge_index`` (the
            edge count is read from ``edge_index.shape[1]``).

    Returns:
        A 9-tuple ``(avg_num_nodes, max_num_nodes, avg_num_edges,
        avg_density, median_num_nodes, median_num_edges, median_density,
        min_num_nodes, std_num_nodes)``.

    Raises:
        ZeroDivisionError: If ``graphs_list`` is empty.
    """
    # Read both sizes of every graph in a single pass.
    sizes = [(g.num_nodes, g.edge_index.shape[1]) for g in graphs_list]
    node_counts = [n for n, _ in sizes]
    edge_counts = [e for _, e in sizes]
    n_graphs = len(graphs_list)

    mean_nodes = sum(node_counts) / n_graphs
    # The mean edge count is halved (each stored edge column counted as half
    # an edge).
    # NOTE(review): the median edge count below is NOT halved, so the two
    # density figures use different edge conventions -- confirm this is wanted.
    mean_edges = sum(edge_counts) / n_graphs / 2.0
    mean_density = mean_edges / (mean_nodes * mean_nodes)

    mid_nodes = np.median(node_counts)
    mid_edges = np.median(edge_counts)
    mid_density = mid_edges / (mid_nodes * mid_nodes)

    return (
        mean_nodes,
        max(node_counts),
        mean_edges,
        mean_density,
        mid_nodes,
        mid_edges,
        mid_density,
        min(node_counts),
        np.std(node_counts),
    )


def prepare_dataset_onehot_y(dataset):
    """Replace every item's integer label ``y`` with a float one-hot vector.

    The number of classes is the count of distinct ``int(data.y)`` values
    found in ``dataset``.

    NOTE(review): this presumably relies on labels being exactly
    ``0 .. num_classes - 1``; a gap in the label values would make the
    one-hot encoding reject the largest label -- confirm against the data.

    Args:
        dataset: Iterable of items with a tensor attribute ``y``; items are
            modified in place.

    Returns:
        The same ``dataset`` object.
    """
    distinct_labels = {int(sample.y) for sample in dataset}
    n_classes = len(distinct_labels)

    for sample in dataset:
        encoded = F.one_hot(sample.y, num_classes=n_classes)
        # Keep only the first row so ``y`` becomes a flat class vector.
        sample.y = encoded.to(torch.float)[0]
    return dataset

def get_graphon(Res, model, coords = None):
    """Evaluate ``model`` on a ``Res`` x ``Res`` grid and return a graphon.

    Args:
        Res: Grid resolution; the model output is reshaped to ``(Res, Res)``.
        model: Module called on a float tensor with one ``(x, y)`` grid
            point per row. It is moved to the module-level ``device`` and
            switched to eval mode.
        coords: Optional 1-D coordinates used for both axes. When ``None``,
            the cell centres ``(i + 0.5) / Res`` are used.

    Returns:
        A symmetric NumPy array of shape ``(Res, Res)`` with a zero diagonal.
    """
    if coords is None:
        axis = (np.arange(Res) + 0.5) / Res
    else:
        axis = coords

    # Every (x, y) pair on the grid, one pair per row.
    grid_x, grid_y = np.meshgrid(axis, axis)
    points = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
    points_t = torch.tensor(points, dtype=torch.float).to(device)

    model = model.to(device)
    model.eval()
    with torch.no_grad():
        values = model(points_t)

    surface = values.cpu().numpy().reshape(Res, Res)
    # Average with the transpose to force symmetry, then remove self-loops.
    surface = (surface + surface.T) / 2
    np.fill_diagonal(surface, 0)
    return surface

def two_graphons_mixup(two_graphons, la=0.5, num_sample=20, ge='ISGL', resolution=None):
    """Sample synthetic graphs from a convex combination of two graphons.

    Args:
        two_graphons: Pair of entries indexed as ``entry[0]`` (label, a NumPy
            array), ``entry[1]`` (precomputed graphon matrix, read only when
            ``ge != 'ISGL'``) and ``entry[2]`` (model handed to
            ``get_graphon``, read only when ``ge == 'ISGL'``).
        la: Weight of the first entry; the second entry gets ``1 - la``.
        num_sample: Number of graphs to attempt to sample.
        ge: ``'ISGL'`` evaluates both models at a per-sample resolution; any
            other value mixes the two precomputed graphon matrices.
        resolution: ``(first, last)`` resolutions; ``num_sample`` integer
            sizes are spread linearly between them. Required when
            ``ge == 'ISGL'`` and ignored otherwise.

    Returns:
        List of ``Data`` objects with ``y`` (the mixed label, shared by all
        samples), ``edge_index`` and ``num_nodes`` set. It can be shorter
        than ``num_sample`` because samples without any edge are skipped.

    Raises:
        ValueError: If ``ge == 'ISGL'`` and ``resolution`` is ``None``.
    """
    label = la * two_graphons[0][0] + (1 - la) * two_graphons[1][0]
    sample_graph_label = torch.from_numpy(label).type(torch.float32)

    use_inr = ge == 'ISGL'
    if use_inr:
        if resolution is None:
            raise ValueError(
                "resolution=(first, last) is required when ge == 'ISGL'"
            )
        # One graph size per sample, spread between the two resolutions.
        sizes = np.linspace(resolution[0], resolution[1], num_sample, dtype=int)
        inr_1 = two_graphons[0][2]
        inr_2 = two_graphons[1][2]
    else:
        # The precomputed graphons do not depend on the sample, so mix once.
        mixed_fixed = la * two_graphons[0][1] + (1 - la) * two_graphons[1][1]
        sizes = [None] * num_sample

    sample_graphs = []
    for size in sizes:
        if use_inr:
            graphon_1 = get_graphon(size, inr_1, None)
            graphon_2 = get_graphon(size, inr_2, None)
            new_graphon = la * graphon_1 + (1 - la) * graphon_2
        else:
            new_graphon = mixed_fixed

        # Bernoulli draw per entry: keep an edge where the uniform sample
        # does not exceed the graphon value.
        sample_graph = (np.random.rand(*new_graphon.shape) <= new_graphon).astype(np.int32)

        # Mirror the upper triangle so the adjacency matrix is symmetric.
        sample_graph = np.triu(sample_graph)
        sample_graph = sample_graph + sample_graph.T - np.diag(np.diag(sample_graph))
        # Remove rows, then columns, that contain no edge (isolated nodes).
        sample_graph = sample_graph[sample_graph.sum(axis=1) != 0]
        sample_graph = sample_graph[:, sample_graph.sum(axis=0) != 0]

        num_nodes = sample_graph.shape[0]
        if num_nodes == 0:
            print('num_nodes is 0')
            continue

        edge_index, _ = dense_to_sparse(torch.from_numpy(sample_graph))

        pyg_graph = Data()
        pyg_graph.y = sample_graph_label
        pyg_graph.edge_index = edge_index
        pyg_graph.num_nodes = num_nodes
        sample_graphs.append(pyg_graph)

    return sample_graphs



def prepare_dataset_x(dataset):
    """Give every graph a degree-based feature matrix when features are absent.

    Nothing is changed when ``dataset[0].x`` is already set. Otherwise each
    graph's ``num_nodes`` is set to the largest index in ``edge_index`` plus
    one, and ``x`` becomes either a one-hot encoding of the node degree (when
    the largest degree in the dataset is below 2000) or a single column of
    degrees standardised with the dataset-wide mean and standard deviation.

    Args:
        dataset: Indexable, iterable collection of graphs exposing ``x`` and
            ``edge_index``; graphs are modified in place.

    Returns:
        The same ``dataset`` object.
    """
    print("Prepare dataset x")
    if dataset[0].x is not None:
        return dataset

    print("dataset[0].x is None")
    largest_degree = 0
    per_graph_degrees = []
    for graph in dataset:
        node_degree = degree(graph.edge_index[0], dtype=torch.long)
        per_graph_degrees.append(node_degree)
        largest_degree = max(largest_degree, node_degree.max().item())
        graph.num_nodes = int(torch.max(graph.edge_index)) + 1

    if largest_degree < 2000:
        # Few enough distinct degrees: one-hot encode them.
        width = largest_degree + 1
        for graph in dataset:
            node_degree = degree(graph.edge_index[0], dtype=torch.long)
            graph.x = F.one_hot(node_degree, num_classes=width).to(torch.float)
        return dataset

    # Too many distinct degrees for one-hot: use a standardised scalar.
    pooled = torch.cat(per_graph_degrees, dim=0).to(torch.float)
    mean = pooled.mean().item()
    std = pooled.std().item()
    for graph in dataset:
        node_degree = degree(graph.edge_index[0], dtype=torch.long)
        graph.x = ((node_degree - mean) / std).view(-1, 1)
    return dataset




def mixup_cross_entropy_loss(input, target, size_average=True):
    r"""Cross entropy against soft (probability-vector) targets.

    Origin: https://github.com/moskomule/mixup.pytorch

    In PyTorch's cross entropy, targets are expected to be labels, so this
    loss is needed to score predictions against probabilities. With q the
    target and p the prediction:

        loss(p, q) = -\sum_i q_i \log p_i

    NOTE(review): no logarithm is taken here -- ``input`` is multiplied with
    ``target`` as-is, so it presumably already holds log-probabilities;
    confirm against the model's output layer.

    Args:
        input: Prediction tensor, same size as ``target``.
        target: Soft target tensor.
        size_average: When true, divide the summed loss by ``input.size(0)``.

    Returns:
        Scalar tensor holding the (optionally averaged) loss.

    Raises:
        AssertionError: If the sizes differ or an argument is not a
            ``Variable``.
    """
    assert input.size() == target.size()
    assert isinstance(input, Variable) and isinstance(target, Variable)
    total = -torch.sum(input * target)
    if size_average:
        return total / input.size(0)
    return total



def train(model, train_loader, optimizer, num_classes):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Called as ``model(x, edge_index, batch)``; the second element
            of its return value is scored with ``mixup_cross_entropy_loss``.
        train_loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``num_graphs``.
        optimizer: Optimiser stepped once per batch.
        num_classes: Width used to reshape ``y`` to ``(-1, num_classes)``.

    Returns:
        ``(model, loss)`` where ``loss`` is the per-graph average of the
        batch losses.

    Raises:
        ZeroDivisionError: If the loader yields no graphs.
    """
    model.train()
    weighted_loss_sum = 0
    graphs_seen = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        _, prediction = model(batch.x, batch.edge_index, batch.batch)
        targets = batch.y.view(-1, num_classes)
        batch_loss = mixup_cross_entropy_loss(prediction, targets)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the final figure is a per-graph average.
        n_graphs = batch.num_graphs
        weighted_loss_sum += batch_loss.item() * n_graphs
        graphs_seen += n_graphs
    return model, weighted_loss_sum / graphs_seen


def test(model, loader, num_classes):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Called as ``model(x, edge_index, batch)``; must return a pair
            ``(embedding, output)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``num_graphs``.
        num_classes: Width used to reshape ``y`` to ``(-1, num_classes)``.

    Returns:
        ``(acc, loss, emb_all)``: the fraction of graphs whose arg-max
        prediction matches the arg-max of the target, the per-graph average
        loss, and the embeddings of all batches concatenated along dim 0.

    Raises:
        ZeroDivisionError: If the loader yields no graphs.
    """
    model.eval()
    correct = 0
    total = 0
    loss = 0
    emb_all = []
    # Evaluation needs no autograd graph; without this guard every stored
    # embedding would keep its whole computation graph alive.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            emb_out, output = model(data.x, data.edge_index, data.batch)
            emb_all.append(emb_out)
            pred = output.max(dim=1)[1]
            y = data.y.view(-1, num_classes)
            loss += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
            # Targets may be mixed (soft) labels; compare against their arg-max.
            y = y.max(dim=1)[1]
            correct += pred.eq(y).sum().item()
            total += data.num_graphs
    acc = correct / total
    loss = loss / total
    emb_all = torch.cat(emb_all, dim=0) if emb_all else None
    return acc, loss, emb_all