import warnings
import time
import os.path as osp
import numpy as np
import argparse
import seaborn as sns
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_adj
from scipy.linalg import fractional_matrix_power, inv

from models import *
from utils import set_seed, random_splits
from unsup_model import MVGRL
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
warnings.filterwarnings("ignore")


class GCN(nn.Module):
    """Two-layer dense graph-convolution encoder operating on batched inputs.

    Each layer applies a bias-free linear map, left-multiplies the result by
    the supplied propagation matrix with ``torch.bmm`` and then adds a
    learnable bias. ReLU and dropout sit between the two layers; the output
    of the second layer is returned without an activation.

    Args:
        in_ft: Size of the input feature dimension.
        out_ft: Size of the hidden and output feature dimension.
        dropout: Dropout probability applied after the first layer.
    """

    def __init__(self, in_ft, out_ft, dropout):
        super(GCN, self).__init__()
        self.fc1 = nn.Linear(in_ft, out_ft, bias=False)
        self.fc2 = nn.Linear(out_ft, out_ft, bias=False)
        self.act = nn.ReLU()
        self.dropout = dropout

        # Biases live outside the Linear layers so that they are added
        # after the multiplication with the propagation matrix.
        self.bias1 = nn.Parameter(torch.zeros(out_ft))
        self.bias2 = nn.Parameter(torch.zeros(out_ft))

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        """Xavier-initialise the weight of ``m`` (and zero its bias) if it is an ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, seq, adj):
        """Encode a batch of graphs.

        Args:
            seq: Node features, shape (batch, nodes, features).
            adj: Propagation matrices; ``torch.bmm`` requires a 3-D tensor
                whose batch size matches ``seq`` and whose last dimension
                equals the number of nodes.

        Returns:
            Tensor with ``out_ft`` channels in the last dimension.
        """
        x = torch.bmm(adj, self.fc1(seq))
        if self.bias1 is not None:
            x = x + self.bias1
        x = self.act(x)
        # ``nn.functional`` is used explicitly: this module never imports ``F``.
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        x = torch.bmm(adj, self.fc2(x))
        if self.bias2 is not None:
            x = x + self.bias2
        return x
    

def parse_args():
    """Parse the script's command-line options and return the resulting namespace.

    The options are declared in a table (flag, keyword arguments) and
    registered in order, so the generated ``--help`` output lists them in
    the same sequence as the table.
    """
    option_table = [
        # general
        ('--seed', dict(type=int, default=42, help='seed.')),
        ('--dataset', dict(type=str, default='Cora')),
        ('--device', dict(type=int, default=0, help='GPU device.')),
        ('--runs', dict(type=int, default=10, help='number of runs.')),
        ('--net', dict(type=str, default='GCN')),
        # network
        ('--num_layers', dict(type=int, default=2)),
        ('--hidden', dict(type=int, default=64, help='hidden units.')),
        ('--dropout', dict(type=float, default=0.5, help='dropout for neural networks.')),
        # data splitting
        ('--fix_split', dict(action='store_true')),
        ('--train_rate', dict(type=float, default=0.6, help='train set rate.')),
        ('--val_rate', dict(type=float, default=0.2, help='val set rate.')),
        # unsupervised learning
        ("--patience", dict(type=int, default=20, help="Patient epochs to wait before early stopping.")),
        ("--unsup_epochs", dict(type=int, default=500, help="Unupservised training epochs.")),
        ("--lr1", dict(type=float, default=0.001, help="Learning rate of the unsupervised model.")),
        ("--lr2", dict(type=float, default=0.01, help="Learning rate of linear evaluator.")),
        ("--wd1", dict(type=float, default=0.0, help="Weight decay of the unsupervised model.")),
        ("--wd2", dict(type=float, default=0.0, help="Weight decay of linear evaluator.")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def compute_ppr(name, adj, alpha):
    """Return the dense PPR diffusion matrix of a graph, cached on disk.

    The matrix is ``alpha * (I - (1 - alpha) * D^-1/2 (A + I) D^-1/2)^-1``
    where ``A`` is ``adj`` and ``D`` is the diagonal of row sums of
    ``A + I``. The result is stored at ``./data/<name.lower()>/raw/diff.npy``.

    NOTE: the cache key is the dataset name only. If the file already
    exists it is loaded and returned as-is, without looking at ``adj`` or
    ``alpha``; delete the file to force a recomputation.

    Args:
        name: Dataset name; lower-cased to build the cache directory.
        adj: Dense square adjacency as a torch tensor.
        alpha: Weight ``alpha`` in the closed form above.

    Returns:
        A NumPy array with the same shape as ``adj``.
    """
    import os  # local import: the module itself only imports ``os.path``

    path = osp.join('./data', name.lower(), 'raw', 'diff.npy')
    if osp.exists(path):
        return np.load(path)

    a = adj.detach().cpu().numpy()
    n = a.shape[0]
    a = a + np.eye(n)                                   # A^ = A + I_n
    # D^ is diagonal, so D^(-1/2) A^ D^(-1/2) reduces to scaling the rows
    # and columns of A^; no matrix power or dense matmul is needed.
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))
    at = d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :]  # A~
    ppr = alpha * inv(np.eye(n) - (1 - alpha) * at)     # a(I_n-(1-a)A~)^-1

    # The cache directory may not exist yet; np.save does not create it.
    os.makedirs(osp.dirname(path), exist_ok=True)
    np.save(path, ppr)
    return ppr


def unsupervised_learning(dataset, data, args, device, verbose=False):
    """Train an MVGRL model on one graph and return embeddings for all nodes.

    Every epoch draws ``batch_size`` contiguous node windows of
    ``sample_size`` nodes, slices the adjacency, the diffusion matrix and
    the features accordingly, and feeds them to the model together with a
    row-permuted copy of the features. The state with the lowest training
    loss is checkpointed under ``unsup_pkl/`` and reloaded before the
    embeddings are computed. Training stops early once the loss has not
    improved for ``args.patience`` consecutive epochs.

    Args:
        dataset: Object exposing ``num_node_features``.
        data: Graph object exposing ``x`` and ``edge_index``.
        args: Namespace providing ``hidden``, ``dropout``, ``dataset``,
            ``net``, ``lr1``, ``wd1``, ``unsup_epochs`` and ``patience``.
        device: Torch device for the model, labels and diffusion matrix.
        verbose: Print per-epoch loss and early-stopping messages.

    Returns:
        The result of ``model.get_embedding`` on the full graph (inputs are
        given a leading batch dimension of one).
    """
    import os  # local import: the module itself only imports ``os.path``

    encoder1 = GCN(in_ft=dataset.num_node_features, out_ft=args.hidden, dropout=args.dropout)
    encoder2 = GCN(in_ft=dataset.num_node_features, out_ft=args.hidden, dropout=args.dropout)

    feat = data.x
    num_nodes, feat_size = feat.size(0), feat.size(1)
    # Pin the dense adjacency to the number of feature rows; otherwise
    # trailing isolated nodes would yield a smaller matrix than ``feat``.
    adj = to_dense_adj(data.edge_index, max_num_nodes=num_nodes).squeeze(0)
    diff = torch.tensor(compute_ppr(args.dataset, adj, alpha=0.2), device=device)

    # A window cannot be larger than the graph itself.
    sample_size = min(2000, num_nodes)
    batch_size = 4

    # Targets for the loss: ones in the first 2 * sample_size columns,
    # zeros in the remaining 2 * sample_size columns.
    lbl_1 = torch.ones(batch_size, sample_size * 2)
    lbl_2 = torch.zeros(batch_size, sample_size * 2)
    lbl = torch.cat((lbl_1, lbl_2), 1).to(device)

    model = MVGRL(encoder1=encoder1, encoder2=encoder2, n_h=args.hidden).to(device)
    # wd1 is the weight decay of the unsupervised model (wd2 belongs to the
    # linear evaluator).
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    unsup_tag = str(int(time.time()))
    os.makedirs('unsup_pkl', exist_ok=True)
    ckpt_path = 'unsup_pkl/' + 'mvgrl_cl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl'

    best = float("inf")
    cnt_wait = 0
    best_t = 0
    for epoch in range(args.unsup_epochs):
        starts = torch.randint(0, num_nodes - sample_size + 1, (batch_size,)).tolist()
        ba = torch.stack([adj[s: s + sample_size, s: s + sample_size] for s in starts]).float()
        bd = torch.stack([diff[s: s + sample_size, s: s + sample_size] for s in starts]).float()
        bf = torch.stack([feat[s: s + sample_size] for s in starts]).float()
        # Row permutation used to build the second (shuffled) feature view.
        perm = torch.randperm(sample_size)

        model.train()
        optimizer.zero_grad()

        logits = model(seq1=bf, seq2=bf[:, perm, :], adj=ba, diff=bd)
        loss = model.loss_fn(logits, lbl)

        loss.backward()
        optimizer.step()

        # Track a plain float so the autograd graph of the best epoch is
        # not kept alive between iterations.
        loss_val = loss.item()
        if verbose:
            print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, loss_val))

        if loss_val < best:
            best = loss_val
            best_t = epoch
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            if verbose:
                print('Early stopping!')
            break

    if verbose:
        print('Loading {}th epoch'.format(best_t))
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    embeds = model.get_embedding(seq=feat.float().unsqueeze(0), adj=adj.float().unsqueeze(0), diff=diff.float().unsqueeze(0))
    return embeds


if __name__ == '__main__':
    # Entry point: learn embeddings once, then score them with a linear
    # probe on `args.runs` different splits and report mean +/- bootstrap CI.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS = [
        1941488137, 4198936517, 983997847, 4023022221, 4019585660,
        2108550661, 1648766618, 629014539, 3212139042, 2424918363,
    ]
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.device))
    else:
        device = torch.device('cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0].to(device)

    # Split sizes derived from the command-line rates.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    embeds = unsupervised_learning(dataset=dataset, data=data, args=args, device=device)

    # Datasets other than Computers/Photo ship their own masks; keep the
    # originals because `data.*_mask` is overwritten inside the loop.
    has_own_masks = args.dataset not in ['Computers', 'Photo']
    if has_own_masks:
        full_train_mask = data.train_mask
        full_val_mask = data.val_mask
        full_test_mask = data.test_mask

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]

        if args.fix_split and has_own_masks:
            # Use column RP of the stored masks for this run.
            data.train_mask, data.val_mask, data.test_mask = (
                full_train_mask[:, RP],
                full_val_mask[:, RP],
                full_test_mask[:, RP],
            )
        else:
            if args.fix_split:
                # no public splitting, train/val/test=1/1/8
                percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
                val_lb = int(round(0.1 * len(data.y)))
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        unsup_results.append(
            unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        )

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Half-width of the 95% bootstrap interval around the sample mean
    # (the larger of the two one-sided distances).
    boot_means = sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000)
    ci_bounds = sns.utils.ci(boot_means, 95)
    uncertainty = np.max(np.abs(ci_bounds - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')
