import argparse
import os
import time
import warnings

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import dropout_adj, add_self_loops
from torch_geometric.utils import to_dense_adj

from dataset_loader import HeterophilousGraphDataset
from ChebnetII_pro import ChebnetII_prop
from utils import set_seed
from eval import unsupervised_eval_linear
from pe import random_walk_pe, laplacian_eigenvector_pe
warnings.filterwarnings("ignore")


class ChebNetII_CORE(torch.nn.Module):
    """Single-layer ChebNetII encoder used as the CCA-SSG backbone.

    Forward pipeline: optional input dropout (``dprate``) -> ChebnetII
    propagation -> optional concatenation of the raw input (``residual``)
    -> dropout (``dropout``) -> optional BatchNorm (``is_bns``) ->
    activation (PReLU when ``act_fn == 'prelu'``, otherwise ReLU).

    Args:
        in_dim: Width of the BatchNorm layer, i.e. the width of the tensor
            after the optional residual concatenation. NOTE(review): the
            caller doubles the feature dimension when ``residual`` is set,
            which presumes ChebnetII_prop preserves feature width — confirm.
        args: Namespace providing ``dprate``, ``dropout``, ``residual``,
            ``K``, ``is_bns`` and ``act_fn``.
    """

    def __init__(self, in_dim, args):
        super(ChebNetII_CORE, self).__init__()
        self.dprate = args.dprate      # dropout applied before propagation
        self.dropout = args.dropout    # dropout applied after propagation/concat
        self.residual = args.residual  # concat raw input with propagated output
        self.K = args.K                # passed to ChebnetII_prop (presumably the polynomial order)
        self.is_bns = args.is_bns      # apply BatchNorm before the activation

        self.act_fn = nn.PReLU() if args.act_fn == 'prelu' else nn.ReLU()
        self.bn = torch.nn.BatchNorm1d(in_dim, momentum=0.01)
        self.prop1 = ChebnetII_prop(args.K)

    def reset_parameters(self):
        """Re-initialise every sub-module that carries learnable state.

        Previously only the propagation layer was reset, leaving BatchNorm's
        affine parameters / running statistics and PReLU's slope untouched.
        """
        self.prop1.reset_parameters()
        self.bn.reset_parameters()
        # nn.ReLU has no parameters; only reset the activation if it supports it.
        if hasattr(self.act_fn, 'reset_parameters'):
            self.act_fn.reset_parameters()

    def get_embeddings(self, x, edge_index):
        """Return the encoder output for ``x`` on graph ``edge_index``."""
        return self(x=x, edge_index=edge_index)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        # Guard skips the dropout call entirely when disabled.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        prop_x = self.prop1(x, edge_index)
        x = torch.concat((x, prop_x), dim=1) if self.residual else prop_x
        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.is_bns:
            x = self.bn(x)
        x = self.act_fn(x)
        return x



class CCA_SSG(nn.Module):
    """CCA-SSG wrapper: a shared encoder applied to two augmented views.

    Args:
        encoder: Callable ``encoder(features, edge_index) -> embeddings``
            shared by both views.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # Added to the per-column std so the division never hits zero.
        self.EOS = 1e-5

    def forward(self, feat1, edge_index1, feat2, edge_index2):
        """Encode both views and standardize each embedding column-wise.

        Every column is centred by its mean over dim 0 and divided by its
        (unbiased) std over dim 0 plus ``EOS``.

        Returns:
            Tuple ``(z1, z2)`` of standardized embeddings for view 1 and 2.
        """
        standardized = []
        # View 1 is encoded before view 2, as the CCA loss expects the pair order.
        for feats, edges in ((feat1, edge_index1), (feat2, edge_index2)):
            emb = self.encoder(feats, edges)
            centred = emb - emb.mean(0)
            standardized.append(centred / (emb.std(0) + self.EOS))
        return tuple(standardized)

    def get_embeddings(self, feat, graph):
        """Return the raw (un-standardized) encoder output, detached from autograd."""
        return self.encoder(feat, graph).detach()
    

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str,default='minesweeper')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--net', type=str, default='ChebNetII_V2')

    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--K', default=5, type=int)
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--dprate', type=float, default=0.5, help='dropout for propagation layer.')
    parser.add_argument('--is_bns', type=bool, default=False)
    parser.add_argument('--act_fn', default='relu', help='activation function')
    parser.add_argument('--residual', action='store_true')

    parser.add_argument('--walk_length', default=10, type=int)
    parser.add_argument('--lap_k', default=5, type=int)

    # unsupervised learning
    parser.add_argument('--lambd', type=float, default=1e-3, help='trade-off ratio.')
    parser.add_argument('--de1', default=0.2, type=float)
    parser.add_argument('--de2', default=0.2, type=float)
    parser.add_argument('--df1', default=0.2, type=float)
    parser.add_argument('--df2', default=0.2, type=float)

    parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=500, help="Unupservised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
    args = parser.parse_args()
    return args


def drop_features(x, drop_prob):
    """Return a copy of ``x`` with a random subset of feature columns zeroed.

    Each column is selected independently with probability ``drop_prob``;
    a selected column is zeroed for every row. ``x`` itself is left intact.
    """
    num_cols = x.size(1)
    # One uniform draw per column decides whether that column is dropped.
    scores = torch.empty((num_cols, ), dtype=torch.float32, device=x.device)
    scores.uniform_(0, 1)
    masked = x.clone()
    masked[:, scores < drop_prob] = 0
    return masked


def random_aug(edge_index, x, df, de):
    """Build one stochastic view of the graph.

    Feature columns are masked with probability ``df`` (``drop_features``)
    and edges are dropped with probability ``de`` (``dropout_adj``).

    Returns:
        ``(edge_index, x)`` for the augmented view.
    """
    # Features are perturbed before edges so the RNG draw order stays fixed.
    masked_x = drop_features(x, df)
    kept_edges = dropout_adj(edge_index, p=de)[0]
    return kept_edges, masked_x


def _cca_ssg_loss(z1, z2, num_nodes, lambd):
    """Return the CCA-SSG objective for two standardized embedding views.

    The invariance term is the negative trace of the cross-correlation
    between the views. The decorrelation terms are the squared Frobenius
    distances of each view's auto-correlation matrix from the identity,
    weighted by ``lambd``.
    """
    c = torch.mm(z1.T, z2) / num_nodes
    c1 = torch.mm(z1.T, z1) / num_nodes
    c2 = torch.mm(z2.T, z2) / num_nodes

    loss_inv = -torch.diagonal(c).sum()
    # Built directly in c's dtype/device (previously a float64 NumPy copy every epoch).
    iden = torch.eye(c.size(0), dtype=c.dtype, device=c.device)
    loss_dec1 = (iden - c1).pow(2).sum()
    loss_dec2 = (iden - c2).pow(2).sum()
    return loss_inv + lambd * (loss_dec1 + loss_dec2)


def unsupervised_learning(data, args, device, model=None, optimizer=None):
    """Train the CCA-SSG model self-supervised and return node embeddings.

    Each epoch draws two augmented views with ``random_aug``, adds self-loops
    to both, and minimises the CCA-SSG loss. The parameters that produced the
    lowest loss are restored at the end and also written to
    ``unsup_pkl/cca_ssg_<net>_best_model_<dataset><timestamp>.pkl``.

    Args:
        data: Graph object exposing ``x``, ``edge_index`` and ``num_nodes``.
        args: Parsed options (uses ``unsup_epochs``, ``df1/df2``, ``de1/de2``,
            ``lambd``, ``net`` and ``dataset``).
        device: Target device; kept for API compatibility. The loss is built
            on the device of the model outputs.
        model: CCA-SSG model to train. Defaults to the module-level
            ``model`` global, which is what the original code used.
        optimizer: Optimizer over ``model``'s parameters. Defaults to the
            module-level ``optimizer`` global.

    Returns:
        Detached embeddings of the unaugmented graph, computed in eval mode.
    """
    if model is None:
        model = globals()['model']
    if optimizer is None:
        optimizer = globals()['optimizer']

    unsup_tag = str(int(time.time()))
    ckpt_path = 'unsup_pkl/' + 'cca_ssg_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl'

    best = float("inf")
    best_state = None
    for _ in range(args.unsup_epochs):
        model.train()
        optimizer.zero_grad()

        edge_index1, feat1 = random_aug(data.edge_index, data.x, args.df1, args.de1)
        edge_index2, feat2 = random_aug(data.edge_index, data.x, args.df2, args.de2)

        edge_index1 = add_self_loops(edge_index1)[0]
        edge_index2 = add_self_loops(edge_index2)[0]

        z1, z2 = model(feat1, edge_index1, feat2, edge_index2)
        loss = _cca_ssg_loss(z1, z2, data.num_nodes, args.lambd)
        loss.backward()

        # Snapshot BEFORE optimizer.step(): these are the parameters that
        # actually produced `loss`. NaN losses never compare as smaller.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

        optimizer.step()

    # Only restore/persist when some epoch produced a finite, tracked loss
    # (e.g. unsup_epochs == 0 previously crashed on torch.load).
    if best_state is not None:
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        torch.save(best_state, ckpt_path)
        model.load_state_dict(best_state)

    model.eval()
    # NOTE(review): training views get self-loops but the inference graph
    # below does not; this only matters if the encoder is sensitive to
    # self-loops — confirm.
    with torch.no_grad():
        embeds = model.get_embeddings(data.x, data.edge_index)
    return embeds


if __name__ == '__main__':
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    root = './data/'
    dataset = HeterophilousGraphDataset(root=root, name=args.dataset)
    data = dataset[0].to(device)

    # Optionally append positional encodings to the node features
    # (a non-positive walk_length / lap_k disables the respective encoding).
    if args.walk_length > 0:
        rw_pe = random_walk_pe(data=data, walk_length=args.walk_length).to(device)
        data.x = torch.cat((data.x, rw_pe), dim=1)
    if args.lap_k > 0:
        lap_pe = laplacian_eigenvector_pe(data=data, k=args.lap_k).to(device)
        data.x = torch.cat((data.x, lap_pe), dim=1)

    # With --residual the encoder concatenates raw and propagated features,
    # so its BatchNorm sees twice the input width.
    in_dim = data.x.shape[1] * 2 if args.residual else data.x.shape[1]
    encoder = ChebNetII_CORE(in_dim=in_dim, args=args)
    # `model` and `optimizer` are intentionally module-level names:
    # unsupervised_learning looks them up as globals.
    model = CCA_SSG(encoder=encoder).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    embeds = unsupervised_learning(data=data, args=args, device=device)

    results = unsupervised_eval_linear(data=data, embeds=embeds, args=args, device=device)
    # Plain float array (was dtype=object) so NumPy/seaborn work on numbers.
    values = np.asarray([v.item() for v in results], dtype=float)
    mean_acc = values.mean()
    test_acc_mean = mean_acc * 100
    # Half-width of the 95% bootstrap confidence interval around the mean.
    boot_means = sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000)
    uncertainty = np.max(np.abs(sns.utils.ci(boot_means, 95) - mean_acc))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')



