import argparse
import torch
import numpy as np
from dataset import *
from network import *
from networkLit import *
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
pl.seed_everything(0, workers=True)
from pathlib import Path
import os  # explicit: do not rely on a `from ... import *` above re-exporting os

# Directory containing this script; the final model weights are written under it.
ABS_PATH = os.path.dirname(os.path.abspath(__file__))

# Redundant with pl.seed_everything(0) above, but keeps the torch seed pinned
# even if that call is moved or removed.
torch.manual_seed(0)
# Permit reduced-precision internal math (e.g. TF32) for float32 matmuls where
# the hardware supports it, trading a little precision for speed.
torch.set_float32_matmul_precision('high')

def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is
    ``True`` for every non-empty string, so such a flag can never be switched
    off from the command line. This converter accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for GNN training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Train GNN')

    # --- Data / graph construction -------------------------------------
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Path to dataset')
    parser.add_argument('--node_features', type=str, default="dim-pose",
                        help='Node features')
    parser.add_argument('--edge_types', type=str, default="IK-GO",
                        help='Edge types')
    parser.add_argument('--edge_features', type=str, default="type-IK-GO",
                        help='Edge features')
    parser.add_argument('--IK_GO_mode', type=str, default="gt",
                        help='IK GO mode (gt or pred)')
    parser.add_argument('--augmentations', type=str, default="dimswitch_all",
                        help='Augmentations')

    # --- Optimisation / runtime ----------------------------------------
    parser.add_argument('--device', type=str, default="cuda",
                        help='Device')
    parser.add_argument('--n_epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=2048,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0,
                        help='Weight decay')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of workers')

    # --- Model architecture --------------------------------------------
    # NOTE: this script's __main__ block overwrites num_node_features and
    # num_edge_features after parsing, so these two CLI values are ignored there.
    parser.add_argument('--num_node_features', type=int, default=7,
                        help='Number of node features')
    parser.add_argument('--num_edge_features', type=int, default=7,
                        help='Number of edge features')
    parser.add_argument('--hidden_size', type=int, default=256,
                        help='Hidden size')
    parser.add_argument('--num_heads', type=int, default=4,
                        help='Number of heads')
    parser.add_argument('--n_message_passing', type=int, default=1,
                        help='Number of message passing')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='Dropout')
    # argparse does not run `type` on non-string defaults, so spell the
    # default as a float to keep args.pos_weight a float either way.
    parser.add_argument('--pos_weight', type=float, default=1.0,
                        help='Positive weight')

    # `type=bool` would turn "--debug False" into True; _str2bool parses the
    # value properly, and nargs='?' + const=True also allows a bare `--debug`.
    parser.add_argument('--debug', type=_str2bool, nargs='?', const=True,
                        default=False,
                        help='Enable debug mode (log every step, anomaly '
                             'detection); accepts a bare flag or true/false')

    return parser.parse_args(argv)

def infer_num_edge_features(edge_features):
    """Return the edge-feature width implied by an ``--edge_features`` spec.

    7 when the spec mentions ``IK`` or ``GO`` (e.g. the default
    ``"type-IK-GO"``), otherwise 2.
    """
    return 7 if ("IK" in edge_features or "GO" in edge_features) else 2


def dataset_name(dataset_path):
    """Return the final directory name of *dataset_path*.

    Used to name log directories and saved weights. Unlike
    ``dataset_path.split("/")[-1]`` this does not return ``""`` for a
    trailing slash (``"data/foo/"`` -> ``"foo"``) or for ``"."``, and it
    honours the OS path separator.
    """
    return os.path.basename(os.path.abspath(dataset_path))


if __name__ == '__main__':
    args = parse_args()

    # Feature widths are derived from the chosen feature configuration; the
    # --num_node_features / --num_edge_features CLI values are overridden.
    args.num_node_features = 7
    args.num_edge_features = infer_num_edge_features(args.edge_features)

    hyperparameters = {"lr": args.lr, "batch_size": args.batch_size, "n_epochs": args.n_epochs,
                       "hidden_size": args.hidden_size, "augmentations": args.augmentations,
                       "num_heads": args.num_heads, "n_message_passing": args.n_message_passing,
                       "node_features": args.node_features, "edge_types": args.edge_types,
                       "edge_features": args.edge_features}

    train_set = GRNDataset(root=os.path.join(args.dataset_path, "train_set"), mode="train", args=args)
    val_set = GRNDataset(root=os.path.join(args.dataset_path, "val_set"), mode="val", args=args)
    test_set = GRNDataset(root=os.path.join(args.dataset_path, "test_set"), mode="test", args=args)

    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers)
    train_loader = DataLoader(dataset=train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(dataset=val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(dataset=test_set, shuffle=False, **loader_kwargs)
    print("Number of datapoints: train_set =", len(train_set), " val_set =", len(val_set), " test_set =", len(test_set))
    print("Number of batches : train_set =", len(train_loader), " val_set =", len(val_loader), " test_set =", len(test_loader))

    model = AGFModuleLit(args, hyperparameters).to(args.device)
    run_name = dataset_name(args.dataset_path)
    logger = TensorBoardLogger("lightning_logs", name="AGFModule/" + run_name + "/" + args.augmentations)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="objective", mode='max', save_top_k=1, save_last=False)

    trainer_kwargs = dict(max_epochs=args.n_epochs, accelerator=args.device, deterministic=True,
                          callbacks=[checkpoint_callback], logger=logger)
    if args.debug:
        # Debug runs log every step and enable autograd anomaly detection.
        trainer_kwargs.update(log_every_n_steps=1, detect_anomaly=True)
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, train_loader, val_loader)

    # Reload the best (max "objective") checkpoint before exporting/testing.
    best_model_path = checkpoint_callback.best_model_path
    if not best_model_path:
        raise RuntimeError("No checkpoint was saved during training; "
                           "check that the 'objective' metric is logged.")
    model = AGFModuleLit.load_from_checkpoint(best_model_path, args=args, hyperparameters=hyperparameters)

    # The TensorBoard logger writes to ./lightning_logs relative to the CWD,
    # so ABS_PATH/lightning_logs may not exist yet when run from elsewhere.
    weights_dir = os.path.join(ABS_PATH, "lightning_logs")
    os.makedirs(weights_dir, exist_ok=True)
    weights_file = "AGFModule_" + run_name + "_" + args.augmentations + "_" + args.edge_features + ".pt"
    torch.save(model.model.state_dict(), os.path.join(weights_dir, weights_file))
    trainer.test(model, test_loader)