"""
SEIR (Susceptible-Exposed-Infectious-Recovered) model implementation
using Flow Matching for Posterior Estimation.

This script implements a modern epidemiological model with:
- Both SFMPE and MCMC inference approaches
- Hydra configuration management

Updated to use latest package interfaces following hierarchical_brownian.py pattern.
This version performs estimation only and saves results to .npy files, plus
JSON files holding the plotting and selective-inference configuration.
"""

import json
import logging
import time
from pathlib import Path
from jaxtyping import Array
import hydra
from omegaconf import DictConfig
from hydra.core.hydra_config import HydraConfig

from jax import numpy as jnp, random as jr, tree, jit
import tensorflow_probability.substrates.jax as tfp
from tensorflow_probability.substrates.jax import distributions as tfd

import arviz as az
import optax
from flax import nnx

from sfmpe.sfmpe import SFMPE
from sfmpe.bottom_up import train_bottom_up
from sfmpe.structured_cnf import StructuredCNF
from sfmpe.nn.transformer.transformer import Transformer

from seir_utils import (
    prior_fn, create_simulator_dist, create_simulator_fn,
    f_in_fn, apply_dequantization,
    f_in_fn_observed, _flatten, flatten_f_in,
    get_standard_bijector_specs, get_y_bijector_specs, create_pytree_bijectors,
    create_selective_prior_fn,
    create_selective_flat_bijector, flatten_selective_theta_dict,
    reconstruct_selective_theta_dict, create_selective_numpyro_seir_model,
    create_selective_sfmpe_functions
)


def run(cfg: DictConfig) -> None:
    """Run one SEIR posterior-estimation experiment and save the results.

    Ground-truth parameters are drawn from the prior and a single observation
    is simulated from them. The posterior over the sampled parameters is then
    estimated with the method named by ``cfg.method`` ("MCMC" or "SFMPE"),
    summarised with ArviZ, and written together with the ground truth, the
    observation, prior draws and plotting metadata to Hydra's runtime output
    directory.

    Args:
        cfg: Hydra configuration. Fields read by this function:
            ``n_timesteps``, ``n_obs``, ``n_sites``, ``n_warmup``,
            ``n_simulations``, ``n_rounds``, ``n_epochs``, ``n_post_samples``,
            ``n_prior_samples``, ``seed``, ``dt``, ``population``, ``I0_prop``,
            ``method``, ``f_in_sample``, the optional
            ``inference.sample_params``, and the ``mcmc``,
            ``sfmpe.transformer`` and ``training`` groups.

    Raises:
        ValueError: If ``cfg.method``, ``cfg.mcmc.sampler`` or
            ``cfg.f_in_sample`` has an unsupported value, or if the MCMC
            burn-in (``n_simulations - n_post_samples``) would be negative.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Running SEIR with n_simulations={cfg.n_simulations}, "
                f"n_rounds={cfg.n_rounds}, n_epochs={cfg.n_epochs}")

    # Extract parameters
    n_timesteps = cfg.n_timesteps
    n_obs = cfg.n_obs
    n_sites = cfg.n_sites
    n_warmup = cfg.n_warmup
    n_simulations = cfg.n_simulations
    n_rounds = cfg.n_rounds
    n_epochs = cfg.n_epochs
    n_post_samples = cfg.n_post_samples

    # Number of prior draws used to build the Z-scaled bijectors
    n_repr_samples = 1000

    # Set up parameters for unified selective approach
    # Full inference is just selective inference with all parameters sampled
    all_params = ['beta_0', 'alpha', 'sigma', 'A', 'T_season', 'phi']

    if (hasattr(cfg, 'inference') and
        cfg.inference is not None and
        cfg.inference.sample_params is not None):
        sample_params = cfg.inference.sample_params
        logger.info(f"Using selective inference: sampling {sample_params}")
    else:
        sample_params = all_params
        logger.info("Using full inference: sampling all parameters")

    fixed_param_names = [p for p in all_params if p not in sample_params]
    if fixed_param_names:
        logger.info(f"Fixed parameters: {fixed_param_names}")
    else:
        logger.info("No fixed parameters")

    key = jr.PRNGKey(cfg.seed)

    # Create functions
    simulator_dist = create_simulator_dist(n_timesteps, cfg.dt, cfg.population, cfg.I0_prop, n_warmup)
    simulator_fn = create_simulator_fn(simulator_dist)

    # Generate ground truth and observations
    theta_key, obs_key, f_in_key, key = jr.split(key, 4)

    theta_truth = prior_fn(n_sites).sample((1,), seed=theta_key)

    # For selective inference, broadcast fixed local parameters to be identical across sites
    for param_name in fixed_param_names:
        if param_name in ['A', 'T_season', 'phi']:
            # Take the first site's value and broadcast it across all sites
            single_value = theta_truth[param_name][0, 0:1]  # Shape: (1, 1)
            broadcasted_value = jnp.broadcast_to(single_value, (n_sites, 1))
            theta_truth[param_name] = theta_truth[param_name].at[0].set(broadcasted_value)

    f_in = f_in_fn(n_obs, n_sites, n_timesteps).sample((1,), seed=f_in_key)
    y_observed = simulator_fn(obs_key, theta_truth, f_in)

    # Extract fixed parameter values (index 0 removes the batch dim)
    fixed_params = {
        param_name: theta_truth[param_name][0]
        for param_name in fixed_param_names
    }
    if fixed_params:
        logger.info(f"Fixed parameter values: {tree.map(jnp.squeeze, fixed_params)}")

    # Generate representative data for consistent Z-scaling across all bijectors
    repr_key, key = jr.split(key)
    repr_theta = prior_fn(n_sites).sample((n_repr_samples,), seed=repr_key)

    # Generate prior samples for comparison plots
    prior_key, key = jr.split(key)
    selective_prior = create_selective_prior_fn(n_sites, sample_params, fixed_params)
    prior_samples_selective = selective_prior(n_sites).sample((cfg.n_prior_samples,), seed=prior_key)
    prior_samples_flat = flatten_selective_theta_dict(prior_samples_selective, sample_params)[None, ...]

    # Define bijector specifications for constrained -> unconstrained transformation
    bijector_specs = get_standard_bijector_specs()

    # Create MCMC bijector (always use selective approach)
    flat_theta_bijector = create_selective_flat_bijector(repr_theta, bijector_specs, n_sites, sample_params)

    # Create proxy functions for MCMC sampling
    def flat_prior_fn(key: Array, n_samples: int) -> Array:
        """Draw ``n_samples`` flattened parameter vectors from the selective prior."""
        selective_prior = create_selective_prior_fn(n_sites, sample_params, fixed_params)
        theta_samples = selective_prior(n_sites).sample((n_samples,), seed=key)
        return flatten_selective_theta_dict(theta_samples, sample_params)

    @jit
    def flat_simulator_log_prob(theta_flat: Array) -> Array:
        """Return prior log-prob plus simulator log-likelihood of ``y_observed``.

        ``theta_flat`` is a batch of flattened sampled parameters; the leading
        axis is the batch axis and one value is returned per batch entry.
        """
        # Reconstruct full theta from selective samples + fixed values
        theta_dict = reconstruct_selective_theta_dict(theta_flat, sample_params, fixed_params, n_sites)
        # Create selective prior for log_prob calculation
        selective_prior = create_selective_prior_fn(n_sites, sample_params, fixed_params)(n_sites)
        # Extract sampled parameters for prior calculation
        theta_selective = {k: v for k, v in theta_dict.items() if k in sample_params}
        prior_p = selective_prior.log_prob(theta_selective)

        # Repeat the single observed f_in so its batch axis matches theta's
        n_batch = theta_flat.shape[0]
        f_in_matched = tree.map(
            lambda leaf: jnp.repeat(leaf, n_batch, axis=0),
            f_in
        )
        sim_dist = simulator_dist(theta_dict, f_in_matched)
        return jnp.sum(prior_p, axis=1) + jnp.sum(sim_dist.log_prob(y_observed), axis=(1, 2))

    if cfg.method == "MCMC":
        logger.info("Starting MCMC sampling")
        start_time = time.time()

        # The simulation budget is split into burn-in and retained draws
        n_burnin = cfg.n_simulations - cfg.n_post_samples
        if n_burnin < 0:
            raise ValueError(
                f"n_simulations ({cfg.n_simulations}) must be >= "
                f"n_post_samples ({cfg.n_post_samples}) for MCMC"
            )
        sample_key, init_key, key = jr.split(key, 3)

        if cfg.mcmc.sampler == "slice":
            init_state = flat_prior_fn(init_key, cfg.mcmc.n_chains)
            kernel = tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=tfp.mcmc.SliceSampler(
                    target_log_prob_fn=flat_simulator_log_prob,
                    step_size=cfg.mcmc.step_size,
                    max_doublings=cfg.mcmc.max_doublings
                ),
                bijector=flat_theta_bijector
            )

            mcmc_posterior_samples = tfp.mcmc.sample_chain(
                num_results=n_post_samples,
                num_burnin_steps=n_burnin,
                current_state=init_state,
                kernel=kernel,
                seed=sample_key
            )
            mcmc_posterior_samples = mcmc_posterior_samples.all_states
            # change axes so that it's [chain, sample, param]
            mcmc_posterior_samples = jnp.swapaxes(
                mcmc_posterior_samples,
                0,
                1
            )
        elif cfg.mcmc.sampler in ["nuts", "ess"]:
            from numpyro.infer import MCMC, NUTS
            from numpyro.infer.ensemble import ESS

            if cfg.mcmc.use_numpyro_model:
                # Use NumPyro model approach
                logger.info(f"Using NumPyro model with {cfg.mcmc.sampler} sampler")

                # Create NumPyro model (always use selective approach)
                numpyro_model = create_selective_numpyro_seir_model(
                    simulator_fn, n_sites, f_in, sample_params, fixed_params
                )

                if cfg.mcmc.sampler == "ess":
                    kernel = ESS(numpyro_model)
                    chain_method = "vectorized"
                else:
                    kernel = NUTS(
                        numpyro_model,
                        step_size=cfg.mcmc.step_size,
                        max_tree_depth=cfg.mcmc.max_tree_depth,
                        adapt_step_size=True,
                        forward_mode_differentiation=True
                    )
                    chain_method = "parallel"

                mcmc = MCMC(
                    kernel,
                    num_warmup=n_burnin,
                    num_samples=n_post_samples,
                    chain_method=chain_method,
                    num_chains=cfg.mcmc.n_chains,
                    jit_model_args=True
                )
                mcmc.run(sample_key, y_observed=y_observed)

                # Extract samples and convert to expected format
                samples = mcmc.get_samples(group_by_chain=True)

                # Reconstruct theta dictionary from selective NumPyro samples
                theta_dict = {}
                for param_name in sample_params:
                    if param_name in ['beta_0', 'alpha', 'sigma']:
                        theta_dict[param_name] = samples[param_name][:, :, None, None]
                    else:
                        theta_dict[param_name] = samples[param_name][:, :, :, None]

                # Add fixed parameters, broadcast over [n_chains, n_samples]
                for param_name in fixed_param_names:
                    fixed_val = fixed_params[param_name]
                    batch_shape = samples[sample_params[0]].shape[:2]  # [n_chains, n_samples]
                    if param_name in ['beta_0', 'alpha', 'sigma']:
                        theta_dict[param_name] = jnp.broadcast_to(
                            fixed_val[None, None, :, :],
                            batch_shape + (1, 1)
                        )
                    else:
                        theta_dict[param_name] = jnp.broadcast_to(
                            fixed_val[None, None, :, :],
                            batch_shape + (n_sites, 1)
                        )

                # Convert to flat format for downstream analysis
                mcmc_posterior_samples = flatten_selective_theta_dict(theta_dict, sample_params)

            else:
                # Use existing manual log_prob approach
                logger.info(f"Using manual log_prob with {cfg.mcmc.sampler} sampler")
                init_state = flat_prior_fn(init_key, cfg.mcmc.n_chains)

                # NOTE(review): the sampler state is initialised with
                # flat_theta_bijector.forward(...) and mapped back with
                # .inverse(...) below, yet transformed_log_prob applies
                # .forward(...) to that same state, and the slice branch hands
                # this bijector to TransformedTransitionKernel. These uses look
                # mutually inconsistent -- confirm the direction of the
                # bijector built by create_selective_flat_bijector.
                def transformed_log_prob(theta: Array) -> Array:
                    """Log density of one sampler state, including the log-det term."""
                    batched_theta = theta[None, ...]
                    unconstrained_theta = flat_theta_bijector.forward(batched_theta)
                    log_prob = flat_simulator_log_prob(unconstrained_theta)[0]
                    det = flat_theta_bijector.forward_log_det_jacobian(
                        batched_theta
                    )[0]
                    return log_prob + det

                def transformed_potential(theta: Array) -> Array:
                    """Potential energy for NumPyro: the negative log density."""
                    # NumPyro's potential_fn is the potential energy, i.e. the
                    # negative log density; passing the log density itself
                    # would make the sampler target the reciprocal density.
                    return -transformed_log_prob(theta)

                if cfg.mcmc.sampler == "ess":
                    kernel = ESS(
                        potential_fn=transformed_potential
                    )
                    chain_method = "vectorized"
                else:
                    kernel = NUTS(
                        potential_fn=transformed_potential,
                        step_size=cfg.mcmc.step_size,
                        max_tree_depth=cfg.mcmc.max_tree_depth,
                        adapt_step_size=True,
                        forward_mode_differentiation=True
                    )
                    chain_method = "parallel"

                mcmc = MCMC(
                    kernel,
                    num_warmup=n_burnin,
                    num_samples=n_post_samples,
                    chain_method=chain_method,
                    num_chains=cfg.mcmc.n_chains,
                    jit_model_args=True
                )
                mcmc.run(sample_key, init_params=flat_theta_bijector.forward(init_state))
                unconstrained_samples = mcmc.get_samples(group_by_chain=True)
                mcmc_posterior_samples = flat_theta_bijector.inverse(unconstrained_samples)
        else:
            raise ValueError(f"Unknown MCMC sampler: {cfg.mcmc.sampler}")

        logger.info(f'MCMC posterior mean: {jnp.mean(mcmc_posterior_samples, axis=(0, 1))}')
        logger.info(f"MCMC posterior sampling completed in {time.time() - start_time:.2f} seconds")

    elif cfg.method == "SFMPE":
        # SFMPE implementation
        logger.info("Starting SFMPE training")
        start_time = time.time()

        # Apply dequantization to observed data
        deq_key, key = jr.split(key)
        y_processed = apply_dequantization(y_observed, deq_key)

        # Generate representative data for consistent Z-scaling across all bijectors.
        # Each random step gets its own key: reusing one key for the prior draw,
        # the simulator and the dequantization noise would correlate them.
        repr_theta_key, repr_sim_key, repr_deq_key, key = jr.split(key, 4)
        # Always use selective functions for representative data generation
        (selective_prior_fn, selective_local_fn, wrapped_simulator_fn,
         global_names, local_names) = create_selective_sfmpe_functions(
            n_sites, sample_params, fixed_params, simulator_fn
        )
        # Sample only the selected parameters for representative data
        repr_theta = selective_prior_fn(n_sites).sample((n_repr_samples,), seed=repr_theta_key)

        # For representative data, use the same f_in for all samples
        repr_f_in = tree.map(lambda leaf: jnp.repeat(leaf, n_repr_samples, axis=0), f_in)
        repr_y_raw = wrapped_simulator_fn(repr_sim_key, repr_theta, repr_f_in)
        repr_y = apply_dequantization(repr_y_raw, repr_deq_key)

        # Create Z-scaled bijector maps and PyTreeBijectors
        # Filter bijector specs to only include sampled parameters
        theta_bijector_specs = get_standard_bijector_specs()
        selective_theta_specs = {k: v for k, v in theta_bijector_specs.items() if k in sample_params}
        y_bijector_specs = get_y_bijector_specs()
        sfmpe_theta_bijector, sfmpe_y_bijector = create_pytree_bijectors(
            repr_theta, repr_y, selective_theta_specs, y_bijector_specs
        )

        # Transform observations to unconstrained space
        y_unconstrained = sfmpe_y_bijector.forward(y_processed)

        # Create wrapped functions for train_bottom_up
        def wrapped_prior_fn(n):
            """Prior function that returns TransformedDistribution."""
            base_prior = selective_prior_fn(n)
            return tfd.TransformedDistribution(
                base_prior,
                sfmpe_theta_bijector,
                name="transformed_prior"
            )

        def wrapped_p_local(g, n):
            """Local prior function that returns TransformedDistribution."""
            base_local = selective_local_fn(g, n)
            return tfd.TransformedDistribution(
                base_local,
                sfmpe_theta_bijector,
                name="transformed_local"
            )

        def wrapped_simulator_fn_for_training(seed, theta, f_in_sample):
            """Simulator function that handles bijector transformations."""
            # Separate keys for the simulator and the dequantization noise
            sim_key, deq_noise_key = jr.split(seed)

            # Transform parameters back to constrained space
            theta_constrained = sfmpe_theta_bijector.inverse(theta)

            # Apply wrapped simulator (handles parameter reconstruction if needed)
            y_constrained = wrapped_simulator_fn(sim_key, theta_constrained, f_in_sample)
            y_deq = apply_dequantization(y_constrained, deq_noise_key)

            # Transform outputs to unconstrained space
            return sfmpe_y_bijector.forward(y_deq)

        # Independence structure for structured inference (dynamic based on sampled parameters)
        local_independence = ['obs'] + local_names
        cross_local_connections = [(param, 'obs', (0, 0)) for param in local_names]
        independence = {
            'local': local_independence,  # Observations independent across time/sites
            'cross': [],
            'cross_local': cross_local_connections
        }

        # SFMPE Neural Network Setup (dynamic n_labels)
        # Give the network its own key so the key that is split again below
        # for training is never consumed twice.
        rngs_key, key = jr.split(key)
        rngs = nnx.Rngs(rngs_key)
        n_labels = len(global_names) + len(local_names) + 1  # sampled parameters + obs
        logger.info(f"Using {n_labels} labels: {len(global_names)} global + {len(local_names)} local + 1 obs")

        transformer_config = {
            'latent_dim': cfg.sfmpe.transformer.latent_dim,
            'label_dim': cfg.sfmpe.transformer.label_dim,
            'index_out_dim': cfg.sfmpe.transformer.index_out_dim,
            'n_encoder': cfg.sfmpe.transformer.n_encoder,
            'n_decoder': cfg.sfmpe.transformer.n_decoder,
            'n_heads': cfg.sfmpe.transformer.n_heads,
            'n_ff': cfg.sfmpe.transformer.n_ff,
            'dropout': cfg.sfmpe.transformer.dropout,
            'activation': nnx.relu,
        }

        nn = Transformer(
            transformer_config,
            value_dim=1,
            n_labels=n_labels,
            index_dim=1,  # Temporal indexing
            rngs=rngs
        )

        model = StructuredCNF(nn, rngs=rngs)
        estim = SFMPE(model, rngs=rngs)

        # Training
        train_key, key = jr.split(key)
        logger.info("Starting SFMPE bottom-up training")

        # Set up f_in function arguments based on configuration
        if cfg.f_in_sample == 'observed':
            f_in_fn_train = f_in_fn_observed
            f_in_args = (n_obs, 1, f_in)
            f_in_args_global = (n_obs, n_sites, f_in)
        elif cfg.f_in_sample == 'prior':
            f_in_fn_train = f_in_fn
            f_in_args = (n_obs, 1, n_timesteps)
            f_in_args_global = (n_obs, n_sites, n_timesteps)
        else:
            raise ValueError(f"Invalid f_in_sample: {cfg.f_in_sample}")

        labels, slices, masks = train_bottom_up(
            train_key,
            estim,
            wrapped_prior_fn,
            wrapped_p_local,
            wrapped_simulator_fn_for_training,  # Use training-specific wrapped simulator
            global_names,  # Dynamic global parameters
            local_names,   # Dynamic local parameters
            n_sites,
            n_rounds,
            n_simulations,
            n_epochs,
            y_unconstrained,  # Use unconstrained data
            independence,
            optimiser=optax.adam(cfg.training.learning_rate),
            batch_size=int(n_simulations * cfg.training.batch_size_fraction),
            f_in=f_in_fn_train,
            f_in_args=f_in_args,
            f_in_args_global=f_in_args_global,
            f_in_target=f_in
        )
        logger.info(f"SFMPE bottom-up training completed in {time.time() - start_time:.2f} seconds")

        # Sample SFMPE posterior
        logger.info("Sampling SFMPE posterior")
        start_time = time.time()

        # Create flattened f_in index for posterior sampling
        f_in_flattened = flatten_f_in(f_in, sample_params=sample_params)
        # Condition on the observation in the same (unconstrained) space that
        # the training simulator outputs and the training target were in.
        posterior = estim.sample_posterior(
            _flatten(y_unconstrained)[..., None],
            labels,
            slices,
            masks=masks,
            n_samples=n_post_samples,
            index=f_in_flattened
        )

        # Transform posterior samples back into constrained space
        posterior = sfmpe_theta_bijector.inverse(posterior)

        # Convert SFMPE posterior to the same format as MCMC for downstream analysis
        mcmc_posterior_samples = flatten_selective_theta_dict(posterior, sample_params)[None, ...]

        logger.info(f'SFMPE posterior mean: {jnp.mean(mcmc_posterior_samples, axis=(0, 1))}')
        logger.info(f"SFMPE posterior sampling completed in {time.time() - start_time:.2f} seconds")

    else:
        raise ValueError(f"Unknown method: {cfg.method}. Choose 'MCMC' or 'SFMPE'.")

    # The analysis below runs for both methods; the variable and log wording
    # keep the "MCMC" name because SFMPE output was put in the same layout.
    logger.info('Analysing MCMC')
    start_time = time.time()
    logger.info("Converting MCMC posterior to az format")

    # Reconstruct only sampled parameters for ArviZ
    # First reconstruct with fixed params to get proper shapes, then filter
    post_dict_full = reconstruct_selective_theta_dict(
        mcmc_posterior_samples,
        sample_params,
        fixed_params,
        n_sites
    )
    post_dict = {k: v for k, v in post_dict_full.items() if k in sample_params}
    posterior = az.from_dict(posterior=post_dict)
    logger.info("Summarising MCMC posterior")
    print(az.summary(posterior))
    logger.info(f"MCMC summarisation completed in {time.time() - start_time:.2f} seconds")

    # Use Hydra's output directory
    hydra_cfg = HydraConfig.get()
    out_dir = Path(hydra_cfg.runtime.output_dir)

    # Save estimation results as .npy files for later visualization
    logger.info("Saving results to .npy files")
    jnp.save(out_dir / "theta_truth.npy", theta_truth)
    jnp.save(out_dir / "y_observed.npy", y_observed)
    jnp.save(out_dir / "f_in.npy", f_in)
    jnp.save(out_dir / "mcmc_posterior_samples.npy", mcmc_posterior_samples)
    jnp.save(out_dir / "prior_samples.npy", prior_samples_flat)

    # Save configuration parameters needed for plotting
    plot_config = {
        'n_sites': n_sites,
        'n_timesteps': n_timesteps,
        'n_warmup': n_warmup,
        'population': cfg.population,
        'I0_prop': cfg.I0_prop,
        'dt': cfg.dt
    }

    # Save as JSON for easy loading
    with open(out_dir / "plot_config.json", 'w') as f:
        json.dump(plot_config, f, indent=2)

    # Always save selective inference configuration
    selective_config = {
        'sample_params': list(sample_params),  # Convert from ListConfig to list
        'fixed_params': tree.map(lambda x: x.tolist(), fixed_params)
    }
    with open(out_dir / "selective_inference_config.json", 'w') as f:
        json.dump(selective_config, f, indent=2)

    logger.info("SEIR MCMC estimation completed successfully!")

@hydra.main(version_base=None, config_path="conf", config_name="seir_mcmc_config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: configure logging, then run the SEIR experiment.

    Args:
        cfg: Configuration composed by Hydra from ``conf/seir_mcmc_config``.
    """
    # Root logger: INFO level, timestamped records.
    log_settings = {
        'level': logging.INFO,
        'format': '%(asctime)s - %(levelname)s - %(message)s',
        'datefmt': '%Y-%m-%d %H:%M:%S',
    }
    logging.basicConfig(**log_settings)
    logging.getLogger(__name__).info("Starting SEIR experiment")

    # Hand the composed configuration to the experiment body.
    run(cfg)

# Script entry point: main() is only invoked when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
