import os
import sys
import json
import torch
import torch.backends.mps
import time
import csv
import pandas as pd
import lightning.pytorch as pl
from typing import Dict
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import CSVLogger

from csiva.model.causal_inducer import CausalInducer
from utils.data import GraphDataset

# Workspace root: the parent of the *current working directory* (resolved at
# import time), not of this file's location.
WORKSPACE = os.path.join(os.getcwd(), "..")
# Folder name under WORKSPACE that is passed to the CSVLogger as its `name`.
LOGS_DIR = "lightning_logs"


def args_sanity_check(args: Dict) -> None:
    """Verify that every requested dataset exists on disk.

    Each entry of ``args["datasets_list"]`` is joined onto
    ``args["data_path"]`` and the resulting path must exist.

    Args:
        args: Script arguments; the keys ``"datasets_list"`` and
            ``"data_path"`` are read.

    Raises:
        OSError: For the first dataset path that does not exist.
    """
    # Lazy generator: "data_path" is only looked up once a dataset is visited.
    candidate_paths = (
        os.path.join(args["data_path"], name) for name in args["datasets_list"]
    )
    first_missing = next(
        (candidate for candidate in candidate_paths if not os.path.exists(candidate)),
        None,
    )
    if first_missing is not None:
        raise OSError(f"The path {first_missing} to the data does not exist.")
    print("All checks completed")


if __name__ == "__main__":
    # Script entry point: run the Lightning test loop of a checkpointed
    # CausalInducer on every dataset listed in a JSON arguments file, and
    # collect the last metrics row of each run in "custom_metrics.csv",
    # prefixed with a "dataset_name" column.

    # The arguments file is named after this script (text before the first dot).
    operation_name = os.path.basename(__file__).split(".")[0]

    if len(sys.argv) > 1:
        # Pre-generated arguments: the first CLI argument selects the file.
        args_dir = os.path.join(WORKSPACE, "script-arguments", "generated")
        args_filename = f"{operation_name}_{sys.argv[1]}.json"
    else:
        args_dir = os.path.join(WORKSPACE, "script-arguments")
        args_filename = operation_name + ".json"

    # A missing directory is a "not found" condition (the original raised
    # IsADirectoryError, which means the opposite).
    if not os.path.isdir(args_dir):
        raise FileNotFoundError(
            f"The arguments directory '{args_dir}' does not exist. "
            f"Please create the directory with the '{args_filename}' file."
        )

    json_args_path = os.path.join(args_dir, args_filename)
    if not os.path.exists(json_args_path):
        raise FileNotFoundError(f"The file '{json_args_path}' does not exist.")
    with open(json_args_path, "r", encoding="utf-8") as file:
        args = json.load(fp=file)

    # Sanity checks
    args_sanity_check(args)
    random_seed = args["random_seed"]
    max_num_nodes = args["max_num_nodes"]  # TODO: implement padding
    datasets_folder = args["data_path"]
    ckpt_path = args["ckpt_path"]
    version = args["version"]
    dataset_ratios = args["dataset_ratios"]
    n_samples = args["n_samples"]
    encoder_layer_type = args["encoder_layer_type"]
    encoder_summary_type = args["encoder_summary_type"]

    # Pick the Lightning accelerator once, by availability: CUDA, then MPS,
    # then CPU. NOTE: as before, args["device"] does not influence this choice;
    # the printed value is now the accelerator that is actually used.
    if torch.cuda.is_available():
        accelerator = "gpu"
    elif torch.backends.mps.is_available():
        accelerator = "mps"
    else:
        accelerator = "cpu"
    print("Device: " + accelerator)

    # CausalInducer hyperparameters (Ke et al.); identical for every dataset.
    embed_dim = 64
    dim_feedforward = embed_dim * 4
    depth_encoder = 8
    depth_decoder = 8
    num_heads = 8

    for dataset_name in args["datasets_list"]:
        path_to_dataset = os.path.join(datasets_folder, dataset_name)
        # Re-seed before each dataset so every run starts from the same state.
        pl.seed_everything(random_seed)

        # Load test data
        dataset = GraphDataset([path_to_dataset], dataset_ratios, n_samples)
        test_dataloader = DataLoader(dataset, num_workers=8, persistent_workers=True)

        # Fresh model per dataset; weights come from `ckpt_path` in trainer.test.
        model = CausalInducer(
            num_nodes=max_num_nodes,
            d_model=embed_dim,
            dim_feedforward=dim_feedforward,
            depth_encoder=depth_encoder,
            depth_decoder=depth_decoder,
            num_heads=num_heads,
            encoder_layer_type=encoder_layer_type,
            encoder_summary_type=encoder_summary_type,
        )

        # Test
        logger = CSVLogger(WORKSPACE, name=LOGS_DIR, version=version)
        start_time = time.time()
        trainer = pl.Trainer(
            accelerator=accelerator,
            devices="auto",
            precision="32",
            logger=logger,
        )
        trainer.test(model=model, dataloaders=test_dataloader, ckpt_path=ckpt_path)

        print(f"Execution time: {time.time() - start_time:5.2f} seconds")

        # Ask the logger for its run directory instead of rebuilding
        # "version_<version>" by hand: the folder name differs when `version`
        # is a string or None.
        run_dir = logger.log_dir
        path_to_test_metrics = os.path.join(run_dir, "metrics.csv")
        # Diagnostic output: where the metrics are read from.
        print(path_to_test_metrics, os.path.abspath(path_to_test_metrics))
        print(run_dir, os.listdir(run_dir))
        metrics = pd.read_csv(path_to_test_metrics)
        if metrics.empty:
            raise ValueError(
                f"No metrics rows found in '{path_to_test_metrics}' "
                f"for dataset '{dataset_name}'."
            )

        # Last logged row, prefixed with the dataset name
        header = ["dataset_name", *metrics.columns.to_list()]
        row = [dataset_name, *metrics.iloc[-1].to_list()]

        path_to_custom_metrics = os.path.join(run_dir, "custom_metrics.csv")
        # An existing but empty file is treated like a missing one.
        needs_header = (
            not os.path.exists(path_to_custom_metrics)
            or os.path.getsize(path_to_custom_metrics) == 0
        )
        if needs_header:
            # newline="" is required by the csv module to avoid blank rows
            # on platforms that translate line endings.
            with open(path_to_custom_metrics, "w", newline="") as f:
                csv.writer(f).writerow(header)
        else:
            # The column order of metrics.csv can differ between runs, so the
            # new row is reordered to match the header already on disk.
            # Columns absent from that header are dropped, as before.
            with open(path_to_custom_metrics, "r", newline="") as f:
                custom_header = next(csv.reader(f), [])
            value_by_column = dict(zip(header, row))
            missing = [col for col in custom_header if col not in value_by_column]
            if missing:
                raise ValueError(
                    f"Columns {missing} of '{path_to_custom_metrics}' are missing "
                    f"from the metrics of dataset '{dataset_name}'."
                )
            row = [value_by_column[col] for col in custom_header]

        with open(path_to_custom_metrics, "a", newline="") as f:
            csv.writer(f).writerow(row)
