import jax
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_platforms", "cpu")

import os
import json
import time

import numpy as np

import jax.numpy as jnp
from jax import random
from tqdm.auto import tqdm

import diffrax
import optax
from flax import nnx

# CFD Data imports
from cfd_data import (
    SITE_ORDER, 
    BaseParams,
    create_arch_network, terminal_ids_ordered,
)

# CFD Simulation imports
from cfd_sim import (
    default_rcr, simulate_5cycles_then_sample
)

# Utils imports
from cfd_tfmpe_utils import (
    TFMPEConfig, 
    Normalizer, 
    run_diagnostics, 
    numpy_seed_from_key
)

# TFMPE imports
from tfmpe.estimators.tfmpe import TFMPE, NormalDistribution
from tfmpe.estimators.training import fit_bottom_up
from tfmpe.preprocessing.tokens import Tokens
from tfmpe.preprocessing.utils import Independence, Labeller
from tfmpe.nn.transformer import Transformer, TransformerConfig

# Visualisation imports
from cfd_tfmpe_viz import ( 
    plot_posterior_marginals, 
    plot_posterior_predictive, 
    plot_training_losses, 
    create_summary_table
)

# =============================================================================
# TFMPE INTERFACE FUNCTIONS  
# =============================================================================

def create_prior_fn(net, base, n_terminals):
    """Build the joint prior sampler used by TFMPE.

    Global parameters are independent Gaussians in log space:
    ``log_beta ~ N(log(base.beta_scale_mean), 0.3)``, ``log_mu ~ N(log(0.004), 0.2)``,
    ``log_Qin ~ N(log(85.0), 0.2)``. The per-group local parameters ``log_Rt`` and
    ``log_C`` share a single standard-normal draw ``eps`` (scaled by 0.4) with
    opposite signs, so their deviations from the reference values returned by
    ``default_rcr`` are exact negatives of each other.

    Args:
        net: Arterial network passed to ``default_rcr`` / ``terminal_ids_ordered``.
        base: Base parameters; ``base.beta_scale_mean`` centres the ``log_beta`` prior.
        n_terminals: Unused; kept for interface compatibility with callers.

    Returns:
        ``prior_fn(rng, n, n_samples=1) -> (params, None)``. ``params`` maps the
        global names to float32 arrays of shape ``(n_samples, 1, 1)`` and the local
        names to float32 arrays of shape ``(n_samples, n, 1)``. Group ``i`` is centred
        on the reference values of terminal ``i % n_ref``, so the references are tiled
        when ``n`` exceeds the number of terminals.

    Raises:
        ValueError: (from ``prior_fn``) if ``n > 0`` but the network has no terminals.
    """
    rcr_ref = default_rcr(net, base)
    tids = terminal_ids_ordered(net)
    # The reference terminal values are fixed for the network, so extract them once.
    Rt_ref = np.array([float(rcr_ref.R2[i]) for i in tids])
    C_ref = np.array([float(rcr_ref.C[i]) for i in tids])
    n_ref = len(Rt_ref)

    def _tiled_refs(n):
        """Return the (Rt, C) reference vectors of length ``n`` (group i -> terminal i % n_ref)."""
        if n > 0 and n_ref == 0:
            raise ValueError("network has no terminals; cannot tile reference RCR values")
        # max(..., 1) only matters for n == 0 with no terminals (empty index, no modulo by 0).
        group_to_ref = np.arange(n) % max(n_ref, 1)
        return jnp.asarray(Rt_ref[group_to_ref]), jnp.asarray(C_ref[group_to_ref])

    def prior_fn(rng, n, n_samples=1):
        k1, k2, k3, k4 = random.split(rng, 4)
        log_beta = random.normal(k1, (n_samples,)) * 0.3 + jnp.log(base.beta_scale_mean)
        log_mu = random.normal(k2, (n_samples,)) * 0.2 + jnp.log(0.004)
        log_Qin = random.normal(k3, (n_samples,)) * 0.2 + jnp.log(85.0)
        eps = random.normal(k4, (n_samples, n))
        Rt_ref_n, C_ref_n = _tiled_refs(n)
        # Shared eps with opposite signs: resistance up <=> compliance down.
        log_Rt = jnp.log(Rt_ref_n) + 0.4 * eps
        log_C = jnp.log(C_ref_n) - 0.4 * eps
        return {
            "log_beta": log_beta[:, None, None].astype(jnp.float32),
            "log_mu": log_mu[:, None, None].astype(jnp.float32),
            "log_Qin": log_Qin[:, None, None].astype(jnp.float32),
            "log_Rt": log_Rt[..., None].astype(jnp.float32),
            "log_C": log_C[..., None].astype(jnp.float32),
        }, None
    return prior_fn


def create_local_fn(net, base, n_terminals):
    """Build the sampler for local (per-group) parameters conditioned on globals.

    The local distribution matches the local part of ``create_prior_fn``. A single
    standard-normal ``eps`` (scaled by 0.4) is added to ``log(Rt_ref)`` and
    subtracted from ``log(C_ref)``. It does not depend on the values in
    ``global_samples``; only their batch size is used.

    Args:
        net: Arterial network passed to ``default_rcr`` / ``terminal_ids_ordered``.
        base: Base parameters forwarded to ``default_rcr``.
        n_terminals: Unused; kept for interface compatibility with callers.

    Returns:
        ``local_fn(rng, global_samples, n) -> (params, None)``. ``params`` has keys
        ``"log_Rt"`` and ``"log_C"``, each a float32 array of shape ``(n_samples, n, 1)``,
        where ``n_samples = global_samples["log_beta"].shape[0]``.

    Raises:
        ValueError: (from ``local_fn``) if ``n > 0`` but the network has no terminals.
    """
    rcr_ref = default_rcr(net, base)
    tids = terminal_ids_ordered(net)
    # The reference terminal values are fixed for the network, so extract them once.
    Rt_ref = np.array([float(rcr_ref.R2[i]) for i in tids])
    C_ref = np.array([float(rcr_ref.C[i]) for i in tids])
    n_ref = len(Rt_ref)

    def _tiled_refs(n):
        """Return the (Rt, C) reference vectors of length ``n`` (group i -> terminal i % n_ref)."""
        if n > 0 and n_ref == 0:
            raise ValueError("network has no terminals; cannot tile reference RCR values")
        group_to_ref = np.arange(n) % max(n_ref, 1)
        return jnp.asarray(Rt_ref[group_to_ref]), jnp.asarray(C_ref[group_to_ref])

    def local_fn(rng, global_samples, n):
        n_samples = global_samples["log_beta"].shape[0]
        eps = random.normal(rng, (n_samples, n))
        Rt_ref_n, C_ref_n = _tiled_refs(n)
        log_Rt = jnp.log(Rt_ref_n) + 0.4 * eps
        log_C = jnp.log(C_ref_n) - 0.4 * eps
        return {
            "log_Rt": log_Rt[..., None].astype(jnp.float32),
            "log_C": log_C[..., None].astype(jnp.float32),
        }, None
    return local_fn

def create_simulator_fn(net, base, config, n_terminals, normalizer=None):
    """Build the batched, noisy hemodynamics simulator used by TFMPE.

    For each parameter sample, the simulator calls ``simulate_5cycles_then_sample``
    once. It then adds Gaussian noise to each group's flow trace ``q``. The noise
    standard deviation is ``config.eta * max(|q|)``, floored at ``config.eta * 1e-9``.

    Args:
        net: Arterial network forwarded to the solver.
        base: Base parameters forwarded to the solver.
        config: Reads ``N_t``, ``nx``, ``dt_init`` (solver settings) and ``eta``
            (relative noise level).
        n_terminals: Number of terminal vessels. Terminal ``s`` takes the local
            parameters of group ``s % n``. Group ``g`` reports site
            ``SITE_ORDER[g % n_terminals]``.
        normalizer: Optional object whose ``normalize(y)`` is applied to the output.

    Returns:
        ``simulator_fn(rng, params_dict, n) -> ({"y": y}, None)``. ``y`` is float32 with
        shape ``(n_samples, n, config.N_t, 1)``. A sample whose solve raised becomes all
        NaN (best-effort), and the failure is printed.

    Raises:
        ValueError: (from ``simulator_fn``) if ``n`` is not positive.
    """
    def simulator_fn(rng, params_dict, n):
        if n <= 0:
            raise ValueError(f"n (number of groups) must be positive, got {n}")
        if hasattr(params_dict, "decode"):
            params_dict = params_dict.decode()
        log_beta_np = np.array(params_dict["log_beta"][:, 0, 0])
        log_mu_np = np.array(params_dict["log_mu"][:, 0, 0])
        log_Qin_np = np.array(params_dict["log_Qin"][:, 0, 0])
        log_Rt_np = np.array(params_dict["log_Rt"][:, :, 0])
        log_C_np = np.array(params_dict["log_C"][:, :, 0])
        n_samples = int(log_beta_np.shape[0])

        # Terminal s reads the local params of group s % n. Gather this once for all
        # samples: shape (n_samples, n_terminals, 2), columns (log_Rt, log_C).
        terminal_to_group = np.arange(n_terminals) % n
        theta_loc_all = np.stack(
            [log_Rt_np[:, terminal_to_group], log_C_np[:, terminal_to_group]],
            axis=-1,
        ).astype(np.float64)

        rng_np = np.random.default_rng(numpy_seed_from_key(rng))
        all_y = []
        for i in tqdm(range(n_samples), desc=f"Simulating (n={n})", leave=False):
            theta_g_i = jnp.array([log_beta_np[i], log_mu_np[i], log_Qin_np[i]])
            theta_loc_i = jnp.array(theta_loc_all[i])
            try:
                resampled = simulate_5cycles_then_sample(
                    net, base, theta_g_i, theta_loc_i,
                    N_t=config.N_t, nx=config.nx, dt_init=config.dt_init,
                )
            except Exception as e:
                # Best-effort: only solver failures become NaN samples. Bugs in the
                # post-processing below should surface, not silently poison the data.
                print(f"Simulation {i} failed: {e}")
                all_y.append(np.full((n, config.N_t), np.nan, dtype=np.float32))
                continue
            y_groups = np.zeros((n, config.N_t), dtype=np.float32)
            # Sequential draws per group keep the noise stream identical to before.
            for g in range(n):
                site_name = SITE_ORDER[g % n_terminals]
                q = np.asarray(resampled[site_name]["q"], dtype=np.float32)
                sigma = config.eta * max(float(np.max(np.abs(q))), 1e-9)
                q_noisy = q + rng_np.normal(0.0, sigma, size=q.shape).astype(np.float32)
                y_groups[g, :] = q_noisy
            all_y.append(y_groups)

        # np.stack rejects an empty list, so build an empty batch explicitly.
        if all_y:
            y_np = np.stack(all_y, axis=0)
        else:
            y_np = np.empty((0, n, config.N_t), dtype=np.float32)
        y = jnp.asarray(y_np, dtype=jnp.float32)[..., None]
        if normalizer is not None:
            y = normalizer.normalize(y)
        return {"y": y}, None
    return simulator_fn

def _repeat_context_for_sampling(y_obs, n_rep):
    """Tile a single-observation context so that its batch axis has ``n_rep`` rows.

    Only a context whose leading axis has size 1 is expanded, using ``jnp.repeat``
    along axis 0. Any other context is returned as the very same object, whether it
    already has ``n_rep`` rows or some other batch size.

    Args:
        y_obs: Mapping with key ``"y"`` holding an array whose axis 0 is the batch.
        n_rep: Desired batch size.

    Returns:
        ``y_obs`` itself, or a new ``{"y": ...}`` dict holding the repeated array.
    """
    obs = y_obs["y"]
    batch = obs.shape[0]
    # NOTE(review): batch sizes other than 1 / n_rep pass through unchanged.
    # Any mismatch will only show up downstream.
    if batch != 1 or batch == n_rep:
        return y_obs
    return {"y": jnp.repeat(obs, repeats=n_rep, axis=0)}

# =============================================================================
# DIAGNOSTICS
# =============================================================================

def run_diagnostics(net, base, config):
    """Smoke-test the prior, the simulator and (optionally) normalization.

    The function draws 5 prior samples with the fixed key ``PRNGKey(123)`` and
    simulates the first 3 of them without normalization. It then prints the
    parameter ranges, the overall and per-site flow ranges, and a NaN count. When
    ``config.use_normalization`` is set, it also fits a ``Normalizer`` on the
    simulated flows and reports its statistics.

    NOTE: this definition shadows the ``run_diagnostics`` imported from
    ``cfd_tfmpe_utils`` at the top of the file.

    Returns:
        The simulator output dict (key ``"y"``), un-normalized.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("RUNNING DIAGNOSTICS")
    print(banner)

    n_terminals = len(terminal_ids_ordered(net))
    prior_fn = create_prior_fn(net, base, n_terminals)
    rng = random.PRNGKey(123)

    # 1. Prior draws.
    rng, prior_key = random.split(rng)
    draws, _ = prior_fn(prior_key, n=n_terminals, n_samples=5)
    print("\n1. Prior samples:")
    for name, arr in draws.items():
        lo, hi = float(arr.min()), float(arr.max())
        print(f"   {name}: shape={arr.shape}, range=[{lo:.3f}, {hi:.3f}]")

    # 2. Forward simulation on a small subset (raw, un-normalized output).
    simulator_fn = create_simulator_fn(net, base, config, n_terminals, normalizer=None)
    print("\n2. Running simulator on 3 samples...")
    rng, sim_key = random.split(rng)
    subset = {name: arr[:3] for name, arr in draws.items()}
    y, _ = simulator_fn(sim_key, subset, n=n_terminals)
    flows = y['y']
    print(f"   y shape: {flows.shape}")
    print(f"   y range: [{float(flows.min()):.2e}, {float(flows.max()):.2e}]")

    # 3. Per-site ranges; the site index selects axis 1 of the output.
    print("\n3. Per-site flow ranges:")
    for idx, site in enumerate(SITE_ORDER):
        per_site = flows[:, idx, :, 0]
        print(f"   {site:25s}: [{float(per_site.min()):.2e}, {float(per_site.max()):.2e}]")

    nan_count = int(jnp.sum(jnp.isnan(flows)))
    if nan_count > 0:
        print(f"\n   WARNING: {nan_count} NaN values in output!")

    # 4. Optional normalization check.
    if config.use_normalization:
        print("\n4. Testing normalization:")
        normalizer = Normalizer.fit(np.array(flows))
        print(f"   Normalizer means: {normalizer.means}")
        print(f"   Normalizer stds:  {normalizer.stds}")
        normed = normalizer.normalize(flows)
        print(f"   Normalized range: [{float(normed.min()):.2f}, {float(normed.max()):.2f}]")

    print("\nDiagnostics complete!")
    print(banner + "\n")
    return y


# =============================================================================
# MAIN
# =============================================================================

def run_inference(config=None):
    """Run the end-to-end TFMPE hemodynamics inference pipeline.

    Phases:
      1. Run diagnostics, then optionally fit a ``Normalizer`` on the diagnostic
         simulations, using only samples free of NaNs.
      2. Draw "true" parameters from the prior with ``PRNGKey(42)`` and simulate
         them to produce the observation.
      3. Build the transformer-based TFMPE model and train it with ``fit_bottom_up``.
      4. Draw ``config.n_posterior_samples`` posterior samples conditioned on the
         observation.
      5. Write the summary table, plots, ``posterior_samples.npz`` and
         ``config.json`` to ``config.output_dir``.

    Args:
        config: ``TFMPEConfig``; a default-constructed one is used when ``None``.

    Returns:
        ``(trained_tfmpe, posterior_samples, summary_df)``.

    Raises:
        RuntimeError: If normalization is requested but every diagnostic simulation
            contains NaNs, or if the observation simulation produced NaNs.
    """
    if config is None:
        config = TFMPEConfig()
    os.makedirs(config.output_dir, exist_ok=True)
    print("=" * 70)
    print("TFMPE HEMODYNAMICS INFERENCE (FIXED VERSION)")
    print("=" * 70)
    net = create_arch_network()
    base = BaseParams(T=1.0)
    n_terminals = len(terminal_ids_ordered(net))
    print(f"\nNetwork: {net.n_vessels} vessels, {n_terminals} terminals")
    print(f"Sites: {SITE_ORDER}")
    print(f"Config: N_t={config.N_t}, n_samples={config.n_samples_per_round}, n_iter={config.n_iter_per_round}")
    print(f"Normalization: {config.use_normalization}")

    # --- 1. Diagnostics + normalizer ------------------------------------------
    diag_y = run_diagnostics(net, base, config)
    normalizer = None
    if config.use_normalization:
        diag = np.array(diag_y['y'])
        # The simulator turns failed solves into NaN samples. Drop any sample that
        # contains a NaN so the normalizer statistics are not poisoned.
        finite_rows = ~np.isnan(diag).reshape(diag.shape[0], -1).any(axis=1)
        if not finite_rows.any():
            raise RuntimeError("All diagnostic simulations failed; cannot fit the normalizer.")
        normalizer = Normalizer.fit(diag[finite_rows])
        print(f"Fitted normalizer - means: {normalizer.means}, stds: {normalizer.stds}")

    labeller = Labeller.for_keys(['log_beta', 'log_mu', 'log_Qin', 'log_Rt', 'log_C', 'y'])
    print("\nCreating interface functions...")
    prior_fn = create_prior_fn(net, base, n_terminals)
    local_fn = create_local_fn(net, base, n_terminals)
    simulator_fn = create_simulator_fn(net, base, config, n_terminals, normalizer=normalizer)

    # --- 2. Ground truth + observation ----------------------------------------
    print("Generating true parameters and observations...")
    rng = random.PRNGKey(42)
    rng, key = random.split(rng)
    true_params, _ = prior_fn(key, n=n_terminals, n_samples=1)
    true_theta_g = np.array([
        float(true_params['log_beta'][0, 0, 0]),
        float(true_params['log_mu'][0, 0, 0]),
        float(true_params['log_Qin'][0, 0, 0]),
    ])
    # Shape (n_terminals, 2); the columns are (log_Rt, log_C).
    true_theta_l = np.stack(
        [np.array(true_params['log_Rt'][0, :, 0]), np.array(true_params['log_C'][0, :, 0])],
        axis=-1,
    )
    print("\nTrue parameters:")
    print(f"  beta_scale = {np.exp(true_theta_g[0]):.2e}")
    print(f"  mu = {np.exp(true_theta_g[1]):.4f}")
    print(f"  Q_in = {np.exp(true_theta_g[2]):.1f} mL/s")

    print("\nRunning simulation to generate observations...")
    rng, key = random.split(rng)
    y_obs_dict, _ = simulator_fn(key, true_params, n=n_terminals)
    y_obs = {'y': y_obs_dict['y'].astype(jnp.float32)}
    print(f"Observation shape: {y_obs['y'].shape}")
    print(f"Observation range: [{float(y_obs['y'].min()):.3f}, {float(y_obs['y'].max()):.3f}]")
    # A failed solve yields NaNs, and training on such an observation is meaningless.
    if bool(jnp.isnan(y_obs['y']).any()):
        raise RuntimeError("Simulation of the observed data produced NaNs; aborting inference.")
    if normalizer is not None:
        y_obs_unnorm = normalizer.denormalize(y_obs['y'])
    else:
        y_obs_unnorm = y_obs['y']

    # --- 3. Model construction + training -------------------------------------
    print("\nInitializing TFMPE model...")
    rng, key = random.split(rng)
    template_params, _ = prior_fn(key, n=10, n_samples=10)
    params_tokens = Tokens.from_pytree(template_params, sample_ndims=1, labeller=labeller)
    transformer_config = TransformerConfig(
        latent_dim=config.latent_dim,
        n_encoder=config.n_encoder,
        n_decoder=config.n_decoder,
        n_heads=config.n_heads,
        n_ff=config.n_ff,
    )
    rngs = nnx.Rngs(params=random.PRNGKey(0), dropout=random.PRNGKey(1))
    transformer = Transformer(config=transformer_config, tokens=params_tokens, rngs=rngs)
    base_dist = NormalDistribution(rngs=rngs)
    tfmpe = TFMPE(
        vf_network=transformer,
        base_dist=base_dist,
        solver=diffrax.Dopri5(),
        ode_kwargs={'rtol': 1e-3, 'atol': 1e-3},
    )
    independence = Independence(cross_local=[('log_Rt', 'y', (0, 0)), ('log_C', 'y', (0, 0))])
    optimizer = optax.adam(learning_rate=config.learning_rate)
    opt = nnx.Optimizer(tfmpe, optimizer, wrt=nnx.Param)
    effective_batch_size = min(config.batch_size, config.n_samples_per_round)

    print("\n" + "=" * 70)
    print("TRAINING")
    print("=" * 70)
    print(f"  n_samples_per_round: {config.n_samples_per_round}")
    print(f"  n_iter_per_round: {config.n_iter_per_round}")
    print(f"  batch_size: {effective_batch_size}")
    print(f"  learning_rate: {config.learning_rate}")
    print("\nStarting training...")

    t_start = time.time()
    rng, key = random.split(rng)
    trained_tfmpe, all_losses = fit_bottom_up(
        tfmpe=tfmpe,
        y_obs=y_obs,
        simulator_fn=simulator_fn,
        prior_fn=prior_fn,
        local_fn=local_fn,
        global_names=['log_beta', 'log_mu', 'log_Qin'],
        n_groups=n_terminals,
        n_rounds=config.n_rounds,
        n_samples_per_round=config.n_samples_per_round,
        n_val_samples=config.n_val_samples,
        opt=opt,
        n_iter_per_round=config.n_iter_per_round,
        batch_size=effective_batch_size,
        rng=key,
        independence=independence,
        labeller=labeller,
    )
    t_train = time.time() - t_start
    print(f"\nTraining complete! (took {t_train:.1f}s)")
    plot_training_losses(all_losses, config.output_dir)

    # --- 4. Posterior sampling -------------------------------------------------
    print("\n" + "=" * 70)
    print("POSTERIOR SAMPLING")
    print("=" * 70)
    print("Creating context tokens from observations...")
    y_obs_for_sampling = _repeat_context_for_sampling(y_obs, config.n_posterior_samples)

    context_tokens = Tokens.from_pytree(
        y_obs_for_sampling,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )
    n_post = config.n_posterior_samples
    print(f"Creating parameter template for {n_post} samples...")
    params_template = {
        'log_beta': jnp.zeros((n_post, 1, 1), dtype=jnp.float32),
        'log_mu': jnp.zeros((n_post, 1, 1), dtype=jnp.float32),
        'log_Qin': jnp.zeros((n_post, 1, 1), dtype=jnp.float32),
        'log_Rt': jnp.zeros((n_post, n_terminals, 1), dtype=jnp.float32),
        'log_C': jnp.zeros((n_post, n_terminals, 1), dtype=jnp.float32),
    }

    params_tokens = Tokens.from_pytree(
        params_template,
        sample_ndims=1,
        labeller=labeller,
        independence=independence,
    )
    # Diagnostic: verify that the token shapes match those used in training.
    print("\nDiagnostic - Token shapes:")
    print(f"  context_tokens.data.shape: {context_tokens.data.shape}")
    print(f"  context_tokens.slices: {context_tokens.slices}")
    print(f"  params_tokens.data.shape: {params_tokens.data.shape}")
    print(f"  params_tokens.slices: {params_tokens.slices}")
    print(f"  params_tokens.independence: {params_tokens.independence}")

    print(f"\nDrawing {n_post} posterior samples...")
    with tqdm(total=1, desc="Sampling posterior") as pbar:
        posterior_samples = trained_tfmpe.sample_posterior(context=context_tokens, params=params_tokens)
        pbar.update(1)
    samples_dict = posterior_samples.decode()
    theta_g_samples = np.stack([
        np.array(samples_dict['log_beta'][:, 0, 0]),
        np.array(samples_dict['log_mu'][:, 0, 0]),
        np.array(samples_dict['log_Qin'][:, 0, 0]),
    ], axis=-1)
    theta_l_samples = np.stack(
        [np.array(samples_dict['log_Rt'][:, :, 0]), np.array(samples_dict['log_C'][:, :, 0])],
        axis=-1,
    )
    print(f"Posterior samples shape: theta_g={theta_g_samples.shape}, theta_l={theta_l_samples.shape}")
    print("\nPosterior vs Prior comparison:")
    # Prior stds must match the scales used in create_prior_fn.
    for i, (name, prior_std) in enumerate(zip(['log_beta', 'log_mu', 'log_Qin'], [0.3, 0.2, 0.2])):
        post_std = np.std(theta_g_samples[:, i])
        print(f"  {name}: prior_std={prior_std:.3f}, posterior_std={post_std:.3f}, ratio={post_std/prior_std:.2f}")

    # --- 5. Outputs -------------------------------------------------------------
    print("\n" + "=" * 70)
    print("GENERATING OUTPUTS")
    print("=" * 70)
    print("\n[1/4] Creating summary table...")
    summary_df = create_summary_table(theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, config.output_dir)
    print("\nPosterior Summary:")
    print(summary_df.to_string(index=False))
    print("\n[2/4] Creating posterior marginal plots...")
    plot_posterior_marginals(theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, config.output_dir)
    print("\n[3/4] Creating posterior predictive plot...")
    plot_posterior_predictive(
        net, base, config, theta_g_samples, theta_l_samples,
        np.array(y_obs_unnorm), config.output_dir, n_ppc=10,
    )
    print("\n[4/4] Saving results...")
    np.savez(
        os.path.join(config.output_dir, 'posterior_samples.npz'),
        theta_g=theta_g_samples,
        theta_l=theta_l_samples,
        true_theta_g=true_theta_g,
        true_theta_l=true_theta_l,
        site_names=SITE_ORDER,
    )
    config_dict = dict(vars(config))
    with open(os.path.join(config.output_dir, 'config.json'), 'w') as f:
        json.dump(config_dict, f, indent=2)
    print("\n" + "=" * 70)
    print("COMPLETE!")
    print("=" * 70)
    print(f"\nOutputs saved to: {os.path.abspath(config.output_dir)}")
    return trained_tfmpe, posterior_samples, summary_df


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="TFMPE Hemodynamics Inference (Fixed)")
    # (flag, type, default) for each value-taking option, in registration order.
    for flag, kind, default in (
        ("--n_rounds", int, 1),
        ("--n_samples", int, 1000),
        ("--n_iter", int, 2000),
        ("--batch_size", int, 64),
        ("--lr", float, 3e-4),
        ("--n_posterior", int, 500),
        ("--output_dir", str, "tfmpe_hemo_results_fixed"),
        ("--N_t", int, 50),
    ):
        cli.add_argument(flag, type=kind, default=default)
    cli.add_argument("--no_normalization", action="store_true")
    cli_args = cli.parse_args()
    config = TFMPEConfig(
        n_rounds=cli_args.n_rounds,
        n_samples_per_round=cli_args.n_samples,
        n_iter_per_round=cli_args.n_iter,
        batch_size=cli_args.batch_size,
        learning_rate=cli_args.lr,
        n_posterior_samples=cli_args.n_posterior,
        output_dir=cli_args.output_dir,
        N_t=cli_args.N_t,
        use_normalization=not cli_args.no_normalization,
    )
    run_inference(config)