import numpy as np
from tqdm import tqdm
from ogb.linkproppred import Evaluator

from torch_geometric.datasets import PPI
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import GATConv
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_scatter.scatter import scatter_add
from torch_geometric.data.data import Data

from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score


import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.utils.data import DataLoader

import networkx as nx

@torch.no_grad()
def preprocess(d: Data):
    """Augment a graph with one "phantom" node per maximal clique.

    The maximal cliques of the undirected graph given by ``d['edge_index']``
    are enumerated with ``networkx.find_cliques``. Each clique gets a new node
    appended after the original nodes. That node is linked in both directions
    to every clique member, and its feature vector is the mean of the
    members' original features.

    Args:
        d: graph with ``x`` (one feature row per node), ``edge_index`` and ``y``.

    Returns:
        A new ``Data`` with:
          - ``x``: the original feature rows followed by one row per clique;
          - ``edge_index``: the original edges, unchanged (phantom nodes do not
            appear in it);
          - ``phantom_edge_index``: a ``(2, E_p)`` long tensor of
            member->phantom and phantom->member edges;
          - ``y``: the original labels (only the original nodes are labelled).
    """
    num_nodes = d['x'].shape[0]

    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    graph.add_edges_from(d['edge_index'].T.tolist())
    cliques = list(nx.find_cliques(graph))

    # Original rows first, then one row per clique (filled in below).
    x = torch.zeros((num_nodes + len(cliques), d['x'].shape[1]), dtype=d['x'].dtype)
    x[:num_nodes] = d['x']

    phantom_edges = []
    for offset, clique in enumerate(cliques):
        phantom = num_nodes + offset
        for u in clique:
            phantom_edges.append([u, phantom])
            phantom_edges.append([phantom, u])
        # Members are original nodes (< num_nodes), so this reads original features only.
        x[phantom] = x[torch.tensor(clique)].mean(dim=0)

    # reshape(-1, 2) keeps the (2, E_p) long layout even when there are no
    # edges; the bare ``torch.tensor([]).T`` would give a 1-D float tensor.
    phantom_edge_index = (
        torch.tensor(phantom_edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )

    return Data(x=x, edge_index=d['edge_index'], phantom_edge_index=phantom_edge_index, y=d['y'])



class GCN_PN_Layer(torch.nn.Module):
    """One GCN2Conv layer that mixes real-graph and phantom-node messages.

    Rows ``[0, N)`` of the node matrix are the original graph nodes. Rows
    ``[N, num_nodes)`` are phantom nodes, linked to their members by
    ``phantom_edge_index``.

    Two ``GCN2Conv`` operators share the same hyper-parameters:
      - ``conv`` runs over the original ``edge_index``;
      - ``convp`` runs over ``phantom_edge_index`` and yields ``Xp``.

    Each real node's output is ``conv``'s output plus a symmetrically
    normalised sum of ``Xp`` over its phantom edges. Each phantom node's
    output is its row of ``Xp``.
    """

    def __init__(self, dim, layer):
        super().__init__()
        self.conv = GCN2Conv(dim, alpha=0.5, theta=1, layer=layer)
        self.convp = GCN2Conv(dim, alpha=0.5, theta=1, layer=layer)
        self.out_dim = dim

    def forward(self, x, x_0, edge_index, phantom_edge_index, N):
        """Apply the layer.

        Args:
            x: current node representations, one row per node (real + phantom).
            x_0: initial representations passed to ``GCN2Conv`` alongside ``x``.
            edge_index: edges among the real nodes.
            phantom_edge_index: member<->phantom edges.
            N: number of real nodes; rows ``N:`` are phantom nodes.

        Returns:
            Tensor with one row per node, in ``x``'s dtype and device.
        """
        num_nodes = x.shape[0]
        Xp = self.convp(x, x_0, phantom_edge_index)

        # D^{-1/2} A D^{-1/2} weights for the phantom adjacency (no self loops);
        # nodes without phantom edges get 0 instead of inf.
        row, col = phantom_edge_index[0], phantom_edge_index[1]
        ones = torch.ones((phantom_edge_index.size(1),), dtype=x.dtype, device=x.device)
        deg = scatter_add(ones, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        phantom_edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        X = self.conv(x, x_0, edge_index)
        # Messages from phantom nodes back to their member (real) nodes.
        Xprop = self.convp.propagate(phantom_edge_index, x=Xp, edge_weight=phantom_edge_weight)

        # torch.cat keeps x's dtype (the old torch.zeros buffers were always float32).
        return torch.cat([X[:N] + Xprop[:N], Xp[N:]], dim=0)

class GCN_PN(torch.nn.Module):
    """Stack of ``GCN_PN_Layer``s between an input and an output projection.

    ``Linear(in_dim, hid_dim) -> ReLU`` produces both the first hidden state
    and ``x_0``, which is fed to every layer. Each of the ``num_layers``
    ``GCN_PN_Layer``s is followed by ReLU and dropout. A final
    ``Linear(hid_dim, out_dim)`` produces the outputs.

    Args:
        in_dim: input feature size.
        hid_dim: hidden size of every ``GCN_PN_Layer``.
        out_dim: output size.
        num_layers: number of ``GCN_PN_Layer``s (default 5, as before).
        dropout: dropout probability after each layer (default 0.5, the
            ``F.dropout`` default used previously).
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_layers=5, dropout=0.5):
        super(GCN_PN, self).__init__()
        self.dropout = dropout
        # Submodules are created in the original order so random init is unchanged.
        self.lin_in = torch.nn.Linear(in_dim, hid_dim)
        # ``layer`` is passed to GCN2Conv 1-based.
        self.convs = torch.nn.ModuleList(
            GCN_PN_Layer(hid_dim, i + 1) for i in range(num_layers)
        )
        self.lin_out = torch.nn.Linear(hid_dim, out_dim)
        # Not used in forward (normalisation is disabled); kept so state_dict keys stay the same.
        self.norm = torch.nn.BatchNorm1d(hid_dim)

    def forward(self, x, edge_index, phantom_edge_index, N):
        """Return one output row per node (real + phantom); ``N`` counts the real nodes."""
        x = x_0 = self.lin_in(x).relu()
        for conv in self.convs:
            x = conv(x, x_0, edge_index, phantom_edge_index, N).relu()
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin_out(x)

def train(model, optimizer, data, device=None):
    """Run one optimisation step of ``model`` on a single graph.

    Args:
        model: network called as ``model(x, edge_index, phantom_edge_index, N)``.
        optimizer: optimizer over ``model``'s parameters.
        data: graph with ``x``, ``edge_index``, ``phantom_edge_index`` and ``y``.
            Only the first ``N = data['y'].shape[0]`` output rows (the labelled
            nodes) enter the loss.
        device: where to move the inputs; defaults to the device of the
            model's first parameter instead of relying on a module-level global.

    Returns:
        The binary cross-entropy-with-logits loss as a Python float.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    optimizer.zero_grad()
    target = data['y'].to(device, torch.float)
    num_labelled = target.shape[0]
    logits = model(
        data['x'].to(device),
        data['edge_index'].to(device),
        data['phantom_edge_index'].to(device),
        num_labelled,
    )[:num_labelled]
    loss = F.binary_cross_entropy_with_logits(logits, target)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_m(model, optimizer, datas, device=None):
    """Run one optimisation step on the summed loss of several graphs.

    Args:
        model: network called as ``model(x, edge_index, phantom_edge_index, N)``.
        optimizer: optimizer over ``model``'s parameters.
        datas: iterable of graphs with ``x``, ``edge_index``,
            ``phantom_edge_index`` and ``y``. For each graph only the first
            ``N = y.shape[0]`` output rows enter its loss.
        device: where to move the inputs; defaults to the device of the
            model's first parameter instead of relying on a module-level global.

    Returns:
        The summed (not averaged) loss over all graphs as a Python float.

    Raises:
        ValueError: if ``datas`` is empty, since there is no loss to backpropagate.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    optimizer.zero_grad()
    total = None
    for data in datas:
        target = data['y'].to(device, torch.float)
        num_labelled = target.shape[0]
        logits = model(
            data['x'].to(device),
            data['edge_index'].to(device),
            data['phantom_edge_index'].to(device),
            num_labelled,
        )[:num_labelled]
        graph_loss = F.binary_cross_entropy_with_logits(logits, target)
        total = graph_loss if total is None else total + graph_loss
    if total is None:
        raise ValueError("train_m needs at least one graph")
    total.backward()
    optimizer.step()
    return total.item()

@torch.no_grad()
def evaluate(model, data, device=None):
    """Score ``model`` on a collection of graphs.

    For each graph, the first ``N = d['y'].shape[0]`` output rows are
    thresholded at 0 (``logits > 0``). Predictions and labels from all graphs
    are concatenated before scoring.

    Args:
        model: network called as ``model(x, edge_index, phantom_edge_index, N)``.
            It is left in eval mode afterwards.
        data: iterable of graphs with ``x``, ``edge_index``,
            ``phantom_edge_index`` and ``y``.
        device: where to move the inputs; defaults to the device of the
            model's first parameter instead of relying on a module-level global.

    Returns:
        ``(acc, f1)``: ``accuracy_score`` over the flattened label matrices and
        ``f1_score(..., average='micro')`` over the unflattened ones.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    preds = []
    ys = []
    for d in data:
        num_labelled = d['y'].shape[0]
        logits = model(
            d['x'].to(device),
            d['edge_index'].to(device),
            d['phantom_edge_index'].to(device),
            num_labelled,
        )[:num_labelled]
        preds.append((logits > 0).cpu())
        ys.append(d['y'].cpu())
    if not ys:
        raise ValueError("evaluate needs at least one graph")

    y_true_np = torch.cat(ys, dim=0).numpy()
    y_pred_np = torch.cat(preds, dim=0).numpy()
    acc = accuracy_score(y_true_np.flatten(), y_pred_np.flatten())
    f1 = f1_score(y_true_np, y_pred_np, average='micro')
    return acc, f1



# Training script. It lives under a module-level guard (not a function) so that
# ``device`` stays a module global for any helper that still reads it.
if __name__ == '__main__':
    train_data = PPI('./data/PPI', split='train')
    test_data = PPI('./data/PPI', split='test')

    # Add one phantom node per maximal clique to every graph.
    aug_train = [preprocess(train_data[i]) for i in range(len(train_data))]
    aug_test = [preprocess(test_data[i]) for i in range(len(test_data))]

    in_dim = 50
    hid_dim = 512
    out_dim = 121
    lr = 5e-3
    wt_decay = 5e-6
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    split = 1  # number of training graphs (aug_train[:split]) optimised per epoch

    model = GCN_PN(in_dim, hid_dim, out_dim).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wt_decay)

    with tqdm(range(3000)) as _tqdm:
        for i in _tqdm:
            loss = sum(train(model, opt, graph) for graph in aug_train[:split])
            if i % 20 == 0:
                train_acc, _ = evaluate(model, aug_train)
                test_acc, test_f1 = evaluate(model, aug_test)
                # The first metric is computed on aug_train, so label it as training accuracy.
                _tqdm.set_postfix_str(f'loss: {loss:.4f}, train_acc: {train_acc:.4f}, test_acc: {test_acc:.4f}, test_f1: {test_f1:.4f}')