import argparse

from loader import BioDataset_aug
from torch_geometric.data import DataLoader
from torch_geometric.nn.inits import uniform
from torch_geometric.nn import global_mean_pool

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_scatter import scatter

from tqdm import tqdm
import numpy as np

from model import GNN
from sklearn.metrics import roc_auc_score

import pandas as pd

from copy import deepcopy


def nt_xent_loss_with_mask(anchor: torch.FloatTensor, samples: torch.FloatTensor, pos_mask: torch.FloatTensor, temperature: float):
    """Mean NT-Xent style loss of each anchor against mask-selected samples.

    Rows of ``anchor`` and ``samples`` are L2-normalised (``F.normalize``
    with its default arguments), compared by dot product, divided by
    ``temperature`` and exponentiated. For every anchor row the loss is
    ``-log(masked score sum / total score sum)``.

    Args:
        anchor: 2-D tensor with one embedding per row.
        samples: 2-D tensor with one embedding per row, same width as
            ``anchor``.
        pos_mask: tensor with one row per anchor and one column per sample;
            it multiplies the score matrix, so non-zero entries select the
            positives.
        temperature: divisor applied to the cosine similarities before the
            exponential.

    Returns:
        Scalar tensor holding the loss averaged over anchors. An anchor whose
        mask row is all zeros contributes ``-log(0)``.
    """
    anchor_unit = F.normalize(anchor)
    sample_unit = F.normalize(samples)
    cosine = anchor_unit @ sample_unit.t()
    scores = torch.exp(cosine / temperature)  # anchor x sample
    assert scores.size() == pos_mask.size()  # sanity check

    positive_total = (scores * pos_mask).sum(dim=1)
    negative_total = scores.sum(dim=1) - positive_total

    per_anchor = -torch.log(positive_total / (positive_total + negative_total))
    return per_anchor.mean()


def cycle_index(num, shift):
    """Return the indices ``0..num-1`` rotated left by ``shift`` positions.

    Element ``i`` of the result is ``(i + shift) % num``, e.g.
    ``cycle_index(5, 2) -> [2, 3, 4, 0, 1]``.

    Args:
        num: number of indices to generate.
        shift: rotation amount. Any integer is accepted; ``0`` (or a multiple
            of ``num``) yields the identity permutation.

    Returns:
        1-D integer tensor of length ``num``.
    """
    indices = torch.arange(num)
    if num == 0:
        # Nothing to rotate; also avoids a modulo by zero.
        return indices
    # Modular arithmetic replaces the slice assignment `arr[-shift:] = ...`,
    # which broke for shift == 0 (``-0`` selects the whole tensor) and for
    # shift > num.
    return (indices + shift) % num

class Discriminator(nn.Module):
    """Bilinear scorer with a learnable square weight matrix.

    For each row ``i`` the score is ``sum(x[i] * (summary @ weight)[i])``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # The storage is allocated uninitialised; reset_parameters() fills it.
        self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` via torch_geometric's ``uniform`` helper.

        The first dimension of ``weight`` is passed as the size argument.
        """
        uniform(self.weight.size(0), self.weight)

    def forward(self, x, summary):
        """Score ``x`` against ``summary`` row by row.

        Args:
            x: tensor whose rows are multiplied element-wise with the
                projected summary.
            summary: tensor right-multiplied by ``weight``; the product must
                be broadcastable against ``x``.

        Returns:
            Tensor of scores obtained by summing over dimension 1.
        """
        projected = summary @ self.weight
        return (x * projected).sum(dim=1)

class graphcl(nn.Module):
    """GNN encoder plus projection head for contrastive pre-training.

    Args:
        gnn: node encoder, called as ``gnn(x, edge_index, edge_attr)``.
        emb_dim: width of the pooled embeddings fed to the projection head.
            Defaults to 300, the value that was previously hard-coded; it
            must match the width of the encoder output.
    """

    def __init__(self, gnn, emb_dim=300):
        super(graphcl, self).__init__()
        self.gnn = gnn
        self.pool = global_mean_pool
        self.projection_head = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward_cl(self, x, edge_index, edge_attr, batch):
        """Encode a batch of graphs.

        Args:
            x, edge_index, edge_attr: forwarded unchanged to the encoder.
            batch: node-to-graph assignment vector used for pooling.

        Returns:
            Tuple ``(node_emb, graph_emb)``: the raw encoder output and the
            pooled embeddings after the projection head.
        """
        x = self.gnn(x, edge_index, edge_attr)
        z = self.pool(x, batch)
        z = self.projection_head(z)
        return x, z

    @staticmethod
    def _graph_membership_mask(batch, num_graphs, fast_mode):
        """Build a float32 ``[num_graphs, num_nodes]`` 0/1 membership matrix.

        Entry ``(g, v)`` is 1 exactly when ``batch[v] == g``.
        """
        if fast_mode:
            one_hot_nodes = torch.eye(batch.size(0), dtype=torch.float32, device=batch.device)  # [M, M]
            # Summing the identity rows per graph gives one indicator row per graph.
            return scatter(one_hot_nodes, batch, dim=0, dim_size=num_graphs, reduce='sum')  # [N, M]
        # Broadcast comparison: a single op instead of one Python iteration per graph.
        graph_ids = torch.arange(num_graphs, device=batch.device).unsqueeze(1)  # [N, 1]
        return (batch.unsqueeze(0) == graph_ids).to(torch.float32)  # [N, M]

    def nt_xent_loss(self,
                     z1: torch.FloatTensor, g1: torch.FloatTensor,
                     z2: torch.FloatTensor, g2: torch.FloatTensor,
                     batch: torch.LongTensor, temperature: float,
                     fast_mode: bool = False,
                     batch2: torch.LongTensor = None
                     ):
        """Symmetric node-vs-graph NT-Xent loss between two views.

        Graph embeddings of one view act as anchors against the node
        embeddings of the other view; a node is a positive for the graph it
        belongs to.

        Args:
            z1: node embeddings of view 1, one row per node.
            g1: graph embeddings of view 1, one row per graph.
            z2: node embeddings of view 2.
            g2: graph embeddings of view 2.
            batch: node-to-graph assignment for the rows of ``z1``.
            temperature: softmax temperature passed to the pairwise loss.
            fast_mode: build the mask with ``scatter`` over an identity
                matrix instead of a broadcast comparison; both produce the
                same mask.
            batch2: node-to-graph assignment for the rows of ``z2``. When
                omitted, ``batch`` is reused, which requires both views to
                share the same node layout.

        Returns:
            Scalar tensor: loss(g1 vs z2) + loss(g2 vs z1).
        """
        num_graphs = batch.max().item() + 1  # N := num_graphs
        if batch2 is None:
            batch2 = batch

        mask_view1 = self._graph_membership_mask(batch, num_graphs, fast_mode)  # [N, M1]
        if batch2 is batch:
            mask_view2 = mask_view1
        else:
            mask_view2 = self._graph_membership_mask(batch2, num_graphs, fast_mode)  # [N, M2]

        # The mask columns must follow the samples: z2 -> view-2 mask, z1 -> view-1 mask.
        l1 = nt_xent_loss_with_mask(g1, z2, pos_mask=mask_view2, temperature=temperature)
        l2 = nt_xent_loss_with_mask(g2, z1, pos_mask=mask_view1, temperature=temperature)

        return l1 + l2

    def loss_cl(self, x1, x2):
        """Graph-level contrastive loss between two aligned embedding batches.

        Row ``i`` of ``x1`` and row ``i`` of ``x2`` form the positive pair.
        Cosine similarities are scaled by a fixed temperature of 0.1 and
        exponentiated; the loss is ``-log(pos / (row_sum - pos))`` averaged
        over rows, i.e. the positive is excluded from the denominator.

        Args:
            x1: 2-D tensor of embeddings, one row per graph.
            x2: 2-D tensor of embeddings aligned row-wise with ``x1``.

        Returns:
            Scalar loss tensor.
        """
        T = 0.1
        batch_size, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        # Pairwise dot products divided by the outer product of the norms.
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs)
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch_size), range(batch_size)]
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = - torch.log(loss).mean()
        return loss


def train(args, model, device, dataset, optimizer):
    """Run one contrastive pre-training epoch.

    The dataset is shuffled once and deep-copied so that both augmented
    views iterate the same graphs in the same order; each copy then gets
    its own augmentation settings from ``args``.

    Args:
        args: namespace providing ``aug1``, ``aug_ratio1``, ``aug2``,
            ``aug_ratio2``, ``batch_size`` and ``num_workers``.
        model: module exposing ``forward_cl`` and ``nt_xent_loss``.
        device: device the batches are moved to.
        dataset: dataset with ``aug`` / ``aug_ratio`` attributes and a
            ``shuffle()`` method. Its ``aug`` attribute is set to ``"none"``
            as a side effect.
        optimizer: optimizer stepped once per batch.

    Returns:
        Tuple ``(mean_acc, mean_loss)`` over the epoch. No accuracy is
        computed for this objective, so ``mean_acc`` is always ``0.0``.

    Raises:
        ValueError: if the loaders yield no batches.
    """
    dataset.aug = "none"
    dataset1 = dataset.shuffle()
    dataset2 = deepcopy(dataset1)
    dataset1.aug, dataset1.aug_ratio = args.aug1, args.aug_ratio1
    dataset2.aug, dataset2.aug_ratio = args.aug2, args.aug_ratio2

    # shuffle=False keeps the two loaders aligned graph-for-graph.
    loader1 = DataLoader(dataset1, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    loader2 = DataLoader(dataset2, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

    model.train()

    train_acc_accum = 0.0  # placeholder: accuracy is not defined for this loss
    train_loss_accum = 0.0
    num_batches = 0

    for batch1, batch2 in tqdm(zip(loader1, loader2), desc="Iteration"):
        batch1 = batch1.to(device)
        batch2 = batch2.to(device)

        optimizer.zero_grad()

        x1, g1 = model.forward_cl(batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch)
        x2, g2 = model.forward_cl(batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch)
        # NOTE(review): batch1.batch is used as the node-to-graph assignment
        # for both views, which presumes aug1/aug2 keep every graph's node
        # count and order - confirm for node-dropping augmentations.
        loss = model.nt_xent_loss(x1, g1, x2, g2, batch1.batch, temperature=0.1)

        loss.backward()
        optimizer.step()

        train_loss_accum += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on `step`.
        raise ValueError("train() received a dataset that produced no batches")

    return train_acc_accum / num_batches, train_loss_accum / num_batches


def main():
    """Parse the command line, build the model and run contrastive pre-training.

    Every 20th epoch the encoder weights (``model.gnn.state_dict()``) are
    written to ``./models_lg/lg_<epoch>.pth``.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=7,
                        help='which gpu to use if any (default: 7)')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--model_file', type=str, default='', help='filename to output the pre-trained model')
    parser.add_argument('--seed', type=int, default=0, help="Random seed for torch and numpy (default: 0).")
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading')
    parser.add_argument('--aug1', type=str, default='none')
    parser.add_argument('--aug_ratio1', type=float, default=0.2)
    parser.add_argument('--aug2', type=str, default='none')
    parser.add_argument('--aug_ratio2', type=float, default=0.2)
    args = parser.parse_args()

    # Honour --seed; it used to be parsed but every seed was hard-coded to 0.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    #set up dataset
    root_unsupervised = 'dataset/unsupervised'
    dataset = BioDataset_aug(root_unsupervised, data_type='unsupervised')

    print(dataset)

    #set up model
    gnn = GNN(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)

    model = graphcl(gnn)

    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    # Create the checkpoint directory up front so a missing folder fails
    # immediately instead of after 20 epochs of training.
    checkpoint_dir = "./models_lg"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Inclusive upper bound: run exactly args.epochs epochs so the final
    # epoch is trained and, when it is a multiple of 20, checkpointed.
    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train_acc, train_loss = train(args, model, device, dataset, optimizer)

        print(train_acc)
        print(train_loss)

        if epoch % 20 == 0:
            torch.save(model.gnn.state_dict(), checkpoint_dir + "/lg_" + str(epoch) + ".pth")

# Run pre-training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
