import argparse

from loader import MoleculeDataset
from dataloader import DataLoaderMasking #, DataListLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
import numpy as np

from model import GNN, EigvecHead
from sklearn.metrics import roc_auc_score

from splitters import scaffold_split, random_split, random_scaffold_split
import pandas as pd

from util import MaskAtom

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

from tensorboardX import SummaryWriter

# Shared classification loss used by train() for both the masked-atom and the
# masked-bond prediction terms.
criterion = nn.CrossEntropyLoss()

import timeit

# for the RWPE
from torchvision import transforms as T

import os 

import yaml
from easydict import EasyDict
from eigvec_util import *

import csv



def compute_accuracy(pred, target):
    """Return the fraction of rows of ``pred`` whose argmax equals ``target``.

    Args:
        pred: tensor of per-class scores with classes along dim 1, one row per
            prediction. It is detached, so no autograd graph is kept.
        target: 1-D tensor of integer class labels, one per row of ``pred``.

    Returns:
        Accuracy as a Python float in [0, 1]. Returns 0.0 for an empty ``pred``
        (e.g. a batch with no masked edges) instead of raising ZeroDivisionError.
    """
    num_rows = len(pred)
    if num_rows == 0:
        return 0.0
    predicted = pred.detach().argmax(dim=1)
    return float((predicted == target).sum().item()) / num_rows


def train(args, model_list, loader, optimizer_list, device, eigvec_head = None, optimizer_eigvec_head=None):
    """Run one epoch of masked-attribute pre-training and return epoch averages.

    Args:
        args: parsed CLI namespace. Reads ``mask_edge`` and ``predict_eigvecs``.
            When the latter is set, also reads ``eigvec_cfg`` (``max_nodes``,
            ``num_eigvecs``, ``head.type``, ``lambda_base``, ``lambda_eigval``).
        model_list: ``[model, linear_pred_atoms, linear_pred_bonds]``.
        loader: iterable of batches supporting ``.to(device)`` and exposing the
            attributes read below (``x``, ``edge_index``, ``edge_attr``,
            ``pre_positional``, ``masked_atom_indices``, ``mask_node_label``, ...).
        optimizer_list: optimizers in the same order as ``model_list``.
        device: device each batch is moved to.
        eigvec_head: eigenvector prediction head. Required only when
            ``args.predict_eigvecs`` is set.
        optimizer_eigvec_head: optimizer for ``eigvec_head``. Same requirement.

    Returns:
        ``(mean_loss, mean_atom_acc, mean_bond_acc)`` averaged over all batches.
        Bond accuracy stays 0 when ``args.mask_edge`` is off.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model, linear_pred_atoms, linear_pred_bonds = model_list
    optimizer_model, optimizer_linear_pred_atoms, optimizer_linear_pred_bonds = optimizer_list

    model.train()
    linear_pred_atoms.train()
    linear_pred_bonds.train()
    if args.predict_eigvecs:
        eigvec_head.train()

    loss_accum = 0
    acc_node_accum = 0
    acc_edge_accum = 0
    num_batches = 0
    # NOTE(review): anomaly detection is a debugging aid that makes every backward
    # pass noticeably slower; consider turning it off for full training runs.
    torch.autograd.set_detect_anomaly(True)
    for batch in tqdm(loader, desc="Iteration"):
        num_batches += 1
        batch = batch.to(device)

        node_rep = model(batch.x, batch.edge_index, batch.edge_attr, batch.pre_positional)
        ## loss for nodes
        pred_node = linear_pred_atoms(node_rep[batch.masked_atom_indices])
        loss = criterion(pred_node.double(), batch.mask_node_label[:,0])

        acc_node = compute_accuracy(pred_node, batch.mask_node_label[:,0])
        acc_node_accum += acc_node

        if args.mask_edge:
            masked_edge_index = batch.edge_index[:, batch.connected_edge_indices]
            # Symmetric edge representation: sum of the two endpoint embeddings.
            edge_rep = node_rep[masked_edge_index[0]] + node_rep[masked_edge_index[1]]
            pred_edge = linear_pred_bonds(edge_rep)
            loss += criterion(pred_edge.double(), batch.mask_edge_label[:,0])

            acc_edge = compute_accuracy(pred_edge, batch.mask_edge_label[:,0])
            acc_edge_accum += acc_edge

        if args.predict_eigvecs:
            pred_eigvecs = eigvec_head(node_rep, batch.batch)

            max_nodes = args.eigvec_cfg.max_nodes

            # The "concat" head variant also needs the max_nodes bound when
            # building the mask. NOTE(review): ``keep`` is presumably a per-node
            # selection from eigvec_util.eigen_mask; confirm its semantics there.
            if args.eigvec_cfg.head.type == "concat":
                keep = eigen_mask(args.eigvec_cfg.num_eigvecs, batch.batch, max_nodes)
            else:
                keep = eigen_mask(args.eigvec_cfg.num_eigvecs, batch.batch)
            L = get_masked_laplacian(batch.edge_index, batch.batch, keep)
            # Index the per-graph eigenvalue rows by each node's graph id, so
            # every node carries a copy of its graph's eigenvalues.
            eigvals_gt = batch.eigvals[batch.batch]
            eigvals_gt = eigvals_gt[keep]

            batch_eigen = batch.batch[keep]
            energy_loss, eigval_loss, ortho_loss, _, _ = SupervisedEigenvalueLoss2(pred_eigvecs, eigvals_gt, batch_eigen, L)

            # NOTE(review): only the eigenvalue term enters the objective;
            # energy_loss is printed for monitoring and ortho_loss is unused.
            orig_loss = loss
            loss = args.eigvec_cfg.lambda_base * orig_loss + args.eigvec_cfg.lambda_eigval * eigval_loss

            print("orig_loss", orig_loss)
            print("energy_loss", energy_loss)
            print("eigval_loss", eigval_loss)

        optimizer_model.zero_grad()
        optimizer_linear_pred_atoms.zero_grad()
        optimizer_linear_pred_bonds.zero_grad()
        if args.predict_eigvecs:
            optimizer_eigvec_head.zero_grad()

        loss.backward()

        optimizer_model.step()
        optimizer_linear_pred_atoms.step()
        optimizer_linear_pred_bonds.step()
        if args.predict_eigvecs:
            optimizer_eigvec_head.step()

        loss_accum += float(loss.cpu().item())

    if num_batches == 0:
        raise ValueError("train() received a loader that yielded no batches")
    # Average over the number of batches. The old code divided by the last
    # 0-based step index, which biased the means upward and failed on 1 batch.
    return loss_accum/num_batches, acc_node_accum/num_batches, acc_edge_accum/num_batches

def main():
    """Parse CLI arguments, run masked-attribute pre-training, and save outputs.

    When ``--predict_eigvecs`` is given, the YAML file named by ``--eigvec_cfg``
    is loaded into ``args.eigvec_cfg`` and an ``EigvecHead`` is trained jointly.
    The average training loss of each epoch is written to
    ``<output_model_file>_losses.csv``. The GNN weights are saved to
    ``<output_model_file>.pth`` only when ``--output_model_file`` is non-empty.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--mask_rate', type=float, default=0.15,
                        help='masking ratio (default: 0.15)')
    parser.add_argument('--mask_edge', type=int, default=0,
                        help='whether to mask edges or not together with atoms')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features are combined across layers. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default = 'zinc_standard_agent', help='root directory of dataset for pretraining')
    parser.add_argument('--output_model_file', type=str, default = '', help='filename to output the model')
    parser.add_argument('--input_model_file', type=str, default = '', help='filename to input the model')

    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed', type=int, default=0, help = "Seed for splitting dataset.")
    parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading')

    parser.add_argument('--rwpe_dim', type=int, default=0, help='k value for rwpe. Suggested size is 20 for ZINC, if using')

    parser.add_argument('--predict_eigvecs', action="store_true", help='Whether to include eigenvector pretraining head')
    parser.add_argument('--eigvec_cfg', type=str, default="eigvec_cfg.yaml", help='cfg file for eigvec head')

    args = parser.parse_args()
    # Replace the config *path* with the parsed config so train() can read fields.
    if args.predict_eigvecs:
        with open(args.eigvec_cfg, 'r') as file:
            # yaml.safe_load does not construct arbitrary Python objects.
            cfg = yaml.safe_load(file)
        args.eigvec_cfg = EasyDict(cfg)
    else:
        args.eigvec_cfg = None

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    print("num layer: %d mask rate: %f mask edge: %d" %(args.num_layer, args.mask_rate, args.mask_edge))

    # Per-sample transforms: attribute masking followed by random-walk positional encodings.
    transform = T.Compose([MaskAtom(num_atom_type = 119, num_edge_type = 5, mask_rate = args.mask_rate, mask_edge=args.mask_edge), RandomWalkPETransform(walk_length=args.rwpe_dim)])

    #set up dataset and transform function.
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset, transform = transform)

    loader = DataLoaderMasking(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)

    #set up models, one for pre-training and one for context embeddings
    model = GNN(args.num_layer, args.emb_dim, args.rwpe_dim, JK = args.JK, drop_ratio = args.dropout_ratio, gnn_type = args.gnn_type).to(device)

    if args.input_model_file != '':
        print("loading from: ", args.input_model_file + ".pth")
        # map_location lets a checkpoint saved on one device (e.g. GPU) load on another.
        model.load_state_dict(torch.load(args.input_model_file + ".pth", map_location=device))

    # Output sizes 119 and 4 are the atom-type and bond-type class counts predicted here.
    linear_pred_atoms = torch.nn.Linear(args.emb_dim, 119).to(device)
    linear_pred_bonds = torch.nn.Linear(args.emb_dim, 4).to(device)

    eigvec_head = None
    if args.predict_eigvecs:
        eigvec_head = EigvecHead(args.emb_dim, args.eigvec_cfg).to(device)
        print("USING EIGVEC HEAD:")
        print(eigvec_head)

    model_list = [model, linear_pred_atoms, linear_pred_bonds]

    #set up optimizers
    optimizer_model = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    optimizer_linear_pred_atoms = optim.Adam(linear_pred_atoms.parameters(), lr=args.lr, weight_decay=args.decay)
    optimizer_linear_pred_bonds = optim.Adam(linear_pred_bonds.parameters(), lr=args.lr, weight_decay=args.decay)

    optimizer_list = [optimizer_model, optimizer_linear_pred_atoms, optimizer_linear_pred_bonds]

    optimizer_eigvec_head = None
    if args.predict_eigvecs:
        optimizer_eigvec_head = optim.Adam(eigvec_head.parameters(), lr=args.lr, weight_decay=args.decay)

    # Collect one record per epoch and build the DataFrame once at the end.
    # The previous code seeded the frame with the column names as data rows and
    # referenced an undefined name ``losses``, which raised NameError.
    loss_records = []
    for epoch in range(1, args.epochs+1):
        print("====epoch " + str(epoch))

        train_loss, train_acc_atom, train_acc_bond = train(args, model_list, loader, optimizer_list, device, eigvec_head=eigvec_head, optimizer_eigvec_head=optimizer_eigvec_head)
        print(train_loss, train_acc_atom, train_acc_bond)
        loss_records.append({"epoch": epoch, "losses": train_loss})

    losses_df = pd.DataFrame(loss_records, columns=["epoch", "losses"])
    losses_df.to_csv(f"{args.output_model_file}_losses.csv")

    if not args.output_model_file == "":
        torch.save(model.state_dict(), args.output_model_file + ".pth")


# Start pre-training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
