import os
import numpy as np
import random
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch_geometric.loader import DataLoader

from gin import Encoder
from evaluate_embedding import evaluate_embedding
from view_generator import ViewGenerator, GIN_NodeWeightEncoder
from load_data import load_dataset
import warnings
warnings.filterwarnings('ignore')


def arg_parse():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``gpu``, ``seed``,
        ``lr``, ``batch_size``, ``epochs``, ``num_layers``, ``hidden_dim``,
        ``aug`` and ``tau``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--dataset', dict(dest='dataset', help='Dataset')),
        ('--gpu', dict(type=int, default=0)),
        ('--seed', dict(type=int, default=0)),
        ('--lr', dict(type=float, default=0.001, help='Learning rate.')),
        ('--batch_size', dict(type=int, default=128, help='')),
        ('--epochs', dict(type=int, default=30, help='')),
        ('--num_layers', dict(type=int, default=5, help='Number of graph convolution layers before each pooling')),
        ('--hidden_dim', dict(type=int, default=128, help='')),
        ('--aug', dict(type=str, default='dnodes')),
        ('--tau', dict(type=float, default=0.5)),
    )

    cli = argparse.ArgumentParser(description='GcnInformax Arguments.')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def set_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    stores the seed in the ``PYTHONHASHSEED`` environment variable, and
    switches cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot source scripts.

    The call is idempotent: directories that already exist are reused, so
    re-running an experiment with the same ``path`` does not raise.

    Args:
        path: root directory of the experiment. ``<path>/model`` is always
            created.
        scripts_to_save: optional iterable of file paths. When given (even if
            empty), ``<path>/scripts`` is created and each file is copied into
            it under its base name, overwriting an earlier copy.
    """
    # makedirs also creates `path` itself; exist_ok makes repeated calls safe.
    os.makedirs(os.path.join(path, 'model'), exist_ok=True)

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)


class simclr(nn.Module):
    """Graph encoder plus a two-layer MLP projection head for contrastive learning.

    Args:
        dataset: object exposing ``num_features``; passed to ``Encoder`` as
            the input feature size.
        hidden_dim: hidden size handed to ``Encoder``.
        num_gc_layers: number of layers handed to ``Encoder``.
        alpha, beta, gamma: stored on the instance; no method of this class
            reads them.
        device: device on which placeholder node features are created when a
            batch carries no ``x``. When ``None``, the device of
            ``data.batch`` is used instead.
    """

    def __init__(self, dataset, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1, device=None):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.device = device
        # NOTE(review): presumably Encoder concatenates the outputs of all
        # layers, which would explain this width; confirm against gin.Encoder.
        self.embedding_dim = hidden_dim * num_gc_layers
        self.encoder = Encoder(dataset.num_features, hidden_dim, num_gc_layers)
        self.proj_head = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embedding_dim, self.embedding_dim),
        )
        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias.

        Iterates over ``self.modules()``, so linear layers nested inside
        ``self.encoder`` are re-initialised as well.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, data):
        """Encode a batch and return its projected embeddings.

        Args:
            data: batch object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            The output of ``proj_head`` applied to the first value returned by
            the encoder (the second value is discarded).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if x is None:
            # No node features: use a constant placeholder, one value per
            # entry of `batch`. Create it next to the batch when no explicit
            # device was configured, so the tensors cannot end up on
            # different devices.
            # NOTE(review): the placeholder is 1-D; confirm Encoder accepts it.
            target_device = self.device if self.device is not None else batch.device
            x = torch.ones(batch.shape[0], device=target_device)
        y, _ = self.encoder(x, edge_index, batch)
        return self.proj_head(y)

    def loss_cal(self, x, x_aug, temperature=0.2):
        """Contrastive loss between two aligned batches of embeddings.

        For row ``i`` the loss is
        ``-log(exp(s_ii / t) / sum_{j != i} exp(s_ij / t))`` where ``s`` is
        the cosine similarity between rows of ``x`` and rows of ``x_aug``;
        the result is averaged over rows. The positive pair is excluded from
        the denominator, so a batch with a single row has a zero denominator.

        Args:
            x: 2-D tensor of embeddings, one row per graph.
            x_aug: 2-D tensor whose row ``i`` is the positive for ``x[i]``.
            temperature: softmax temperature ``t``. Defaults to 0.2.

        Returns:
            Scalar tensor with the mean loss.
        """
        batch_size, _ = x.size()
        x_abs = x.norm(dim=1)
        x_aug_abs = x_aug.norm(dim=1)
        # Pairwise dot products divided by the outer product of the norms
        # gives the cosine similarity matrix.
        sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs)
        sim_matrix = torch.exp(sim_matrix / temperature)
        diag = torch.arange(batch_size, device=sim_matrix.device)
        pos_sim = sim_matrix[diag, diag]
        ratio = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        return -torch.log(ratio).mean()


def loss_cl(x1, x2, tau=0.5):
    """Symmetric contrastive loss between two aligned batches of embeddings.

    The same one-directional loss is evaluated with ``x1`` as anchor against
    ``x2`` and with ``x2`` as anchor against ``x1``; the two values are
    averaged. In each direction, row ``i`` contributes
    ``-log(exp(s_ii / tau) / sum_{j != i} exp(s_ij / tau))`` with ``s`` the
    cosine similarity matrix.

    Args:
        x1: 2-D tensor of embeddings, one row per graph.
        x2: 2-D tensor whose row ``i`` is the positive for ``x1[i]``.
        tau: temperature dividing the cosine similarities.

    Returns:
        Scalar tensor with the averaged loss.
    """
    n_rows, _ = x1.size()
    diag = range(n_rows)
    norm1 = x1.norm(dim=1)
    norm2 = x2.norm(dim=1)

    def directional(anchor, other, anchor_norm, other_norm):
        # Cosine similarities, temperature-scaled and exponentiated.
        dots = torch.einsum('ik,jk->ij', anchor, other)
        scale = torch.einsum('i,j->ij', anchor_norm, other_norm)
        scores = torch.exp((dots / scale) / tau)
        positives = scores[diag, diag]
        # The positive pair is removed from the denominator.
        negatives = scores.sum(dim=1) - positives
        return - torch.log(positives / negatives).mean()

    forward_term = directional(x1, x2, norm1, norm2)
    backward_term = directional(x2, x1, norm2, norm1)
    return (forward_term + backward_term) / 2


def train_cl_with_sim_loss(view_gen1, view_gen2, view_optimizer, model, optimizer, data_loader, device, tau):
    """Run one training epoch of the encoder and both view generators.

    For every batch, each view generator produces a ``(sample, view)`` pair.
    The objective is ``(1 - mse(sample1, sample2)) + loss_cl(out1, out2)``,
    where ``out1`` / ``out2`` are the model outputs for two inputs drawn at
    random from ``[data, view1, view2]``. Minimising the first term increases
    the mean squared difference between the two samples.

    Args:
        view_gen1, view_gen2: callables invoked as ``view_gen(data, True)``.
            NOTE(review): the meaning of the second argument is defined in
            view_generator; confirm there.
        view_optimizer: optimizer stepping the view generators' parameters.
        model: module mapping a batch to embeddings; put into train mode here.
        optimizer: optimizer stepping ``model``'s parameters.
        data_loader: iterable of batches exposing ``to(device)`` and
            ``num_graphs``.
        device: device each batch is moved to.
        tau: temperature forwarded to ``loss_cl``.

    Returns:
        The loss averaged over all graphs seen in the epoch, or ``0.0`` when
        ``data_loader`` yields no graphs.
    """
    loss_all = 0
    model.train()
    total_graphs = 0
    for data in data_loader:
        optimizer.zero_grad()
        view_optimizer.zero_grad()

        data = data.to(device)

        sample1, view1 = view_gen1(data, True)
        sample2, view2 = view_gen2(data, True)

        sim_loss = F.mse_loss(sample1, sample2)
        sim_loss = (1 - sim_loss)

        # NOTE(review): random.choices samples WITH replacement, so the pair
        # may be the same input twice or may contain neither generated view.
        # Left unchanged to keep training behaviour and RNG consumption
        # identical; consider random.sample if distinct inputs are intended.
        input_list = [data, view1, view2]
        input1, input2 = random.choices(input_list, k=2)

        out1 = model(input1)
        out2 = model(input2)

        cl_loss = loss_cl(out1, out2, tau=tau)
        loss = sim_loss + cl_loss

        # Weight by batch size so the epoch value is a per-graph mean.
        loss_all += loss.item() * data.num_graphs
        total_graphs += data.num_graphs
        loss.backward()
        optimizer.step()
        view_optimizer.step()

    if total_graphs == 0:
        # Empty loader: nothing was trained; avoid dividing by zero.
        return 0.0
    return loss_all / total_graphs


def cl_exp(args):
    """Train the contrastive model end to end and evaluate its embeddings.

    Seeds the RNGs, loads the dataset, builds the model and two view
    generators, trains for ``args.epochs`` epochs, then evaluates the
    embeddings returned by ``model.encoder.get_embeddings`` on the same
    (shuffled) loader that was used for training.

    Args:
        args: namespace providing ``seed``, ``gpu``, ``dataset``,
            ``batch_size``, ``hidden_dim``, ``num_layers``, ``lr``,
            ``epochs`` and ``tau``.

    Returns:
        The two values returned by ``evaluate_embedding`` (accuracy mean and
        standard deviation, going by the names used here).
    """
    set_seed(args.seed)

    # Use the requested GPU when CUDA is available, otherwise the CPU.
    cuda_name = 'cuda:%d' % (args.gpu)
    device = torch.device(cuda_name if torch.cuda.is_available() else 'cpu')

    dataset = load_dataset(args.dataset, args)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Construction order (model, then generator 1, then generator 2) is kept
    # so that seeded parameter initialisation stays reproducible.
    model = simclr(dataset, args.hidden_dim, args.num_layers, device=device).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    view_gen1 = ViewGenerator(dataset, args.hidden_dim, GIN_NodeWeightEncoder)
    view_gen2 = ViewGenerator(dataset, args.hidden_dim, GIN_NodeWeightEncoder)
    view_gen1, view_gen2 = view_gen1.to(device), view_gen2.to(device)

    # One optimizer with a separate parameter group per generator.
    generator_groups = [
        {'params': view_gen1.parameters()},
        {'params': view_gen2.parameters()},
    ]
    view_optimizer = optim.Adam(generator_groups, lr=args.lr, weight_decay=0)

    for _ in range(args.epochs):
        train_cl_with_sim_loss(
            view_gen1, view_gen2, view_optimizer,
            model, optimizer, data_loader, device,
            tau=args.tau,
        )

    model.eval()
    emb, y = model.encoder.get_embeddings(data_loader)
    acc_mean, acc_std = evaluate_embedding(emb, y)
    return acc_mean, acc_std


# Script entry point: parse options, echo the configuration, run the
# experiment and report its evaluation result.
if __name__ == '__main__':
    args = arg_parse()
    print(f'dataset={args.dataset}, hidden_dim={args.hidden_dim}, tau={args.tau}, lr={args.lr}, epochs={args.epochs}')
    test_acc, test_std = cl_exp(args)
    # Print the outcome so the run does not finish without reporting it.
    print(f'test_acc={test_acc}, test_std={test_std}')

