import os
import json
import torch
import time
import lightning.pytorch as pl

from csiva.model.causal_inducer import CausalInducer
from typing import Dict
from torch.utils.data import DataLoader, random_split
from utils.data import GraphDataset
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping

# Project root directory: the parent of the current working directory.
# NOTE(review): this assumes the script is launched from a direct
# subdirectory of the project — confirm with how it is invoked.
WORKSPACE = os.path.join(os.getcwd(), os.pardir)
# Name of the TensorBoard run directory; the training loop passes it as
# `name=` to a TensorBoardLogger rooted at WORKSPACE.
LOGS_DIR = "lightning_logs"

def args_sanity_check(args: Dict) -> None:
    """Validate the dataset-related entries of the script arguments.

    Checks that every folder listed in ``args["datasets_list"]`` exists under
    ``args["data_path"]``. When ``args["datasets_ratios"]`` is present, also
    checks that it has an entry for every dataset group. A bad config then
    fails before any training starts, not partway through the training loop.

    Args:
        args: Parsed script arguments. ``"datasets_list"`` is a list of groups,
            each an iterable of folder names relative to ``"data_path"``.
            ``"datasets_ratios"`` (optional here) is indexed per group.

    Raises:
        TypeError: if a group in ``datasets_list`` is a bare string instead of
            a list of folder names (iterating it would walk its characters).
        OSError: if a listed dataset path does not exist.
        ValueError: if ``datasets_ratios`` has fewer entries than
            ``datasets_list`` has groups.
    """
    datasets_list = args["datasets_list"]

    # Check that every dataset path exists
    for group_idx, datasets in enumerate(datasets_list):
        if isinstance(datasets, str):
            raise TypeError(
                f"datasets_list[{group_idx}] is the string {datasets!r}; "
                "each entry must be a list of dataset folder names."
            )
        for folder in datasets:
            path = os.path.join(args["data_path"], folder)
            if not os.path.exists(path):
                raise OSError(f"The path {path} to the data does not exist.")

    # Each group is paired with datasets_ratios[i]; a short list would only
    # fail after the earlier groups had already been trained.
    ratios = args.get("datasets_ratios")
    if ratios is not None and len(ratios) < len(datasets_list):
        raise ValueError(
            f"datasets_ratios has {len(ratios)} entries but datasets_list has "
            f"{len(datasets_list)} groups."
        )

    print("All checks completed")


if __name__ == "__main__":

    # Check script directory exists
    if not os.path.isdir(os.path.join(WORKSPACE, "script-arguments")):
        raise IsADirectoryError("The '../script-arguments' directory does not exist."\
                                "Please create the directory with the 'dataset_inference.json' file.")

    # Read hyperparameters
    json_args_file = os.path.basename(__file__).split(".")[0] + ".json"
    json_args_path = os.path.join(WORKSPACE, "script-arguments", json_args_file)
    if os.path.exists(json_args_path):
        with open(json_args_path, "r") as file:
            args = json.load(fp=file)
    else:
        raise FileNotFoundError(f"The file '{json_args_path}' does not exist."\
                                " Please create one.")
    
    # Sanity checks
    args_sanity_check(args)
    random_seed = args["random_seed"]
    learning_rate = args["learning_rate"]
    p_dropout = args["p_dropout"]
    training_epochs = args["training_epochs"]
    max_num_nodes = args["max_num_nodes"] # TODO: implement padding
    device = args["device"]
    datasets_folder = args["data_path"]
    ckpt_path = args["ckpt_path"]
    eps_layer_norm = args["eps_layer_norm"]
    batch_size = args["batch_size"]
    datasets_ratios = args["datasets_ratios"]
    datasets_list = args["datasets_list"]
    n_samples = args["n_samples"]
    patience = args["patience"]
    encoder_summary_type = args["encoder_summary_type"]
    encoder_layer_type = args["encoder_layer_type"]

    # Device to cpu if cuda not available
    if not torch.cuda.is_available():
        device = "cpu"
    print("Device: " + device)


    # Train and evaluate a model for each dataset
    for i in range(len(datasets_list)):
        training_datasets = datasets_list[i]
        current_dataset_ratio = datasets_ratios[i] 
        path_to_datasets = [os.path.join(datasets_folder, dataset) for dataset in training_datasets]

        pl.seed_everything(random_seed)
        generator = torch.Generator().manual_seed(random_seed)

        # Split in train, eval, test data
        dataset = GraphDataset(path_to_datasets, current_dataset_ratio, n_samples)
        training_data, val_data, test_data = random_split(dataset, [0.75, 0.25, 0.0], generator)
        train_dataloader = DataLoader(training_data, batch_size = batch_size, num_workers=8, persistent_workers=True)
        val_dataloader = DataLoader(val_data, batch_size = batch_size, num_workers=8, persistent_workers=True)
        test_dataloader = DataLoader(test_data, num_workers=8, persistent_workers=True)


        # Construct CausalInducer (less parameters than Ke et al.)
        embed_dim=64
        dim_feedforward = embed_dim*4
        depth_encoder = 8
        depth_decoder = 8
        num_heads = 8
        model = CausalInducer(
            num_nodes=max_num_nodes, d_model=embed_dim, dim_feedforward=dim_feedforward, depth_encoder=depth_encoder,
            depth_decoder=depth_decoder, num_heads=num_heads, lr=learning_rate, eps_layer_norm=eps_layer_norm,
            p_dropout=p_dropout, encoder_summary_type=encoder_summary_type, encoder_layer_type=encoder_layer_type
        )

        # Set up the model (e.g. device)
        device = "gpu" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

        # Training and validation
        logger = TensorBoardLogger(WORKSPACE, name=LOGS_DIR)
        start_time = time.time()
        trainer = pl.Trainer(
            accelerator=device, devices="auto", max_epochs=training_epochs, precision="32",
            logger=logger, callbacks=[EarlyStopping(monitor="val_loss", patience=patience)]
        )
        trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=ckpt_path)

        print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
