# pretrain_eigvec.py 
import argparse

from loader import MoleculeDataset
from dataloader import DataLoaderEigvec #, DataListLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
import numpy as np

from model import GNN
from sklearn.metrics import roc_auc_score

from splitters import scaffold_split, random_split, random_scaffold_split
import pandas as pd

from util import MaskAtom
from eigvec_util import *

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.utils import to_dense_batch

from tensorboardX import SummaryWriter

from torchvision import transforms as T


# NOTE(review): ``criterion`` is not referenced anywhere in this file and appears
# unused — confirm no other module imports it before removing.
criterion = nn.CrossEntropyLoss()

import timeit

import time 

import os



def train(args, model_list, loader, optimizer_list, device):
    """Run one training epoch of the eigenvector-prediction pretraining task.

    For each batch the GNN produces node representations (optionally
    concatenated with ``batch.post_positional``). ``linear_pred_eigvecs``
    maps them to predicted eigenvectors via ``eigvec_forward``. The
    predictions are then explicitly orthogonalized and normalized per graph.
    After that, the supervised energy and eigenvalue losses are computed.
    The optimized objective is ``energy_loss + eigval_loss``. Both modules
    are updated with their own optimizer.

    Args:
        args: Parsed command-line arguments; ``args.max_nodes`` is read.
        model_list: ``[model, linear_pred_eigvecs]`` — the GNN encoder and the
            linear eigenvector head.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``pre_positional``, ``post_positional``,
            ``eigvals`` and ``batch``; must also support ``len()``.
        optimizer_list: ``[optimizer_model, optimizer_linear_pred_eigvecs]``.
        device: torch device that batches are moved to and that is handed to
            ``eigvec_forward``.

    Returns:
        Tuple ``(mean_loss, mean_energy_loss, mean_eigval_loss)`` averaged
        over batches; all three are 0.0 when the loader yields no batches.
    """
    model, linear_pred_eigvecs = model_list
    optimizer_model, optimizer_linear_pred_eigvecs = optimizer_list

    model.train()
    linear_pred_eigvecs.train()

    total_loss = 0.0
    total_energy_loss = 0.0
    total_eigval_loss = 0.0

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        node_rep = model(batch.x, batch.edge_index, batch.edge_attr, batch.pre_positional)

        # Append optional post-GNN positional features when the dataset
        # transform produced any columns for them.
        if batch.post_positional.shape[1] > 0:
            node_rep = torch.cat((node_rep, batch.post_positional), dim=1)

        # Pass the resolved torch.device rather than the raw ``args.device``
        # GPU index, so CPU-only runs are not pointed at a CUDA device.
        pred_eigvecs = eigvec_forward(linear_pred_eigvecs, node_rep, batch.batch, args.max_nodes, device)
        pred_eigvecs = orthogonalize_by_batch(pred_eigvecs, batch.batch)
        pred_eigvecs = normalize_by_batch(pred_eigvecs, batch.batch)

        energy_loss, eigval_loss = SupervisedEigenvalueLoss(
            pred_eigvecs, batch.edge_index, batch.eigvals, batch.batch, args.max_nodes
        )
        loss = energy_loss + eigval_loss

        optimizer_model.zero_grad()
        optimizer_linear_pred_eigvecs.zero_grad()
        loss.backward()
        optimizer_model.step()
        optimizer_linear_pred_eigvecs.step()

        # .item() already copies the scalar to host and returns a Python float.
        total_loss += loss.item()
        total_eigval_loss += eigval_loss.item()
        total_energy_loss += energy_loss.item()

    num_batches = len(loader)
    if num_batches == 0:
        # Avoid ZeroDivisionError on an empty dataset.
        return 0.0, 0.0, 0.0
    return (
        total_loss / num_batches,
        total_energy_loss / num_batches,
        total_eigval_loss / num_batches,
    )

def _parse_args():
    """Build the command-line parser for pretraining and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features are combined across layers. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default = 'zinc_standard_agent', help='root directory of dataset for pretraining')
    parser.add_argument('--input_model_file', type=str, default = '', help='filename to input the model')
    parser.add_argument('--output_model_file', type=str, default = '', help='filename to output the model')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed', type=int, default=0, help = "Seed for splitting dataset.")
    parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading')

    parser.add_argument('--max_nodes', type=int, default = 50, help='max nodes, for padding')
    parser.add_argument('--num_eigenvectors', type=int, default = 6, help='number of eigenvectors to predict')

    parser.add_argument('--wavelet_positional_num_nodes', type=int, default=0, help='number of random nodes to compute wavelets from')
    parser.add_argument('--sanity_check_dim', type=int, default=0, help='number of dims to add to sanity check transform')

    parser.add_argument('--rwpe_dim', type=int, default=0, help='k value for rwpe. Suggested size is 20 for ZINC, if using')

    return parser.parse_args()


def _build_transform(args):
    """Return the dataset transform selected by the positional-encoding flags.

    ``--sanity_check_dim`` takes precedence over ``--wavelet_positional_num_nodes``.
    The random-walk PE transform is always applied, last.
    """
    if args.sanity_check_dim > 0:
        return T.Compose([PostTransformSanityCheck(dim = args.sanity_check_dim), RandomWalkPETransform(walk_length=args.rwpe_dim)])
    if args.wavelet_positional_num_nodes > 0:
        return T.Compose([PositionalWaveletTransform(num_nodes = args.wavelet_positional_num_nodes), RandomWalkPETransform(walk_length=args.rwpe_dim)])
    return RandomWalkPETransform(walk_length=args.rwpe_dim)


def main():
    """Pretrain a GNN with an eigenvector-prediction head from the command line.

    Steps:
    - Parse arguments and seed torch/numpy with 0.
    - Build the dataset and loader.
    - Train for ``--epochs`` epochs.
    - Save the GNN weights when ``--output_model_file`` is non-empty.
    - Save the per-epoch loss history.
    - Write one loss plot per tracked loss under ``plots/``.
    """
    args = _parse_args()

    torch.manual_seed(0)
    np.random.seed(0)

    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    transform = _build_transform(args)
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset, transform=transform)
    loader = DataLoaderEigvec(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)

    # Width of the extra per-node features concatenated after the GNN in train().
    # NOTE(review): the factor 5 presumably matches the per-node output width of
    # PositionalWaveletTransform — confirm against its implementation.
    if args.sanity_check_dim > 0:
        post_positional_dim = args.sanity_check_dim
    else:
        post_positional_dim = args.wavelet_positional_num_nodes * 5

    model = GNN(args.num_layer, args.emb_dim, args.rwpe_dim, JK = args.JK, drop_ratio = args.dropout_ratio, gnn_type = args.gnn_type).to(device)
    print("total parameters:", sum(p.numel() for p in model.parameters()))
    # Input/output widths scale with max_nodes; presumably eigvec_forward feeds a
    # flattened, padded per-graph feature matrix — TODO confirm.
    linear_pred_eigvecs = torch.nn.Linear((args.emb_dim + post_positional_dim) * args.max_nodes, args.num_eigenvectors * args.max_nodes).to(device)

    if args.input_model_file != '':
        print("loading from: ", args.input_model_file + ".pth")
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(args.input_model_file + ".pth", map_location=device))

    model_list = [model, linear_pred_eigvecs]

    optimizer_model = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    optimizer_linear_pred_eigvecs = optim.Adam(linear_pred_eigvecs.parameters(), lr=args.lr, weight_decay=args.decay)
    optimizer_list = [optimizer_model, optimizer_linear_pred_eigvecs]

    train_loss_hist = []
    train_energy_loss_hist = []
    train_eigval_loss_hist = []

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train_loss, train_energy_loss, train_eigval_loss = train(args, model_list, loader, optimizer_list, device)

        train_loss_hist.append(train_loss)
        train_energy_loss_hist.append(train_energy_loss)
        train_eigval_loss_hist.append(train_eigval_loss)

        print("loss: ", train_loss, "  energy_loss: ", train_energy_loss, "  eigval_loss: ", train_eigval_loss)

    # Only the GNN encoder is saved; the linear eigenvector head is discarded.
    if args.output_model_file != "":
        torch.save(model.state_dict(), args.output_model_file + ".pth")

    loss_hist_dict = {"loss": train_loss_hist, "energy_loss": train_energy_loss_hist, "eigval_loss": train_eigval_loss_hist}
    torch.save(loss_hist_dict, args.output_model_file + "loss_hist_dict" + ".pth")

    for key, history in loss_hist_dict.items():
        path = f"plots/{args.output_model_file}_{key}.png"
        # output_model_file may contain sub-directories; create them under plots/
        # so plotting does not fail on a missing directory.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        plot_loss_history(history, path)

# Run pretraining only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
