import os
import warnings
import time
import numpy as np
import argparse
import seaborn as sns
import networkx as nx
import scipy.stats as stats

import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dropout_adj, degree, to_undirected, to_networkx

from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
warnings.filterwarnings("ignore")



class GCN(nn.Module):
    """Encoder made of ``num_layers`` stacked ``GCNConv`` layers.

    The first layer maps ``in_channels`` to ``out_channels``; every following
    layer maps ``out_channels`` to ``out_channels``. One shared ``PReLU`` is
    applied after each convolution, including the last one.
    """

    def __init__(self, in_channels, out_channels, num_layers):
        super().__init__()
        assert num_layers >= 2
        self.num_layers = num_layers

        # Input width of each layer: the raw feature size once, then the
        # hidden size for the remaining layers.
        input_widths = [in_channels] + [out_channels] * (num_layers - 1)
        self.conv = nn.ModuleList(
            [GCNConv(width, out_channels) for width in input_widths]
        )
        self.activation = nn.PReLU()

    def forward(self, x, edge_index):
        """Run ``x`` through every convolution, activating after each one."""
        hidden = x
        for layer in self.conv:
            hidden = self.activation(layer(hidden, edge_index))
        return hidden


class GRACE(torch.nn.Module):
    """Contrastive model: an encoder, a projection head and InfoNCE-style losses.

    Two families of node-level losses are provided, each in a full-batch and a
    row-batched form:

    * ``semi_loss`` / ``batched_semi_loss``: every other node is an equally
      weighted negative.
    * ``semi_loss_bmm`` / ``batched_semi_loss_bmm``: negatives are re-weighted
      (``args.mode == 'weight'``) or complemented with mixed negatives
      (``args.mode == 'mix'``) using per-negative scores ``B`` obtained from
      ``bmm_model``.

    The scores are kept in the module-level global ``B``. They are computed
    only by a call made with ``fit=True`` while ``epoch == args.epoch_start``
    and are reused by every later call, so that call has to happen first.
    ``B`` is a tensor for the full-batch variant and a list with one tensor
    per batch for the batched variant.
    """

    # Number of randomly drawn similarity columns used to fit the mixture.
    BMM_FIT_COLUMNS = 100

    def __init__(self, encoder, num_hidden, num_proj_hidden, tau=0.5):
        super(GRACE, self).__init__()
        self.encoder = encoder
        self.tau = tau
        self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)
        self.num_hidden = num_hidden

    def forward(self, x, edge_index):
        """Return the encoder output for ``(x, edge_index)``."""
        return self.encoder(x, edge_index)

    def projection(self, z):
        """Apply the two-layer projection head (Linear -> ELU -> Linear)."""
        z = F.elu(self.fc1(z))
        return self.fc2(z)

    def sim(self, z1, z2):
        """Return the matrix of cosine similarities between rows of z1 and z2."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def _exp_sim(self, s):
        """Temperature-scaled exponential used by every loss."""
        return torch.exp(s / self.tau)

    @staticmethod
    def _require_known_mode(mode):
        """Fail early on an unsupported ``args.mode`` instead of returning None."""
        if mode not in ('weight', 'mix'):
            raise ValueError(
                f"Mode Error! args.mode must be 'weight' or 'mix', got {mode!r}"
            )

    def semi_loss(self, z1, z2, epoch):
        """Per-node InfoNCE loss of view ``z1`` against view ``z2``.

        ``epoch`` is accepted for signature parity with the BMM variants and
        is not used.
        """
        refl_sim = self._exp_sim(self.sim(z1, z1))
        between_sim = self._exp_sim(self.sim(z1, z2))
        return -torch.log(
            between_sim.diag()
            / (between_sim.sum(1) + refl_sim.sum(1) - refl_sim.diag())
        )

    def semi_loss_bmm(self, z1, z2, epoch, args, bmm_model, fit = False):
        """Full-batch loss whose negatives are scored by ``bmm_model``.

        When ``fit`` is true and ``epoch == args.epoch_start`` the mixture is
        fitted on ``BMM_FIT_COLUMNS`` random columns of the cross-view
        similarity matrix (mapped from [-1, 1] to [0, 1]) and the global ``B``
        is recomputed; otherwise the existing ``B`` is reused.

        Returns a 1-D tensor with one loss value per row of ``z1``.
        Raises ``ValueError`` if ``args.mode`` is neither 'weight' nor 'mix'.
        """
        global B
        self._require_known_mode(args.mode)

        refl_sim = self.sim(z1, z1)
        between_sim = self.sim(z1, z2)
        N = between_sim.size(0)

        # True everywhere except on the diagonal, i.e. the negatives.
        node_ids = torch.arange(N, device=z1.device)
        mask = torch.ones((N, N), dtype=torch.bool, device=z1.device)
        mask[node_ids, node_ids] = False

        if epoch == args.epoch_start and fit:
            index_fit = np.random.randint(0, N, self.BMM_FIT_COLUMNS)
            sim_fit = (between_sim[:, index_fit] + 1) / 2   # Min-Max Normalization
            bmm_model.fit(sim_fit.flatten())
            between_sim_norm = (between_sim.masked_select(mask).view(N, -1) + 1) / 2
            print('Computing positive probility,wait...')
            B = bmm_model.posterior(between_sim_norm, 0) * between_sim_norm.detach()
            print('Over!')

        if args.mode == 'weight':
            refl_sim = self._exp_sim(refl_sim)
            between_sim = self._exp_sim(between_sim)
            ng_bet = (between_sim.masked_select(mask).view(N, -1) * B).sum(1) / B.mean(1)
            ng_refl = (refl_sim.masked_select(mask).view(N, -1) * B).sum(1) / B.mean(1)
            return -torch.log(between_sim.diag() / (between_sim.diag() + ng_bet + ng_refl))

        # args.mode == 'mix'
        eps = 1e-12
        B_sorted, order = torch.sort(B, descending=True)
        # Never ask for more negatives than exist (N - 1 per row).
        sel_num = min(args.sel_num, B_sorted.size(1))
        N_sel = torch.gather(between_sim[mask].view(N, -1), -1, order)[:, :sel_num]
        random_index = np.random.permutation(np.arange(sel_num))
        N_random = N_sel[:, random_index]
        M = B_sorted[:, :sel_num]
        M_random = M[:, random_index]
        # Score-weighted average of each selected negative with a random partner.
        M = (N_sel * M + N_random * M_random) / (M + M_random + eps)
        refl_sim = self._exp_sim(refl_sim)
        between_sim = self._exp_sim(between_sim)
        M = self._exp_sim(M)
        return -torch.log(
            between_sim.diag()
            / (M.sum(1) + between_sim.sum(1) + refl_sim.sum(1) - refl_sim.diag())
        )

    def batched_semi_loss(self, z1, z2, batch_size: int, epoch):
        """Row-batched equivalent of ``semi_loss``.

        Only ``batch_size`` rows of the similarity matrices are materialised
        at a time. ``epoch`` is unused.
        """
        # Space complexity: O(BN) (semi_loss: O(N^2))
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        losses = []

        for i in range(num_batches):
            lo, hi = i * batch_size, (i + 1) * batch_size
            refl_sim = self._exp_sim(self.sim(z1[lo:hi], z1))  # [B, N]
            # The cross-view term must come from z2, not from refl_sim.
            between_sim = self._exp_sim(self.sim(z1[lo:hi], z2))  # [B, N]
            losses.append(-torch.log(
                between_sim[:, lo:hi].diag()
                / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim[:, lo:hi].diag())
            ))

        return torch.cat(losses)

    def batched_semi_loss_bmm(self, z1, z2, batch_size, epoch, args, bmm_model, fit=False):
        """Row-batched loss whose negatives are scored by ``bmm_model``.

        Every batch is processed and the per-batch losses are concatenated, so
        the result has one entry per row of ``z1``. When ``fit`` is true and
        ``epoch == args.epoch_start`` the global list ``B`` is rebuilt with
        one score tensor per batch (similarities are min-max normalised per
        batch before fitting); otherwise the stored list is reused, which
        requires the same ``batch_size`` as the fitting call.

        In 'mix' mode all negatives of a row are used; ``args.sel_num`` is not
        consulted here.

        Raises ``ValueError`` if ``args.mode`` is neither 'weight' nor 'mix'.
        """
        global B
        self._require_known_mode(args.mode)

        device = z1.device
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        refit = epoch == args.epoch_start and fit
        if refit:
            # Only a fitting call may discard the stored scores.
            B = []
        losses = []

        for i in range(num_batches):
            # Global node ids of this batch; the last batch may be shorter.
            index = torch.arange(i * batch_size, min((i + 1) * batch_size, num_nodes), device=device)
            rows = torch.arange(index.size(0), device=device)
            # Mask out exactly one entry per row: the node itself.
            neg_mask = torch.ones((index.size(0), num_nodes), dtype=torch.bool, device=device)
            neg_mask[rows, index] = False

            refl_sim = self.sim(z1[index], z1)
            between_sim = self.sim(z1[index], z2)

            if refit:
                index_fit = np.random.randint(0, num_nodes, self.BMM_FIT_COLUMNS)
                sim_fit = between_sim[:, index_fit]
                sim_fit = (sim_fit - sim_fit.min()) / (sim_fit.max() - sim_fit.min())
                bmm_model.fit(sim_fit.flatten())
                between_sim_norm = between_sim.masked_select(neg_mask).view(index.size(0), -1)
                between_sim_norm = (between_sim_norm - between_sim_norm.min()) / (between_sim_norm.max() - between_sim_norm.min())
                print('Computing positive probility,wait...')
                B.append(bmm_model.posterior(between_sim_norm, 0) * between_sim_norm.detach())
                print('Over!')

            scores = B[i]
            refl_exp = self._exp_sim(refl_sim)
            between_exp = self._exp_sim(between_sim)
            # Positive pair of row j sits in column index[j], not column j.
            positive = between_exp[rows, index]

            if args.mode == 'weight':
                ng_bet = (between_exp.masked_select(neg_mask).view(index.size(0), -1) * scores).sum(1) / scores.mean(1)
                ng_refl = (refl_exp.masked_select(neg_mask).view(index.size(0), -1) * scores).sum(1) / scores.mean(1)
                losses.append(-torch.log(positive / (positive + ng_bet + ng_refl)))
            else:  # 'mix'
                eps = 1e-12
                B_sel, order = torch.sort(scores, descending=True)
                # Gather from the masked matrix so columns line up with the scores.
                negatives = between_sim.masked_select(neg_mask).view(index.size(0), -1)
                N_sel = torch.gather(negatives, -1, order)
                random_index = np.random.permutation(np.arange(N_sel.size(1)))
                N_sel_random = N_sel[:, random_index]
                B_sel_random = B_sel[:, random_index]
                M = (B_sel * N_sel + B_sel_random * N_sel_random) / (B_sel + B_sel_random + eps)
                M = self._exp_sim(M)
                losses.append(-torch.log(
                    positive
                    / (M.sum(1) + between_exp.sum(1) + refl_exp.sum(1) - refl_exp[rows, index])
                ))

        return torch.cat(losses)

    def loss(self, z1, z2, epoch, args, bmm_model, mean=True, batch_size=None):
        """Symmetric contrastive loss between two views.

        Both embeddings go through the projection head. Before
        ``args.epoch_start`` the plain loss is used; from that epoch on the
        BMM loss is used, with the first direction (z1 against z2) allowed to
        fit the mixture. ``batch_size=None`` selects the full-batch variants.

        Returns the mean (``mean=True``) or the sum of the per-node average of
        both directions.
        """
        h1 = self.projection(z1)
        h2 = self.projection(z2)
        if epoch < args.epoch_start:
            if batch_size is None:
                l1 = self.semi_loss(h1, h2, epoch)
                l2 = self.semi_loss(h2, h1, epoch)
            else:
                l1 = self.batched_semi_loss(h1, h2, batch_size, epoch)
                l2 = self.batched_semi_loss(h2, h1, batch_size, epoch)
        else:
            if batch_size is None:
                l1 = self.semi_loss_bmm(h1, h2, epoch, args, bmm_model, fit = True)
                l2 = self.semi_loss_bmm(h2, h1, epoch, args, bmm_model)
            else:
                l1 = self.batched_semi_loss_bmm(h1, h2, batch_size, epoch, args, bmm_model, fit = True)
                l2 = self.batched_semi_loss_bmm(h2, h1, batch_size, epoch, args, bmm_model)
        ret = (l1 + l2) * 0.5
        return ret.mean() if mean else ret.sum()




def weighted_mean(x, w):
    """Return the average of ``x`` weighted by ``w`` (sum(w * x) / sum(w))."""
    weighted_total = (w * x).sum()
    weight_total = w.sum()
    return weighted_total / weight_total


def fit_beta_weighted(x, w):
    """Estimate beta-distribution parameters from weighted samples.

    Uses the method of moments: the weighted mean and weighted variance of
    ``x`` are matched to those of a beta distribution.

    Returns:
        The tuple ``(alpha, beta)``.
    """
    mean = weighted_mean(x, w)
    variance = weighted_mean((x - mean)**2, w)
    alpha = mean * ((mean * (1 - mean)) / variance - 1)
    beta = alpha * (1 - mean) / mean
    return alpha, beta


class BetaMixture1D(object):
    """Two-component mixture of beta distributions fitted with EM.

    ``alphas``, ``betas`` and ``weight`` are indexable by component (0 or 1).
    The tensors passed to the constructor are stored as-is and are updated in
    place by ``fit``.
    """

    def __init__(self, max_iters, alphas_init, betas_init, weights_init):
        self.max_iters = max_iters
        self.alphas = alphas_init
        self.betas = betas_init
        self.weight = weights_init
        # Floor that keeps divisions and responsibilities away from zero.
        self.eps_nan = 1e-12

    def likelihood(self, x, y):
        """Beta density of ``x`` under component ``y``, on the device of ``x``."""
        # scipy works on NumPy arrays, so everything is moved to the CPU first.
        samples = x.cpu().detach().numpy()
        shape_a = self.alphas.cpu().detach().numpy()[y]
        shape_b = self.betas.cpu().detach().numpy()[y]
        density = stats.beta.pdf(samples, shape_a, shape_b)
        return torch.from_numpy(density).to(x.device)

    def weighted_likelihood(self, x, y):
        """Density of component ``y`` scaled by its mixing weight."""
        return self.weight[y] * self.likelihood(x, y)

    def probability(self, x):
        """Mixture density: the sum of both weighted component densities."""
        first = self.weighted_likelihood(x, 0)
        second = self.weighted_likelihood(x, 1)
        return first + second

    def posterior(self, x, y):
        """Probability that ``x`` was generated by component ``y``."""
        evidence = self.probability(x) + self.eps_nan
        return self.weighted_likelihood(x, y) / evidence

    def responsibilities(self, x):
        """Return a (2, n) tensor of normalised component responsibilities."""
        per_component = [self.weighted_likelihood(x, k).view(1, -1) for k in (0, 1)]
        resp = torch.cat(per_component, 0)
        resp[resp <= self.eps_nan] = self.eps_nan
        resp /= resp.sum(0)
        return resp

    def fit(self, x):
        """Run ``max_iters`` EM iterations on ``x`` and return ``self``.

        ``x`` is clamped in place to the open interval (0, 1) before fitting.
        """
        eps = 1e-12
        x[x >= 1 - eps] = 1 - eps
        x[x <= eps] = eps

        for _ in range(self.max_iters):
            # E-step
            resp = self.responsibilities(x)
            # M-step
            for k in (0, 1):
                self.alphas[k], self.betas[k] = fit_beta_weighted(x, resp[k])
            # The second component's beta is reset to 1.01 whenever it drops below 1.
            if self.betas[1] < 1:
                self.betas[1] = 1.01
            self.weight = resp.sum(1)
            self.weight /= self.weight.sum()
        return self

    def predict(self, x):
        """Return True where component 1 is the more likely source of ``x``."""
        return self.posterior(x, 1) > 0.5

    def __str__(self):
        return f'BetaMixture1D(w={self.weight}, a={self.alphas}, b={self.betas})'


def drop_feature_weighted(x, w, p: float, threshold: float = 0.7):
    """Zero individual entries of ``x`` with per-column probabilities.

    The weights are rescaled so their mean equals ``p`` and capped at
    ``threshold``; the resulting per-column probability is repeated for each
    row of ``x`` and every entry is dropped independently. ``x`` itself is
    left untouched; a modified copy is returned.
    """
    scaled = w / w.mean() * p
    capped = torch.where(scaled < threshold, scaled, torch.ones_like(scaled) * threshold)
    num_rows = x.size(0)
    entry_prob = capped.repeat(num_rows).view(num_rows, -1)
    zero_mask = torch.bernoulli(entry_prob).to(torch.bool)
    result = x.clone()
    result[zero_mask] = 0.
    return result


def drop_feature_weighted_2(x, w, p: float, threshold: float = 0.7):
    """Zero whole columns of ``x`` with per-column probabilities.

    The weights are rescaled so their mean equals ``p`` and capped at
    ``threshold``; one Bernoulli draw per column decides whether that column
    is zeroed for all rows. A modified copy of ``x`` is returned.
    """
    scaled = w / w.mean() * p
    column_prob = torch.where(scaled < threshold, scaled, torch.ones_like(scaled) * threshold)
    dropped_columns = torch.bernoulli(column_prob).to(torch.bool)
    result = x.clone()
    result[:, dropped_columns] = 0.
    return result


def feature_drop_weights(x, node_c):
    """Score each feature column for dropping from per-node values ``node_c``.

    A column's raw weight is the sum of ``node_c`` over the rows where the
    column is non-zero. The log of that weight is rescaled as
    ``(max - value) / (max - mean)``, so the column with the largest weight
    scores 0.
    """
    presence = x.to(torch.bool).to(torch.float32)
    log_weight = torch.log(presence.t() @ node_c)
    top = log_weight.max()
    return (top - log_weight) / (top - log_weight.mean())


def degree_drop_weights(edge_index):
    """Return one drop weight per edge from the degree of its target node.

    Degrees are counted on the undirected version of the graph, looked up for
    the second row of ``edge_index``, log-transformed and rescaled as
    ``(max - value) / (max - mean)``.
    """
    symmetric_edges = to_undirected(edge_index)
    node_degree = degree(symmetric_edges[1])
    log_target_degree = torch.log(node_degree[edge_index[1]].to(torch.float32))
    top = log_target_degree.max()
    return (top - log_target_degree) / (top - log_target_degree.mean())


def pr_drop_weights(edge_index, aggr: str = 'sink', k: int = 10):
    """Return one drop weight per edge from the scores of ``compute_pr``.

    ``aggr`` selects which endpoint's log-score is used: 'source' takes the
    first row of ``edge_index``, 'mean' averages both endpoints, and 'sink'
    (as well as any unrecognised value) takes the second row. The chosen
    scores are rescaled as ``(max - value) / (max - mean)``.
    """
    node_score = compute_pr(edge_index, k=k)
    log_source = torch.log(node_score[edge_index[0]].to(torch.float32))
    log_target = torch.log(node_score[edge_index[1]].to(torch.float32))

    if aggr == 'source':
        edge_score = log_source
    elif aggr == 'mean':
        edge_score = (log_target + log_source) * 0.5
    else:
        # 'sink' and every unknown aggregation use the target node.
        edge_score = log_target

    top = edge_score.max()
    return (top - edge_score) / (top - edge_score.mean())


def evc_drop_weights(data):
    """Return one drop weight per edge from the centrality of its target node.

    Non-positive centralities are replaced by 0 and 1e-8 is added before
    taking the log. The log-centrality of each edge's target node (second row
    of ``data.edge_index``) is rescaled as ``(max - value) / (max - mean)``.
    """
    centrality = eigenvector_centrality(data)
    non_negative = torch.where(centrality > 0, centrality, torch.zeros_like(centrality))
    log_centrality = (non_negative + 1e-8).log()
    edge_score = log_centrality[data.edge_index[1]]
    top = edge_score.max()
    return (top - edge_score) / (top - edge_score.mean())


def compute_pr(edge_index, damp: float = 0.85, k: int = 10):
    """Compute PageRank-style node scores with ``k`` propagation steps.

    Every node starts at 1. In each step a node sends its score, divided by
    its out-degree, along its outgoing edges; the new score is
    ``(1 - damp) * old + damp * received``.

    Args:
        edge_index: 2 x E tensor of (source, target) node ids.
        damp: Weight of the propagated term.
        k: Number of propagation steps.

    Returns:
        Float32 tensor with one score per node id in ``0..edge_index.max()``.
    """
    num_nodes = edge_index.max().item() + 1
    deg_out = degree(edge_index[0])
    x = torch.ones((num_nodes, )).to(edge_index.device).to(torch.float32)
    for _ in range(k):
        edge_msg = x[edge_index[0]] / deg_out[edge_index[0]]
        # Pin the output length: without dim_size the result only reaches the
        # largest *target* id and cannot be added to x when the highest node
        # id has no incoming edge.
        agg_msg = scatter(edge_msg, edge_index[1], dim_size=num_nodes, reduce='sum')
        x = (1 - damp) * x + damp * agg_msg
    return x


def eigenvector_centrality(data):
    """Return the eigenvector centrality of every node as a float32 tensor.

    The graph is converted to networkx, centralities are computed there and
    then laid out by node id ``0..data.num_nodes - 1`` on the device of
    ``data.edge_index``.
    """
    centrality_by_node = nx.eigenvector_centrality_numpy(to_networkx(data))
    ordered = [centrality_by_node[node] for node in range(data.num_nodes)]
    as_tensor = torch.tensor(ordered, dtype=torch.float32)
    return as_tensor.to(data.edge_index.device)


def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    """Randomly remove edges with per-edge drop probabilities.

    The weights are rescaled so their mean equals ``p`` and capped at
    ``threshold``; each edge is then kept with probability ``1 - weight``.
    Returns the columns of ``edge_index`` that survive.
    """
    scaled = edge_weights / edge_weights.mean() * p
    drop_prob = torch.where(scaled < threshold, scaled, torch.ones_like(scaled) * threshold)
    keep = torch.bernoulli(1. - drop_prob).to(torch.bool)
    return edge_index[:, keep]


def drop_edge(drop_rate):
    """Return an augmented edge index according to ``args.drop_scheme``.

    Reads the module-level ``args``, ``data`` and ``drop_weights``. 'uniform'
    drops edges uniformly at rate ``drop_rate``; 'degree', 'evc' and 'pr' use
    the precomputed ``drop_weights`` with a cap of 0.7. Any other scheme
    raises an ``Exception``.
    """
    scheme = args.drop_scheme
    if scheme == 'uniform':
        return dropout_adj(data.edge_index, p=drop_rate)[0]
    if scheme in ('degree', 'evc', 'pr'):
        return drop_edge_weighted(data.edge_index, drop_weights, p=drop_rate, threshold=0.7)
    raise Exception(f'undefined drop scheme: {args.drop_scheme}')


def drop_feature(x, drop_prob):
    """Zero whole feature columns of ``x``, each with probability ``drop_prob``.

    One uniform draw is made per column; columns whose draw falls below
    ``drop_prob`` are zeroed for all rows. A modified copy is returned.
    """
    column_draws = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1)
    dropped_columns = column_draws < drop_prob
    result = x.clone()
    result[:, dropped_columns] = 0
    return result


def unsupervised_learning(data, args):
    """Train the contrastive model and return embeddings from its best state.

    Relies on the module-level ``model``, ``optimizer``, ``bmm_model``,
    ``feature_weights`` and ``device``. Each epoch builds two augmented views
    (edge dropping plus feature dropping), computes the contrastive loss and
    takes one optimizer step. The state with the lowest training loss is
    checkpointed under ``unsup_pkl/``; training stops once the loss has not
    improved for ``args.patience`` consecutive epochs.

    Args:
        data: Graph object providing ``x`` and ``edge_index``.
        args: Parsed options (``unsup_epochs``, ``patience``, ``de1``,
            ``de2``, ``df1``, ``df2``, ``drop_scheme``, ``net``, ``dataset``).

    Returns:
        Detached embeddings of the un-augmented graph, on ``device``.
    """
    # torch.save does not create missing directories.
    os.makedirs('unsup_pkl', exist_ok=True)
    unsup_tag = str(int(time.time()))
    ckpt_path = 'unsup_pkl/' + 'progcl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl'

    best = float("inf")
    cnt_wait = 0
    for epoch in range(1, args.unsup_epochs + 1):
        model.train()
        optimizer.zero_grad()

        edge_index_1 = drop_edge(args.de1)
        edge_index_2 = drop_edge(args.de2)

        # The uniform draw is always made, even when it is replaced below, so
        # that the random stream stays the same as in earlier runs.
        x_1 = drop_feature(data.x, args.df1)
        x_2 = drop_feature(data.x, args.df2)

        if args.drop_scheme in ['pr', 'degree', 'evc']:
            x_1 = drop_feature_weighted_2(data.x, feature_weights, args.df1)
            x_2 = drop_feature_weighted_2(data.x, feature_weights, args.df2)

        z1 = model(x_1, edge_index_1)
        z2 = model(x_2, edge_index_2)
        loss = model.loss(z1, z2, epoch, args, bmm_model, batch_size=None)
        loss.backward()
        optimizer.step()

        # Track the best loss as a plain float rather than a tensor.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            break

    try:
        model.load_state_dict(torch.load(ckpt_path))
        model.eval()
        embeds = model(data.x, data.edge_index).detach().to(device)
    finally:
        # The checkpoint is temporary; remove it even if loading or the
        # forward pass failed.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
    return embeds


def parse_args():
    """Parse the command line and return the resulting namespace.

    Options are declared in a table and registered in order, so the help
    output lists them in the order shown here.
    """
    option_table = [
        ('--seed', dict(type=int, default=42, help='seed.')),
        ('--dataset', dict(type=str, default='Cora')),
        ('--device', dict(type=int, default=0, help='GPU device.')),
        ('--runs', dict(type=int, default=10, help='number of runs.')),
        ('--net', dict(type=str, default='GCN')),

        ('--num_layers', dict(type=int, default=2)),
        ('--hidden', dict(type=int, default=64, help='hidden units.')),
        ('--dropout', dict(type=float, default=0.5, help='dropout for neural networks.')),

        ('--fix_split', dict(action='store_true')),
        ('--train_rate', dict(type=float, default=0.6, help='train set rate.')),
        ('--val_rate', dict(type=float, default=0.2, help='val set rate.')),

        # unsupervised learning
        ('--de1', dict(default=0.2, type=float)),
        ('--de2', dict(default=0.2, type=float)),
        ('--df1', dict(default=0.2, type=float)),
        ('--df2', dict(default=0.2, type=float)),
        ('--tau', dict(default=0.5, type=float)),
        ('--proj_hid_dim', dict(type=int, default=128, help="Projection hidden layer dim.")),

        ('--patience', dict(type=int, default=50, help="Patient epochs to wait before early stopping.")),
        ('--unsup_epochs', dict(type=int, default=1000, help="Unupservised training epochs.")),
        ('--lr1', dict(type=float, default=0.01, help="Learning rate of the unsupervised model.")),
        ('--lr2', dict(type=float, default=0.01, help="Learning rate of linear evaluator.")),
        ('--wd1', dict(type=float, default=0.0, help="Weight decay of the unsupervised model.")),
        ('--wd2', dict(type=float, default=0.0, help="Weight decay of linear evaluator.")),

        ('--drop_scheme', dict(default='evc', type=str)),
        ('--epoch_start', dict(type=int, default=400)),
        ('--mode', dict(type=str, default='weight')),
        ('--sel_num', dict(type=int, default=1000)),
        ('--weight_init', dict(type=float, default=0.05)),
        ('--iters', dict(type=int, default=10)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# Script entry point: train the contrastive encoder once, then evaluate the
# frozen embeddings with a linear classifier over several data splits.
# The names bound here (args, data, device, model, optimizer, bmm_model,
# drop_weights, feature_weights) are module globals read by the functions above.
if __name__ == '__main__':
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    set_seed(args.seed)
    # 10 fixed seeds, one per evaluation run, used for the random splits
    # (taken from BernNet).
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    # Use the requested GPU when CUDA is available, otherwise the CPU.
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    # Split sizes: training nodes per class and total validation nodes.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    encoder = GCN(dataset.num_features, args.hidden, num_layers=args.num_layers).to(device)
    model = GRACE(encoder, args.hidden, args.proj_hid_dim, args.tau).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    # Initial parameters of the two-component beta mixture; component 1
    # starts with mixing weight args.weight_init.
    alphas_init = torch.tensor([1, 2], dtype=torch.float64, device=device)
    betas_init = torch.tensor([2, 1], dtype=torch.float64, device=device)
    weights_init = torch.tensor([1-args.weight_init, args.weight_init], dtype=torch.float64, device=device)
    bmm_model = BetaMixture1D(args.iters, alphas_init, betas_init, weights_init)

    # Precompute the edge and feature drop weights for the chosen scheme.
    # Any other scheme gets no edge weights and uniform feature weights.
    if args.drop_scheme == 'degree':
        drop_weights = degree_drop_weights(data.edge_index).to(device)
        edge_index_ = to_undirected(data.edge_index)
        node_deg = degree(edge_index_[1])
        feature_weights = feature_drop_weights(data.x, node_c=node_deg).to(device)
    elif args.drop_scheme == 'pr':
        # NOTE(review): edge weights use k=200 steps, node scores the default k.
        drop_weights = pr_drop_weights(data.edge_index, aggr='sink', k=200).to(device)
        node_pr = compute_pr(data.edge_index)
        feature_weights = feature_drop_weights(data.x, node_c=node_pr).to(device)
    elif args.drop_scheme == 'evc':
        drop_weights = evc_drop_weights(data).to(device)
        node_evc = eigenvector_centrality(data)
        feature_weights = feature_drop_weights(data.x, node_c=node_evc).to(device)
    else:
        drop_weights = None
        feature_weights = torch.ones((data.x.size(1),)).to(device)

    # The embeddings are computed once and shared by all evaluation runs.
    embeds = unsupervised_learning(data=data, args=args)
    
    # Keep the dataset's own masks so a fixed split can be selected per run.
    if args.dataset not in ['Computers', 'Photo']:
        full_train_mask, full_val_mask, full_test_mask = data.train_mask, data.val_mask, data.test_mask
    
    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]
        if args.fix_split:
            if args.dataset in ['Computers', 'Photo']:  # no public splitting, train/val/test=1/1/8
                percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
                val_lb = int(round(0.1 * len(data.y)))
                data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
            else:
                # assumes the stored masks are 2-D with one column per run — TODO confirm
                data.train_mask, data.val_mask, data.test_mask = full_train_mask[:, RP], full_val_mask[:, RP], full_test_mask[:, RP]
        else:       
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
        
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    # Report the mean accuracy in percent, together with the largest absolute
    # distance between the mean and the bounds returned by the 95% bootstrap
    # interval (presumably lower/upper percentiles — verify against seaborn).
    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')
