import os
import warnings
import time
import random
import pickle
import argparse
import seaborn as sns
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import k_hop_subgraph, contains_isolated_nodes, to_undirected, is_undirected, remove_isolated_nodes
from torch_geometric.utils.num_nodes import maybe_num_nodes

from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
warnings.filterwarnings("ignore")



def check_device(model, data, specified_device):
    """Print a diagnostic line for every tensor that is not on ``specified_device``.

    Nothing is moved and nothing is raised. Model parameters are reported first
    (in ``named_parameters()`` order). Tensor-valued attributes of ``data`` are
    reported after them, in ``dir(data)`` order.
    """
    for pname, tensor in model.named_parameters():
        where = tensor.device
        if where == specified_device:
            continue
        print(f"Parameter '{pname}' is on {where}, but should be on {specified_device}")

    # Lazily pair every attribute name with its value, keeping only tensors.
    tensor_attrs = (
        (attr, value)
        for attr, value in ((a, getattr(data, a)) for a in dir(data))
        if isinstance(value, torch.Tensor)
    )
    for attr, value in tensor_attrs:
        if value.device != specified_device:
            print(f"Data '{attr}' is on {value.device}, but should be on {specified_device}")


def coalesce(edge_index, num_nodes=None, sort_by_row: bool = True):
    """Sort the columns of ``edge_index`` and drop duplicate edges.

    Each edge is encoded as one integer key, ``primary * num_nodes + secondary``.
    The key is the row then the column when ``sort_by_row`` is True, and the
    column then the row otherwise. Columns are sorted by that key when they are
    not already in order, and repeated keys are removed.

    Args:
        edge_index: integer tensor of shape ``(2, E)``.
        num_nodes: number of nodes. Inferred from ``edge_index`` when None.
        sort_by_row: order by (row, col) if True, else by (col, row).

    Returns:
        A ``(2, E')`` tensor with sorted, unique columns. When the input is
        already sorted and duplicate-free, the input object itself is returned.
    """
    if num_nodes is None:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

    nnz = edge_index.size(1)
    primary = edge_index[1 - int(sort_by_row)]
    secondary = edge_index[int(sort_by_row)]

    # idx[0] = -1 is a sentinel so the first real edge always survives the
    # "strictly greater than predecessor" test below.
    idx = edge_index.new_empty(nnz + 1)
    idx[0] = -1
    idx[1:] = primary * num_nodes + secondary

    # The duplicate test only works on sorted keys. Without this step, unsorted
    # input would lose valid edges and keep duplicates.
    if not bool((idx[1:] >= idx[:-1]).all()):
        perm = idx[1:].argsort()
        idx[1:] = idx[1:][perm]
        edge_index = edge_index[:, perm]

    mask = idx[1:] > idx[:-1]
    # Avoid the copy when there is nothing to drop.
    if mask.all():
        return edge_index
    return edge_index[:, mask]


def get_activation(name: str) -> torch.nn.Module:
    """Return a fresh activation module for ``name``.

    Every option is an ``nn.Module``, so the result can be placed inside
    ``torch.nn.Sequential``. Only the requested activation is constructed.

    Args:
        name: one of ``relu``, ``hardtanh``, ``elu``, ``leakyrelu``, ``prelu``, ``rrelu``.

    Raises:
        KeyError: if ``name`` is not a supported activation.
    """
    factories = {
        'relu': torch.nn.ReLU,
        'hardtanh': torch.nn.Hardtanh,
        'elu': torch.nn.ELU,
        # Previously returned F.leaky_relu (a plain function), which nn.Sequential rejects.
        'leakyrelu': torch.nn.LeakyReLU,
        'prelu': torch.nn.PReLU,
        'rrelu': torch.nn.RReLU,
    }
    try:
        factory = factories[name]
    except KeyError:
        raise KeyError(f"Unknown activation {name!r}; expected one of {sorted(factories)}") from None
    return factory()


class GNN_Encoder(nn.Module):
    """Stack of GCN layers, each followed by optional batch norm, PReLU and dropout.

    Args:
        in_dim: width of the input node features.
        args: namespace providing ``num_layers``, ``hidden``, ``no_bn`` and ``dropout``.
    """

    def __init__(self, in_dim, args):
        super().__init__()
        self.num_forward_layer = args.num_layers
        self.use_bn = not args.no_bn
        self.dropout = args.dropout

        # The first layer maps in_dim -> hidden and the rest map hidden -> hidden.
        # One layer is always constructed, even when num_layers < 1.
        input_widths = [in_dim] + [args.hidden] * max(args.num_layers - 1, 0)
        convs, norms = [], []
        for width in input_widths:
            convs.append(GCNConv(width, args.hidden))
            norms.append(nn.BatchNorm1d(args.hidden))
        self.forward_pass = nn.ModuleList(convs)
        self.bns = nn.ModuleList(norms)
        # A single PReLU module, shared by every layer.
        self.activation = nn.PReLU()

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph ``edge_index``."""
        h = x
        for layer in range(self.num_forward_layer):
            h = self.forward_pass[layer](h, edge_index)
            h = self.bns[layer](h) if self.use_bn else h
            h = F.dropout(self.activation(h), p=self.dropout, training=self.training)
        return h


class SPGCL(nn.Module):
    """Contrastive wrapper around a GNN encoder with an MLP projection head.

    Each training step works as follows:

    1. Sample a set of nodes.
    2. Project their embeddings and L2-normalise them.
    3. Reward the mean cosine similarity to each node's ``args.topk`` most
       similar candidates.
    4. Penalise the summed squared similarity to ``args.neg_topk`` negatives.
       These are either the least similar candidates or uniformly random ones,
       per ``args.neg_selection``.

    Args:
        encoder: module called as ``encoder(x, edge_index)`` that returns node
            embeddings of width ``args.hidden``.
        args: namespace with ``hidden``, ``proj_activation``, ``max_size``,
            ``seed_sampling``, ``seed_num``, ``square_subg``, ``topk``,
            ``neg_topk`` and ``neg_selection``.
    """

    def __init__(self, encoder, args):
        super().__init__()
        self.args = args
        self.encoder = encoder
        self.fc_pipe = torch.nn.Sequential(
            torch.nn.Linear(args.hidden, args.hidden),
            get_activation(args.proj_activation),
            torch.nn.Linear(args.hidden, args.hidden))

    def projection(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the projection head; it is used only inside the training loss."""
        return self.fc_pipe(z)

    def forward(self, x, edge_index):
        """Return encoder embeddings, with gradients tracked."""
        return self.encoder(x, edge_index)

    def embed(self, x, edge_index):
        """Return encoder embeddings without building an autograd graph."""
        with torch.no_grad():
            return self.encoder(x, edge_index).detach()

    def _sample_nodes(self, num_nodes, subgraph_cache):
        """Return a NumPy array of at most ``args.max_size`` node indices for one step."""
        max_size = self.args.max_size
        if self.args.seed_sampling == 'random':
            return np.array(random.sample(range(num_nodes), min(max_size, num_nodes)))
        # 'tree': pool the cached neighbourhoods of a few random seed nodes.
        seeds = random.sample(range(num_nodes), min(self.args.seed_num, num_nodes))
        pool = torch.cat([subgraph_cache[i] for i in seeds]).tolist()
        # Cap the pool to prevent neighbour explosion. A node shared by
        # overlapping neighbourhoods can be drawn more than once.
        return np.array(random.sample(pool, min(max_size, len(pool))))

    def model_train(self, features, edge_index, optimizer, subgraph_cache=None):
        """Run one optimisation step and return the scalar loss.

        Args:
            features: node feature matrix; its first dimension is the node count.
            edge_index: graph connectivity passed to the encoder.
            optimizer: optimizer over this module's parameters.
            subgraph_cache: per-node neighbour-index tensors, indexed by node id.
                Used only when ``args.seed_sampling != 'random'``. When None,
                the cache falls back to the module-level ``data.subgraph_cache``
                set up by the training script.

        Returns:
            ``loss.item()`` for this step.

        Raises:
            ValueError: if ``args.neg_selection`` is neither ``'topk'`` nor ``'random'``.
        """
        if self.args.neg_selection not in ('topk', 'random'):
            raise ValueError(f"neg_selection must be 'topk' or 'random', got {self.args.neg_selection!r}")

        optimizer.zero_grad()
        self.MAX_SIZE = self.args.max_size  # kept for code that reads this attribute
        orig_node_embed_full = self.encoder(features, edge_index)

        if self.args.seed_sampling != 'random' and subgraph_cache is None:
            # Backward-compatible fallback to the script-level `data` object.
            subgraph_cache = data.subgraph_cache
        sample_idx = self._sample_nodes(features.shape[0], subgraph_cache)

        if self.args.square_subg:
            # Similarities among the sampled nodes only: shape (S, S).
            node_embed_sample = self.projection(F.relu(orig_node_embed_full[sample_idx]))
            norm_node_embed = F.normalize(node_embed_sample, p=2, dim=-1)
            sim_matrix_sample = torch.mm(norm_node_embed, norm_node_embed.t())
        else:
            # Sampled nodes against every node: shape (S, N).
            node_embed_full = self.projection(F.relu(orig_node_embed_full))
            norm_node_embed_full = F.normalize(node_embed_full, p=2, dim=-1)
            sim_matrix_sample = torch.mm(norm_node_embed_full[sample_idx], norm_node_embed_full.t())
        num_rows, num_cols = sim_matrix_sample.shape

        # Positives: each row's top-k most similar columns. For a non-zero
        # embedding the node's own column has similarity 1, so it is normally
        # among them.
        pos_cols = torch.topk(sim_matrix_sample, self.args.topk, dim=1)[1]
        pos_part = -2 * torch.gather(sim_matrix_sample, 1, pos_cols).mean(dim=1).mean()

        if self.args.neg_selection == 'topk':
            neg_cols = torch.topk(-sim_matrix_sample, self.args.neg_topk, dim=1)[1]
        else:
            # Draw uniformly over all candidate columns: S in square mode, N otherwise.
            neg_cols = torch.randint(0, num_cols, (num_rows, self.args.neg_topk)).to(sim_matrix_sample.device)
        neg_part = torch.gather(sim_matrix_sample, 1, neg_cols).pow(2).sum(dim=1).mean()

        loss = pos_part + neg_part
        loss.backward()
        optimizer.step()
        return loss.item()


def unsupervised_learning(model, data, features, edge_index, args, optimizer=None, device=None):
    """Train ``model`` self-supervised with early stopping, then embed all nodes of ``data``.

    The weights from the lowest-loss epoch are kept in memory and restored
    before embedding. This replaces the old on-disk temporary checkpoint, which
    needed an existing ``unsup_pkl/`` directory and leaked on errors.

    Args:
        model: module exposing ``model_train(features, edge_index, optimizer)``
            and ``embed(x, edge_index)``.
        data: graph object whose ``x`` and ``edge_index`` are embedded after training.
        features: node features used for training.
        edge_index: training connectivity; it must not contain isolated nodes.
        args: namespace with ``unsup_epochs`` and ``patience``. It also needs
            ``lr1`` and ``wd1`` when ``optimizer`` is None.
        optimizer: optimizer over ``model.parameters()``. When None, a fresh
            ``AdamW(lr=args.lr1, weight_decay=args.wd1)`` is created, matching
            the optimizer built by the training script.
        device: device the model is moved to before embedding. Defaults to
            ``data.x.device``.

    Returns:
        Detached node embeddings for ``data``.

    Raises:
        ValueError: if ``edge_index`` contains isolated nodes.
    """
    if contains_isolated_nodes(edge_index):
        raise ValueError("edge_index must not contain isolated nodes")
    if optimizer is None:
        optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr1, weight_decay=args.wd1)
    if device is None:
        device = data.x.device

    best = float("inf")
    best_state = None
    cnt_wait = 0
    for _ in range(args.unsup_epochs):
        model.train()
        loss = model.model_train(features, edge_index, optimizer)
        if loss < best:
            best = loss
            cnt_wait = 0
            # Clone, because state_dict() tensors share storage with the live
            # parameters and later optimizer steps would overwrite them.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            cnt_wait += 1
        if cnt_wait == args.patience:
            break

    # best_state stays None if no epoch ran or every loss was NaN; keep the current weights then.
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    model = model.to(device)
    return model.embed(data.x, data.edge_index)


def parse_args(argv=None):
    """Parse the command-line options for SPGCL training and linear evaluation.

    Args:
        argv: list of argument strings. None (the default) reads ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` holding every option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    parser.add_argument("--patience", type=int, default=50, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=1000, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.01, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    parser.add_argument('--seed_num', type=int, default=128)
    parser.add_argument('--subg_num_hops', type=int, default=4)
    parser.add_argument('--sub_dataset', type=str, default=None)
    parser.add_argument('--no_bn', type=int, default=0, help='do not use batchnorm')
    parser.add_argument('--max_size', type=int, default=512)
    # `choices` makes typos fail at parse time instead of misbehaving later.
    parser.add_argument('--proj_activation', type=str, default='prelu',
                        choices=['relu', 'hardtanh', 'elu', 'leakyrelu', 'prelu', 'rrelu'],
                        help='activation inside the projection head.')
    parser.add_argument('--square_subg', type=int, default=1)
    parser.add_argument('--topk', type=int, default=5)
    parser.add_argument('--neg_topk', type=int, default=100)
    parser.add_argument('--neg_selection', type=str, default='random', choices=['topk', 'random'],
                        help='topk, random')
    parser.add_argument('--seed_sampling', type=str, default='tree', choices=['random', 'tree'],
                        help='random, tree')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660,
             2108550661, 1648766618, 629014539, 3212139042, 2424918363]
    # Fail before the expensive training instead of with an IndexError afterwards.
    if args.runs > len(SEEDS):
        raise ValueError(f"--runs must be at most {len(SEEDS)} (one fixed split seed per run), got {args.runs}")
    device = torch.device('cuda:' + str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)
    data.edge_index = to_undirected(data.edge_index)
    # Train on the graph with isolated nodes removed; `mask` selects the surviving nodes' features.
    if contains_isolated_nodes(data.edge_index, num_nodes=data.num_nodes):
        edge_index, _, mask = remove_isolated_nodes(data.edge_index, num_nodes=data.num_nodes)
        features = data.x[mask]
    else:
        edge_index = data.edge_index
        features = data.x

    assert not contains_isolated_nodes(edge_index)
    t1 = time.time()
    # Per-node k-hop neighbourhoods, cached on disk by dataset and hop count.
    # List position i holds the neighbourhood of node i.
    subg_file_path = 'saved_models/{}_hop_{}_subg.pkl'.format(args.dataset, args.subg_num_hops)
    if os.path.exists(subg_file_path):
        # NOTE(review): pickle can execute arbitrary code on load; only load caches this script wrote.
        with open(subg_file_path, 'rb') as file:
            subgraph_cache = pickle.load(file)
        mask_matrix = []
        for node_idx, result in enumerate(subgraph_cache):
            l = (torch.zeros_like(result) + node_idx).unsqueeze(0)
            r = result.unsqueeze(0)
            mask_matrix.append(torch.cat([l, r], dim=0))
        mask_matrix = torch.cat(mask_matrix, dim=1)
        if not is_undirected(mask_matrix):
            print(args.dataset + 'need check!!!!')
            mask_matrix = to_undirected(mask_matrix)
        mask_matrix = coalesce(mask_matrix)
    else:
        subgraph_cache = []
        mask_matrix = []
        assert edge_index.max() + 1 == features.shape[0]
        # Ascending order keeps list position == node id, which the load branch
        # above and the training sampler both rely on.
        for node_idx in sorted(set(edge_index[0].tolist())):
            result = k_hop_subgraph(node_idx=node_idx, num_hops=args.subg_num_hops, edge_index=edge_index)[0]
            if len(result) == 0:  # if no neighbor, skip
                continue
            # Store on the CPU so the pickled cache also loads on CPU-only machines.
            result = result.cpu()
            subgraph_cache.append(result)
            l = (torch.zeros_like(result) + node_idx).unsqueeze(0)
            r = result.unsqueeze(0)
            mask_matrix.append(torch.cat([l, r], dim=0))
        mask_matrix = torch.cat(mask_matrix, dim=1)
        os.makedirs(os.path.dirname(subg_file_path), exist_ok=True)
        with open(subg_file_path, 'wb') as file:
            pickle.dump(subgraph_cache, file)

    data.subgraph_cache = subgraph_cache
    data.mask_matrix = mask_matrix.to(device)
    t2 = time.time()
    # print("Pre-compute subgraph: {}s".format(t2-t1))

    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    encoder = GNN_Encoder(dataset.num_node_features, args).to(device)
    model = SPGCL(encoder=encoder, args=args).to(device)
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    embeds = unsupervised_learning(model=model, data=data, features=features, edge_index=edge_index, args=args).to(device)

    # Linear evaluation of the frozen embeddings on args.runs fixed random splits.
    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]
        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Half-width of the 95% bootstrap confidence interval of the mean accuracy.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')