import argparse
import torch
import numpy as np
from dataset import *
from network import *
from networkLit import *
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# Global seeding via Lightning with seed 0; workers=True is passed so that
# dataloader worker processes are covered by the seeding as well.
pl.seed_everything(0, workers=True)
from pathlib import Path
# Directory containing this script; used below as the base directory for the
# "lightning_logs/*.pt" weight files that are loaded and saved.
# NOTE(review): `os` is not imported explicitly in this file -- it presumably
# arrives through one of the wildcard imports above; confirm.
ABS_PATH = os.path.dirname(os.path.abspath(__file__))


# Seed torch's RNG with 0 again (seed_everything above was already called
# with the same seed).
torch.manual_seed(0)
# NOTE(review): the return value is discarded, so this call does not influence
# any decision made in this script.
torch.cuda.is_available()
# Select the 'high' float32 matmul precision setting in torch.
torch.set_float32_matmul_precision('high')

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string -- including ``"False"``. This converter
    interprets the usual textual spellings instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string stays falsy, as it was with ``bool("")``.
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


def parse_args():
    """Parse the command-line options of the GRN training script.

    Only ``--dataset_path`` is required; every other option has a default.
    ``--debug`` accepts an explicit value (``--debug True`` / ``--debug False``)
    or may be given bare (``--debug``) to switch debugging on.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='Train GNN')

    # --- data and graph construction ---
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Path to dataset')
    parser.add_argument('--node_features', type=str, default="dim-pose",
                        help='Node features')
    parser.add_argument('--edge_types', type=str, default="IK-GO",
                        help='Edge types')
    parser.add_argument('--edge_features', type=str, default="type-IK-GO",
                        help='Edge features')
    parser.add_argument('--IK_GO_mode', type=str, default="pred",
                        help='IK GO mode')
    parser.add_argument('--augmentations', type=str, default="dimswitch_all",
                        help='Augmentations')

    # --- optimisation and runtime ---
    parser.add_argument('--device', type=str, default="cuda",
                        help='Device')
    parser.add_argument('--n_epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=2048,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0,
                        help='Weight decay')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of workers')

    # --- network architecture ---
    parser.add_argument('--num_node_features', type=int, default=7,
                        help='Number of node features')
    parser.add_argument('--num_edge_features', type=int, default=7,
                        help='Number of edge features')
    parser.add_argument('--hidden_size', type=int, default=256,
                        help='Hidden size')
    parser.add_argument('--num_heads', type=int, default=4,
                        help='Number of heads')
    parser.add_argument('--n_message_passing', type=int, default=1,
                        help='Number of message passing')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='Dropout')
    # Default is a float so it matches what type=float yields for CLI values.
    parser.add_argument('--pos_weight', type=float, default=1.0,
                        help='Positive weight')

    # nargs='?' + const=True keeps "--debug True" working and additionally
    # allows a bare "--debug"; _str2bool makes "--debug False" really False.
    parser.add_argument('--debug', type=_str2bool, nargs='?', const=True,
                        default=False,
                        help='Enable debug mode (true/false)')
    parser.add_argument('--gnn_type', type=str, default="EGAT",
                        help='GNN type')
    parser.add_argument('--training_mode', type=str, default="one_by_one",
                        help='Training mode')

    return parser.parse_args()

def dataset_name_from_path(dataset_path):
    """Return the last path component of *dataset_path*.

    Trailing separators are ignored, so ``"data/foo"`` and ``"data/foo/"``
    both yield ``"foo"``. (A plain ``split("/")[-1]`` yields ``""`` for the
    latter, which would corrupt every file name derived from it.)
    """
    return os.path.basename(os.path.normpath(dataset_path))


if __name__ == '__main__':
    args = parse_args()

    # The feature widths are set here and override whatever was passed on the
    # command line: 7 node features; 7 edge features when IK or GO edge
    # features are requested, otherwise 2.
    args.num_node_features = 7
    uses_ik_or_go = "IK" in args.edge_features or "GO" in args.edge_features
    args.num_edge_features = 7 if uses_ik_or_go else 2

    hyperparameters = {
        "lr": args.lr,
        "batch_size": args.batch_size,
        "n_epochs": args.n_epochs,
        "hidden_size": args.hidden_size,
        "augmentations": args.augmentations,
        "num_heads": args.num_heads,
        "n_message_passing": args.n_message_passing,
        "node_features": args.node_features,
        "edge_types": args.edge_types,
        "edge_features": args.edge_features,
    }

    # Computed once; used in every weight file name and in the logger name.
    dataset_name = dataset_name_from_path(args.dataset_path)
    weights_dir = os.path.join(ABS_PATH, "lightning_logs")

    # One dataset / loader per split; only the training loader shuffles.
    train_set = GRNDataset(root=os.path.join(args.dataset_path, "train_set"), mode="train", args=args)
    val_set = GRNDataset(root=os.path.join(args.dataset_path, "val_set"), mode="val", args=args)
    test_set = GRNDataset(root=os.path.join(args.dataset_path, "test_set"), mode="test", args=args)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(dataset=val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    print("Number of datapoints: train_set =", len(train_set), " val_set =", len(val_set), " test_set =", len(test_set))
    print("Number of batches : train_set =", len(train_loader), " val_set =", len(val_loader), " test_set =", len(test_loader))

    model = GRNLit(args, hyperparameters).to(args.device)
    if args.training_mode == "one_by_one":
        # In this mode the sub-modules start from weight files stored under
        # <script dir>/lightning_logs. map_location lets weights that were
        # saved on another device be deserialised onto the requested one.
        module_prefix = dataset_name + "_" + args.augmentations
        if "IK" in args.edge_features:
            ik_path = os.path.join(weights_dir, "IKModule_" + module_prefix + ".pt")
            model.model.ik_module.load_state_dict(torch.load(ik_path, map_location=args.device))
        if "GO" in args.edge_features:
            go_path = os.path.join(weights_dir, "GOModule_" + module_prefix + ".pt")
            model.model.go_module.load_state_dict(torch.load(go_path, map_location=args.device))
        agf_path = os.path.join(
            weights_dir,
            "AGFModule_" + module_prefix + "_" + args.edge_features + "_" + args.gnn_type + ".pt")
        model.model.agf_module.load_state_dict(torch.load(agf_path, map_location=args.device))

    logger = TensorBoardLogger("lightning_logs", name="GRN/" + dataset_name + "/" + args.augmentations)
    # Keep only the checkpoint with the highest logged "Feasibility" value.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="Feasibility", mode='max', save_top_k=1, save_last=False)

    trainer_kwargs = dict(
        max_epochs=args.n_epochs,
        accelerator=args.device,
        deterministic=True,
        callbacks=[checkpoint_callback],
        logger=logger,
    )
    if args.debug:
        # Debug runs log every step and enable anomaly detection.
        trainer_kwargs.update(log_every_n_steps=1, detect_anomaly=True)
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model, train_loader, val_loader)

    # Reload the best checkpoint, export the bare network weights, then test.
    model = GRNLit.load_from_checkpoint(trainer.checkpoint_callback.best_model_path, args=args, hyperparameters=hyperparameters)
    export_path = os.path.join(
        weights_dir,
        "GRN_" + dataset_name + "_" + args.augmentations + "_" + args.edge_features + ".pt")
    torch.save(model.model.state_dict(), export_path)
    trainer.test(model, test_loader)