import math
from pathlib import Path
from typing import Any, Optional

import diffrax
import jax
import jax.numpy as jnp
import numpy
import pyro
import torch
from pyro import distributions as pdist
from pyro.distributions import constraints
from pyro.distributions.transforms import biject_to

from sbibm.tasks.distributions import (
    HierarchicalDistribution,
    SummedStackTransform,
)
from sbibm.tasks.simulator import Simulator
from sbibm.tasks.task import Task


class HierarchicalSIR(Task):
    def __init__(
        self,
        n_l: int = 5,
        N: float = 1000000.0,
        I0: float = 1.0,
        R0: float = 0.0,
        days: float = 160.0,
        saveat: float = 1.0,
        total_count: int = 1000,
        summary: Optional[str] = "subsample",
        device: str = 'cpu'
    ):
        """Hierarchical SIR epidemic model

        Hierarchical extension of the SIR (Susceptible-Infected-Recovered)
        epidemic model where each observation consists of n_l local contexts
        (e.g., different regions or populations). Uses Strategy 1 (natural
        global/local split): recovery rate is global (shared across regions),
        while transmission rates are local (region-specific).

        The SIR model describes disease spread through three compartments:
        susceptible (S), infected (I), and recovered (R) individuals.

        Model parameters
        ----------------
        The model has (1 + n_l) parameters split into:

        **Global parameters** (dim=1):
            γ : scalar, γ > 0
                Recovery rate (mean rate at which infected individuals recover)
                Shared across all regions/contexts
                Prior: γ ~ LogNormal(log(0.125), 0.2)

        **Local parameters** (dim=n_l):
            β_i : scalar, β_i > 0 (for each region i=1,...,n_l)
                Region-specific transmission rate (contact rate × infection probability)
                Prior: β_i ~ LogNormal(log(0.4), 0.5) independently for each region

        Parameter Layout
        ----------------
        Parameters are concatenated as: θ = [γ, β_1, β_2, ..., β_{n_l}]

        For n_l regions:
        - θ[0]: Global recovery rate γ
        - θ[1]: Local transmission rate β_1 for region 1
        - θ[2]: Local transmission rate β_2 for region 2
        - ...
        - θ[n_l]: Local transmission rate β_{n_l} for region n_l

        Simulator
        ---------
        For each local region i, the model solves the SIR ODE system:

            dS/dt = -β_i · S · I / N
            dI/dt = β_i · S · I / N - γ · I
            dR/dt = γ · I

        With initial conditions: [S(0), I(0), R(0)] = [N - I0 - R0, I0, R0]

        Where:
        - S(t), I(t), R(t) are populations of susceptible, infected, and recovered
        - N is the total population (constant)
        - β_i is the region-specific transmission rate
        - γ is the global recovery rate
        - t ∈ [0, days] is time in days

        The ODE is solved using JAX's diffrax library with the Dopri5 solver.
        Observations are generated by:
        1. Solving the ODE to get I(t) trajectory
        2. Subsampling every 17 timesteps (10 observations total)
        3. Sampling counts: y_{i,t} ~ Binomial(total_count, I_{i,t} / N)
        4. Returning the fractions y_{i,t} / total_count

        Likelihood
        ----------
        The likelihood factorizes across independent regions and time points:

            p(y | θ) = ∏_{i=1}^{n_l} ∏_{t=1}^{10} Binomial(y_{i,t} | total_count, I_{i,t}(θ) / N)

        Where:
        - y_{i,t} is the observed count at time t for region i
        - I_{i,t}(θ) is the infected population from the ODE solution
        - The ODE solution depends on both γ and β_i

        Args
        ----
        n_l : int, default=5
            Number of local contexts/regions
        N : float, default=1000000.0
            Total population per region
        I0 : float, default=1.0
            Initial number of infected individuals per region
        R0 : float, default=0.0
            Initial number of recovered individuals per region
        days : float, default=160.0
            Number of days to simulate
        saveat : float, default=1.0
            Time step for saving ODE trajectory (in days)
        total_count : int, default=1000
            Total count for Binomial sampling of observations
        summary : str or None, default='subsample'
            Type of summary statistics ('subsample' for 10 subsampled points, None for all)
        device : str, default='cpu'
            Torch device on which the prior tensors are created and to which
            solved trajectories and simulated data are moved

        Raises
        ------
        ValueError
            If summary == 'subsample' and the saved time grid does not yield
            exactly 10 points when subsampled every 17 steps
        NotImplementedError
            If summary is neither None nor 'subsample'

        Notes
        -----
        This model demonstrates Strategy 1 hierarchical modeling where recovery
        rate is naturally shared (diseases have similar recovery dynamics) while
        transmission varies by region (different contact patterns, interventions, etc.).

        The SIR model assumptions:
        - Homogeneous mixing within each region
        - No births, deaths, or migration
        - Permanent immunity after recovery
        - Population size N remains constant

        Implementation details:
        - Uses JAX's diffrax library with Dopri5 (Runge-Kutta) solver
        - Vectorized over samples and regions using jax.vmap
        - Observations are subsampled every 17 timesteps (160 days → 10 observations)
        - Binomial sampling adds observation noise to infected counts

        See Also
        --------
        sir : The non-hierarchical version
        """
        self.n_l = n_l
        self.N = N
        self.I0 = I0
        self.R0 = R0
        self.days = days
        self.saveat = saveat
        self.total_count = total_count
        self.summary = summary
        self.device = device

        # Number of saved time points. This is the same expression the ODE
        # solver uses for its save grid, so the raw dimension below always
        # matches the trajectory size (3 compartments x time points), even
        # when days / saveat is not an integer.
        num_timepoints = int(days / saveat) + 1
        self.dim_data_raw = 3 * num_timepoints

        if summary is None:
            dim_data = self.dim_data_raw * n_l
        elif summary == "subsample":
            # The simulator and likelihood take every 17th saved point and
            # lay the result out as 10 values per region. Reject grids for
            # which that layout cannot hold instead of failing (or silently
            # broadcasting) later inside the simulator.
            num_subsampled = len(range(0, num_timepoints, 17))
            if num_subsampled != 10:
                raise ValueError(
                    "summary='subsample' requires a time grid that yields "
                    "exactly 10 points when subsampled every 17 steps, but "
                    f"days={days} and saveat={saveat} yield {num_subsampled}."
                )
            dim_data = 10 * n_l  # 10 subsampled points per region
        else:
            raise NotImplementedError

        # Observation seeds
        observation_seeds = [
            3000001,
            3000002,
            3000003,
            3000004,
            3000005,
            3000006,
            3000007,
            3000008,
            3000009,
            3000010,
        ]

        super().__init__(
            dim_parameters=1 + n_l,  # 1 global (gamma) + n_l local (beta)
            dim_data=dim_data,
            name="hierarchical_sir",
            name_display="Hierarchical SIR",
            num_observations=10,
            num_posterior_samples=10000,
            num_reference_posterior_samples=10000,
            num_simulations=[100, 1000, 10000, 100000, 1000000],
            observation_seeds=observation_seeds,
            path=Path(__file__).parent.absolute(),
        )

        # Define hierarchical prior distribution
        # Global parameter: gamma (recovery rate)
        global_dist = pdist.LogNormal(
            loc=torch.tensor([math.log(0.125)]).to(device=device),
            scale=torch.tensor([0.2]).to(device=device),
            validate_args=False
        ).to_event(1)

        # Local parameters: beta_i (transmission rates per region)
        def local_dist_fn(global_params, n_local):
            # Return LogNormal distribution for n_local transmission rates.
            # The global draw only contributes its batch shape; the betas
            # do not depend on the value of gamma.
            batch_shape = global_params.shape[:-1]
            return pdist.Independent(
                pdist.LogNormal(
                    loc=torch.tensor(math.log(0.4)).to(device=device),
                    scale=torch.tensor(0.5).to(device=device),
                    validate_args=False
                ).expand(list(batch_shape) + [n_local]),
                1,
            )

        self.prior_dist = HierarchicalDistribution(
            global_dist,
            local_dist_fn,
            dim_global=1,
            dim_local=1,
            n_local=n_l,
        )
        self.prior_dist.set_default_validate_args(False)

        # Build composite transform (constrained <-> unconstrained)
        # All parameters (gamma followed by the n_l betas) are positive
        # reals, so each one gets the R+ <-> R bijection.
        transforms_list = [
            biject_to(constraints.positive) for _ in range(1 + n_l)
        ]

        # Use custom wrapper to ensure Jacobian is properly summed
        self.composite_transform = SummedStackTransform(
            transforms_list, dim=-1
        )

        # Initial conditions per region. Kept on the CPU because the ODE
        # solver converts this tensor with .numpy().
        self.u0 = torch.tensor([N - I0 - R0, I0, R0])

    def _sir_ode_func_single(
        self, t: jnp.ndarray, u: jnp.ndarray, args
    ) -> jnp.ndarray:
        """SIR ODE function for single (sample, region) pair.

        Args:
            t: Time (scalar)
            u: State vector shape (3,) with [S, I, R]
            args: Tuple of (gamma, beta) scalars

        Returns:
            du/dt: State derivatives shape (3,)
        """
        S = u[0]
        I = u[1]
        gamma = args[0]
        beta = args[1]

        dS = -beta * S * I / self.N
        dI = beta * S * I / self.N - gamma * I
        dR = gamma * I

        return jnp.array([dS, dI, dR])

    def _sir_ode_func(
        self, t: jnp.ndarray, u: jnp.ndarray, args
    ) -> jnp.ndarray:
        """Vectorized SIR ODE function for batch solving.

        For hierarchical SIR, we batch across all (sample, region) pairs.

        Args:
            t: Time (scalar)
            u: State vector shape (batch, 3) with [S, I, R]
            args: Tuple of (gamma, beta) with shape (batch, 2)

        Returns:
            du/dt: State derivatives shape (batch, 3)
        """
        S = u[:, 0]
        I = u[:, 1]
        gamma = args[:, 0]
        beta = args[:, 1]

        dS = -beta * S * I / self.N
        dI = beta * S * I / self.N - gamma * I
        dR = gamma * I

        return jnp.stack([dS, dI, dR], axis=1)

    def _solve_ode_single_site_jax(
        self, gamma: jnp.ndarray, beta: jnp.ndarray
    ) -> jnp.ndarray:
        """Integrate the SIR ODE for one (sample, region) pair with diffrax.

        The trajectory is saved on an evenly spaced grid from 0 to
        ``self.days`` with ``int(self.days / self.saveat) + 1`` points.

        Args:
            gamma: Recovery rate (JAX scalar)
            beta: Transmission rate (JAX scalar)

        Returns:
            JAX array of shape (3, num_timepoints); rows are S, I, R.
        """
        num_timepoints = int(self.days / self.saveat) + 1

        # The grid is built with torch and then handed to JAX
        save_times = jnp.array(
            torch.linspace(0, self.days, num_timepoints).numpy()
        )

        # Starting state [S0, I0, R0] as float32
        initial_state = jnp.array(self.u0.numpy(), dtype=jnp.float32)

        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self._sir_ode_func_single),
            diffrax.Dopri5(),
            t0=save_times[0],
            t1=save_times[-1],
            dt0=0.1,
            y0=initial_state,
            args=jnp.array([gamma, beta]),
            saveat=diffrax.SaveAt(ts=save_times),
            max_steps=16**5,
        )

        # solution.ys is (timepoints, 3); callers expect (3, timepoints)
        return solution.ys.T

    def _solve_ode_trajectories_jax(
        self, gamma_jax: jnp.ndarray, beta_jax: jnp.ndarray
    ) -> jnp.ndarray:
        """Solve hierarchical SIR ODE using vmapped solver in JAX.

        All computation stays in JAX land. Caller converts torch to/from JAX.

        Args:
            gamma_jax: JAX array shape (num_samples,)
            beta_jax: JAX array shape (num_samples, n_local)

        Returns:
            Trajectories in JAX shape (num_samples, n_local, 3, num_timepoints)
        """
        # Create vmapped solver: vmap over sites (axis 1) of beta
        solve_vmap_sites = jax.vmap(
            self._solve_ode_single_site_jax, in_axes=(None, 0)
        )

        # Create vmapped solver over samples: vmap over both gamma and beta
        solve_vmap_both = jax.vmap(
            solve_vmap_sites, in_axes=(0, 0)
        )

        # Solve all ODE systems (stays in JAX)
        trajectories_jax = solve_vmap_both(gamma_jax, beta_jax)

        return trajectories_jax

    def solve_ode_trajectories(self, parameters: torch.Tensor) -> torch.Tensor:
        """Solve hierarchical SIR ODE using vmapped solver.

        Solves ODE for each (sample, region) pair independently using
        JAX's vmap for vectorization.

        Args:
            parameters: Shape (num_samples, 1 + n_local) with
                [gamma, beta_1, ..., beta_n_local]

        Returns:
            Trajectories shape (num_samples, n_local, 3, num_timepoints)
        """
        # Extract gamma and beta for each sample
        gamma = parameters[:, 0]  # (num_samples,)
        beta = parameters[:, 1:]  # (num_samples, n_local)

        # Convert to JAX arrays
        gamma_jax = jnp.array(gamma.cpu().numpy())  # (num_samples,)
        beta_jax = jnp.array(beta.cpu().numpy())  # (num_samples, n_local)

        # Solve in JAX
        trajectories_jax = self._solve_ode_trajectories_jax(
            gamma_jax,
            beta_jax
        )

        # Convert back to PyTorch
        trajectories_np = numpy.asarray(trajectories_jax).copy()
        trajectories = torch.from_numpy(trajectories_np).to(torch.float32).to(
            device=self.device
        )

        return trajectories.float()

    def get_labels_parameters(self):
        """Get list containing parameter labels"""
        labels = [r"$\gamma$"]  # Global parameter
        for i in range(self.n_l):
            labels.append(rf"$\beta_{i}$")  # Local parameters
        return labels

    def get_prior(self):
        """Get prior distribution.

        Returns a callable that samples from self.prior_dist using pyro.
        """

        def prior(num_samples=1):
            return pyro.sample(
                "parameters", self.prior_dist.expand_by([num_samples])
            )

        return prior

    def get_simulator(self, max_calls: Optional[int] = None) -> Simulator:
        """Build the simulator for this task.

        The wrapped function solves the SIR ODE for every (sample, region)
        pair, using the shared recovery rate gamma (parameter column 0) and
        the region-specific transmission rates beta_i (remaining columns).

        With ``summary=None`` it returns the flattened ODE trajectories.
        With ``summary="subsample"`` it takes the infected compartment at
        every 17th saved time point, draws Binomial counts per region and
        returns them divided by ``total_count``; rows whose trajectory
        contains NaN are returned as NaN.

        Args:
            max_calls: Maximum number of simulator calls

        Returns:
            Simulator wrapping the function described above
        """

        def simulator(parameters):
            num_samples = parameters.shape[0]
            # Region count comes from the parameter width, not self.n_l
            n_local = parameters.shape[1] - 1

            # (num_samples, n_local, 3, num_timepoints)
            trajectories = self.solve_ode_trajectories(parameters)
            flat = trajectories.reshape(num_samples, -1)

            # Split row indices by whether the trajectory contains a NaN
            row_has_nan = torch.isnan(flat).any(axis=1)
            nan_rows = torch.where(row_has_nan)[0]
            clean_rows = torch.where(~row_has_nan)[0]

            if self.summary is None:
                # Raw trajectories, one flattened row per sample
                return flat

            if self.summary != "subsample":
                raise NotImplementedError

            # 10 values per region; start from NaN so that rows with a
            # failed solve stay NaN.
            data = torch.full(
                (num_samples, 10 * n_local), float("nan")
            ).to(device=self.device)
            if len(nan_rows) == num_samples:
                return data

            # Infected compartment (index 1) at every 17th saved time point
            infected = trajectories[:, :, 1, ::17]

            for region_idx in range(n_local):
                infected_region = infected[:, region_idx, :]
                probs = (infected_region[clean_rows, :] / self.N).clamp(
                    0.0, 1.0
                )
                counts = pyro.sample(
                    f"data_region_{region_idx}",
                    pdist.Binomial(
                        total_count=self.total_count,
                        probs=probs,
                    ).to_event(1),
                )
                # Store the counts as fractions of total_count in this
                # region's block of 10 columns.
                start = region_idx * 10
                data[clean_rows, start : start + 10] = (
                    counts / self.total_count
                )

            return data

        return Simulator(task=self, simulator=simulator, max_calls=max_calls)

    def _likelihood(
        self, parameters: torch.Tensor, data: torch.Tensor, log: bool = True
    ) -> torch.Tensor:
        """Compute likelihood of data given parameters.

        Likelihood model: For each region and time point, observed
        counts follow Binomial(total_count, p=I(t)/N) where I(t) is
        the infected population from the ODE solution, taken at every
        17th saved time point.

        ``data`` is expected on the scale produced by the simulator in this
        class, i.e. Binomial counts divided by ``total_count``, laid out as
        10 consecutive values per region. The fractions are converted back
        to integer counts before the Binomial log-probability is evaluated.

        Args:
            parameters: Parameter tensor (batch_size, dim_parameters)
            data: Data tensor (batch_size, dim_data). A single observation,
                given either as a 1-D tensor or with a leading dimension of
                one, is shared across all parameter rows.
            log: Whether to return log-likelihood (default: True)

        Returns:
            (Log-)likelihood values (batch_size,). Parameter rows whose ODE
            solution contains NaN get log-likelihood -inf.
        """
        num_samples = parameters.shape[0]

        # Let one observation be evaluated against many parameter sets
        if data.dim() == 1:
            data = data.unsqueeze(0)
        if data.shape[0] == 1 and num_samples != 1:
            data = data.expand(num_samples, -1)

        # Solve ODE for all parameters and regions
        all_observations = self.solve_ode_trajectories(parameters)

        log_likelihoods = []

        for num_sample in range(num_samples):
            trajectories = all_observations[num_sample]

            # A failed solve cannot explain any data
            if torch.isnan(trajectories).any():
                log_likelihoods.append(float("-inf"))
                continue

            # The simulator emits counts / total_count. Undo that scaling;
            # rounding removes the floating-point error of the round trip so
            # the values are exact integers in the Binomial support.
            counts = torch.round(data[num_sample, :] * self.total_count)

            sample_log_likelihood = 0.0
            for region_idx in range(self.n_l):
                # Infected compartment, subsampled every 17 time steps
                I_subsampled = trajectories[region_idx, 1, ::17]

                # This region's block of 10 observations
                counts_region = counts[
                    region_idx * 10 : (region_idx + 1) * 10
                ]

                probs = (I_subsampled / self.N).clamp(0.0, 1.0)
                binomial_dist = pdist.Binomial(
                    total_count=self.total_count,
                    probs=probs,
                )

                # Sum log-likelihood across time points
                sample_log_likelihood += (
                    binomial_dist.log_prob(counts_region).sum()
                )

            log_likelihoods.append(float(sample_log_likelihood))

        log_likelihoods = torch.tensor(log_likelihoods).to(device=self.device)

        if log:
            return log_likelihoods
        else:
            return torch.exp(log_likelihoods)

    def get_prior_dist(self):
        """Get prior distribution object.

        Returns:
            Prior distribution
        """
        return self.prior_dist

    def _get_transforms(
        self,
        automatic_transforms_enabled: bool = True,
        n_l: Optional[int] = None,
        **kwargs: Any,
    ):
        """Get transforms for unconstrained <-> constrained space.

        Args:
            automatic_transforms_enabled: Whether to return transforms
            n_l: Number of local contexts (defaults to self.n_l)

        Returns:
            Dictionary with 'parameters' key containing the transform
        """
        if n_l is None:
            n_l = self.n_l

        # Build composite transform (constrained <-> unconstrained)
        # All parameters are log-scale (positive real) -> R
        transforms_list = []

        # Gamma: LogNormal (R+) <-> R
        transforms_list.append(biject_to(constraints.positive))

        # Betas: LogNormal (R+) <-> R
        for _ in range(n_l):
            transforms_list.append(biject_to(constraints.positive))

        # Use custom wrapper to ensure Jacobian is properly summed
        composite_transform = SummedStackTransform(transforms_list, dim=-1)
        return {"parameters": composite_transform.inv}

    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample reference posterior - not implemented for hierarchical.

        Reference posteriors are not available for hierarchical tasks.
        Use reference-free metrics (reverse KL, LC2ST) instead.

        Args:
            num_samples: Number of samples to generate
            num_observation: Observation number
            observation: Observation tensor

        Raises:
            NotImplementedError: Reference posteriors not available
        """
        raise NotImplementedError(
            "Reference posteriors are not available for hierarchical tasks. "
            "Use reference-free metrics (reverse KL, LC2ST) instead."
        )


if __name__ == "__main__":
    import argparse

    # Command-line entry point: build the task for the requested number of
    # local contexts and run its setup without creating a reference.
    cli = argparse.ArgumentParser(description="Setup hierarchical SIR task")
    cli.add_argument(
        "--n_l", type=int, default=5, help="Number of local contexts"
    )
    args = cli.parse_args()

    task = HierarchicalSIR(n_l=args.n_l)
    task._setup(n_jobs=1, create_reference=False)
