import warnings
import time
import os
import numpy as np
import argparse
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj, add_remaining_self_loops
from torch_scatter import scatter

from utils import set_seed, random_splits, drop_features
from eval import unsupervised_test_linear
from dataset_loader import DataLoader
# NOTE(review): this silences every Python warning for the rest of the process,
# including library deprecation warnings; consider narrowing the filter.
warnings.filterwarnings("ignore")


def print_params(model, string):
    """Print every named parameter of ``model.encoder`` between a titled banner and a footer.

    Args:
        model: object with an ``encoder`` attribute exposing ``named_parameters()``.
        string: title embedded in the opening banner.
    """
    banner = f'----------- {string} ----------'
    print(banner)
    for param_name, tensor in model.encoder.named_parameters():
        print(param_name)
        print(tensor)
    print('-----------------------------------')


class Decoupled_GCN(nn.Module):
    """Propagate-then-transform encoder: K rounds of sum aggregation, then a 2-layer MLP.

    Propagation has no learnable parameters and no degree normalisation: each
    round replaces every node's feature with the sum of ``x[col]`` over the
    edges whose ``row`` index is that node (nodes with no such edge get zeros).

    Args:
        dataset: object exposing ``num_node_features`` (input width of ``lin1``).
        args: namespace providing ``K`` (propagation rounds), ``dropout``
            (dropout probability between the linear layers) and ``hidden``
            (output width of both linear layers).
    """

    def __init__(self, dataset, args):
        super(Decoupled_GCN, self).__init__()
        self.K = args.K
        self.dropout = args.dropout
        self.lin1 = nn.Linear(dataset.num_node_features, args.hidden)
        self.lin2 = nn.Linear(args.hidden, args.hidden)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    @staticmethod
    def _frozen_normal_like(weight, mean, std):
        """Return a non-trainable Parameter of N(mean, std) samples matching ``weight``'s shape, dtype and device."""
        # empty_like keeps dtype/device, so a .double() or GPU model stays consistent.
        sample = torch.empty_like(weight).normal_(mean=mean, std=std)
        return nn.Parameter(sample, requires_grad=False)

    def get_embeddings(self, x, edge_index, random1=False, random2=False):
        """Run ``forward``, optionally after swapping linear weights for random frozen ones.

        Args:
            x: node feature matrix passed to ``forward``.
            edge_index: pair ``(row, col)`` of index tensors passed to ``forward``.
            random1: replace ``lin1.weight`` with normal samples whose mean/std
                match the current ``lin1.weight`` (the two statistics are printed).
            random2: replace ``lin2.weight`` with N(0, 0.01) samples.

        Returns:
            Output of ``forward``. The weight replacement persists on the module.
        """
        if random1:
            w1 = self.lin1.weight.data
            mean1, std1 = w1.mean().item(), w1.std().item()
            print(mean1, std1)
            self.lin1.weight = self._frozen_normal_like(w1, mean1, std1)
        if random2:
            self.lin2.weight = self._frozen_normal_like(self.lin2.weight.data, 0.0, 0.01)
        return self(x, edge_index)

    def forward(self, x, edge_index):
        """Sum-propagate ``x`` ``K`` times along ``edge_index``, then apply Linear-ReLU-Dropout-Linear-ReLU."""
        # edge_index never changes inside the loop, so unpack it once.
        row, col = edge_index
        num_nodes = x.size(0)
        for _ in range(self.K):
            x = scatter(x[col], row, dim=0, reduce='add', dim_size=num_nodes)

        x = self.lin1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        x = F.relu(x)
        return x


class GRACE(torch.nn.Module):
    """Two-view graph contrastive model wrapping a node encoder.

    Two stochastic views are made by independently dropping edges and feature
    entries; both are encoded by the shared ``encoder`` and pulled together by
    a symmetric InfoNCE loss on the output of a projection head
    (``fc1`` -> ELU -> ``fc2``).

    Args:
        encoder: module called as ``encoder(x, edge_index)`` in training and
            exposing ``get_embeddings(x, edge_index, random1, random2)``.
        input_dim: width of the encoder output (input of ``fc1``).
        num_hidden: output width of the projection head.
        num_proj_hidden: hidden width of the projection head.
        tau: InfoNCE temperature.
        drop_rate: 4-tuple ``(edge_p_view1, edge_p_view2, feat_p_view1, feat_p_view2)``.
    """

    def __init__(self, encoder, input_dim, num_hidden, num_proj_hidden, tau, drop_rate):
        super(GRACE, self).__init__()
        self.encoder = encoder

        self.tau = tau
        self.drop_rate = drop_rate
        self.num_proj_hidden = num_proj_hidden
        self.fc1 = torch.nn.Linear(input_dim, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)

    def augmentation(self, x, edge_index):
        """Return two independently corrupted views ``(x1, edge_index_1, x2, edge_index_2)``."""
        edge_index_1 = dropout_adj(edge_index, p=self.drop_rate[0])[0]
        edge_index_2 = dropout_adj(edge_index, p=self.drop_rate[1])[0]
        x1 = drop_features(x, self.drop_rate[2])
        x2 = drop_features(x, self.drop_rate[3])
        return x1, edge_index_1, x2, edge_index_2

    def forward(self, x, edge_index):
        """Encode two augmented views of the graph and return their embeddings ``(z1, z2)``."""
        x1, edge_index_1, x2, edge_index_2 = self.augmentation(x, edge_index)
        z1 = self.encoder(x1, edge_index_1)
        z2 = self.encoder(x2, edge_index_2)
        return z1, z2

    def get_embedding(self, x, edge_index, random1=False, random2=False):
        """Return detached encoder embeddings of the un-augmented graph.

        ``random1``/``random2`` are forwarded to ``encoder.get_embeddings``.
        """
        # Inference only: skip building an autograd graph that would be detached anyway.
        with torch.no_grad():
            z = self.encoder.get_embeddings(x, edge_index, random1, random2)
        return z.detach()

    def projection_mlp(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the projection head ``fc2(elu(fc1(z)))``."""
        z = F.elu(self.fc1(z))
        return self.fc2(z)

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Pairwise cosine-similarity matrix between the rows of ``z1`` and ``z2``."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def infonce(self, z1, z2):
        """Per-anchor InfoNCE loss of anchors ``z1`` against positives ``z2``.

        With cosine similarity s and temperature tau, for anchor i:
            loss_i = -s(z1_i, z2_i)/tau
                     + log(sum_j exp(s(z1_i, z2_j)/tau) + sum_{j != i} exp(s(z1_i, z1_j)/tau))

        Evaluated in log space with ``logsumexp`` so a small ``tau`` cannot
        overflow ``exp`` (e.g. exp(1/0.005) is inf in float32).
        """
        between_logits = self.sim(z1, z2) / self.tau
        refl_logits = self.sim(z1, z1) / self.tau
        # An anchor is not its own negative: drop the intra-view diagonal.
        self_mask = torch.eye(refl_logits.size(0), dtype=torch.bool, device=refl_logits.device)
        refl_logits = refl_logits.masked_fill(self_mask, float('-inf'))
        alignment_loss = -between_logits.diag()
        uniformity_loss = torch.logsumexp(torch.cat([refl_logits, between_logits], dim=1), dim=1)
        return alignment_loss + uniformity_loss

    def infonce_loss(self, z1, z2):
        """Symmetric InfoNCE on projected embeddings, averaged over both directions and all nodes."""
        z1 = self.projection_mlp(z1)
        z2 = self.projection_mlp(z2)
        l1 = self.infonce(z1, z2)
        l2 = self.infonce(z2, z1)
        ret = (l1 + l2) * 0.5
        ret = ret.mean()
        return ret


def unsupervised_learning(dataset, data, args, device):
    """Train GRACE with a Decoupled_GCN encoder and return frozen node embeddings.

    Training runs for at most ``args.unsup_epochs`` full-graph steps. It stops
    early once the training loss has failed to improve for ``args.patience``
    consecutive epochs. The parameters snapshotted at the lowest-loss epoch are
    restored before embedding.

    Args:
        dataset: passed to ``Decoupled_GCN`` (reads ``num_node_features``).
        data: graph object with ``x`` and ``edge_index``. Self-loops are added
            to ``edge_index`` for both training and embedding.
        args: hyper-parameter namespace (see ``parse_args``).
        device: device the model is moved to.

    Returns:
        Detached embeddings from ``GRACE.get_embedding``.
    """
    encoder = Decoupled_GCN(dataset=dataset, args=args)
    model = GRACE(encoder=encoder, input_dim=args.hidden, num_hidden=args.hidden, num_proj_hidden=args.proj_hid_dim,
                  tau=args.tau, drop_rate=(args.de1, args.de2, args.df1, args.df2)).to(device)
    optimizer = torch.optim.Adam([{'params': model.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1}])

    best = float("inf")
    # The best parameters are kept in memory rather than in a timestamp-named file.
    # A file under unsup_pkl/ required that directory to exist, and could be
    # clobbered by a concurrent run started in the same second.
    best_state = None
    cnt_wait = 0
    edge_index = add_remaining_self_loops(data.edge_index)[0]
    for _ in range(args.unsup_epochs):
        model.train()
        optimizer.zero_grad()

        z1, z2 = model(data.x, edge_index)
        loss = model.infonce_loss(z1, z2)

        loss.backward()
        optimizer.step()

        # Compare Python floats so no autograd-attached tensor is retained.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            cnt_wait = 0
            # state_dict() returns live references; clone so later steps do not
            # overwrite the snapshot.
            # NOTE(review): this captures the post-step weights, not the weights
            # that produced `loss` (same as the original checkpointing).
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            break

    # If the loss never improved (e.g. NaN from the first epoch), keep the final weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    return model.get_embedding(data.x, edge_index, random1=args.random1, random2=args.random2)


def parse_args():
    """Parse command-line hyper-parameters for unsupervised training and linear evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    parser.add_argument('--de1', default=0.2, type=float, help='Edge drop rate for view 1.')
    parser.add_argument('--de2', default=0.2, type=float, help='Edge drop rate for view 2.')
    parser.add_argument('--df1', default=0.2, type=float, help='Feature drop rate for view 1.')
    parser.add_argument('--df2', default=0.2, type=float, help='Feature drop rate for view 2.')
    parser.add_argument('--tau', default=0.5, type=float, help='Temperature of the InfoNCE loss.')
    parser.add_argument("--proj_hid_dim", type=int, default=128, help="Projection hidden layer dim.")

    parser.add_argument("--patience", type=int, default=20,
                        help="Epochs without loss improvement to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=500, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    parser.add_argument('--random1', action='store_true',
                        help='Replace the trained first encoder layer weight with random frozen weights before embedding.')
    parser.add_argument('--random2', action='store_true',
                        help='Replace the trained second encoder layer weight with random frozen weights before embedding.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660,
             2108550661, 1648766618, 629014539, 3212139042, 2424918363]
    # Validate before the expensive unsupervised training. Each run needs its own
    # fixed split seed, and zero runs would leave nothing to average.
    if not 1 <= args.runs <= len(SEEDS):
        raise ValueError(f'--runs must be between 1 and {len(SEEDS)}, got {args.runs}')
    device = torch.device('cuda:' + str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0].to(device)

    # Per-class training count and total validation count passed to random_splits.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    # Embeddings are learned once and reused for every split.
    embeds = unsupervised_learning(dataset=dataset, data=data, args=args, device=device)

    unsup_results = []
    for split_seed in SEEDS[:args.runs]:
        # Kept on args as before, in case the evaluator reads args.seed.
        args.seed = split_seed
        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Largest distance between the 95% bootstrap CI bounds and the sample mean.
    boot_means = sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000)
    uncertainty = np.max(np.abs(sns.utils.ci(boot_means, 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')
