import argparse

from loader import MoleculeDataset
from torch_geometric.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
import numpy as np

from model import GNN, GNN_graphpred
from sklearn.metrics import roc_auc_score

from splitters import scaffold_split, random_split, random_scaffold_split
import pandas as pd

from tensorboardX import SummaryWriter

import os 

import yaml
from easydict import EasyDict

# for the RWPE
from torchvision import transforms as T
from eigvec_util import *




# Element-wise BCE-with-logits loss. reduction="none" keeps one loss value per
# (graph, task) entry so that train() can zero out entries with missing labels
# before averaging over the valid ones.
criterion = nn.BCEWithLogitsLoss(reduction = "none")

def train(args, model, device, loader, optimizer):
    """Run one epoch of supervised multi-task binary-classification pre-training.

    Labels in ``batch.y`` are treated as +1 / -1, with 0 meaning "missing": missing
    entries are masked out of the BCE loss and the +1/-1 labels are mapped to 1/0
    targets. When ``args.predict_eigvecs`` is set, the model also returns per-node
    eigenvector predictions, and an eigenvalue loss is added to the objective with
    weights ``args.eigvec_cfg.lambda_base`` and ``args.eigvec_cfg.lambda_eigval``.

    Args:
        args: parsed CLI namespace; reads ``predict_eigvecs`` and ``eigvec_cfg``.
        model: network called as ``model(x, edge_index, edge_attr, pre_positional, batch)``.
        device: torch device each batch is moved to.
        loader: iterable of PyG batches.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(mean_base_loss, mean_energy_loss, mean_eigval_loss)`` averaged over the
        epoch's batches when ``args.predict_eigvecs`` is set, otherwise ``None``.
    """
    # TODO: precompute eigvecs and eigvals for the whole dataset, the same way the
    # other pipeline does it.
    model.train()
    total_base_loss = 0.0
    total_eigval_loss = 0.0
    total_energy_loss = 0.0

    progress = tqdm(loader, desc="Iteration")
    for batch in progress:
        batch = batch.to(device)
        if args.predict_eigvecs:
            pred, pred_eigvecs = model(batch.x, batch.edge_index, batch.edge_attr, batch.pre_positional, batch.batch)
        else:
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.pre_positional, batch.batch)
        y = batch.y.view(pred.shape).to(torch.float64)

        # Entries whose label is 0 are treated as missing.
        is_valid = y ** 2 > 0
        # Per-element loss; {-1, +1} labels are mapped onto {0, 1} BCE targets.
        loss_mat = criterion(pred.double(), (y + 1) / 2)
        # Drop the contribution of missing labels.
        loss_mat = torch.where(is_valid, loss_mat, torch.zeros_like(loss_mat))

        optimizer.zero_grad()
        # clamp(min=1): a batch with no valid labels would otherwise give 0/0 = NaN,
        # and backward() would propagate that NaN into the weights.
        loss = torch.sum(loss_mat) / torch.sum(is_valid).clamp(min=1)

        if args.predict_eigvecs:
            cfg = args.eigvec_cfg
            keep = eigen_mask(cfg.num_eigvecs, batch.batch, cfg.max_nodes)
            laplacian = get_masked_laplacian(batch.edge_index, batch.batch, keep)
            # Expand the per-graph eigenvalues to one copy per node of each graph,
            # then keep only the masked-in nodes.
            eigvals_gt = batch.eigvals[batch.batch][keep]
            batch_eigen = batch.batch[keep]

            energy_loss, eigval_loss, ortho_loss, _, _ = SupervisedEigenvalueLoss2(
                pred_eigvecs, eigvals_gt, batch_eigen, laplacian
            )

            orig_loss = loss
            # NOTE(review): energy_loss and ortho_loss are computed (energy_loss is
            # also reported) but are not part of the optimized objective -- confirm
            # this is intended.
            loss = cfg.lambda_base * orig_loss + cfg.lambda_eigval * eigval_loss

        loss.backward()
        optimizer.step()

        if args.predict_eigvecs:
            base_val = orig_loss.item()
            energy_val = energy_loss.item()
            eigval_val = eigval_loss.item()
            total_base_loss += base_val
            total_energy_loss += energy_val
            total_eigval_loss += eigval_val
            # Show the current batch losses on the progress bar instead of
            # printing three lines per step.
            progress.set_postfix(orig_loss=base_val, energy_loss=energy_val, eigval_loss=eigval_val)

    if not args.predict_eigvecs:
        return None

    # Guard against an empty loader.
    num_batches = max(len(loader), 1)
    return (
        total_base_loss / num_batches,
        total_energy_loss / num_batches,
        total_eigval_loss / num_batches,
    )

def eval(args, model, device, loader, normalized_weight):
    """Score ``model`` on ``loader`` and return a weighted sum of per-task ROC-AUCs.

    Labels are compared against +1 (positive) and -1 (negative); entries with
    label 0 are excluded. A task contributes only if it has at least one positive
    and one negative label. Its AUC is then multiplied by ``normalized_weight[task]``
    and the products are summed.

    NOTE(review): weights of skipped tasks are simply dropped (no renormalization),
    so the result is a weighted mean only when every task is scorable -- confirm
    what ``normalized_weight`` is expected to sum to.

    Args:
        args: parsed CLI namespace; reads ``predict_eigvecs``.
        model: network evaluated in ``eval()`` mode without gradients.
        device: torch device each batch is moved to.
        loader: iterable of PyG batches.
        normalized_weight: per-task weights indexed by task column.

    Returns:
        The weighted sum of ROC-AUCs over the scorable tasks.
    """
    model.eval()
    true_chunks = []
    score_chunks = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        with torch.no_grad():
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.pre_positional, batch.batch)
        # With the eigenvector head enabled the model returns (pred, pred_eigvecs);
        # only the task predictions are scored here.
        if args.predict_eigvecs:
            pred, _ = out
        else:
            pred = out
        true_chunks.append(batch.y.view(pred.shape).cpu())
        score_chunks.append(pred.cpu())

    labels = torch.cat(true_chunks, dim=0).numpy()
    scores = torch.cat(score_chunks, dim=0).numpy()
    n_tasks = labels.shape[1]

    aucs = []
    task_weights = []
    for task in range(n_tasks):
        col = labels[:, task]
        # ROC-AUC is undefined unless both classes occur in this task.
        if not ((col == 1).any() and (col == -1).any()):
            continue
        mask = col ** 2 > 0
        aucs.append(roc_auc_score((col[mask] + 1) / 2, scores[mask, task]))
        task_weights.append(normalized_weight[task])

    if len(aucs) < n_tasks:
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(aucs))/n_tasks))

    return np.array(task_weights).dot(np.array(aucs))


def main():
    """Parse CLI arguments and run supervised GNN pre-training.

    When ``--predict_eigvecs`` is given, the file named by ``--eigvec_cfg`` is
    parsed as YAML and ``args.eigvec_cfg`` is replaced by the resulting EasyDict.
    After training, the encoder weights (``model.gnn``) are written to
    ``<output_model_file>.pth`` if an output name was given. With
    ``--predict_eigvecs``, the per-epoch loss histories are also saved to
    ``<output_model_file>loss_hist_dict.pth`` and plotted under ``plots/``.

    Raises:
        ValueError: if ``--dataset`` is not a supported dataset name.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.2,
                        help='dropout ratio (default: 0.2)')
    parser.add_argument('--graph_pooling', type=str, default="mean",
                        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default = 'chembl_filtered', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--input_model_file', type=str, default = '', help='filename to read the model (if there is any)')
    parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model')
    parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading')

    parser.add_argument('--rwpe_dim', type=int, default=0, help='k value for rwpe. Suggested size is 20 for ZINC, if using')

    parser.add_argument('--predict_eigvecs', action="store_true", help='Whether to include eigenvector pretraining head')
    parser.add_argument('--eigvec_cfg', type=str, default="eigvec_cfg.yaml", help='cfg file for eigvec head')
    args = parser.parse_args()

    # Replace the cfg *path* on args with the parsed cfg object (or None).
    if args.predict_eigvecs:
        with open(args.eigvec_cfg, 'r', encoding='utf-8') as file:
            # safe_load: the config is plain data; never construct arbitrary objects.
            args.eigvec_cfg = EasyDict(yaml.safe_load(file))
    else:
        args.eigvec_cfg = None

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    # Bunch of classification tasks
    if args.dataset == "chembl_filtered":
        num_tasks = 1310
    else:
        raise ValueError("Invalid dataset name.")

    transform = RandomWalkPETransform(walk_length=args.rwpe_dim)
    # NOTE(review): the argument 5 is hard-coded; check whether it has to agree
    # with eigvec_cfg.num_eigvecs (which train() passes to eigen_mask) -- TODO confirm.
    pre_transform = EigvecPretransform2(5) if args.predict_eigvecs else None

    # set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset, transform=transform, pre_transform=pre_transform)

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)

    # set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, args.rwpe_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type, predict_eigvecs=args.predict_eigvecs, eigvec_cfg = args.eigvec_cfg)
    print(model)
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("total params:", pytorch_total_params)

    if args.input_model_file:
        model.from_pretrained(args.input_model_file + ".pth")

    model.to(device)

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    train_loss_hist = []
    train_energy_loss_hist = []
    train_eigval_loss_hist = []

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        if args.predict_eigvecs:
            train_loss, train_energy_loss, train_eigval_loss = train(args, model, device, loader, optimizer)

            train_loss_hist.append(train_loss)
            train_energy_loss_hist.append(train_energy_loss)
            train_eigval_loss_hist.append(train_eigval_loss)

            print("loss: ", train_loss, "  energy_loss: ", train_energy_loss, "  eigval_loss: ", train_eigval_loss)
        else:
            train(args, model, device, loader, optimizer)

    # Save the encoder once, before any plotting, so a later failure cannot lose
    # the trained weights.
    if args.output_model_file:
        torch.save(model.gnn.state_dict(), args.output_model_file + ".pth")

    if args.predict_eigvecs:
        loss_hist_dict = {"loss": train_loss_hist, "energy_loss": train_energy_loss_hist, "eigval_loss": train_eigval_loss_hist}
        torch.save(loss_hist_dict, args.output_model_file + "loss_hist_dict" + ".pth")

        for key, history in loss_hist_dict.items():
            path = f"plots/{args.output_model_file}_{key}.png"
            # output_model_file may itself contain directories (e.g. "runs/model"),
            # so create the plot's full parent directory, not just "plots/".
            os.makedirs(os.path.dirname(path), exist_ok=True)
            plot_loss_history(history, path)


# Entry point: run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
