import argparse
import random
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import DataLoader

from losses import *
from gin import GIN, ChebNetII_V2, GPRGNN_V2, BernNet_V2
from evaluate_embedding import evaluate_embedding
from model import *
from aug import TUDataset_aug as TUDataset


class simclr(nn.Module):
    """Graph encoder plus a two-layer projection head, trained with a contrastive loss.

    Args:
        encoder: module called as ``encoder(x, edge_index, batch)``. It must return a
            pair whose first element is the graph-level embedding; the second element
            is ignored here.
        embedding_dim: width of the encoder's graph embedding. The projection head
            maps ``embedding_dim -> embedding_dim``.
        tau: temperature applied to the cosine similarities in ``loss_cal``.
        alpha, beta, gamma: stored on the instance; not used by this class itself.
        prior: stored as ``self.prior`` (not used by this class itself). When omitted,
            it falls back to ``args.prior`` from a module-level ``args`` namespace (as
            set by the training script), or ``False`` if no such namespace exists.
        eps: lower bound on the product of row norms in ``loss_cal``, so all-zero
            embeddings do not produce 0/0 = NaN.
    """

    def __init__(self, encoder, embedding_dim, tau, alpha=0.5, beta=1., gamma=.1,
                 prior=None, eps=1e-8):
        super().__init__()
        self.tau = tau
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        if prior is None:
            # Backward compatibility: the script's __main__ defines a global ``args``;
            # do not crash when the class is imported and used elsewhere.
            prior = getattr(globals().get('args'), 'prior', False)
        self.prior = prior
        self.eps = eps
        self.embedding_dim = embedding_dim

        self.encoder = encoder
        self.proj_head = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embedding_dim, self.embedding_dim),
        )
        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias.

        ``self.modules()`` is recursive, so this also re-initialises any Linear
        layers inside ``self.encoder``, not only the projection head.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.fill_(0.0)

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs and return the projected graph embeddings.

        If ``x`` is None (featureless graphs), a constant 1 per node is used instead.
        It is created on the same device as ``batch``.
        """
        if x is None:
            x = torch.ones(batch.shape[0], device=batch.device)
        y, _ = self.encoder(x, edge_index, batch)
        return self.proj_head(y)

    def loss_cal(self, x, x_aug):
        """Contrastive loss between two 2-D embedding matrices with matching rows.

        Row ``i`` of ``x`` and row ``i`` of ``x_aug`` form the positive pair. Every other
        row of ``x_aug`` is a negative for row ``i``. The denominator excludes the
        positive term. Returns the mean over rows of
        ``-log(pos / sum(neg))`` with similarities ``exp(cos / tau)``.
        """
        # Product of row norms, clamped so all-zero embeddings do not give 0/0 = NaN.
        norms = torch.einsum('i,j->ij', x.norm(dim=1), x_aug.norm(dim=1)).clamp_min(self.eps)
        cosine = torch.einsum('ik,jk->ij', x, x_aug) / norms
        sim_matrix = torch.exp(cosine / self.tau)
        pos_sim = sim_matrix.diagonal()
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        return -torch.log(loss).mean()


def setup_seed(seed):
    """Make a run reproducible by seeding every RNG the script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch (CPU and
    all CUDA devices). It also switches cuDNN to deterministic kernels.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


def arg_parse(argv=None):
    """Parse the command-line options for contrastive pre-training.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args()``.

    Returns:
        argparse.Namespace with one attribute per option below.

    Raises:
        SystemExit: on invalid arguments (argparse's standard behaviour). This includes
            a missing ``--dataset`` and an ``--encoder`` the script cannot build.
    """
    parser = argparse.ArgumentParser(description='GcnInformax Arguments.')
    # Required: the dataset name is passed straight to TUDataset, which cannot load None.
    parser.add_argument('--dataset', required=True, help='Dataset')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--prior', dest='prior', action='store_true')
    parser.add_argument('--gpu', default=0, type=int)

    parser.add_argument('--tau', type=float, default=0.5)
    # Restricted to the encoders the training script knows how to construct.
    parser.add_argument('--encoder', type=str, default='GIN',
                        choices=['GIN', 'ChebNetII_V2', 'GPRGNN_V2', 'BernNet_V2'])
    parser.add_argument('--aug', type=str, default='dnodes')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--lr', default=0.01, dest='lr', type=float, help='Learning rate.')
    parser.add_argument('--weight_decay', default=0, type=float)
    parser.add_argument('--num_layers', type=int, default=5,
                        help='Number of graph convolution layers before each pooling')
    parser.add_argument('--hidden_dim', type=int, default=32, help='')

    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--use_bn', action='store_true')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha for APPN/GPRGNN.')
    parser.add_argument('--dprate', type=float, default=0.0, help='dropout for propagation layer.')
    parser.add_argument('--q', type=int, default=0, help='The constant for ChebBase.')
    parser.add_argument('--Init', type=str, choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'],
                        default='PPR', help='initialization for GPRGNN.')
    return parser.parse_args(argv)


def drop_missing_nodes(data_aug, batch, node_num):
    """Remove nodes of ``data_aug`` that appear in no edge, and re-index what remains.

    A node in ``range(node_num)`` is kept only if it occurs as an endpoint in
    ``data_aug.edge_index``. Kept nodes are renumbered 0..k-1 in ascending order of
    their old ids. Self-loops are discarded while re-indexing. An edge endpoint
    outside ``range(node_num)`` raises ``KeyError``.

    Args:
        data_aug: object with ``x`` (indexed along dim 0) and ``edge_index`` (shape
            ``(2, E)``). Its ``x``, ``batch`` and ``edge_index`` are overwritten in place.
        batch: per-node graph assignment of the original (un-augmented) batch. It is
            indexed with the kept node ids to build ``data_aug.batch``.
        node_num: number of nodes in the original batch.

    Returns:
        ``data_aug`` (the same object).
    """
    src, dst = data_aug.edge_index.tolist()
    # Build the membership set once; testing ``n in ndarray`` per node was O(N * E).
    present = set(src) | set(dst)
    kept = [n for n in range(node_num) if n in present]
    new_id = {old: new for new, old in enumerate(kept)}

    data_aug.x = data_aug.x[kept]
    data_aug.batch = batch[kept]
    edges = [[new_id[s], new_id[d]] for s, d in zip(src, dst) if s != d]
    # reshape(-1, 2) yields a well-formed (2, 0) tensor when no edges survive;
    # transposing a 1-D empty tensor would raise.
    data_aug.edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return data_aug


if __name__ == '__main__':
    args = arg_parse()
    setup_seed(args.seed)
    print(args)
    print('---------------------')

    batch_size = 128
    path = 'data'
    # Augmentations after which nodes without edges are removed from the view.
    node_dropping_augs = {'dnodes', 'subgraph', 'random2', 'random3', 'random4'}

    dataset = TUDataset(path, name=args.dataset, aug=args.aug).shuffle()
    dataset_eval = TUDataset(path, name=args.dataset, aug='none').shuffle()
    print(len(dataset))

    # The training loop unpacks each item as (data, data_aug), so dataset[0] may be
    # such a pair. Read the feature width from the un-augmented graph in that case
    # (a bare ``except`` previously hid this and fell back to 1 feature).
    sample = dataset[0]
    if isinstance(sample, (tuple, list)):
        sample = sample[0]
    sample_x = getattr(sample, 'x', None)
    if sample_x is not None and sample_x.dim() > 1:
        dataset_num_features = sample_x.shape[1]
    else:
        dataset_num_features = 1
    print(dataset_num_features)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size)

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    if args.encoder == 'GIN':
        encoder = GIN(dataset_num_features, args.hidden_dim, args.num_layers, device=device)
        embedding_dim = args.hidden_dim * args.num_layers
    elif args.encoder in ('ChebNetII_V2', 'GPRGNN_V2', 'BernNet_V2'):
        encoder_cls = {
            'ChebNetII_V2': ChebNetII_V2,
            'GPRGNN_V2': GPRGNN_V2,
            'BernNet_V2': BernNet_V2,
        }[args.encoder]
        encoder = encoder_cls(dataset_num_features, args, device=device)
        embedding_dim = dataset_num_features * (args.K + 1)
    else:
        raise ValueError(f'Unknown encoder: {args.encoder!r}')

    model = simclr(encoder=encoder, embedding_dim=embedding_dim, tau=args.tau).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    for epoch in range(1, args.epochs + 1):
        loss_all = 0
        model.train()
        for data, data_aug in dataloader:
            optimizer.zero_grad()

            node_num = data.num_nodes
            data = data.to(device)
            x = model(data.x, data.edge_index, data.batch)

            if args.aug in node_dropping_augs:
                drop_missing_nodes(data_aug, data.batch, node_num)

            data_aug = data_aug.to(device)
            x_aug = model(data_aug.x, data_aug.edge_index, data_aug.batch)

            loss = model.loss_cal(x, x_aug)
            loss_all += loss.item() * data.num_graphs
            loss.backward()
            optimizer.step()
        # loss_all is weighted by graphs per batch, so average over graphs, not batches.
        print('Epoch {}, Loss {}'.format(epoch, loss_all / len(dataloader.dataset)))

    model.eval()
    emb, y = model.encoder.get_embeddings(dataloader_eval)
    acc_mean, acc_std = evaluate_embedding(emb, y)
    print(acc_mean)
    print(acc_std)