"""
This module provides a script for training and evaluating JKOnet* and other models for learning diffusion terms on population data.

Functions
-------------
- ``numpy_collate``
    A custom collate function for PyTorch's DataLoader to properly stack or nest NumPy arrays when using JAX.

- ``main``
    The main function that orchestrates the training loop, evaluation, logging, and visualization. It reads configurations, initializes models and datasets, and executes the training and evaluation processes.

Command-Line arguments
----------------------
The script accepts the following command-line arguments:

- `--solver`, `-s` (`EnumMethod`):
    Name of the solver (model) to use. Choices are defined in the `EnumMethod` class.

- `--dataset`, `-d` (`str`):
    Name of the dataset to train the model on. The dataset should be prepared and located in a directory matching this name.

- `--eval` (`str`):
    Option to test the fit on `'train_data'` or `'test_data'` (e.g., for debugging purposes). Default is `'test_data'`.

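- `--wandb` (`bool`):
    If specified, activates experiment tracking. Despite the flag name, parameters, metrics and images are logged through a Comet (`comet_ml`) `Experiment`. Requires `--group_name`.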

- `--debug` (`bool`):
    If specified, runs the script in debug mode (disables JIT compilation in JAX for easier debugging).

- `--seed` (`int`):
    Seed for random number generation to ensure reproducibility.

- `--epochs` (`int`):
    Number of epochs to train the model. If not specified, the number of epochs is taken from the configuration file.

Usage example
-------------
To train a model using the `jkonet-star-potential` solver on a dataset named `my_dataset` with experiment logging enabled (`--wandb` requires `--group_name`):

.. code-block:: bash

    python train.py --solver jkonet-star-potential --dataset my_dataset --wandb --group_name my_group

"""

import argparse
import os
import random
import sys
import typing as tp
from pathlib import Path
from time import time

import chex
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optuna
import yaml
from comet_ml import Experiment
from flax.training.train_state import TrainState
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import PopulationDataset, PopulationEvalDataset
from models import EnumMethod, get_model
from models.base import LearningDiffusionModel
from utils.plotting import plot_level_curves, plot_predictions
from utils.sde_simulator import get_SDE_predictions

import torch  # isort:skip


# Root directory of the running script (train.py).
SCRIPT_DIR = Path(__file__).resolve().parent

# Add the script's root directory to the module search path so that modules
# located next to the script can be imported. ``sys.path`` entries must be
# strings: the import system ignores any other type, including ``Path``.
sys.path.append(str(SCRIPT_DIR))


def numpy_collate(batch: list[np.ndarray | tuple | list]) -> np.ndarray | list:
    """
    Collates a batch of samples into a single array or nested list of arrays.

    This function recursively processes a batch of samples, stacking NumPy arrays, and collating lists or tuples by grouping elements together. If the batch consists of NumPy arrays, they are stacked. If the batch contains tuples or lists, the function recursively applies the collation.

    This collate function is taken from the `JAX tutorial with PyTorch Data Loading <https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html>`_.

    Parameters
    ----------
    batch : list[np.ndarray | tuple | list]
        A batch of samples where each sample is either a NumPy array, a tuple, or a list. It depends on the
        data loader.

    Returns
    -------
    np.ndarray
        The collated batch, either as a stacked NumPy array or as a nested structure of arrays.
    """
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


def setup_seeds(seed: int) -> jnp.ndarray:
    """
    Seed every random number generator used by the script and create a JAX PRNG key.

    Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) are seeded with the same value,
    and cuDNN is switched to deterministic mode with benchmarking disabled.

    Parameters
    ----------
    seed : int
        Seed shared by all generators.

    Returns
    -------
    jnp.ndarray
        The key returned by ``jax.random.PRNGKey(seed)``.
    """
    prng_key = jax.random.PRNGKey(seed)

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    return prng_key


def load_configs(args) -> tuple[dict[str, tp.Any], str]:
    """
    Load the main and the extra YAML configs and merge them into one dictionary.

    Both paths are resolved relative to the directory containing this file. Values from the extra
    config override the main config; nested dictionaries are merged recursively. The merged config
    additionally receives ``config["T"] = args.n_timesteps`` and, when ``args.epochs`` is truthy,
    ``config["train"]["epochs"] = args.epochs``.

    Parameters
    ----------
    args
        Parsed command-line arguments. Reads ``config``, ``extra_config``, ``n_timesteps``,
        ``epochs`` and ``solver``.

    Returns
    -------
    tuple[dict[str, tp.Any], str]
        The merged configuration and a run-name postfix. The postfix is empty unless the solver
        name starts with ``"inverse-jkonet"``, in which case it encodes the ``otmap`` inner
        iteration count and model type taken from the extra config.
    """
    directory = Path(__file__).resolve().parent
    config_main_path = directory / args.config
    config_extra_path = directory / args.extra_config

    def read_yaml(path: Path) -> dict:
        # The context manager guarantees the file handle is closed. An empty
        # YAML document parses to ``None``; treat it as an empty mapping.
        with open(path) as stream:
            return yaml.safe_load(stream) or {}

    config_main = read_yaml(config_main_path)
    config_extra = read_yaml(config_extra_path)

    def deep_update(d: dict, u: dict) -> dict:
        for k, v in u.items():
            if isinstance(v, dict):
                # Merge only into an existing mapping; a missing or non-dict
                # value is replaced by the merged override.
                base = d.get(k)
                d[k] = deep_update(base if isinstance(base, dict) else {}, v)
            else:
                d[k] = v
        return d

    config = deep_update(config_main, config_extra)
    config["T"] = args.n_timesteps
    if args.epochs:
        config["train"]["epochs"] = args.epochs

    postfix = ""
    if str(args.solver).startswith("inverse-jkonet"):
        postfix = f".{config_extra['otmap']['optim']['inner_iter']}.{config_extra['otmap']['model']['type']}"

    return config, postfix


def log_model_params(
    model: LearningDiffusionModel,
    state: TrainState,
    config: dict[str, tp.Any],
    experiment: Experiment | None = None,
) -> None:
    """
    Log the model's parameter groups to the experiment tracker.

    Nothing happens unless an experiment is provided and ``config["wandb"]["save_model"]`` is truthy.

    Parameters
    ----------
    model : LearningDiffusionModel
        Model whose ``get_params`` splits the state into potential, internal and interaction parameters.
    state : TrainState
        Current training state.
    config : dict[str, tp.Any]
        Run configuration; only ``config["wandb"]["save_model"]`` is read.
    experiment : Experiment | None
        Tracker that receives the stringified parameters.
    """
    if not experiment or not config["wandb"]["save_model"]:
        return

    potential_params, internal_params, interaction_params = model.get_params(state)
    labelled_groups = (
        ("potential_parameters", potential_params),
        ("internal_parameters", internal_params),
        ("interaction_parameters", interaction_params),
    )
    for label, params in labelled_groups:
        experiment.log_parameter(label, str(params))


def plot_and_log_predictions(
    predictions: jnp.ndarray,
    dataset_eval: PopulationEvalDataset,
    model: LearningDiffusionModel,
    state: TrainState,
    epoch: int,
    args,
    plot_folder_name: str,
    save_locally: bool = False,
    experiment: Experiment | None = None,
):
    """
    Plot predicted trajectories and, for 2-D data, level curves of the learned energies.

    The trajectory plot is always produced. When ``dataset_eval.data_dim == 2`` the level curves of
    the potential and of the interaction energy are plotted as well. For the
    ``"jkonet-star-time-potential"`` solver one potential plot is produced per timestep
    ``t = 1..dataset_eval.T``, with ``t`` appended to the input of the potential; these per-timestep
    plots are not uploaded to the experiment.

    Parameters
    ----------
    predictions : jnp.ndarray
        Predicted trajectories passed to ``plot_predictions``.
    dataset_eval : PopulationEvalDataset
        Evaluation dataset providing the reference trajectory, the data dimension and ground truths.
    model : LearningDiffusionModel
        Model providing ``get_potential`` and ``get_interaction``.
    state : TrainState
        Current training state.
    epoch : int
        Step attached to the uploaded images.
    args
        Parsed command-line arguments; only ``args.solver`` is read.
    plot_folder_name : str
        Folder in which images are written.
    save_locally : bool
        If true, the plotting helpers save their figures into ``plot_folder_name``.
    experiment : Experiment | None
        If given, the images are written to ``plot_folder_name`` and uploaded.
    """
    solver_name = str(args.solver)
    domain = ((-4, -4), (4, 4))
    level_curve_figs = []

    if experiment:
        # Images are uploaded from disk, so the folder has to exist even when
        # local saving is disabled.
        os.makedirs(plot_folder_name, exist_ok=True)

    # Always build the path: it is needed for the upload even if the plotting
    # helper is not asked to save the figure itself.
    plot_path = os.path.join(plot_folder_name, "predictions.png")
    trajectory_fig = plot_predictions(
        predictions,
        dataset_eval.trajectory,
        interval=None,
        model=solver_name,
        save_to=plot_path if save_locally else None,
    )
    if experiment:
        trajectory_fig.savefig(plot_path)
        experiment.log_image(plot_path, name="trajectory", step=epoch)

    if dataset_eval.data_dim == 2:
        if solver_name != "jkonet-star-time-potential":
            img_path = os.path.join(plot_folder_name, "level_curves_potential.png")
            potential_fig = plot_level_curves(
                model.get_potential(state),
                domain,
                dimensions=dataset_eval.data_dim,
                save_to=img_path if save_locally else None,
                ground_truth=dataset_eval.gt_potential,  # WARN: Not optimal since gt_potential curves always recomputed
            )
            level_curve_figs.append(potential_fig)
            if experiment:
                potential_fig.savefig(img_path)
                experiment.log_image(img_path, name="level_curves_potential", step=epoch)
        else:
            potential = model.get_potential(state)

            def potential_at(t):
                # Bind ``t`` explicitly so each callable keeps its own timestep.
                return lambda x: potential(jnp.concatenate([x, jnp.array([t])], axis=0))

            for t in range(1, dataset_eval.T + 1):
                level_curve_figs.append(
                    plot_level_curves(
                        potential_at(t),
                        domain,
                        dimensions=dataset_eval.data_dim,
                        save_to=(
                            os.path.join(plot_folder_name, f"level_curves_potential_t_{t}") if save_locally else None
                        ),
                    )
                )

        img_path = os.path.join(plot_folder_name, "level_curves_interaction.png")
        interaction_fig = plot_level_curves(
            model.get_interaction(state),
            domain,
            dimensions=dataset_eval.data_dim,
            save_to=img_path if save_locally else None,
            ground_truth=dataset_eval.gt_interaction,
        )
        level_curve_figs.append(interaction_fig)
        if experiment:
            interaction_fig.savefig(img_path)
            experiment.log_image(img_path, name="level_curves_interaction", step=epoch)

    # Release every figure created above; otherwise they accumulate across
    # evaluations.
    plt.close(trajectory_fig)
    for fig in level_curve_figs:
        # NOTE(review): assumes plot_level_curves returns a matplotlib figure;
        # skip None so that plt.close does not close an unrelated current figure.
        if fig is not None:
            plt.close(fig)


def compute_and_log_errors(
    dataset_eval: PopulationEvalDataset,
    model: LearningDiffusionModel,
    key_eval: jnp.ndarray,
    state: TrainState,
    config: dict[str, tp.Any],
    epoch: int,
    args,
    plot_folder_name: str,
    save_locally: bool = False,
    experiment: Experiment | None = None,
) -> dict[str, tp.Any]:
    """
    Compute the one-step-ahead error metrics, print them, and log them to the experiment.

    Metrics are only computed when ``config["metrics"]["compute_one_ahead"]["enabled"]`` is truthy.
    Datasets whose name contains ``"LO"`` are evaluated with ``errors_leave_one_out``, all others
    with ``errors_one_step_ahead``. For ``"simulator_jko"`` datasets combined with an ``"inverse"``
    solver the map errors are merged into the result.

    Parameters
    ----------
    dataset_eval : PopulationEvalDataset
        Evaluation dataset that implements the error computations.
    model : LearningDiffusionModel
        Model providing the potential, beta and interaction for ``state``.
    key_eval : jnp.ndarray
        PRNG key forwarded to the error computation.
    state : TrainState
        Current training state.
    config : dict[str, tp.Any]
        Run configuration; ``config["metrics"]["compute_one_ahead"]`` is read.
    epoch : int
        Epoch stored in the logs and used as the logging step.
    args
        Parsed command-line arguments; ``dataset``, ``solver`` and ``simulator`` are read.
    plot_folder_name : str
        Folder forwarded to the error computation when ``save_locally`` is true.
    save_locally : bool
        Whether the error computation may write plots to ``plot_folder_name``.
    experiment : Experiment | None
        If given, receives the resulting logs.

    Returns
    -------
    dict[str, tp.Any]
        The epoch plus, per metric, the per-timestep values and their mean and standard deviation.
    """
    logs = {"epoch": epoch}
    one_ahead_cfg = config["metrics"]["compute_one_ahead"]

    if one_ahead_cfg["enabled"]:
        eval_func = (
            dataset_eval.errors_leave_one_out if "LO" in args.dataset else dataset_eval.errors_one_step_ahead
        )
        errors = eval_func(
            model.get_potential(state),
            model.get_beta(state),
            model.get_interaction(state),
            key_eval,
            model=str(args.solver),
            metrics=one_ahead_cfg["types"],
            simulator=str(args.simulator),
            plot_folder_name=plot_folder_name if save_locally else None,
        )
        if "simulator_jko" in args.dataset and "inverse" in str(args.solver):
            errors = errors | dataset_eval.map_errors_one_step_ahead(str(args.solver), model, state)

        # Per-timestep values; "L2_UVP_beta" is stored as-is instead of being iterated.
        for metric, values in errors.items():
            if metric == "L2_UVP_beta":
                logs[f"error_{metric}_one_ahead"] = values
                continue
            for t, val in enumerate(values):
                logs[f"error_{metric}_one_ahead_{t}"] = float(val)
                print(f"Epoch {epoch} | {metric} one step ahead at t={t}: {val:.4f}")

        def mean_and_std(x):
            return (jnp.mean(x), jnp.std(x))

        # Aggregates; these overwrite the un-indexed entry of every metric.
        stats = jax.tree.map(mean_and_std, errors)
        for metric, (mean_val, std_val) in stats.items():
            logs[f"error_{metric}_one_ahead"] = float(mean_val)
            logs[f"error_{metric}_one_ahead_std"] = float(std_val)
            print(f"Epoch {epoch} | {metric} one step ahead: mean={mean_val:.4f}, std={std_val:.4f}")

    if experiment:
        experiment.log_metrics(logs, step=epoch)

    return logs


def evaluate_and_log(
    loader_val: DataLoader,
    dataset_eval: PopulationEvalDataset,
    model: LearningDiffusionModel,
    key: jnp.ndarray,
    state: TrainState,
    config: dict[str, tp.Any],
    epoch: int,
    args,
    save_locally: bool = False,
    experiment: Experiment | None = None,
) -> tuple[jnp.ndarray, dict[str, tp.Any]]:
    """
    Run one evaluation pass: simulate predictions, plot them, and compute error metrics.

    The incoming key is split; one half drives both the SDE simulation and the error computation,
    the other half is returned to the caller. Plots go to ``out/plots/<dataset>/<solver>/<epoch>``,
    which is created here only when ``save_locally`` is true.

    Parameters
    ----------
    loader_val : DataLoader
        Validation loader; its first batch is used as the initial particles.
    dataset_eval : PopulationEvalDataset
        Evaluation dataset providing ``dt``, ``T`` and the error computations.
    model : LearningDiffusionModel
        Model providing the potential, beta and interaction for ``state``.
    key : jnp.ndarray
        PRNG key to split.
    state : TrainState
        Current training state.
    config : dict[str, tp.Any]
        Run configuration forwarded to the logging helpers.
    epoch : int
        Epoch used for the plot folder and as the logging step.
    args
        Parsed command-line arguments; ``solver``, ``simulator`` and ``dataset`` are read.
    save_locally : bool
        Whether plots are written to disk.
    experiment : Experiment | None
        Optional experiment tracker.

    Returns
    -------
    tuple[jnp.ndarray, dict[str, tp.Any]]
        The fresh PRNG key and the logs produced by ``compute_and_log_errors``.
    """
    key, key_eval = jax.random.split(key)

    initial_particles = next(iter(loader_val))
    predictions = get_SDE_predictions(
        str(args.solver),
        dataset_eval.dt,
        dataset_eval.T,
        1,
        model.get_potential(state),
        model.get_beta(state),
        model.get_interaction(state),
        key_eval,
        initial_particles,
        str(args.simulator),
    )

    plot_folder_name = os.path.join("out", "plots", args.dataset, str(args.solver), str(epoch))
    if save_locally:
        os.makedirs(plot_folder_name, exist_ok=True)

    log_model_params(model, state, config, experiment)
    plot_and_log_predictions(
        predictions,
        dataset_eval,
        model,
        state,
        epoch,
        args,
        plot_folder_name,
        save_locally,
        experiment,
    )
    logs = compute_and_log_errors(
        dataset_eval,
        model,
        key_eval,
        state,
        config,
        epoch,
        args,
        plot_folder_name,
        save_locally,
        experiment,
    )
    return key, logs


def check_pruning(
    metric_name: str, threshold: float, logs: dict[str, tp.Any], epoch: int, comparison: tp.Literal[">", "<"] = ">"
) -> None:
    """
    Prune the current Optuna trial when a logged metric is invalid or crosses a threshold.

    Parameters
    ----------
    metric_name : str
        Key of the metric to look up in ``logs``.
    threshold : float
        Value the metric is compared against.
    logs : dict[str, tp.Any]
        Logged metric values.
    epoch : int
        Current training epoch; only used in the messages.
    comparison : tp.Literal[">", "<"]
        ``">"`` prunes when the metric exceeds the threshold, ``"<"`` when it falls below it.
        Any other value disables the threshold check.

    Raises
    ------
    optuna.exceptions.TrialPruned
        If the metric is missing, NaN, or violates the threshold condition.
    """
    value = logs.get(metric_name)
    prune_message = f"Early stopping at epoch {epoch}, metric={metric_name}, value={value}"

    # A missing or NaN metric always ends the trial.
    if value is None or jnp.isnan(value):
        print(f"Pruning trial at epoch {epoch} due to {metric_name}: value is NaN or missing")
        raise optuna.exceptions.TrialPruned(prune_message)

    if comparison == ">":
        violated = value > threshold
    elif comparison == "<":
        violated = value < threshold
    else:
        violated = False

    if violated:
        print(f"Pruning trial at epoch {epoch} due to {metric_name}: {value}")
        raise optuna.exceptions.TrialPruned(prune_message)


def main(args: argparse.Namespace) -> float:
    """
    Train a model on the specified dataset, evaluate it, and return the final score.

    The model is evaluated once before training (epoch 0) and then every ``eval_freq`` epochs.
    After every epoch the accumulated ``loss_energy`` is checked and the trial is pruned if it is
    missing, NaN, or below ``-1e10``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (see module description).

    Returns
    -------
    float
        ``error_L2_UVP_potential_backward_one_ahead`` from the most recent evaluation if it was
        computed, otherwise ``error_BW2_UVP_one_ahead``.

    Raises
    ------
    ValueError
        If neither of the two metrics above is present in the most recent evaluation logs.
    optuna.exceptions.TrialPruned
        Raised by ``check_pruning`` when the training loss degenerates.
    """
    key = setup_seeds(args.seed)

    config, postfix = load_configs(args)

    train_config = config["train"]
    save_locally = train_config["save_locally"]
    batch_size = train_config["batch_size"]
    epochs = train_config["epochs"]
    eval_freq = train_config["eval_freq"]

    chex.assert_scalar_positive(epochs)
    chex.assert_scalar_positive(batch_size)
    chex.assert_scalar_positive(eval_freq)

    # Set up experiment tracking (a Comet experiment, enabled by the --wandb flag).
    config["group_name"] = args.group_name
    experiment = None

    if args.wandb:
        experiment = Experiment(
            project_name=args.group_name,
            auto_output_logging=None,
            log_code=False,
            parse_args=False,
        )
        experiment.set_name(f"{args.solver}.{args.dataset}.{args.seed}{postfix}")
        experiment.log_parameters(config)

    # Load model and dataset
    dataset_eval = PopulationEvalDataset(
        key,
        args.dataset,
        str(args.solver),
        config["metrics"]["wasserstein_order"],
        args.eval,
        dt=config["train"]["dt"],
    )

    dataset = (
        PopulationDataset(args.dataset, batch_size, "train_data.npy", "train_sample_labels.npy")
        if "inverse" in str(args.solver)
        else None
    )  # TODO: fix this crutch in the future

    model = get_model(args.solver, config, dataset_eval.data_dim, args.array_tau, dataset)
    state = model.create_state(key)
    dataset_train = model.load_dataset(args.dataset)
    loader_train = DataLoader(
        dataset_train,
        batch_size=model.batch_size if batch_size > 0 else len(dataset_train),
        shuffle=True,
        collate_fn=numpy_collate,
    )
    loader_val = DataLoader(dataset_eval, batch_size=len(dataset_eval), shuffle=False, collate_fn=numpy_collate)

    print(f"Training {args.solver} on {args.dataset} with seed {args.seed} for {epochs} epochs.\nArguments: {args}")
    train_step = model.train_step
    available_devices = jax.devices()
    if epochs > 1:
        # NOTE(review): the ``device`` argument of ``jax.jit`` is reported as deprecated in recent
        # JAX releases -- confirm against the pinned JAX version.
        train_step = jax.jit(model.train_step, device=available_devices[args.device])

    print("Evaluating model before training (epoch 0)...")
    key, logs = evaluate_and_log(
        loader_val, dataset_eval, model, key, state, config, 0, args, save_locally, experiment
    )

    # Only "multimap" solvers take the current epoch as an extra train-step argument.
    step_takes_epoch = "multimap" in str(args.solver)

    progress_bar = tqdm(range(1, epochs + 1))
    for epoch in progress_bar:
        epoch_metrics = {}
        t_start = time()

        for sample in loader_train:
            if step_takes_epoch:
                state, metrics = train_step(state, sample, epoch)
            else:
                state, metrics = train_step(state, sample)
            # Metrics are summed (not averaged) over the batches of the epoch.
            epoch_metrics = (
                metrics if not epoch_metrics else jax.tree.map(lambda x, y: float(x + y), epoch_metrics, metrics)
            )

        t_end = time()
        beta = model.get_beta(state)
        progress_bar.desc = f"Epoch {epoch} | Loss Energy: {epoch_metrics['loss_energy']}"
        if experiment:
            experiment.log_metrics(
                {"epoch": epoch, "time_per_epoch": t_end - t_start, "beta": beta} | epoch_metrics, step=epoch
            )

        if epoch % eval_freq == 0:
            key, logs = evaluate_and_log(
                loader_val, dataset_eval, model, key, state, config, epoch, args, save_locally, experiment
            )

        # Stop the trial early when the training loss diverges.
        check_pruning("loss_energy", -(10**10), epoch_metrics, epoch, "<")

    # Report the objective from the most recent evaluation, in order of preference.
    for objective_key in ("error_L2_UVP_potential_backward_one_ahead", "error_BW2_UVP_one_ahead"):
        if objective_key in logs:
            return logs[objective_key]

    raise ValueError("Include L2_UVP_potential_backward or BW2_UVP into metrics!")


def get_parser():
    """
    Build the command-line argument parser for the training script.

    Returns
    -------
    argparse.ArgumentParser
        Parser exposing the model, dataset, evaluation, reproducibility and logging options.
    """

    def array_of_floats(arg):
        # Turn a comma-separated string such as "0.1,0.2" into a NumPy array of floats.
        return np.array([float(item) for item in arg.split(",")])

    parser = argparse.ArgumentParser()

    # NOTE(review): the default is a plain float while parsed values are arrays -- confirm that
    # downstream code accepts both.
    parser.add_argument("--array-tau", type=array_of_floats, default=0.01)

    # Configuration files
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to main config.")
    parser.add_argument(
        "--extra_config",
        type=str,
        default="config-inverse-jkonet-extra.yaml",
        help="Path to extra model-specific config.",
    )

    # Model and simulation
    parser.add_argument(
        "--solver",
        "-s",
        type=EnumMethod,
        choices=list(EnumMethod),
        default=EnumMethod.JKO_NET_STAR_POTENTIAL,
        help="Name of the solver to use.",
    )
    parser.add_argument(
        "--simulator",
        type=str,
        default="forward",
        choices=["forward", "backward"],
        help="Name of the simulator to use",
    )

    # Data
    parser.add_argument(
        "--dataset",
        "-d",
        type=str,
        help=(
            "Name of the dataset to train the model on. The name of the dataset should match the name of the "
            "directory generated by the `data_generator.py` script."
        ),
    )
    parser.add_argument(
        "--eval",
        type=str,
        default="test_data",
        choices=["train_data", "test_data"],
        help="Option to test fit on test data or train data (e.g., for debugging purposes).",
    )

    # Run control
    parser.add_argument("--debug", action="store_true", help="Option to run in debug mode.")
    parser.add_argument("--epochs", type=int, help="Number of epochs to train the model.")

    # Reproducibility
    parser.add_argument("--seed", type=int, default=0, help="Set seed for the run")

    parser.add_argument("--device", type=int, default=0, help="Set device for the run")
    parser.add_argument(
        "--n-timesteps", type=int, default=-1, help="Specify the number of  timesteps for this exp!"
    )

    # Path of the dataset required by train and test
    parser.add_argument(
        "--data_dir",
        type=str,
        default=SCRIPT_DIR / "data",
        help="Specify the path for experiments dataset",
    )
    parser.add_argument("--project", type=str, default="inverse-jko", help="Set project for running")

    # Experiment-tracking options
    tracking_group = parser.add_argument_group("wandb", "Options related to Weights & Biases (wandb)")
    tracking_group.add_argument("--wandb", action="store_true", help="Option to run with activated wandb.")
    tracking_group.add_argument(
        "--group_name",
        type=str,
        help="Name of the group to use in wandb (only applicable if --wandb is enabled).",
    )

    return parser


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    # Experiment tracking cannot start without a group name, so reject the
    # combination before doing any work.
    if args.wandb:
        if not args.group_name:
            parser.error("--group_name is required when --wandb is enabled.")

    # Debug mode turns JIT compilation off.
    if args.debug:
        print("Running in DEBUG mode.")
        jax.config.update("jax_disable_jit", True)

    # The default JAX device is left untouched here; the device index from
    # --device is consumed inside main().
    main(args)
