import argparse
import os
import time
import math
import warnings
import seaborn as sns

import torch
import numpy as np
import torch as th
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, get_laplacian
warnings.filterwarnings("ignore")

from dataset_loader import DataLoader
from eval import unsupervised_test_linear
from utils import random_splits, set_seed, cheby


def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string. This converter maps the usual spellings
    explicitly and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the PolyGCL command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, as argparse does by default.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="PolyGCL")
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')  # Default seed same as GCNII

    parser.add_argument("--dataset", type=str, default="cora", help="Name of dataset.")
    parser.add_argument("--device", type=int, default=0, help="GPU index. Default: 0. CPU is used when CUDA is unavailable.")
    parser.add_argument("--epochs", type=int, default=500, help="Training epochs.")
    parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")

    parser.add_argument("--lr_prop", type=float, default=0.010, help="Learning rate of prop.")
    parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of PolyGCL.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd_prop", type=float, default=0.0, help="Weight decay of PolyGCL prop.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of PolyGCL.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    parser.add_argument("--hid_dim", type=int, default=512, help="Hidden layer dim.")
    parser.add_argument("--K", type=int, default=10, help="Layer of encoder.")
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--dprate', type=float, default=0.5, help='dropout for propagation layer.')
    # type=bool would turn "--is_bns False" into True; parse the text explicitly.
    parser.add_argument('--is_bns', type=_str2bool, default=False)
    parser.add_argument('--act_fn', default='relu', help='activation function')

    parser.add_argument('--train_rate', default=0.6, type=float)
    parser.add_argument('--val_rate', default=0.2, type=float)
    return parser.parse_args(argv)


class Discriminator(nn.Module):
    """Bilinear scorer comparing node embeddings against a summary vector."""

    def __init__(self, dim):
        super().__init__()
        # One scalar score per (embedding row, summary) pair.
        self.fn = nn.Bilinear(dim, dim, 1)

    def forward(self, h1, h2, h3, h4, c):
        """Score four embedding matrices against the summary ``c``.

        ``c`` is broadcast to the shape of ``h1``. The returned 1-D tensor
        concatenates the scores in the order ``h2, h1`` (positive pairs)
        followed by ``h4, h3`` (negative pairs).
        """
        summary = c.expand_as(h1).contiguous()
        ordered_views = (h2, h1, h4, h3)
        scores = [self.fn(view, summary).squeeze(1) for view in ordered_views]
        return th.cat(scores)


class PolyGCL(nn.Module):
    """Contrastive model pairing a high-pass and a low-pass view of one encoder.

    The same ``ChebNetII`` encoder is evaluated twice per input (``highpass``
    on and off). ``alpha`` and ``beta`` are learnable scalars, both starting at
    0.5, that weight the two views when they are combined.
    """

    def __init__(self, in_dim, out_dim, K, dprate, dropout, is_bns, act_fn):
        super().__init__()

        encoder_cfg = dict(
            num_features=in_dim,
            hidden=out_dim,
            K=K,
            dprate=dprate,
            dropout=dropout,
            is_bns=is_bns,
            act_fn=act_fn,
        )
        self.encoder = ChebNetII(**encoder_cfg)

        self.disc = Discriminator(out_dim)
        # Activation applied to the graph summary; independent of the encoder's act_fn.
        self.act_fn = nn.ReLU()
        self.alpha = nn.Parameter(torch.tensor(0.5), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor(0.5), requires_grad=True)

    def _encode_views(self, edge_index, feat):
        """Return the (high-pass, low-pass) encodings of ``feat``, in that call order."""
        high = self.encoder(x=feat, edge_index=edge_index, highpass=True)
        low = self.encoder(x=feat, edge_index=edge_index, highpass=False)
        return high, low

    def _mix(self, high, low):
        """Combine both views as ``alpha * high + beta * low``."""
        return self.alpha * high + self.beta * low

    def get_embedding(self, edge_index, feat):
        """Return the mixed embedding of ``feat``, detached from the autograd graph."""
        high, low = self._encode_views(edge_index, feat)
        return self._mix(high, low).detach()

    def forward(self, edge_index, feat, shuf_feat):
        """Return discriminator logits for true and shuffled features.

        The summary vector is the ReLU of the node-wise mean of the mixed
        embedding of the true features; the discriminator scores the views of
        ``feat`` (positive) and of ``shuf_feat`` (negative) against it.
        """
        pos_high, pos_low = self._encode_views(edge_index, feat)
        neg_high, neg_low = self._encode_views(edge_index, shuf_feat)

        mixed = self._mix(pos_high, pos_low)
        summary = self.act_fn(torch.mean(mixed, dim=0))
        return self.disc(pos_high, pos_low, neg_high, neg_low, summary)


def presum_tensor(h, initial_val):
    """Return the running sums of ``h`` seeded with ``initial_val``.

    The result has ``len(h) + 1`` entries: entry 0 is ``initial_val`` and entry
    ``k`` is entry ``k - 1`` plus ``h[k - 1]``. The output buffer is created
    with ``torch.zeros`` (default dtype and device) and filled in place.
    """
    running = torch.zeros(len(h) + 1)
    running[0] = initial_val
    for pos, increment in enumerate(h, start=1):
        running[pos] = running[pos - 1] + increment
    return running

def preminus_tensor(h, initial_val):
    """Return the running differences of ``h`` seeded with ``initial_val``.

    The result has ``len(h) + 1`` entries: entry 0 is ``initial_val`` and entry
    ``k`` is entry ``k - 1`` minus ``h[k - 1]``. The output buffer is created
    with ``torch.zeros`` (default dtype and device) and filled in place.
    """
    running = torch.zeros(len(h) + 1)
    running[0] = initial_val
    for pos, decrement in enumerate(h, start=1):
        running[pos] = running[pos - 1] - decrement
    return running

def reverse_tensor(h):
    """Return a copy of ``h`` with the order of its first dimension reversed.

    Uses a single ``torch.flip`` instead of copying element by element in a
    Python loop. The result keeps the dtype and device of ``h`` and does not
    share storage with it.
    """
    return torch.flip(h, dims=[0])


class ChebnetII_prop(MessagePassing):
    """Order-``K`` polynomial propagation with separate high-pass and low-pass coefficients.

    Two learnable vectors of length ``K`` (``temp_high`` and ``temp_low``) are
    passed through ReLU and accumulated into ``K + 1`` values:

    * high-pass: running sums starting from ``initial_val_high`` (0.0), so the
      values are non-decreasing;
    * low-pass: running differences starting from ``initial_val_low`` (2.0), so
      the values are non-increasing.

    These values are combined with ``cheby(i, x_j)`` evaluated at the nodes
    ``x_j = cos((K - j + 0.5) * pi / (K + 1))`` to obtain the polynomial
    coefficients used in ``forward``. ``cheby`` comes from ``utils`` and is
    presumably the degree-``i`` Chebyshev polynomial -- confirm there.
    """

    def __init__(self, K, **kwargs):
        super(ChebnetII_prop, self).__init__(aggr='add', **kwargs)

        self.K = K
        # Fixed (non-trainable) starting values of the two accumulated sequences.
        self.initial_val_low = Parameter(torch.tensor(2.0), requires_grad=False)
        self.temp_low = Parameter(torch.Tensor(self.K), requires_grad=True)
        self.temp_high = Parameter(torch.Tensor(self.K), requires_grad=True)
        self.initial_val_high = Parameter(torch.tensor(0.0), requires_grad=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill both increment vectors with ``2 / K``."""
        self.temp_low.data.fill_(2.0/self.K)
        self.temp_high.data.fill_(2.0/self.K)

    def forward(self, x, edge_index, edge_weight=None, highpass=True):
        """Apply the selected polynomial filter to ``x``.

        Args:
            x: Node feature matrix; its size along ``self.node_dim`` is used as
                the number of nodes.
            edge_index: Graph connectivity passed to ``get_laplacian``.
            edge_weight: Optional edge weights passed to ``get_laplacian``.
            highpass: Use the high-pass coefficients when true, otherwise the
                low-pass ones.

        Returns:
            The filtered features, with the same shape as ``x``.
        """
        if highpass:
            TEMP = F.relu(self.temp_high)
            coe_tmp = presum_tensor(TEMP, self.initial_val_high)
        else:
            TEMP = F.relu(self.temp_low)
            coe_tmp = preminus_tensor(TEMP, self.initial_val_low)

        # coe[i] = 2 / (K + 1) * sum_j coe_tmp[j] * cheby(i, x_j); the j = 0 term
        # is written out first, the inner loop adds j = 1..K.
        coe = coe_tmp.clone()
        for i in range(self.K + 1):
            coe[i] = coe_tmp[0] * cheby(i, math.cos((self.K + 0.5) * math.pi / (self.K + 1)))
            for j in range(1, self.K + 1):
                x_j = math.cos((self.K - j + 0.5) * math.pi / (self.K + 1))
                coe[i] = coe[i] + coe_tmp[j] * cheby(i, x_j)
            coe[i] = 2 * coe[i] / (self.K + 1)

        # Symmetric-normalised Laplacian, then extra self-loops of weight -1; the
        # duplicate diagonal entries are summed by the 'add' aggregation, so the
        # operator that gets propagated is the Laplacian minus the identity.
        edge_index1, norm1 = get_laplacian(edge_index, edge_weight, normalization='sym', dtype=x.dtype, num_nodes=x.size(self.node_dim))
        edge_index_tilde, norm_tilde = add_self_loops(edge_index1, norm1, fill_value=-1.0, num_nodes=x.size(self.node_dim))

        # Three-term recurrence: T_0 = x, T_1 = P x, T_i = 2 P T_{i-1} - T_{i-2}.
        Tx_0 = x
        Tx_1 = self.propagate(edge_index_tilde, x=x, norm=norm_tilde, size=None)
        out = coe[0] / 2 * Tx_0 + coe[1] * Tx_1
        for i in range(2, self.K + 1):
            Tx_2 = self.propagate(edge_index_tilde, x=Tx_1, norm=norm_tilde, size=None)
            Tx_2 = 2 * Tx_2 - Tx_0
            out = out + coe[i] * Tx_2
            Tx_0, Tx_1 = Tx_1, Tx_2
        return out

    def message(self, x_j, norm):
        """Scale each source-node feature row by its edge weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        # The module has no ``temp`` attribute (only ``temp_low``/``temp_high``);
        # referencing it made repr()/print(model) raise AttributeError.
        return '{}(K={}, temp_low={}, temp_high={})'.format(
            self.__class__.__name__, self.K, self.temp_low, self.temp_high)
    

class ChebNetII(torch.nn.Module):
    """Encoder: polynomial propagation followed by one linear layer and an activation."""

    def __init__(self, num_features, hidden=512, K=10, dprate=0.50, dropout=0.50, is_bns=False, act_fn='relu'):
        super().__init__()
        self.lin1 = nn.Linear(num_features, hidden)

        self.prop1 = ChebnetII_prop(K=K)
        assert act_fn in ['relu', 'prelu']
        if act_fn == 'prelu':
            self.act_fn = nn.PReLU()
        else:
            self.act_fn = nn.ReLU()
        # Applied to the propagated features, hence sized by the input width.
        self.bn = torch.nn.BatchNorm1d(num_features, momentum=0.01)
        self.is_bns = is_bns
        self.dprate = dprate
        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the propagation coefficients and the linear layer."""
        self.prop1.reset_parameters()
        self.lin1.reset_parameters()

    def get_emebedding(self, x, edge_index, highpass):
        """Run the module on the given inputs (same result as calling it directly)."""
        return self(x, edge_index, highpass=highpass)

    def forward(self, x, edge_index, highpass=True):
        """Encode ``x``: optional input dropout, propagation, dropout,
        optional batch norm, linear layer, activation."""
        # Input dropout is skipped entirely when its rate is exactly zero.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        x = self.prop1(x, edge_index, highpass=highpass)

        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.is_bns:
            x = self.bn(x)
        return self.act_fn(self.lin1(x))
 

def unsupervised_learning(data, args):
    """Train the contrastive model with early stopping and return node embeddings.

    Relies on the module-level ``model``, ``optimizer`` and ``loss_fn`` objects
    created by the script entry point.

    Each epoch scores the true features (label 1) and a row-permuted copy
    (label 0). Whenever the training loss improves, the model state is saved
    to a temporary file under ``unsup_pkl/``; training stops after
    ``args.patience`` consecutive epochs without improvement. The saved state
    is restored before the embeddings are computed, and the temporary file is
    always removed afterwards.

    Args:
        data: Graph object exposing ``x`` (node features) and ``edge_index``.
        args: Parsed options; ``epochs``, ``patience`` and ``dataset`` are read.

    Returns:
        The detached embedding tensor produced by ``model.get_embedding``.
    """
    feat = data.x
    edge_index = data.edge_index

    n_node = feat.shape[0]
    # The model emits 2 * n_node positive logits followed by 2 * n_node negative ones.
    lbl1 = th.ones(n_node * 2, device=feat.device)
    lbl2 = th.zeros(n_node * 2, device=feat.device)
    lbl = th.cat((lbl1, lbl2))

    ckpt_dir = 'unsup_pkl'
    # th.save does not create missing directories.
    os.makedirs(ckpt_dir, exist_ok=True)
    # The PID keeps concurrent runs started within the same second apart.
    tag = f"{int(time.time())}_{os.getpid()}"
    ckpt_path = os.path.join(ckpt_dir, 'polygcl_best_model_' + args.dataset + tag + '.pkl')

    best = float("inf")
    cnt_wait = 0
    saved = False
    try:
        for _ in range(args.epochs):
            model.train()
            optimizer.zero_grad()

            shuf_idx = np.random.permutation(n_node)
            shuf_feat = feat[shuf_idx, :]

            out = model(edge_index, feat, shuf_feat)
            loss = loss_fn(out, lbl)

            loss.backward()
            optimizer.step()

            # Compare plain floats so no loss tensor is kept alive between epochs.
            loss_val = loss.item()
            if loss_val < best:
                best = loss_val
                cnt_wait = 0
                th.save(model.state_dict(), ckpt_path)
                saved = True
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                # print("Early stopping")
                break

        # With zero epochs nothing was saved; keep the current weights then.
        if saved:
            model.load_state_dict(th.load(ckpt_path))
        model.eval()
        with th.no_grad():
            embeds = model.get_embedding(edge_index, feat)
        return embeds
    finally:
        # Remove the temporary checkpoint even when training raised.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)


if __name__ == "__main__":
    # Script entry point: pre-train PolyGCL once, then evaluate the frozen
    # embeddings with a linear probe over several fixed random splits.
    args = parse_args()
    print(args)
    print('---------------')

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    # Fail before the expensive pre-training instead of with an IndexError after it.
    if args.runs > len(SEEDS):
        raise ValueError(f'--runs must be at most {len(SEEDS)} (number of fixed split seeds), got {args.runs}')
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    # Step 1: Load data =================================================================== #
    dataset = DataLoader(name=args.dataset)
    data = dataset[0]
    n_feat = data.x.shape[1]
    data = data.to(device)

    # Step 2: Create model =================================================================== #
    # Move the model to the resolved ``device`` (not the raw integer GPU index),
    # so that the CPU fallback above also applies to the model.
    model = PolyGCL(in_dim=n_feat, out_dim=args.hid_dim, K=args.K, dprate=args.dprate, dropout=args.dropout, is_bns=args.is_bns, act_fn=args.act_fn).to(device)

    # Step 3: Create training components ===================================================== #
    # Separate parameter groups: encoder linear layer and discriminator use lr1/wd1,
    # the propagation coefficients and the alpha/beta mixing weights use lr_prop/wd_prop.
    optimizer = torch.optim.Adam([{'params': model.encoder.lin1.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1},
                                  {'params': model.disc.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1},
                                  {'params': model.encoder.prop1.parameters(), 'weight_decay': args.wd_prop, 'lr': args.lr_prop},
                                  {'params': model.alpha, 'weight_decay': args.wd_prop, 'lr': args.lr_prop},
                                  {'params': model.beta, 'weight_decay': args.wd_prop, 'lr': args.lr_prop}
                                  ])

    loss_fn = nn.BCEWithLogitsLoss()

    embeds = unsupervised_learning(data=data, args=args)

    # Split sizes: training nodes per class and total validation nodes.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]
        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    # Mean accuracy with a 95% bootstrap confidence half-width.
    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')