"""
Module for handling datasets and computing prediction errors in population dynamics.

This module provides dataset classes for loading and accessing population trajectory data in different formats (raw trajectories and couplings), and for computing fit and prediction errors.

Classes
-------
    - ``PopulationDataset``
        Handles loading and batching of particle trajectory data. The single unit is a particle trajectory.
    - ``CouplingsDataset``
        Loads coupling data for trajectory models, including weights, features, and densities. The single unit is a coupling.
    - ``LinearParametrizationDataset``
        Loads data for the linear parametrization. The single unit is the entire dataset.
    - ``PopulationEvalDataset``
        Facilitates evaluation of model predictions using particle trajectories and computes prediction errors such as the Wasserstein distance.
"""

import glob
import math
import os
import pickle
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset

from utils.functions import interactions_all, potentials_all
from utils.metrics import (
    compute_BW_UVP_by_gt_samples,
    compute_sinkhorn_divergence,
    gradients_interaction,
    gradients_potential,
    l2_distance,
    l2_uvp_backward,
    mmd,
    earth_mover_distance,
    MMD_loss
)
from utils.ot import wasserstein_loss
from utils.plotting import plot_predictions
from utils.sde_simulator import get_SDE_predictions

import re
from typing import List

def extract_lo_numbers(s: str) -> List[int]:
    """
    Extract the integers that follow ``LO_`` in a dataset name.

    Only underscore-separated runs of digits are captured. Trailing name
    components (e.g. ``"_seed"``) or a trailing underscore therefore never
    produce empty fields.

    Parameters
    ----------
    s : str
        Dataset name, e.g. ``"RAW_RNA_EB_100_LO_1_3"``.

    Returns
    -------
    List[int]
        The left-out labels in the order they appear; empty if ``s`` has no
        ``LO_<digits>`` fragment.

    Examples
    --------
    >>> extract_lo_numbers("RAW_RNA_EB_100_LO_1_3")
    [1, 3]
    >>> extract_lo_numbers("RAW_RNA_EB_100_LO_2_seed_0")
    [2]
    >>> extract_lo_numbers("RAW_RNA_EB_100")
    []
    """
    # Digits separated by single underscores. The previous pattern, which allowed
    # any mix of digits and "_", also swallowed a trailing "_" and then crashed on int("").
    match = re.search(r"LO_(\d+(?:_\d+)*)", s)
    if match is None:
        return []
    return [int(part) for part in match.group(1).split("_")]

def parse_name(name: str) -> dict[str, str]:
    """
    Extract the ground-truth run parameters encoded in a dataset name.

    The name must contain the fragment
    ``potential_<P>_internal_<I>_beta_<B>_interaction_<J>_dt``. Each value is
    captured non-greedily, so it stops at the first following keyword.

    Parameters
    ----------
    name : str
        Dataset name, e.g.
        ``"potential_wavy_internal_wiener_beta_0.1_interaction_flat_dt_0.01"``.

    Returns
    -------
    dict[str, str]
        Mapping with the keys ``"potential"``, ``"interaction"`` and ``"beta"``.
        All values are the raw matched strings; the internal term is matched
        but not returned.

    Raises
    ------
    ValueError
        If ``name`` does not contain the expected pattern.
    """
    match = re.search(r"potential_(.+?)_internal_(.+?)_beta_(.+?)_interaction_(.+?)_dt", name)
    if match is None:
        # Include the offending name so a bad dataset path is easy to spot.
        raise ValueError(f"Can not match the pattern expression in dataset name {name!r}!")

    potential, _internal, beta, interaction = match.groups()
    return {"potential": potential, "interaction": interaction, "beta": beta}


class PopulationDataset(Dataset):
    """
    Dataset class for loading and accessing particle trajectory data.

    The dataset is expected to be located in ``data/{dataset_name}`` next to
    this module. It consists of two .npy files: the particle states
    (``data_file``) and, for every particle, the timestep label it belongs to
    (``labels_file``).

    Timesteps may contain different numbers of particles. Each item draws one
    particle per timestep. In deterministic mode, an index past the end of a
    smaller timestep wraps around. In random mode, a particle is drawn
    uniformly from each timestep's own pool.

    Attributes
    ----------
    trajectory : dict
        Maps each timestep label to an array whose first axis enumerates the
        particles observed at that timestep.
    max_particles : int
        Largest per-timestep particle count, rounded up to a multiple of
        ``batch_size``; used as the dataset length.
    """

    def __init__(
        self,
        dataset_name: str,
        batch_size: int,
        data_file="data.npy",
        labels_file="sample_labels.npy",
        random_indices: bool = True,
    ):
        """
        Initialize the PopulationDataset by loading the particle and label files.

        Parameters
        ----------
        dataset_name : str
            Name of the dataset directory under ``data/`` next to this module.
        batch_size : int
            Batch size; the dataset length is rounded up to a multiple of it.
        data_file : str, optional
            File holding the particle states, one row per particle.
        labels_file : str, optional
            File holding the timestep label of every particle (same order as
            ``data_file``).
        random_indices : bool, optional
            If True, every item draws a uniformly random particle per timestep;
            otherwise the item index (wrapped per timestep) is used.
        """
        directory = Path(__file__).resolve().parent / "data" / dataset_name
        self.data = np.load(directory / data_file)
        self.sample_labels = np.load(directory / labels_file)
        self.batch_size = batch_size
        self.random_indices = random_indices

        # Group particles by their timestep label.
        grouped = defaultdict(list)
        for value, label in zip(self.data, self.sample_labels):
            grouped[label].append(value)
        self.trajectory = {label: np.array(values) for label, values in grouped.items()}

        # Round the largest timestep population up so an epoch is whole batches.
        largest = max(particles.shape[0] for particles in self.trajectory.values())
        self.max_particles = -(-largest // self.batch_size) * self.batch_size

    def __len__(self) -> int:
        """
        Return the number of items per epoch.

        Returns
        -------
        int
            The largest per-timestep particle count, rounded up to a multiple
            of ``batch_size``.
        """
        return self.max_particles

    def __getitem__(self, idx: int) -> list:
        """
        Retrieve one particle from every timestep.

        Parameters
        ----------
        idx : int
            Item index. Only used when ``random_indices`` is False; it is then
            wrapped around the size of each timestep's population.

        Returns
        -------
        list of np.ndarray
            One particle state per timestep, ordered by sorted timestep label.
        """
        sampled_batches = []
        for timestep in sorted(self.trajectory):
            pool = self.trajectory[timestep]
            if self.random_indices:
                # Draw within this timestep's own pool. Drawing from
                # range(max_particles) and wrapping would over-sample low rows.
                row = np.random.randint(len(pool))
            else:
                row = idx % self.max_particles % len(pool)
            sampled_batches.append(pool[row])
        return sampled_batches


class CouplingsDataset(Dataset):
    """
    Dataset class for loading and accessing couplings data.

    The dataset is expected to be located in ``data/{dataset_name}`` next to
    this module. It consists of ``couplings_*.npy`` and
    ``density_and_grads_*.npy`` files. Each couplings row is laid out as
    ``[x (d), y (d), time, weight]``. Each density row is laid out as
    ``[density, grad...]``.

    Attributes
    ----------
    weight : np.ndarray
        Array of weights extracted from the couplings data.
    x : np.ndarray
        Array of input features extracted from the couplings data.
    y : np.ndarray
        Array of target features extracted from the couplings data.
    time : np.ndarray
        Array of time labels extracted from the couplings data.
    densities : np.ndarray
        Array of density values extracted from the densities files.
    densities_grads : np.ndarray
        Array of gradients of densities extracted from the densities files.
    """

    def __init__(self, dataset_name: str) -> None:
        """
        Initialize the CouplingsDataset by loading data from .npy files.

        Rows of the couplings and density files are matched by position. Both
        file lists are therefore sorted by file name, so files with the same
        suffix (``couplings_3.npy`` / ``density_and_grads_3.npy``) occupy the
        same rows after concatenation.

        Parameters
        ----------
        dataset_name : str
            The name of the dataset directory under ``data/`` next to this module.

        Raises
        ------
        FileNotFoundError
            If no couplings or no density files are found.
        """
        data_dir = Path(__file__).resolve().parent / "data" / dataset_name
        # glob yields files in arbitrary order; sort so both concatenations line up.
        coupling_files = sorted(data_dir.glob("couplings_*.npy"))
        density_files = sorted(data_dir.glob("density_and_grads_*.npy"))
        if not coupling_files:
            raise FileNotFoundError(f"No 'couplings_*.npy' files found in {data_dir}")
        if not density_files:
            raise FileNotFoundError(f"No 'density_and_grads_*.npy' files found in {data_dir}")

        # Load couplings for all timesteps together.
        couplings = np.concatenate([np.load(f) for f in coupling_files])
        dim = (couplings.shape[1] - 2) // 2
        self.weight = couplings[:, -1]
        self.x = couplings[:, :dim]
        self.y = couplings[:, dim:-2]
        self.time = couplings[:, -2]

        density_data = np.concatenate([np.load(f) for f in density_files])
        self.densities_grads = density_data[:, 1:]
        self.densities = density_data[:, 0]

    def __len__(self) -> int:
        """
        Return the number of samples in the dataset.

        Returns
        -------
        int
            The number of coupling rows.
        """
        return self.x.shape[0]

    def __getitem__(
        self, idx: int
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Retrieve a sample (x, y, t, w, rho, rho_grad) from the dataset at the given index.

        Parameters
        ----------
        idx : int
            The index of the sample to retrieve.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
            A tuple containing:

            - Input features: initial particle position.
            - Target features: target particle position.
            - Time label.
            - Weight of the coupling.
            - Density value.
            - Gradient of the density.
        """
        return (
            self.x[idx],
            self.y[idx],
            self.time[idx],
            self.weight[idx],
            self.densities[idx],
            self.densities_grads[idx],
        )


class LinearParametrizationDataset(Dataset):
    """
    This dataset class loads and organizes data necessary for linear parametrization solver tasks, for which all data is analyzed together.
    """

    def __init__(self, dataset_name: str) -> None:
        """
        Initialize the LinearParametrizationDataset.

        Files are looked up in ``data/{dataset_name}`` next to this module, as
        the other dataset classes in this file do. The t-th couplings file is
        paired with the t-th density file. Both lists are sorted by file name
        so that files with the same suffix are paired; ``glob`` itself
        returns them in arbitrary order.

        Parameters
        ----------
        dataset_name : str
            The name of the dataset to load.
        """
        data_dir = Path(__file__).resolve().parent / "data" / dataset_name
        couplings = [np.load(f) for f in sorted(data_dir.glob("couplings_*.npy"))]
        densities = [np.load(f) for f in sorted(data_dir.glob("density_and_grads_*.npy"))]

        self.data = []
        for t, c in enumerate(couplings):
            # Couplings rows are laid out as [x (d), y (d), time, weight].
            dim = (c.shape[1] - 2) // 2
            self.data.append(
                (
                    c[:, :dim],
                    c[:, dim:-2],
                    c[:, -2],
                    c[:, -1],
                    densities[t][:, 0],
                    densities[t][:, 1:],
                )
            )

    def __len__(self) -> int:
        """
        Return the number of elements in the dataset.

        Returns
        -------
        int
            The number of elements (always 1 for this dataset).
        """
        return 1

    def __getitem__(self, _) -> list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
        """
        Retrieve the entire dataset.

        For the linear parametrization all data is used together, so this
        method returns everything at once and ignores the index parameter `_`.

        Parameters
        ----------
        _ : any
            This parameter is ignored.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]]
            One tuple per couplings file, each containing:

            - Input features: initial particle positions.
            - Target features: target particle positions.
            - Time label associated with each sample.
            - Weight assigned to each coupling.
            - Density values.
            - Gradient of the density values.
        """
        return self.data


class PopulationEvalDataset(Dataset):
    """
    This dataset class loads and organizes population trajectory data for evaluation.

    Attributes
    ----------
    trajectory : dict
        A dictionary where each key corresponds to a unique timestep in the dataset, and the value is an array of trajectory data associated with that timestep.
    label_mapping : dict
        A dictionary mapping the original sample labels to consecutive integer indices.
    T : int
        The number of timesteps in the trajectories.
    data_dim : int
        The dimensionality of the data at each timestep.
    no_ground_truth : bool
        Flag indicating if the dataset lacks a ground truth file.
    potential : str
        The potential function used in the predictions.
    internal : str
        The internal dynamics setting used.
    beta : float
        The beta parameter used in the simulations.
    interaction : str
        The interaction function used in the predictions.
    dt : float
        The timestep size used in the simulation.
    trajectory_only_potential : np.ndarray
        Trajectory predictions considering only the potential term.
    trajectory_only_interaction : np.ndarray
        Trajectory predictions considering only the interaction term.
    """

    def __init__(
        self,
        key: Any,
        dataset_name: str,
        solver: str,
        wasserstein_metric: int,
        label: str = "test_data",
        dt: float = 0.01,
    ):
        """
        Initialize the PopulationEvalDataset.

        Parameters
        ----------
        key : Any
            A key used for random number generation or seeding.
        dataset_name : str
            The name of the dataset to load. The data should be located in the directory
            'data/{dataset_name}' and consist of .npy files.
        solver : str
            The solver method used, primarily for plotting or prediction purposes.
        wasserstein_metric: int
            Specifies the order of the Wasserstein distance to be used for the error calculation.
        label : str, optional
            Specifies whether to load 'test_data' or 'train_data'. Default is 'test_data'.
        dt : float, optional
            Parameter of SDE discretization.

        """
        # dt does not actually matter for learning,
        # because everything can be scaled accordingly - as long as it
        # is always consistent
        self.key = key
        self.solver = solver
        self.wasserstein_metric = wasserstein_metric
        self.dt = dt

        data, sample_labels = self._load_data(dataset_name, label)
        self.label_mapping = self._create_label_mapping(sample_labels)
        self.trajectory = self._organize_trajectory(data, sample_labels)

        first_label = next(iter(self.trajectory))
        self.data_dim = self.trajectory[first_label].shape[1] if self.trajectory else 0
        self.T = len(self.trajectory) - 1  # It's strange to subtract 1 in order to add later 1
        self.no_ground_truth = False

        self.gt_potential = None
        self.gt_interaction = None
        self.gt_beta = None
        # shape of trajectory[i]: (num_samples, dim)
        MMD_DMSB = MMD_loss()
        self.data_metrics = {
            "MMD_DMSB": lambda pred, t: MMD_DMSB(pred,self.trajectory[t]),
            "EMD_Tong": lambda pred, t: earth_mover_distance(pred, self.trajectory[t]),
            "MMD": lambda pred, t: mmd(pred, self.trajectory[t]),
            "EMD": lambda pred, t: wasserstein_loss(pred, self.trajectory[t], self.wasserstein_metric),
            "BW2_UVP": lambda pred, t: compute_BW_UVP_by_gt_samples(pred, self.trajectory[t]),
            "L2_distance": lambda pred, t: l2_distance(pred, self.trajectory[t]),
            "ott_sinkhorn": lambda pred, t: compute_sinkhorn_divergence(pred, self.trajectory[t]),
        }

        if "RNA" not in dataset_name:
            run_detail = parse_name(dataset_name)
            if run_detail["potential"] != "none":

                self.gt_potential = potentials_all[run_detail["potential"]]
            else:
                self.gt_potential = lambda _: 0.0

            if run_detail["interaction"] != "none":
                self.gt_interaction = interactions_all[run_detail["interaction"]]
            else:
                self.gt_interaction = lambda _: 0.0

            self.gt_beta = float(run_detail["beta"])

            self.functional_metrics = {
                "L2_UVP_potential_backward": lambda potential_func, t: l2_uvp_backward(
                    self.trajectory[t], self.trajectory[t - 1], gradients_potential, self.gt_potential, potential_func
                )
                * 0.01,  # self.dt**2,
                "L2_UVP_interaction_backward": lambda interaction_func, t: l2_uvp_backward(
                    self.trajectory[t],
                    self.trajectory[t - 1],
                    gradients_interaction,
                    self.gt_interaction,
                    interaction_func,
                )
                * 0.01,  # self.dt**2,
                "L2_UVP_beta": lambda beta, t: abs(beta - self.gt_beta),
            }
        else:
            self.functional_metrics = {}

    def _load_data(self, dataset_name: str, label: str) -> tuple[np.ndarray, np.ndarray]:
        """Loads data and sample labels from disk."""
        base_path = Path(__file__).resolve().parent / "data" / dataset_name
        suffix = "test" if label == "test_data" else "train"
        data = np.load(base_path / f"{suffix}_data.npy")
        sample_labels = np.load(base_path / f"{suffix}_sample_labels.npy")
        self.leave_out_list = extract_lo_numbers(dataset_name)

        if "simulator_jko" in dataset_name:
            with open(base_path / "gt_maps.pkl", "rb") as f:
                self.transport_maps = pickle.load(f)
            with open(base_path / "gt_params.pkl", "rb") as f:
                self.transport_params = pickle.load(f)

        return data, sample_labels

    def _create_label_mapping(self, sample_labels: np.ndarray) -> dict[int, int]:
        """Creates a mapping from original labels to consecutive indices."""
        unique_labels = np.unique(sample_labels)
        return {original: i for i, original in enumerate(unique_labels)}

    def _organize_trajectory(self, data: np.ndarray, sample_labels: np.ndarray) -> dict[int, jnp.ndarray]:
        """Organizes trajectory data by mapped labels."""
        trajectory = defaultdict(list)
        for value, label in zip(data, sample_labels):
            trajectory[self.label_mapping[label]].append(value)
        return {label: jnp.array(values) for label, values in trajectory.items()}

    def __len__(self) -> int:
        """Returns the number of particles at the first timestep."""
        return self.trajectory[0].shape[0]

    def __getitem__(self, idx: int) -> jnp.ndarray:
        """Retrieves a particle's features at the first timestep."""
        return self.trajectory[0][idx, :]

    def errors_leave_one_out(self,
        potential: Callable[[jnp.ndarray], float],
        beta: float,
        interaction: Callable[[jnp.ndarray], float],
        key_eval: jnp.ndarray,
        model: str,
        metrics: list[str],
        simulator: str,
        plot_folder_name: str | None = None):


        #if self.errors_leave_one_out:
        lo_num = len(self.leave_out_list)
        errors = {
            metric: jnp.ones(lo_num)
            for metric in metrics
            if metric in self.data_metrics or metric in self.functional_metrics
        }
        for l_idx, label in enumerate(self.leave_out_list):
            
            target_idx = self.label_mapping[label]
            init_idx = target_idx - 1
            idx = init_idx
            t = label 
            rho = self.trajectory[idx]

            predictions = get_SDE_predictions(
                    self.solver, self.dt, 1, t, potential, beta, interaction, key_eval, rho, simulator
                )
            if plot_folder_name:
                plot_path = os.path.join(plot_folder_name, f"one_ahead_tp_{t + 1}")
                prediction_fig = plot_predictions(
                    predictions[-1].reshape(1, -1, self.data_dim),
                    self.trajectory,
                    interval=(idx + 1, idx + 1),
                    model=model,
                    save_to=plot_path,
                )
                plt.close(prediction_fig)
            for metric in errors.keys():
                if metric in self.data_metrics:
                    errors[metric] = errors[metric].at[l_idx].set(self.data_metrics[metric](predictions[-1], idx + 1))
                elif metric in self.functional_metrics:
                    if "potential" in metric and self.gt_potential is not None:
                        errors[metric] = errors[metric].at[l_idx].set(self.functional_metrics[metric](potential, idx + 1))
                    elif "interaction" in metric and self.gt_interaction is not None:
                        errors[metric] = errors[metric].at[l_idx].set(self.functional_metrics[metric](interaction, idx+ 1))
                    elif "beta" in metric and self.gt_beta is not None:
                        errors[metric] = errors[metric].at[l_idx].set(self.functional_metrics[metric](beta, idx + 1))
                else:
                    print(f"Skipped {metric} metric since it is impossible to use it in this experiment.")

        return errors



    def errors_one_step_ahead(
        self,
        potential: Callable[[jnp.ndarray], float],
        beta: float,
        interaction: Callable[[jnp.ndarray], float],
        key_eval: jnp.ndarray,
        model: str,
        metrics: list[str],
        simulator: str,
        plot_folder_name: str | None = None,
    ) -> dict[str, jnp.ndarray]:
        errors = {
            metric: jnp.ones(self.T)
            for metric in metrics
            if metric in self.data_metrics or metric in self.functional_metrics
        }

        for t in range(self.T):
            rho = self.trajectory[t]

            predictions = get_SDE_predictions(
                self.solver, self.dt, 1, t + 1, potential, beta, interaction, key_eval, rho, simulator
            )  # [K-1, batch_size, d]

            if plot_folder_name:
                plot_path = os.path.join(plot_folder_name, f"one_ahead_tp_{t + 1}")
                prediction_fig = plot_predictions(
                    predictions[-1].reshape(1, -1, self.data_dim),
                    self.trajectory,
                    interval=(t + 1, t + 1),
                    model=model,
                    save_to=plot_path,
                )
                plt.close(prediction_fig)

            for metric in errors.keys():
                if metric in self.data_metrics:
                    errors[metric] = errors[metric].at[t].set(self.data_metrics[metric](predictions[-1], t + 1))
                elif metric in self.functional_metrics:
                    if "potential" in metric and self.gt_potential is not None:
                        errors[metric] = errors[metric].at[t].set(self.functional_metrics[metric](potential, t + 1))
                    elif "interaction" in metric and self.gt_interaction is not None:
                        errors[metric] = errors[metric].at[t].set(self.functional_metrics[metric](interaction, t + 1))
                    elif "beta" in metric and self.gt_beta is not None:
                        errors[metric] = errors[metric].at[t].set(self.functional_metrics[metric](beta, t + 1))
                else:
                    print(f"Skipped {metric} metric since it is impossible to use it in this experiment.")

        return errors

    def map_errors_one_step_ahead(self, solver: str, model, state) -> dict[str, jnp.ndarray]:
        if "multimap" in solver:
            rho_list = []
            rho_pred_gt_list = []

            for t in range(self.T):
                rho = self.trajectory[t]

                predictions_gt = jax.vmap(
                    lambda x: self.transport_maps[t].apply({"params": self.transport_params[t]}, x)
                )(rho)

                rho_list.append(rho)
                rho_pred_gt_list.append(predictions_gt)

            # Model prediction using learned OT maps
            rho = jnp.stack(rho_list)  # shape: (T, N, D)
            rho_pred = jnp.swapaxes(
                jax.vmap(model.model_otmaps.apply, in_axes=(None, 0))(
                    {"params": state.otmaps.params}, jnp.swapaxes(rho, 0, 1)
                ),
                1,
                0,
            )  # shape: (T, N, D)

            errors = jnp.array(
                [
                    l2_distance(rho_pred_gt_list[t], rho_pred[t]) / jnp.var(self.trajectory[t], axis=0, ddof=1).sum()
                    for t in range(self.T)
                ]
            )

            return {"L2_map": errors}
        else:
            errors_list = []
            for t in range(self.T):
                rho = self.trajectory[t]
                rho_next_gt = jax.vmap(
                    lambda x: self.transport_maps[t].apply({"params": self.transport_params[t]}, x)
                )(rho)

                _, _, _, otmap = state
                pushforward = model._get_pushforward(otmap.params, t)
                rho_next = jax.vmap(pushforward)(rho)

                errors_list.append(
                    l2_distance(rho_next_gt, rho_next) / jnp.var(self.trajectory[t], axis=0, ddof=1).sum()
                )

            return {"L2_map": jnp.array(errors_list)}
