import os
import warnings
import time
import argparse
import torch
import numpy as np

from models import *
from utils import set_seed
from unsup_model import GRACE
from dataset_loader import DataLoader
warnings.filterwarnings("ignore")


def print_params(model, string):
    """Print every named parameter of ``model.encoder`` between two banner lines.

    Args:
        model: object exposing an ``encoder`` attribute that supports
            ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        string: label shown in the opening banner.
    """
    encoder = model.encoder
    print(f'----------- {string} ----------')
    for param_name, tensor in encoder.named_parameters():
        # Name first, then the full tensor value on the following line(s).
        print(param_name)
        print(tensor)
    print('-----------------------------------')


def get_encoder(dataset, args):
    """Instantiate the encoder class selected by ``args.net``.

    Args:
        dataset: dataset object forwarded unchanged to the encoder constructor.
        args: parsed arguments; ``args.net`` picks the encoder and the whole
            namespace is forwarded to the constructor.

    Returns:
        The constructed encoder module.

    Raises:
        NotImplementedError: if ``args.net`` names no known encoder.
    """
    # Each entry is a zero-argument factory, so only the selected encoder's
    # class name is looked up (and constructed), never the others.
    factories = {
        'ChebNetII': lambda: UnsupChebNetII(dataset=dataset, args=args),
        'GCN': lambda: UnsupGCN_Net(dataset=dataset, args=args),
        'SGC': lambda: UnsupSGC_Net(dataset=dataset, args=args),
        'BernNet': lambda: UnsupBernNet(dataset=dataset, args=args),
        'GPRGNN': lambda: UnsupGPRGNN(dataset=dataset, args=args),
        'ChebNetII_V2': lambda: UnsupChebNetII_V2(dataset=dataset, args=args),
        'BernNet_V2': lambda: UnsupBernNet_V2(dataset=dataset, args=args),
        'GPRGNN_V2': lambda: UnsupGPRGNN_V2(dataset=dataset, args=args),
    }
    factory = factories.get(args.net)
    if factory is None:
        raise NotImplementedError
    return factory()


def _module_global(name):
    """Return the module-level object ``name`` (set under ``__main__``), or raise NameError."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass it to unsupervised_learning explicitly"
        ) from None


def unsupervised_learning(data, args, model=None, optimizer=None):
    """Train the contrastive model for ``args.unsup_epochs`` epochs and report timing.

    Each epoch puts the model in train mode and clears gradients. It then runs
    ``model(data.x, data.edge_index)`` to obtain two views ``(z1, z2)`` and
    computes ``model.infonce_loss(z1, z2, l1_norm=False)``. Finally it
    back-propagates and takes one optimizer step. The mean wall-clock time per
    epoch is printed.

    Args:
        data: graph object exposing ``x`` and ``edge_index`` (presumably a
            PyG ``Data`` — TODO confirm against ``DataLoader``).
        args: parsed arguments; only ``args.unsup_epochs`` is read here.
        model: module returning ``(z1, z2)`` from its forward pass and providing
            ``infonce_loss``. Defaults to the module-level ``model`` created
            under ``__main__`` (the original behavior).
        optimizer: optimizer over ``model``'s parameters. Defaults to the
            module-level ``optimizer``.

    Returns:
        list[float]: wall-clock duration of each epoch in seconds.

    Raises:
        NameError: if ``model``/``optimizer`` are neither passed nor defined
            at module level.
    """
    if model is None:
        model = _module_global('model')
    if optimizer is None:
        optimizer = _module_global('optimizer')

    # torch.cuda.synchronize() raises on CPU-only builds/machines, so only
    # synchronize when CUDA is actually present.
    use_cuda = torch.cuda.is_available()

    epoch_times = []
    for _ in range(args.unsup_epochs):
        # Drain any queued GPU work before starting the clock.
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.time()

        model.train()
        optimizer.zero_grad()
        z1, z2 = model(data.x, data.edge_index)
        loss = model.infonce_loss(z1, z2, l1_norm=False)
        loss.backward()
        optimizer.step()

        # CUDA kernels launch asynchronously; wait for them to finish so the
        # measured time covers the GPU work, not just the kernel launches.
        if use_cuda:
            torch.cuda.synchronize()
        epoch_times.append(time.time() - start_time)

    # np.mean([]) would yield nan, so only report when at least one epoch ran.
    if epoch_times:
        print(f'Average epoch time = {np.mean(epoch_times):.6f} seconds.\n')
    return epoch_times


def parse_args(argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; an explicit list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--net', type=str, default='ChebNetII')

    # encoder architecture / propagation
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=128, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--use_bn', action='store_true')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha for APPNP/GPRGNN.')
    parser.add_argument('--dprate', type=float, default=0.0, help='dropout for propagation layer.')
    parser.add_argument('--q', type=int, default=0, help='The constant for ChebBase.')
    parser.add_argument('--Init', type=str, choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'], default='PPR', help='initialization for GPRGNN.')

    # unsupervised learning: drop rates for the two augmented views and the
    # contrastive temperature (all forwarded to GRACE by the script body)
    parser.add_argument('--de1', default=0.2, type=float)
    parser.add_argument('--de2', default=0.2, type=float)
    parser.add_argument('--df1', default=0.2, type=float)
    parser.add_argument('--df2', default=0.2, type=float)
    parser.add_argument('--tau', default=0.5, type=float)
    parser.add_argument("--proj_hid_dim", type=int, default=128, help="Projection hidden layer dim.")

    parser.add_argument("--patience", type=int, default=50, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=10, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.01, help="Learning rate of the unsupervised model.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")

    parser.add_argument('--residual', action='store_true')
    parser.add_argument('--constant', action='store_true')
    parser.add_argument('--reg_coef', default=0.5, type=float)
    # norm layer
    parser.add_argument('--norm_type', default='none', type=str)
    parser.add_argument('--scale', default=0.5, type=float)
    parser.add_argument('--norm_x', action='store_true')
    parser.add_argument('--plusone', action='store_true')
    parser.add_argument('--layer_norm', action='store_true')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    print(f'{args.dataset}, {args.net}')
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet.
    # NOTE(review): not referenced anywhere in this script — presumably kept
    # for parity with a supervised/evaluation pipeline; confirm before removing.
    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660, 2108550661, 1648766618, 629014539, 3212139042, 2424918363]

    # Honor --device: previously 'cuda' always meant the default GPU and the
    # flag was ignored. Making the chosen GPU current also ensures any
    # torch.cuda.synchronize() used for timing waits on that device.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.device}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    encoder = get_encoder(dataset=dataset, args=args).to(device)
    # Width of the encoder output fed to GRACE: the *_V2 encoders use the raw
    # feature width (doubled with --residual), the others use args.hidden.
    # Presumably this matches each encoder's output size — confirm in models.
    if args.net in ['ChebNetII_V2', 'BernNet_V2', 'GPRGNN_V2']:
        fc_dim = dataset.num_node_features * 2 if args.residual else dataset.num_node_features
    else:
        fc_dim = args.hidden
    model = GRACE(encoder=encoder, input_dim=fc_dim, num_hidden=args.hidden, num_proj_hidden=args.proj_hid_dim,
                  tau=args.tau, drop_rate=(args.de1, args.de2, args.df1, args.df2), args=args).to(device)
    optimizer = torch.optim.Adam([{'params': model.parameters(), 'weight_decay': args.wd1, 'lr': args.lr1}])

    # unsupervised_learning reads the module-level `model` and `optimizer`
    # defined just above.
    unsupervised_learning(data=data, args=args)