import argparse
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger
from torch_geometric.loader import DataLoader

from dataset import *
from network import *
from networkLit import *
# Seed Python, NumPy and torch RNGs for reproducible runs; workers=True also
# asks Lightning to derive distinct seeds for dataloader worker processes.
pl.seed_everything(0, workers=True)

# Absolute directory containing this script; trained weights are exported
# beneath it in the __main__ block.
ABS_PATH = os.path.dirname(os.path.abspath(__file__))

# Redundant with seed_everything above (which already seeds torch), kept for
# explicitness.
torch.manual_seed(0)
# Allow reduced-precision (e.g. TF32) float32 matmuls where the hardware
# supports them, trading a little precision for speed.
torch.set_float32_matmul_precision('high')

def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"no" to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as True
    (``bool("False") is True``), so an explicit parser is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for training GOModule.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in which
            case argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the training configuration.
    """
    parser = argparse.ArgumentParser(description='Train GOModule')

    parser.add_argument('--dataset_path', type=str,
                        help='Path to dataset',
                        required=True)

    parser.add_argument('--batch_size', type=int,
                        help='Batch size',
                        required=False,
                        default=8192)

    parser.add_argument('--hidden_size', type=int,
                        help='Hidden size',
                        required=False,
                        default=512)

    parser.add_argument('--device', type=str,
                        help='Device',
                        required=False,
                        default="cuda")

    parser.add_argument('--lr', type=float,
                        help='Learning rate',
                        required=False,
                        default=0.001)

    parser.add_argument('--weight_decay', type=float,
                        help='Weight decay',
                        required=False,
                        default=0.0)

    parser.add_argument('--n_epochs', type=int,
                        help='Number of epochs',
                        required=False,
                        default=100)

    parser.add_argument('--num_workers', type=int,
                        help='Number of workers',
                        required=False,
                        default=0)

    parser.add_argument('--augmentations', type=str,
                        help='Augmentations',
                        required=False,
                        default="dimswitch_all")

    parser.add_argument('--dropout', type=float,
                        help='Dropout',
                        required=False,
                        default=0.0)

    parser.add_argument('--pos_weight', type=float,
                        help='Positive weight',
                        required=False,
                        default=1.0)

    # Accepts a bare "--debug" (-> True) as well as "--debug True/False".
    # A plain type=bool would turn "--debug False" into True.
    parser.add_argument('--debug', type=_str2bool,
                        nargs='?',
                        const=True,
                        help='Enable debug mode (log every step, detect autograd anomalies)',
                        required=False,
                        default=False)

    args = parser.parse_args(argv)

    return args


if __name__ == '__main__':
    args = parse_args()

    # Subset of the CLI options handed to GOModuleLit alongside args itself.
    hyperparameters = {"lr": args.lr, "batch_size": args.batch_size, "n_epochs": args.n_epochs, "hidden_size": args.hidden_size, "augmentations": args.augmentations}

    # Last component of the dataset directory, used to name the logs and the
    # exported weights. abspath() drops a trailing separator, so "data/foo/"
    # yields "foo" rather than the empty string str.split("/")[-1] gave.
    dataset_name = os.path.basename(os.path.abspath(args.dataset_path))

    # Each split is read from "<dataset_path>/<split>_set".
    train_set = GODataset(path=os.path.join(args.dataset_path, "train_set"), mode="train", args=args)
    val_set = GODataset(path=os.path.join(args.dataset_path, "val_set"), mode="val", args=args)
    test_set = GODataset(path=os.path.join(args.dataset_path, "test_set"), mode="test", args=args)
    # Only the training loader is shuffled; evaluation order stays fixed.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    print("Number of datapoints: train_set =", len(train_set), " val_set =", len(val_set), " test_set =", len(test_set))
    print("Number of batches : train_set =", len(train_loader), " val_set =", len(val_loader), " test_set =", len(test_loader))

    model = GOModuleLit(args, hyperparameters)
    logger = TensorBoardLogger("lightning_logs", name=f"GO/{dataset_name}/{args.augmentations}")
    # Keep only the single checkpoint with the lowest logged "objective".
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="objective", mode='min', save_top_k=1, save_last=False)

    trainer_kwargs = dict(max_epochs=args.n_epochs, accelerator=args.device, deterministic=True,
                          callbacks=[checkpoint_callback], logger=logger)
    if args.debug:
        # Debug mode: log every step and enable autograd anomaly detection.
        trainer_kwargs.update(log_every_n_steps=1, detect_anomaly=True)
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model, train_loader, val_loader)

    # Reload the best checkpoint (by "objective") for export and testing. If
    # no checkpoint was written, keep the in-memory model instead of handing
    # an empty path to load_from_checkpoint.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        model = GOModuleLit.load_from_checkpoint(best_model_path, args=args, hyperparameters=hyperparameters)
    else:
        print("Warning: no checkpoint was saved; exporting and testing the final in-memory weights.")

    # TensorBoard logs go to ./lightning_logs (relative to the working
    # directory), but the weights are exported next to this script, so that
    # directory may not exist yet when launched from elsewhere.
    export_dir = os.path.join(ABS_PATH, "lightning_logs")
    os.makedirs(export_dir, exist_ok=True)
    torch.save(model.model.state_dict(), os.path.join(export_dir, f"GOModule_{dataset_name}_{args.augmentations}.pt"))
    trainer.test(model, test_loader)

