import numpy as np
from tqdm import tqdm
from ogb.linkproppred import Evaluator

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_scatter.scatter import scatter_add


import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.utils.data import DataLoader

def train(model, optimizer, data, with_phantom, batch_size=None, device=None):
    """Train ``model`` for one epoch on the supervision edges of ``data``.

    Mini-batches are drawn (shuffled) over the columns of
    ``data['edge_label_index']``; message passing always uses the full
    ``data['x']`` / ``data['edge_index']`` graph.

    :param model: link predictor called as
        ``model(x, edge_index, target_edges, with_phantom)`` that returns one
        logit per target edge.
    :param optimizer: optimizer over ``model``'s parameters.
    :param data: split providing ``x``, ``edge_index``, ``edge_label_index``
        and ``edge_label``.
    :param with_phantom: forwarded unchanged to ``model``.
    :param batch_size: target edges per batch; defaults to ``arg_batch_size``.
    :param device: device to run on; defaults to ``arg_device``.
    :return: mean binary cross-entropy per target edge over the epoch.
    """
    batch_size = arg_batch_size if batch_size is None else batch_size
    device = arg_device if device is None else device

    model.train()
    target_edges = data['edge_label_index']
    target_labels = data['edge_label']
    # The message-passing graph is the same for every batch: move it once.
    x = data['x'].to(device)
    edge_index = data['edge_index'].to(device)
    loader = DataLoader(range(target_edges.shape[-1]), batch_size, shuffle=True)
    total_loss = 0.0

    for indices in tqdm(loader):
        cur_target_edges = target_edges[:, indices].to(device)
        cur_target_labels = target_labels[indices].to(device=device, dtype=torch.float)
        optimizer.zero_grad()
        logits = model(x, edge_index, cur_target_edges, with_phantom)
        loss = F.binary_cross_entropy_with_logits(logits.view(-1), cur_target_labels)
        loss.backward()
        optimizer.step()
        # Weight by the actual batch length: the final batch is usually smaller
        # than batch_size, and weighting it by batch_size inflates the mean.
        total_loss += loss.item() * len(indices)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def get_pos_neg_pred(model, data, target_data, with_phantom, batch_size=None, device=None):
    """Score the supervision edges of ``target_data`` and split them by label.

    Message passing uses the graph of ``data`` (``x`` and ``edge_index``),
    while the scored edges come from ``target_data['edge_label_index']``.

    :param model: link predictor called as
        ``model(x, edge_index, target_edges, with_phantom)``.
    :param data: split whose ``x`` / ``edge_index`` form the message-passing graph.
    :param target_data: split providing ``edge_label_index`` and ``edge_label``.
    :param with_phantom: forwarded unchanged to ``model``.
    :param batch_size: target edges per batch; defaults to ``arg_batch_size``.
    :param device: device to run on; defaults to ``arg_device``.
    :return: ``(pos_pred, neg_pred)`` NumPy arrays of raw logits for the edges
        labelled 1 and 0 respectively, in edge order.
    """
    batch_size = arg_batch_size if batch_size is None else batch_size
    device = arg_device if device is None else device

    model.eval()
    target_edges = target_data['edge_label_index']
    target_labels = target_data['edge_label']
    # The message-passing graph is identical for every batch: move it once.
    x = data['x'].to(device)
    edge_index = data['edge_index'].to(device)
    # Scoring does not depend on batch order, so iterate deterministically.
    loader = DataLoader(range(target_edges.shape[-1]), batch_size, shuffle=False)
    all_pred = []
    all_y = []
    for indices in loader:
        cur_target_edges = target_edges[:, indices].to(device)
        logits = model(x, edge_index, cur_target_edges, with_phantom).view(-1)
        all_pred += logits.cpu().tolist()
        all_y += target_labels[indices].tolist()
    all_pred = np.array(all_pred)
    all_y = np.array(all_y)
    return all_pred[all_y == 1], all_pred[all_y == 0]

@torch.no_grad()
def evaluate(model, train_data, val_data, test_data, with_phantom):
    """Report Hits@K of ``model`` on the train, validation and test edges.

    Message passing always runs over ``train_data``'s graph. Only the scored
    supervision edges change between splits.

    :return: the dict produced by ``evaluate_hits`` (``'Hits@K'`` ->
        ``(train, val, test)``), computed with the module-level ``evaluator``.
    """
    # Collect [pos_train, neg_train, pos_val, neg_val, pos_test, neg_test] in order.
    split_scores = []
    for split in (train_data, val_data, test_data):
        split_scores.extend(get_pos_neg_pred(model, train_data, split, with_phantom))
    return evaluate_hits(evaluator, *split_scores)

def evaluate_hits(evaluator, pos_train_pred, neg_train_pred, pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred,
                  Ks=(10, 20, 100), use_val_negs_for_train=True):
    """Evaluate the hit rate at each K for the train, validation and test splits.

    :param evaluator: an ogb ``Evaluator`` whose ``eval`` returns a
        ``'hits@K'`` entry. Its ``K`` attribute is overwritten and is left at
        the last value of ``Ks``.
    :param pos_train_pred: 1-D scores of the positive train edges.
    :param neg_train_pred: 1-D scores of the negative train edges.
    :param pos_val_pred: 1-D scores of the positive validation edges.
    :param neg_val_pred: 1-D scores of the negative validation edges.
    :param pos_test_pred: 1-D scores of the positive test edges.
    :param neg_test_pred: 1-D scores of the negative test edges.
    :param Ks: the cut-offs K to evaluate.
    :param use_val_negs_for_train: rank the train positives against the
        validation negatives instead of ``neg_train_pred``.
    :return: dict mapping ``'Hits@K'`` to ``(train_hits, valid_hits, test_hits)``.
    """
    # As the training performance is used to assess overfitting, it can help to
    # use the same set of negatives for the train and val comparisons.
    neg_train = neg_val_pred if use_val_negs_for_train else neg_train_pred

    def _hits(pos, neg, k):
        # Hit rate of one (positives, negatives) pair at the current evaluator.K.
        return evaluator.eval({
            'y_pred_pos': pos,
            'y_pred_neg': neg,
        })[f'hits@{k}']

    results = {}
    for K in Ks:
        evaluator.K = K
        results[f'Hits@{K}'] = (
            _hits(pos_train_pred, neg_train, K),
            _hits(pos_val_pred, neg_val_pred, K),
            _hits(pos_test_pred, neg_test_pred, K),
        )
    return results

class PhantomEdge(torch.nn.Module):
    """Edge scorer: a node encoder followed by a Hadamard-product MLP head.

    ``forward`` encodes every node with ``GNN``, projects the result, and
    multiplies the endpoint embeddings of each column of
    ``phantom_edge_index`` element-wise. It then returns one logit per edge as
    a ``(num_edges, 1)`` tensor. Layer widths and dropout come from the
    module-level ``arg_*`` settings.
    """

    def __init__(self, GNN) -> None:
        super().__init__()
        # Sub-modules are created in this exact order so that parameter
        # registration and random initialisation stay the same.
        self.GNN = GNN
        self.lin_feat = Linear(arg_d_feature, arg_hidden_channels)
        self.lin_out = Linear(arg_hidden_channels, arg_hidden_channels)
        self.bn_feats = torch.nn.BatchNorm1d(arg_hidden_channels)
        self.lin = Linear(arg_hidden_channels, 1)

    def forward(self, x, edge_index, phantom_edge_index, with_phantom):
        node_emb = self.lin_feat(self.GNN(x, edge_index, phantom_edge_index, with_phantom))
        src, dst = phantom_edge_index[0], phantom_edge_index[1]
        # Element-wise product of the two endpoint embeddings of each edge.
        pair_emb = node_emb[src] * node_emb[dst]
        hidden = F.relu(self.bn_feats(self.lin_out(pair_emb)))
        hidden = F.dropout(hidden, p=arg_feature_dropout, training=self.training)
        return self.lin(hidden)


class GCN(torch.nn.Module):
    """Two-layer GCNII encoder over the real edges only (no phantom edges).

    Uses the same layer hyper-parameters as ``GCN_PE`` so that it can be passed
    as the ``GNN`` of ``PhantomEdge`` for a phantom-free baseline. The output
    width equals the input width ``dim``.
    """

    def __init__(self, dim):
        super(GCN, self).__init__()
        # GCN2Conv's second parameter is the initial-residual weight ``alpha``,
        # not an output width (the layer maps dim -> dim). Pass every
        # hyper-parameter by keyword so that ``dim`` is never used as ``alpha``.
        self.conv1 = GCN2Conv(dim, alpha=0.5, theta=1, layer=1)
        self.conv2 = GCN2Conv(dim, alpha=0.5, theta=1, layer=2)
        self.norm = torch.nn.BatchNorm1d(dim)

    def forward(self, x, edge_index, phantom_edge_index, with_phantom=False):
        """Encode nodes with two GCN2Conv layers over ``edge_index``.

        ``phantom_edge_index`` and ``with_phantom`` are accepted only so that
        the signature matches how ``PhantomEdge`` calls its GNN. Both are ignored.
        """
        # GCN2Conv.forward requires the initial representation x_0 for its
        # residual term. It is detached here, as in GCN_PE.
        x_0 = x.detach()
        x = self.conv1(x, x_0, edge_index)
        x = self.norm(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, x_0, edge_index)

        return x

class GCN_PE(torch.nn.Module):
    """Two-layer GCNII encoder with an optional "phantom edge" propagation step.

    After each real-edge GCN2Conv layer, when ``with_phantom`` is true, node
    features are also aggregated along ``phantom_edge_index`` using a
    degree-normalised edge weight. The result is scaled by 1e-4 and added back
    to the features.
    """
    def __init__(self, dim):
        super(GCN_PE, self).__init__()
        # Real-edge layers (GCN2Conv with normalize=True, alpha=0.5, theta=1).
        self.conv1 = GCN2Conv(dim, normalize=True, layer=1, alpha=0.5, theta=1)
        self.conv2 = GCN2Conv(dim, normalize=True, layer=2, alpha=0.5, theta=1)

        # Phantom-edge layers: forward only calls their ``propagate`` method,
        # with pre-normalised weights, so their learnable weights are never used
        # in the computation.
        # NOTE(review): ``bias=False`` is forwarded through GCN2Conv's **kwargs
        # to MessagePassing; whether it is accepted depends on the PyG version.
        # Confirm before relying on it.
        self.conv3 = GCN2Conv(dim, normalize=False, add_self_loops=False, bias=False, layer=1, alpha=0.5, theta=1)
        self.conv4 = GCN2Conv(dim, normalize=False, add_self_loops=False, bias=False, layer=2, alpha=0.5, theta=1)
        self.norm = torch.nn.BatchNorm1d(dim)

        # NOTE(review): ``factor`` is registered as a parameter but is never read
        # in forward. It may be an unfinished learnable scale for the phantom term.
        self.factor = torch.nn.Parameter(torch.ones(()))


    def forward(self, x, edge_index, phantom_edge_index, with_phantom):
        """Encode nodes, optionally mixing in messages along phantom edges.

        :param x: node feature matrix, one row per node.
        :param edge_index: real graph edges as a 2 x E index tensor.
        :param phantom_edge_index: extra edges (2 x P) to propagate over. In
            this script, PhantomEdge passes the target edges being scored.
        :param with_phantom: when False, the phantom propagation is skipped and
            only the real-edge layers run.
        :return: node representations after two GCN2Conv layers.
        """
        # Initial representation for GCN2Conv's initial-residual term. It is
        # detached, so no gradient flows back through x_0.
        x_0 = x.clone().detach()
        row, col = edge_index[0], edge_index[1]
        num_nodes = x.shape[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=x.dtype,
                                     device=edge_index.device)
        phantom_edge_weight = torch.ones((phantom_edge_index.size(1), ), dtype=x.dtype,
                                     device=edge_index.device)
        # NOTE(review): this real-edge degree is overwritten three lines below
        # and is never used. Either this is dead code, or real and phantom
        # degrees were meant to be combined. Confirm intent.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

        # Degree of each node counted over the phantom edges, by source index
        # (row 0) only.
        row, col = phantom_edge_index[0], phantom_edge_index[1]
        deg = scatter_add(phantom_edge_weight, row, dim=0, dim_size=num_nodes)

        # deg^{-1/2} (computed in place on deg); nodes with degree 0 become inf
        # and are then mapped to 0.
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        row, col = phantom_edge_index[0], phantom_edge_index[1]
        # Symmetric normalisation: w_ij = deg_i^{-1/2} * 1 * deg_j^{-1/2}.
        phantom_edge_weight = deg_inv_sqrt[row] * phantom_edge_weight * deg_inv_sqrt[col]
        # all_edge_index = edge_index



        x = self.conv1(x, x_0, edge_index)
        if with_phantom:
            # Aggregate x along the phantom edges with the weights above, using
            # GCN2Conv's message/aggregation, and add it with a fixed 1e-4 scale.
            out = self.conv3.propagate(phantom_edge_index, x=x, edge_weight=phantom_edge_weight,
                             size=None)
            x = x + 1e-4*out
        x = self.norm(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, x_0, edge_index)
        if with_phantom:
            # Same phantom step after the second layer, reusing the same weights.
            out = self.conv4.propagate(phantom_edge_index, x=x, edge_weight=phantom_edge_weight,
                             size=None)
            x = x + 1e-4*out

        return x


# Hyper-parameters. The functions and classes above read these as module globals.
arg_hidden_channels = 1024
arg_feature_dropout = 0.5
arg_batch_size = 2048
arg_lr = 1e-4
arg_epoch = 100
# Fall back to the CPU when no GPU is available instead of failing on .to('cuda').
arg_device = 'cuda' if torch.cuda.is_available() else 'cpu'
arg_wt = 5e-4  # Adam weight decay
with_phantom = True

# Planetoid holds a single graph. Index it instead of reading the
# ``dataset.data`` attribute, which newer PyG releases deprecate.
D = Planetoid('./dataset/Citeseer', 'Citeseer')[0]
arg_d_feature = D.num_node_features
# Citeseer: N = 3327 nodes, E = 9104 directed edges.
# Optional experiment: append the node class label as an extra input feature.
# If enabled, this line must run before arg_d_feature is computed above.
# D['x'] = torch.concat([D['x'], D['y'].view(-1, 1)], dim=1)
# 70/10/20 train/val/test edge split. Negative edges are also sampled for the
# training split, so edge_label there contains both 1s and 0s.
transform = RandomLinkSplit(is_undirected=True, num_val=0.1, num_test=0.2,
                            add_negative_train_samples=True)
train_data, val_data, test_data = transform(D)
# The ogbl-ppa evaluator is used only for its Hits@K metric. evaluate_hits
# overrides its K for each cut-off.
evaluator = Evaluator(name='ogbl-ppa')

GNN = GCN_PE(arg_d_feature)
model = PhantomEdge(GNN).to(arg_device)
opt = torch.optim.Adam(model.parameters(), lr=arg_lr, weight_decay=arg_wt)

for i in range(arg_epoch):
    loss = train(model, opt, train_data, with_phantom)
    print(f'Epoch {i}, loss:{loss:.2f}')

    results = evaluate(model, train_data, val_data, test_data, with_phantom)
    # Each entry is a (train, val, test) tuple of hit rates.
    print(results['Hits@100'])