import os
import copy
import warnings
import time
import numpy as np
import argparse
import seaborn as sns

import torch
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch_geometric.utils.dropout import dropout_adj
from torch_geometric.transforms import Compose

from models import *
from unsup_model import BGRL, MLP_Predictor
from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
# Silence every Python warning for the whole process (this also hides library
# deprecation warnings).
warnings.filterwarnings("ignore")


def parse_args(argv=None):
    """Parse the options for BGRL pre-training and linear evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as before; passing an explicit list
            lets the options be built programmatically (tests, notebooks).

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    # General experiment setup.
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    # Encoder architecture.
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=512, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    # Data splits.
    parser.add_argument('--fix_split', action='store_true')
    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning: edge-drop (de*) and feature-drop (df*) rates for
    # the two augmented views.
    parser.add_argument('--de1', default=0.2, type=float)
    parser.add_argument('--de2', default=0.2, type=float)
    parser.add_argument('--df1', default=0.2, type=float)
    parser.add_argument('--df2', default=0.2, type=float)
    parser.add_argument("--pred_hid_dim", type=int, default=128, help="Projection hidden layer dim.")

    parser.add_argument("--unsup_epochs", type=int, default=10000, help="Unsupervised training epochs.")
    parser.add_argument("--lr_warmup_epochs", type=int, default=1000, help='Warmup period for learning rate.')
    parser.add_argument("--mm", type=float, default=0.99, help='The momentum for moving average.')
    parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
    return parser.parse_args(argv)


class CosineDecayScheduler:
    """Linear warm-up followed by a half-cosine decay to zero.

    ``get(step)`` returns:

    * ``max_val * step / warmup_steps`` while ``step < warmup_steps``;
    * ``max_val * (1 + cos(pi * (step - warmup_steps) / (total_steps - warmup_steps))) / 2``
      for ``warmup_steps <= step <= total_steps`` (``max_val`` at the start of
      the decay, 0 at ``total_steps``);
    * raises ``ValueError`` for ``step > total_steps``.

    When ``warmup_steps == total_steps`` the decay phase has zero length and
    ``get(total_steps)`` returns ``max_val`` instead of dividing by zero.
    """

    def __init__(self, max_val, warmup_steps, total_steps):
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step):
        """Return the scheduled value at integer ``step``."""
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps
        if step > self.total_steps:
            raise ValueError('Step ({}) > total number of steps ({}).'.format(step, self.total_steps))
        # Here warmup_steps <= step <= total_steps, so decay_steps >= 0.
        decay_steps = self.total_steps - self.warmup_steps
        # A zero-length decay phase is only reached at step == total_steps:
        # treat it as the start of the decay (angle 0) rather than dividing by 0.
        angle = 0.0 if decay_steps == 0 else (step - self.warmup_steps) * np.pi / decay_steps
        return self.max_val * (1 + np.cos(angle)) / 2
        

class DropFeatures:
    """Augmentation that zeroes randomly selected feature columns of ``data.x``.

    Each column is dropped independently with probability ``p``. The same
    columns are zeroed for every row, and ``data.x`` is modified in place.
    """

    def __init__(self, p=None):
        assert 0. < p < 1., (
            'Dropout probability has to be between 0 and 1, but got %.2f' % p)
        self.p = p

    def __call__(self, data):
        features = data.x
        # One uniform score per feature column; columns scoring below p are dropped.
        scores = torch.empty((features.size(1),), dtype=torch.float32, device=features.device)
        scores.uniform_(0, 1)
        dropped_cols = scores < self.p
        features[:, dropped_cols] = 0
        return data

    def __repr__(self):
        return f'{type(self).__name__}(p={self.p})'


class DropEdges:
    """Augmentation that randomly removes edges from a graph.

    Delegates to ``torch_geometric.utils.dropout_adj`` with probability ``p``.
    If the graph carries ``edge_attr``, the attributes of the surviving edges
    are written back alongside the new ``edge_index``.
    """

    def __init__(self, p, force_undirected=False):
        assert 0. < p < 1., (
            'Dropout probability has to be between 0 and 1, but got %.2f' % p)
        self.p = p
        self.force_undirected = force_undirected

    def __call__(self, data):
        attrs = data.edge_attr if 'edge_attr' in data else None
        kept_index, kept_attrs = dropout_adj(
            data.edge_index, attrs, p=self.p, force_undirected=self.force_undirected)

        data.edge_index = kept_index
        if kept_attrs is not None:
            data.edge_attr = kept_attrs
        return data

    def __repr__(self):
        return f'{type(self).__name__}(p={self.p}, force_undirected={self.force_undirected})'
    

def get_graph_drop_transform(drop_edge_p, drop_feat_p):
    """Build the augmentation pipeline for one contrastive view.

    The pipeline first deep-copies the input graph, so the in-place edits that
    follow never touch the caller's object. It then drops edges when
    ``drop_edge_p > 0`` and drops feature columns when ``drop_feat_p > 0``.
    A probability of 0 disables the corresponding step.
    """
    optional_steps = [(DropEdges, drop_edge_p), (DropFeatures, drop_feat_p)]
    steps = [copy.deepcopy] + [make(p) for make, p in optional_steps if p > 0.]
    return Compose(steps)


def unsupervised_learning(data, args):
    """Pre-train the BGRL model on ``data`` and return embeddings from the best epoch.

    Relies on module-level globals created in the ``__main__`` section:
    ``model``, ``optimizer``, ``lr_scheduler``, ``mm_scheduler``,
    ``transform_1`` and ``transform_2``.

    Each epoch proceeds as follows:

    1. Build two augmented views.
    2. Compute a symmetric negative-cosine loss between each view's ``q``
       output of ``model`` and the other view's detached ``y`` output.
    3. Take one optimizer step.
    4. Update the target network with momentum ``mm``.

    The weights from the epoch with the lowest loss are restored before
    embedding.

    Args:
        data: graph fed to the augmentations and to ``model.get_embeddings``.
        args: parsed options; only ``args.unsup_epochs`` is read here.

    Returns:
        The result of ``model.get_embeddings(data)`` for the restored model.
    """
    best_loss = float("inf")
    # The best weights are kept in memory rather than in a temporary
    # checkpoint file. This avoids a crash when the checkpoint directory is
    # missing, and avoids filename clashes between runs started in the same
    # second.
    best_state = None
    for step in range(1, args.unsup_epochs + 1):
        model.train()

        # Scheduler steps are 0-based for the learning rate.
        lr = lr_scheduler.get(step - 1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # mm_scheduler schedules (1 - momentum), so convert back to momentum.
        mm = 1 - mm_scheduler.get(step)
        optimizer.zero_grad()

        data1, data2 = transform_1(data), transform_2(data)
        q1, y2 = model(data1, data2)
        q2, y1 = model(data2, data1)

        # The y outputs are detached, so gradients flow only through the q outputs.
        loss = 2 - cosine_similarity(q1, y2.detach(), dim=-1).mean() - cosine_similarity(q2, y1.detach(), dim=-1).mean()
        loss.backward()
        optimizer.step()
        model.update_target_network(mm)

        # Compare a Python float so the loss tensor (and its autograd graph)
        # is not kept alive between epochs.
        loss_value = loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            # state_dict() returns references to live tensors; copy to snapshot them.
            best_state = copy.deepcopy(model.state_dict())

    # With zero epochs there is no snapshot; embed with the current weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    model.online_encoder.eval()
    embeds = model.get_embeddings(data)
    return embeds



if __name__ == '__main__':
    # Script entry point. It runs in three stages:
    #   1. Parse options and pre-train the BGRL model once on the whole graph.
    #   2. Evaluate the resulting embeddings with a linear classifier on
    #      `args.runs` train/val/test splits.
    #   3. Print the mean test accuracy ± a bootstrap confidence half-width.
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    set_seed(args.seed)
    # 10 fixed seeds for the random splits, taken from BernNet. The list is
    # indexed by run number below, so --runs greater than 10 raises IndexError.
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    # Split sizes passed to random_splits.
    #   percls_trn = train_rate * num_nodes / num_classes
    #                (presumably training nodes per class).
    #   val_lb     = val_rate * num_nodes
    #                (presumably the total number of validation nodes).
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))
    
    # Independently configured augmentations, one per view.
    transform_1 = get_graph_drop_transform(drop_edge_p=args.de1, drop_feat_p=args.df1)
    transform_2 = get_graph_drop_transform(drop_edge_p=args.de2, drop_feat_p=args.df2)

    # NOTE: model, optimizer, lr_scheduler, mm_scheduler, transform_1 and
    # transform_2 are module-level globals read directly by
    # unsupervised_learning(). They must stay at module scope.
    encoder = UnsupGCN_Net(dataset=dataset, args=args)
    predictor = MLP_Predictor(args.hidden, args.hidden, hidden_size=args.pred_hid_dim)
    model = BGRL(encoder, predictor).to(device)

    optimizer = AdamW(model.trainable_parameters(), lr=args.lr1, weight_decay=args.wd1)
    # Learning rate: linear warm-up over lr_warmup_epochs, then cosine decay to 0.
    lr_scheduler = CosineDecayScheduler(args.lr1, args.lr_warmup_epochs, args.unsup_epochs)
    # Schedules (1 - momentum) with no warm-up. The value starts at
    # 1 - args.mm and decays to 0, so the target-network momentum rises from
    # about args.mm to 1.
    mm_scheduler = CosineDecayScheduler(1 - args.mm, 0, args.unsup_epochs)

    # Pre-training runs once; the same embeddings are evaluated on every split.
    embeds = unsupervised_learning(data=data, args=args)
        
    # Keep the dataset-provided masks so each run can select its own column.
    if args.dataset not in ['Computers', 'Photo']:
        full_train_mask, full_val_mask, full_test_mask = data.train_mask, data.val_mask, data.test_mask

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]
        
        if args.fix_split:
            if args.dataset in ['Computers', 'Photo']:  # no public splitting, train/val/test=1/1/8
                percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
                val_lb = int(round(0.1 * len(data.y)))
                data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
            else:
                # Assumes the dataset masks are 2-D (num_nodes, num_splits)
                # with at least args.runs columns. TODO confirm against DataLoader.
                data.train_mask, data.val_mask, data.test_mask = full_train_mask[:, RP], full_val_mask[:, RP], full_test_mask[:, RP]
        else:
            # Fresh random split per run, sized by train_rate / val_rate.
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    # NOTE(review): dtype=object presumably tolerates eval_acc values that are
    # not plain floats. Confirm what unsupervised_test_linear returns.
    values = np.asarray(unsup_results, dtype=object)
    # Half-width of the 95% bootstrap confidence interval of the mean
    # (1000 resamples): the larger distance from either CI bound to the mean.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')