import os
import math
import warnings
import time
import copy
import numpy as np
import argparse
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected, remove_isolated_nodes, contains_isolated_nodes

from utils import set_seed
from dataset_loader import HeterophilousGraphDataset
from eval import unsupervised_eval_linear
# Install a blanket "ignore" filter so no Python warning is shown for the
# remainder of the process.
warnings.filterwarnings("ignore")


def sample_neighborhood(edge_index, num_nodes, neighbor_max, device):
    """Build a fixed-width table of sampled neighbours, one row per node.

    For every node ``i`` the neighbours are the entries of ``edge_index[1]``
    whose matching entry in ``edge_index[0]`` equals ``i``.

    * more than ``neighbor_max`` neighbours: a random subset of size
      ``neighbor_max`` is kept (without replacement);
    * fewer than ``neighbor_max`` neighbours: the neighbour list is tiled and
      truncated to ``neighbor_max`` entries;
    * no neighbour at all: the row is filled with the node's own index, so
      the function no longer divides by zero on isolated nodes.

    Args:
        edge_index: integer tensor with two rows (source row 0, target row 1).
        num_nodes: number of nodes; node ids are ``0 .. num_nodes - 1``.
        neighbor_max: number of neighbour slots per node.
        device: device the permutation is drawn on and the result is moved to.

    Returns:
        Tensor of shape ``(num_nodes, neighbor_max)`` on ``device``.
    """
    rows = []
    for node in range(num_nodes):
        neighbors = edge_index[1, edge_index[0] == node]
        degree = neighbors.shape[0]
        if degree == 0:
            # Isolated node: fall back to a self-neighbourhood instead of
            # crashing in the tiling branch below.
            neighbors = torch.full(
                (neighbor_max,), node,
                dtype=edge_index.dtype, device=edge_index.device,
            )
        elif degree < neighbor_max:
            repeats = math.ceil(neighbor_max / degree)
            neighbors = neighbors.repeat(repeats)[:neighbor_max]
        elif degree > neighbor_max:
            keep = torch.randperm(degree, device=device)[:neighbor_max]
            # The index must live on the same device as the tensor it indexes.
            neighbors = neighbors[keep.to(neighbors.device)]
        rows.append(neighbors)

    if not rows:
        # torch.vstack rejects an empty list.
        return torch.empty((0, neighbor_max), dtype=edge_index.dtype, device=device)
    return torch.vstack(rows).to(device)


def sample_neg_neighborhood(edge_index, num_nodes, neighbor_max, device):
    """Draw random "negative" neighbours for every node.

    Each row holds the first ``neighbor_max`` entries of an independent random
    permutation of ``0 .. num_nodes - 1``. The draw is uniform over all nodes:
    it does not exclude the node itself or its true neighbours.

    Args:
        edge_index: unused; kept so the signature mirrors
            ``sample_neighborhood``.
        num_nodes: number of nodes (also the number of rows returned).
        neighbor_max: number of samples per node (rows are shorter if
            ``num_nodes < neighbor_max``).
        device: device the permutations are drawn on.

    Returns:
        Integer tensor of shape ``(num_nodes, min(neighbor_max, num_nodes))``
        on ``device``; an empty ``(0, neighbor_max)`` tensor when
        ``num_nodes`` is 0.
    """
    rows = [
        torch.randperm(num_nodes, device=device)[:neighbor_max]
        for _ in range(num_nodes)
    ]
    if not rows:
        # torch.vstack rejects an empty list.
        return torch.empty((0, neighbor_max), dtype=torch.long, device=device)
    return torch.vstack(rows).to(device)

    
class GCN(nn.Module):
    """Stack of ``GCNConv`` layers used as the graph encoder.

    Every layer except the last is followed by optional batch normalisation,
    ReLU and dropout; the last layer's output is returned as-is.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.5, use_bn=True):
        super().__init__()
        # One input layer plus ``num_layers - 1`` hidden-to-hidden layers,
        # each paired with a BatchNorm1d of the same width.
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.bns = nn.ModuleList()
        self.bns.append(nn.BatchNorm1d(hidden_channels))
        extra_layers = num_layers - 1
        for _ in range(extra_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn

    def reset_parameters(self):
        """Re-initialise every convolution and batch-norm layer."""
        for layer in self.convs:
            layer.reset_parameters()
        for norm in self.bns:
            norm.reset_parameters()

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        *hidden_convs, output_conv = self.convs
        for norm, conv in zip(self.bns, hidden_convs):
            x = conv(x, edge_index)
            if self.use_bn:
                x = norm(x)
            x = F.dropout(self.activation(x), p=self.dropout, training=self.training)
        return output_conv(x, edge_index)


class EMA:
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``old`` and ``new``; return ``new`` when there is no ``old``."""
        if old is not None:
            decay = self.beta
            return old * decay + (1 - decay) * new
        return new


def update_moving_average(target_ema_updater, ma_model, current_model):
    """Move each parameter of ``ma_model`` towards ``current_model``.

    Parameters are paired positionally; the blended value comes from
    ``target_ema_updater.update_average(old, new)`` and is written back into
    the averaged model's ``.data``.
    """
    paired = zip(current_model.parameters(), ma_model.parameters())
    for online_param, averaged_param in paired:
        averaged_param.data = target_ema_updater.update_average(
            averaged_param.data, online_param.data
        )

def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model``."""
    for _, tensor in model.named_parameters():
        tensor.requires_grad = val

def init_weights(m):
    """Xavier-initialise a plain ``nn.Linear``; leave other modules untouched.

    The weight gets Xavier-uniform values and the bias, when the layer has
    one, is filled with 0.01. Layers built with ``bias=False`` are supported.
    The check is on the exact type, so subclasses of ``nn.Linear`` are skipped
    just as before.
    """
    if type(m) is not nn.Linear:
        return
    torch.nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        m.bias.data.fill_(0.01)

def loss_fn(x, y):
    """Return ``2 - 2 * cos(x, y)`` computed along the last dimension.

    Both inputs are L2-normalised first, so the value is 0 for parallel
    vectors and 4 for opposite ones.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> dropout -> Linear.

    ``num_layers`` is accepted for interface compatibility but is not used;
    the network always has exactly two linear layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=.0):
        super().__init__()
        self.linear1 = nn.Linear(in_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise both linear layers."""
        for layer in (self.linear1, self.linear2):
            layer.reset_parameters()

    def forward(self, data):
        """Apply the perceptron to ``data`` along its last dimension."""
        hidden = F.relu(self.linear1(data))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.linear2(hidden)
    

class DSSL(torch.nn.Module):
    """Self-supervised model with an online encoder, a frozen EMA target
    encoder, a cluster-inference head and neighbour-prediction heads.

    ``forward`` takes already-computed embeddings (it does not run the
    encoders itself) and returns the loss terms that the training loop
    combines.
    """

    def __init__(self, encoder, hidden_channels, dataset, device, cluster_num, alpha, gamma, tao, beta, moving_average_decay=0.0):
        """
        Args:
            encoder: module used as the online encoder; a deep copy of it
                becomes the (gradient-free) target encoder.
            hidden_channels: embedding width; sizes the MLP heads and the
                rows of the cluster matrix.
            dataset: stored on the instance; not read inside this class.
            device: device used for the index tensor in ``update_cluster``.
            cluster_num: number of clusters (columns of ``self.clusters``).
            alpha, gamma: stored on the instance; not read inside this class.
            tao: temperature dividing the logits in ``context_network``.
            beta: weight of the cluster-conditioned predictor term.
            moving_average_decay: decay factor of the target-encoder EMA.
        """
        super(DSSL, self).__init__()
        self.dataset = dataset
        self.device = device
        self.cluster_num = cluster_num
        self.alpha = alpha
        self.gamma = gamma
        self.tao = tao
        self.beta = beta
        self.inner_dropout = True
        self.inner_activation = True
        self.online_encoder = encoder
        self.target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(self.target_encoder, False)
        self.target_ema_updater = EMA(moving_average_decay)

        self.mlp_inference = MLP(hidden_channels, hidden_channels, cluster_num, 2)
        self.mlp_predictor = MLP(hidden_channels, hidden_channels, hidden_channels, 2)
        self.clusters = nn.Parameter(torch.nn.init.normal_(torch.Tensor(hidden_channels, cluster_num)))
        self.mlp_predictor2 = MLP(cluster_num, hidden_channels, hidden_channels, 1)

        # Switches selecting the MLP-based variant of the generative and
        # inference networks; callers may set them to False after construction.
        self.Embedding_mlp = True
        self.inference_mlp = True

    def reset_parameters(self):
        """Re-initialise the heads and the online encoder, then rebuild the
        target encoder as a frozen copy of the online encoder."""
        self.mlp_inference.reset_parameters()
        self.mlp_predictor.reset_parameters()
        self.mlp_predictor2.reset_parameters()
        self.online_encoder.reset_parameters()
        self.target_encoder = copy.deepcopy(self.online_encoder)
        # The fresh copy inherits requires_grad=True from the online encoder;
        # freeze it again exactly as __init__ does.
        set_requires_grad(self.target_encoder, False)

    def update_moving_average(self):
        """Pull the target encoder's parameters towards the online encoder."""
        #assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, embedding, neighbor_embedding):
        """Compute the loss terms for a batch.

        Args:
            embedding: 2-D tensor of node embeddings (nodes x features).
            neighbor_embedding: 3-D tensor (nodes x neighbours x features);
                its second dimension fixes how often each node embedding is
                repeated.

        Returns:
            ``(main_loss, context_loss, entropy_loss, k_node)``.
        """
        # Repeat every node embedding once per neighbour slot.
        embedding_more = embedding[:, None, :]
        embedding_expand = embedding_more.expand(-1, neighbor_embedding.shape[1], -1)

        if self.inference_mlp:
            k, k_node, entropy_loss = self.inference_network(embedding_expand, neighbor_embedding)
        else:
            k, k_node, entropy_loss = self.inference_network2(embedding_expand, neighbor_embedding)

        if self.Embedding_mlp:
            main_loss = self.generative_network(embedding_expand, k, neighbor_embedding)
        else:
            main_loss = self.generative_network2(embedding_expand, k, neighbor_embedding)
        context_loss = self.context_network(embedding, k_node)
        return main_loss, context_loss, entropy_loss, k_node

    def inference_network(self, embedding_expand, neighbor_embedding):
        """MLP-based cluster inference.

        Returns the per-(node, neighbour) cluster distribution ``k``, its mean
        over neighbours ``k_node`` and the mean negative entropy of ``k_node``.
        """
        # Element-wise product of node and neighbour embeddings (not a
        # concatenation, despite the variable name).
        cat_embedding = embedding_expand * neighbor_embedding
        k = F.softmax(self.mlp_inference(cat_embedding), dim=2)
        k_node = k.mean(dim=1)  # average over the neighbour dimension
        negative_entropy = k_node * torch.log(k_node + 1e-10)
        entropy_loss = negative_entropy.sum(-1).mean()  # minimize negative entropy
        return k, k_node, entropy_loss

    def inference_network2(self, embedding_expand, neighbor_embedding):
        """Cluster inference by projecting onto ``self.clusters`` (no MLP).

        NOTE(review): ``k`` and ``k_node`` are raw scores here, not
        probabilities, yet ``context_network`` uses ``k_node`` as weights —
        confirm this is intended.
        """
        cat_embedding = embedding_expand * neighbor_embedding
        k = torch.matmul(cat_embedding, self.clusters)
        k_node = k.mean(dim=1)  # average over the neighbour dimension
        negative_entropy = k_node * torch.log(F.softmax(k_node, dim=1) + 1e-10)
        entropy_loss = negative_entropy.sum(-1).mean()  # minimize negative entropy
        return k, k_node, entropy_loss

    def generative_network(self, embedding_expand, k, neighbor_embedding):
        """Predict neighbour embeddings from node embedding plus a sampled
        cluster code; the neighbour side is detached from the graph.

        NOTE(review): ``F.gumbel_softmax`` is documented to take logits, while
        ``inference_network`` passes softmax probabilities — confirm.
        """
        # re-parameterization trick
        gumbel_k = F.gumbel_softmax(k, hard=False)
        central = self.mlp_predictor(embedding_expand) + self.beta * self.mlp_predictor2(gumbel_k)
        neighbor = neighbor_embedding
        loss = loss_fn(central, neighbor.detach()).mean()
        return loss

    def generative_network2(self, embedding_expand, k, neighbor_embedding):
        """Variant without ``mlp_predictor``: uses the raw node embedding and
        a hard (one-hot) Gumbel sample; the neighbour side is not detached."""
        # em algorithm (to do)
        # re-parameterization trick
        gumbel_k = F.gumbel_softmax(k, hard=True)
        central = embedding_expand + self.beta * self.mlp_predictor2(gumbel_k)
        neighbor = neighbor_embedding
        loss = loss_fn(central, neighbor).mean()
        return loss

    def context_network(self, embedding, k_node):
        """Cross-entropy between ``k_node`` and the temperature-scaled softmax
        of ``embedding @ self.clusters``."""
        kprior = torch.matmul(embedding, self.clusters)
        kprior = F.softmax(kprior / self.tao, dim=1)
        context_loss = k_node * torch.log(kprior + 1e-10)
        context_loss = - 1.0 * context_loss.sum(-1).mean()
        return context_loss

    def update_cluster(self, new_center, batch_sum):
        """Overwrite the cluster matrix with ``new_center`` and rescale it.

        Every column of ``self.clusters`` is replaced by the matching column
        of ``new_center``; the transposed matrix is then multiplied by
        ``1 / (batch_sum + 1)``.
        # assumes batch_sum has shape (cluster_num, 1) so each cluster column
        # is divided by its own count + 1 — TODO confirm against caller
        """
        with torch.no_grad():
            out_ids = torch.arange(self.cluster_num).to(self.device)
            out_ids = out_ids.long()
            self.clusters.index_copy_(1, out_ids, new_center)
            self.clusters.data = torch.mul(self.clusters.data.T, 1 / (batch_sum + 1)).T



def unsupervised_learning(data, features, edge_index, args):
    """Train the DSSL model without labels and return node embeddings.

    Relies on module-level names set up by the script entry point: ``model``,
    ``optimizer``, ``device``, ``sampled_neighborhoods`` and, when
    ``args.neg_alpha`` is truthy, ``sampled_neg_neighborhoods``.

    Each epoch runs mini-batch training over a random node permutation, then
    a gradient-free pass that re-estimates the cluster matrix. The state with
    the lowest summed epoch loss is checkpointed under ``unsup_pkl/`` and
    training stops after ``args.patience`` epochs without improvement.

    Args:
        data: graph object; ``data.x`` and ``data.edge_index`` are encoded for
            the returned embeddings.
        features: node features used during training.
        edge_index: edges used during training.
        args: namespace providing ``unsup_epochs``, ``batch_size``,
            ``cluster_num``, ``neg_alpha``, ``gamma``, ``patience``, ``net``
            and ``dataset``.

    Returns:
        Detached output of the online encoder on ``data.x`` /
        ``data.edge_index``, from the best checkpoint when one was written.
    """
    # torch.save does not create missing directories.
    os.makedirs('unsup_pkl', exist_ok=True)

    model.train()
    best = float("inf")
    cnt_wait = 0
    unsup_tag = str(int(time.time()))
    checkpoint_path = 'unsup_pkl/' + 'dssl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl'
    checkpoint_saved = False
    num_nodes = features.size(0)

    for epoch in range(args.unsup_epochs):
        model.train()
        perm = torch.randperm(num_nodes)
        epoch_loss = 0
        for batch in range(0, num_nodes, args.batch_size):
            optimizer.zero_grad()
            online_embedding = model.online_encoder(features, edge_index)
            # The target encoder is never trained through back-propagation.
            with torch.no_grad():
                target_embedding = model.target_encoder(features, edge_index)

            batch_idx = perm[batch:batch + args.batch_size].to(device)
            batch_neighbor_index = sampled_neighborhoods[batch_idx]
            batch_embedding = F.normalize(online_embedding[batch_idx], dim=-1, p=2)
            # Indexing with the (batch, neighbours) index table gathers a
            # (batch, neighbours, features) tensor in one call.
            batch_neighbor_embedding = F.normalize(
                target_embedding[batch_neighbor_index], dim=-1, p=2)

            main_loss, context_loss, entropy_loss, _ = model(batch_embedding, batch_neighbor_embedding)
            loss = main_loss + args.gamma * (context_loss + entropy_loss)
            if args.neg_alpha:
                batch_neg_neighbor_index = sampled_neg_neighborhoods[batch_idx]
                batch_neg_embedding = F.normalize(
                    target_embedding[batch_neg_neighbor_index], dim=-1, p=2)
                main_neg_loss, _, _, _ = model(batch_embedding, batch_neg_embedding)
                # NOTE(review): the negative-sample loss is added with a
                # positive sign, as in the original code — confirm intent.
                loss = loss + main_neg_loss
            # print("batch : {}, main_loss: {}, context_loss: {}, entropy_loss: {}".format(batch, main_loss, context_loss, entropy_loss))
            loss.backward()
            optimizer.step()
            model.update_moving_average()
            # Detach so the running total does not hold on to autograd graphs.
            epoch_loss = epoch_loss + loss.detach()

        # Cluster re-estimation: parameters are fixed here, so the full-graph
        # embeddings are computed once and no gradients are tracked.
        model.eval()
        with torch.no_grad():
            online_embedding = model.online_encoder(features, edge_index)
            target_embedding = model.target_encoder(features, edge_index)
            cluster = None
            batch_sum = None
            for batch in range(0, num_nodes, args.batch_size):
                batch_idx = perm[batch:batch + args.batch_size].to(device)
                batch_embedding = online_embedding[batch_idx]
                batch_neighbor_embedding = target_embedding[sampled_neighborhoods[batch_idx]]
                _, _, _, k_node = model(batch_embedding, batch_neighbor_embedding)
                # One-hot hard assignment of every node to its top cluster.
                assignment = F.one_hot(
                    torch.argmax(k_node, dim=1), num_classes=args.cluster_num).float()
                batch_cluster = torch.matmul(batch_embedding.t(), assignment)
                batch_count = torch.reshape(torch.sum(assignment, 0), (-1, 1))
                if cluster is None:
                    cluster, batch_sum = batch_cluster, batch_count
                else:
                    cluster += batch_cluster
                    batch_sum += batch_count
            if cluster is not None:
                cluster = F.normalize(cluster, dim=-1, p=2)
                model.update_cluster(cluster, batch_sum)

        # print("epoch: {}, loss: {}".format(epoch, epoch_loss))
        if epoch_loss < best:
            best = epoch_loss
            cnt_wait = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            cnt_wait += 1
        if cnt_wait == args.patience:
            break

    # No checkpoint exists when zero epochs ran or the loss never compared
    # below infinity (e.g. NaN); keep the current weights in that case.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
        os.remove(checkpoint_path)
    model.eval()
    with torch.no_grad():
        embeds = model.online_encoder(data.x, data.edge_index).detach()
    return embeds
        

def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--no_bn', action='store_true', help='do not use batchnorm')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    parser.add_argument("--patience", type=int, default=50, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=1000, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.01, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    parser.add_argument('--batch_size', type=int, default=1024, help="batch size")
    parser.add_argument('--neighbor_max', type=int, default=5, help="neighbor num max")
    parser.add_argument('--cluster_num', type=int, default=6, help="cluster num")
    parser.add_argument('--alpha', type=float, default=1)
    parser.add_argument('--gamma', type=float, default=0.1,
                        help="weight of the context and entropy losses")
    # --tau and --tao are different options: tau is the EMA decay of the
    # target encoder, tao the softmax temperature of the context network.
    parser.add_argument('--tau', type=float, default=0.99,
                        help="moving-average decay of the target encoder")
    parser.add_argument('--mlp_bool', type=int, default=1, help="embedding with mlp predictor")
    parser.add_argument('--tao', type=float, default=1,
                        help="softmax temperature of the context network")
    parser.add_argument('--beta', type=float, default=1,
                        help="weight of the cluster-conditioned predictor term")
    parser.add_argument('--mlp_inference_bool', type=int, default=1,
                        help="cluster inference with mlp (0 uses the cluster matrix instead)")
    parser.add_argument('--neg_alpha', type=int, default=0,
                        help="non-zero enables the negative-neighbor loss term")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Script entry point. The names bound here (model, optimizer, device,
    # sampled_neighborhoods, sampled_neg_neighborhoods) are read as globals by
    # unsupervised_learning(), so they must keep these exact names.
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    set_seed(args.seed)
    #10 fixed seeds for random splits from BernNet
    # (not referenced again in this script)
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    # Step 1: Load data =================================================================== #
    root = './data/'
    dataset = HeterophilousGraphDataset(root=root, name=args.dataset)
    data = dataset[0]
    data = data.to(device)

    data.edge_index = to_undirected(data.edge_index)
    # Train on the graph without isolated nodes; `data` itself keeps every
    # node and is what the final embeddings are computed on.
    if contains_isolated_nodes(data.edge_index, num_nodes=data.num_nodes):
        edge_index, _, mask = remove_isolated_nodes(data.edge_index, num_nodes=data.num_nodes)
        features = data.x[mask]
    else:
        edge_index = data.edge_index
        features = data.x
    
    sampled_neighborhoods = sample_neighborhood(edge_index, features.size(0), args.neighbor_max, device)
    if args.neg_alpha:
        # Pass the same four arguments as sample_neighborhood; the previous
        # call passed (dataset, device, args), which does not match the
        # function's signature.
        sampled_neg_neighborhoods = sample_neg_neighborhood(edge_index, features.size(0), args.neighbor_max, device)
        # print('sample_neg_neighborhoods')

    # Split sizes; computed here but not used further in this script.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))
    
    encoder = GCN(in_channels=dataset.num_node_features,
                    hidden_channels=args.hidden,
                    num_layers=args.num_layers,
                    dropout=args.dropout,
                    use_bn=not args.no_bn).to(device)

    model = DSSL(encoder=encoder,
                hidden_channels=args.hidden,
                dataset=dataset,
                device=device,
                cluster_num=args.cluster_num,
                alpha=args.alpha,
                gamma=args.gamma,
                tao=args.tao,
                beta=args.beta,
                moving_average_decay=args.tau).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    if not args.mlp_bool: # 0 embedding without mlp predictor
        model.Embedding_mlp = False
    if not args.mlp_inference_bool: # 0 cluster inference without mlp
        model.inference_mlp = False

    embeds = unsupervised_learning(data=data, features=features, edge_index=edge_index, args=args)
    
    results = unsupervised_eval_linear(data=data, embeds=embeds, args=args, device=device)
    results = [v.item() for v in results]
    test_acc_mean = np.mean(results, axis=0) * 100
    values = np.asarray(results, dtype=object)
    # Half-width of the 95% bootstrap confidence interval around the mean.
    uncertainty = np.max(
        np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')