import argparse
import warnings
import time
import numpy as np
import argparse
import seaborn as sns
import networkx as nx

import torch
from torch_scatter import scatter
from torch_geometric.utils import dropout_adj, degree, to_undirected, to_networkx

from models import *
from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
from unsup_model import GCA
warnings.filterwarnings("ignore")


def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    """Sample a subset of edges, removing each with a weight-dependent probability.

    The weights are rescaled so that their mean equals ``p``, capped at
    ``threshold``, and then used as per-edge removal probabilities.

    Args:
        edge_index: Tensor whose columns (dim 1) are the edges.
        edge_weights: One weight per edge; a larger weight makes removal
            more likely.
        p: Mean removal probability before the cap is applied.
        threshold: Upper bound on any single removal probability.

    Returns:
        ``edge_index`` restricted to the columns that survived sampling.
    """
    removal_prob = edge_weights / edge_weights.mean() * p
    cap = torch.ones_like(removal_prob) * threshold
    removal_prob = torch.where(removal_prob < threshold, removal_prob, cap)
    # Bernoulli is drawn on the *keep* probability, so 1 means "edge stays".
    keep = torch.bernoulli(1. - removal_prob).to(torch.bool)
    return edge_index[:, keep]


def drop_edge(p):
    """Return an edge-dropped edge index for the configured drop scheme.

    Reads the module-level ``args``, ``data`` and ``drop_weights``.

    Args:
        p: Drop probability forwarded to the selected edge-dropping routine.

    Returns:
        The first element returned by ``dropout_adj`` for the ``'uniform'``
        scheme, or the result of ``drop_edge_weighted`` (with a fixed
        threshold of 0.7) for ``'degree'``, ``'evc'`` and ``'pr'``.

    Raises:
        Exception: If ``args.drop_scheme`` is none of the schemes above.
    """
    global drop_weights
    scheme = args.drop_scheme

    if scheme == 'uniform':
        return dropout_adj(data.edge_index, p=p)[0]

    if scheme in ('degree', 'evc', 'pr'):
        return drop_edge_weighted(data.edge_index, drop_weights, p=p, threshold=0.7)

    raise Exception(f'undefined drop scheme: {args.drop_scheme}')
    

def drop_feature(x, drop_prob):
    """Zero out whole feature columns of ``x``, each chosen independently.

    Args:
        x: Feature tensor with the features along dim 1.
        drop_prob: Probability with which each column is zeroed.

    Returns:
        A copy of ``x`` with the sampled columns set to zero; ``x`` itself
        is not modified.
    """
    num_features = x.size(1)
    # One uniform draw in [0, 1) per column; a column is dropped when its
    # draw falls below drop_prob.
    noise = torch.rand(num_features, dtype=torch.float32, device=x.device)
    masked = x.clone()
    masked[:, noise < drop_prob] = 0
    return masked


def degree_drop_weights(edge_index):
    """Compute per-edge drop weights from the log-degree of each edge's row-1 node.

    Degrees are counted on the undirected version of the graph. An edge
    whose row-1 node has the largest degree gets weight 0; weights grow as
    that degree shrinks, scaled by the gap between the maximum and the mean
    log-degree over all edges.

    Args:
        edge_index: Edge tensor; row 1 selects the node used for scoring.

    Returns:
        One float weight per column of ``edge_index``.
    """
    undirected = to_undirected(edge_index)
    node_deg = degree(undirected[1])
    log_deg = torch.log(node_deg[edge_index[1]].to(torch.float32))
    top = log_deg.max()
    return (top - log_deg) / (top - log_deg.mean())


def compute_pr(edge_index, damp: float = 0.85, k: int = 10):
    """Compute PageRank-style node scores with ``k`` propagation rounds.

    Every node starts with score 1. In each round a node's score is split
    equally over its outgoing edges, the shares are summed at the receiving
    nodes, and the sum is blended with the previous score using ``damp``.

    Args:
        edge_index: Edge tensor; row 0 holds the sending node and row 1 the
            receiving node of each edge.
        damp: Weight given to the aggregated messages in each round.
        k: Number of propagation rounds.

    Returns:
        A float32 tensor with one score per node, where the node count is
        taken as the largest index in ``edge_index`` plus one.
    """
    num_nodes = edge_index.max().item() + 1
    src, dst = edge_index[0], edge_index[1]
    deg_out = degree(src)
    x = torch.ones(num_nodes, dtype=torch.float32, device=edge_index.device)
    for _ in range(k):
        edge_msg = x[src] / deg_out[src]
        # Without dim_size the output length is inferred from the largest
        # receiving index, which is shorter than x whenever the
        # highest-numbered node has no incoming edge.
        agg_msg = scatter(edge_msg, dst, dim_size=num_nodes, reduce='sum')
        x = (1 - damp) * x + damp * agg_msg
    return x


def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10):
    """Compute per-edge drop weights from the log of ``compute_pr`` scores.

    Args:
        edge_index: Edge tensor; row 0 and row 1 index the two endpoints.
        aggr: Which endpoint score to use: ``'source'`` takes row 0,
            ``'mean'`` averages both log-scores, and ``'sink'`` (or any
            other value) takes row 1.
        k: Number of propagation rounds passed to ``compute_pr``.

    Returns:
        One weight per edge: 0 for the highest-scoring edge, growing as the
        score falls, scaled by the gap between the maximum and mean score.
    """
    pv = compute_pr(edge_index, k=k)
    log_src = torch.log(pv[edge_index[0]].to(torch.float32))
    log_dst = torch.log(pv[edge_index[1]].to(torch.float32))

    if aggr == 'source':
        s = log_src
    elif aggr == 'mean':
        s = (log_dst + log_src) * 0.5
    else:
        # 'sink' and every unrecognised value score by the row-1 endpoint.
        s = log_dst

    top = s.max()
    return (top - s) / (top - s.mean())


def eigenvector_centrality(data):
    """Return the eigenvector centrality of every node as a float32 tensor.

    The graph is converted with ``to_networkx`` and scored with
    ``nx.eigenvector_centrality_numpy``.

    Args:
        data: Graph object providing ``num_nodes`` and ``edge_index``.

    Returns:
        A tensor of length ``data.num_nodes`` ordered by node id
        ``0 .. num_nodes - 1``, placed on the device of ``data.edge_index``.
    """
    centrality = nx.eigenvector_centrality_numpy(to_networkx(data))
    # The result is keyed by node id; flatten it into node-id order.
    ordered = [centrality[node] for node in range(data.num_nodes)]
    target_device = data.edge_index.device
    return torch.tensor(ordered, dtype=torch.float32).to(target_device)


def evc_drop_weights(data):
    """Compute per-edge drop weights from log eigenvector centrality.

    Non-positive centralities are replaced by 0 and a 1e-8 offset is added
    before the logarithm. Each edge is scored by its row-1 endpoint.

    Args:
        data: Graph object accepted by ``eigenvector_centrality`` that also
            provides ``edge_index``.

    Returns:
        One weight per edge: 0 for the highest-scoring edge, growing as the
        score falls, scaled by the gap between the maximum and mean score.
    """
    evc = eigenvector_centrality(data)
    non_negative = torch.where(evc > 0, evc, torch.zeros_like(evc))
    log_evc = (non_negative + 1e-8).log()
    s = log_evc[data.edge_index[1]]
    top = s.max()
    return (top - s) / (top - s.mean())


def feature_drop_weights(x, node_c, dense=False, eps: float = 1e-8):
    """Compute per-feature drop weights from centrality-weighted feature usage.

    Each feature's total is the sum over nodes of the node centrality times
    the feature's magnitude (``dense=True``) or its presence indicator
    (``dense=False``). The log of that total is mapped so the feature with
    the largest total gets weight 0 and smaller totals get larger weights.

    Args:
        x: Node feature matrix with nodes along dim 0 and features along
            dim 1.
        node_c: One centrality value per node.
        dense: Use ``|x|`` when true; otherwise use a 0/1 indicator of
            non-zero entries.
        eps: Lower bound applied to the totals before taking the log.

    Returns:
        A tensor with one weight per feature.
    """
    x = x.abs() if dense else x.to(torch.bool).to(torch.float32)
    w = x.t() @ node_c
    # A feature that never occurs has total 0; log(0) = -inf would drag the
    # mean to -inf and turn the weights into NaN/0. Totals at or above eps
    # are left untouched.
    w = w.clamp(min=eps).log()
    top = w.max()
    return (top - w) / (top - w.mean())


def drop_feature_weighted(x, w, p: float, threshold: float = 0.7):
    """Zero out feature columns of ``x`` with weight-dependent probabilities.

    The weights are rescaled so that their mean equals ``p``, capped at
    ``threshold``, and used as per-column drop probabilities.

    Args:
        x: Feature tensor with the features along dim 1.
        w: One weight per feature column; larger means more likely dropped.
        p: Mean drop probability before the cap is applied.
        threshold: Upper bound on any single drop probability.

    Returns:
        A copy of ``x`` with the sampled columns set to zero; ``x`` itself
        is not modified.
    """
    drop_prob = w / w.mean() * p
    cap = torch.ones_like(drop_prob) * threshold
    drop_prob = torch.where(drop_prob < threshold, drop_prob, cap)
    dropped = torch.bernoulli(drop_prob).to(torch.bool)
    out = x.clone()
    out[:, dropped] = 0.
    return out


def unsupervised_learning(data, args):
    """Train the module-level ``model`` on two augmented views and return embeddings.

    Each epoch builds two views of the graph (edges via ``drop_edge``,
    features via ``drop_feature`` or ``drop_feature_weighted``), runs both
    through ``model`` and minimises ``model.loss(z1, z2)``. The weights with
    the lowest loss seen so far are checkpointed; training stops once the
    loss has failed to improve for ``args.patience`` consecutive epochs.

    Reads the module-level ``model``, ``optimizer`` and ``feature_weights``.

    Args:
        data: Graph object providing ``x`` and ``edge_index``.
        args: Namespace providing ``unsup_epochs``, ``patience``, ``de1``,
            ``de2``, ``df1``, ``df2``, ``drop_scheme``, ``net`` and
            ``dataset``.

    Returns:
        The output of the best-checkpoint model on the un-augmented graph,
        computed without gradient tracking.
    """
    import os

    ckpt_dir = 'unsup_pkl'
    unsup_tag = str(int(time.time()))
    ckpt_path = f'{ckpt_dir}/gca_cl_{args.net}_best_model_{args.dataset}{unsup_tag}.pkl'
    # torch.save does not create missing parent directories.
    os.makedirs(ckpt_dir, exist_ok=True)

    best = float("inf")
    cnt_wait = 0
    for _ in range(1, args.unsup_epochs + 1):
        model.train()
        optimizer.zero_grad()

        edge_index_1 = drop_edge(p=args.de1)
        edge_index_2 = drop_edge(p=args.de2)
        # The uniform feature views are always sampled first, even though the
        # centrality schemes replace them below; this keeps the order of
        # random draws the same as in earlier runs with a given seed.
        x_1 = drop_feature(data.x, args.df1)
        x_2 = drop_feature(data.x, args.df2)

        if args.drop_scheme in ['pr', 'degree', 'evc']:
            x_1 = drop_feature_weighted(data.x, feature_weights, args.df1)
            x_2 = drop_feature_weighted(data.x, feature_weights, args.df2)

        z1 = model(x_1, edge_index_1)
        z2 = model(x_2, edge_index_2)

        loss = model.loss(z1, z2, batch_size=None)
        loss.backward()
        optimizer.step()

        # Track the best loss as a plain float so no tensor (and its autograd
        # history) is kept alive across epochs.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            break

    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    # Inference only: skip building the autograd graph instead of detaching it.
    with torch.no_grad():
        embeds = model(data.x, data.edge_index)
    return embeds


def parse_args(argv=None):
    """Build the command-line parser and parse the script's options.

    Args:
        argv: Optional list of argument strings to parse. With the default
            ``None`` argparse reads ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    parser.add_argument('--fix_split', action='store_true')
    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    # These are the only schemes drop_edge accepts; rejecting anything else
    # here fails at start-up instead of in the first training epoch.
    parser.add_argument('--drop_scheme', default='degree', type=str,
                        choices=['uniform', 'degree', 'evc', 'pr'])
    parser.add_argument('--de1', default=0.2, type=float)
    parser.add_argument('--de2', default=0.2, type=float)
    parser.add_argument('--df1', default=0.2, type=float)
    parser.add_argument('--df2', default=0.2, type=float)
    parser.add_argument('--tau', default=0.5, type=float)
    parser.add_argument("--proj_hid_dim", type=int, default=128, help="Projection hidden layer dim.")

    parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=500, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: train the contrastive model without labels, then
    # pass the resulting embeddings to unsupervised_test_linear once per run.
    # The names args, data, model, optimizer, drop_weights and feature_weights
    # are read as module globals by the functions above.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    #10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    # Every run indexes SEEDS, so reject an oversized --runs before training
    # rather than failing with an IndexError after it.
    if args.runs > len(SEEDS):
        raise ValueError(f'--runs must be at most {len(SEEDS)} (one fixed seed per run), got {args.runs}')
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    encoder = UnsupGCN_Net(dataset=dataset, args=args)
    model = GCA(encoder=encoder, num_hidden=args.hidden, num_proj_hidden=args.proj_hid_dim, tau=args.tau).to(device)
    optimizer = torch.optim.Adam([{'params': model.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1}])

    # Only the 'CS' dataset is scored on feature magnitudes; every other
    # dataset uses a non-zero indicator.
    dense_features = args.dataset == 'CS'

    if args.drop_scheme == 'degree':
        drop_weights = degree_drop_weights(data.edge_index).to(device)
        edge_index_ = to_undirected(data.edge_index)
        # num_nodes keeps one entry per row of data.x even when the
        # highest-numbered nodes have no edges.
        node_deg = degree(edge_index_[1], num_nodes=data.num_nodes)
        feature_weights = feature_drop_weights(data.x, node_c=node_deg, dense=dense_features).to(device)

    elif args.drop_scheme == 'pr':
        drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200).to(device)
        node_pr = compute_pr(data.edge_index)
        feature_weights = feature_drop_weights(data.x, node_c=node_pr, dense=dense_features).to(device)

    elif args.drop_scheme == 'evc':
        drop_weights = evc_drop_weights(data).to(device)
        node_evc = eigenvector_centrality(data)
        feature_weights = feature_drop_weights(data.x, node_c=node_evc, dense=dense_features).to(device)

    else:
        # No centrality scheme: no edge weights and equal feature weights.
        drop_weights = None
        feature_weights = torch.ones((data.x.size(1),)).to(device)

    embeds = unsupervised_learning(data=data, args=args)

    # Keep the dataset's own masks; the fixed-split branch below selects
    # column RP from them for each run.
    if args.dataset not in ['Computers', 'Photo']:
        full_train_mask, full_val_mask, full_test_mask = data.train_mask, data.val_mask, data.test_mask

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]

        if args.fix_split:
            if args.dataset in ['Computers', 'Photo']:  # no public splitting, train/val/test=1/1/8
                percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
                val_lb = int(round(0.1 * len(data.y)))
                data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
            else:
                data.train_mask, data.val_mask, data.test_mask = full_train_mask[:, RP], full_val_mask[:, RP], full_test_mask[:, RP]
        else:
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Reported uncertainty: the larger distance between the mean and the two
    # bounds returned by sns.utils.ci(..., 95) over 1000 bootstrap resamples
    # of the mean.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')