import os
import warnings

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from cfd_data import SITE_ORDER
from cfd_sim import simulate_5cycles_then_sample

# =============================================================================
# VISUALIZATION
# =============================================================================

def _hist_with_reference_lines(ax, samples, true_value):
    """Draw a density histogram of ``samples`` with the true value and the sample mean marked."""
    ax.hist(samples, bins=30, density=True, alpha=0.7, color='steelblue', edgecolor='white')
    ax.axvline(true_value, color='red', linestyle='--', linewidth=2, label='True')
    ax.axvline(np.mean(samples), color='green', linestyle='-', linewidth=2, label='Mean')


def plot_posterior_marginals(theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, output_dir):
    """Save histograms of the posterior marginals against the true parameter values.

    Writes two figures into ``output_dir``:
      * ``posterior_global.png`` -- one panel per global parameter (the three columns
        labelled log beta_scale, log mu, log Q_in).
      * ``posterior_local.png`` -- one row per terminal, with one panel each for the
        parameters labelled log R_T and log C_T.

    Args:
        theta_g_samples: 2-D array indexed ``[sample, global_param]``.
        theta_l_samples: 3-D array indexed ``[sample, terminal, local_param]``.
        true_theta_g: true global parameters, indexed like the columns of ``theta_g_samples``.
        true_theta_l: true local parameters, indexed ``[terminal, local_param]``.
        output_dir: existing directory that receives the PNG files.
    """
    global_names = [r'$\log \beta_{scale}$', r'$\log \mu$', r'$\log Q_{in}$']
    local_names = [r'$\log R_T$', r'$\log C_T$']

    # --- global parameters -------------------------------------------------
    fig, axes = plt.subplots(1, len(global_names), figsize=(12, 4), squeeze=False)
    for i, (ax, name) in enumerate(zip(axes[0], global_names)):
        _hist_with_reference_lines(ax, theta_g_samples[:, i], true_theta_g[i])
        ax.set_xlabel(name, fontsize=12)
        ax.set_ylabel('Density' if i == 0 else '')
        ax.legend()
        ax.set_title(f'Posterior: {name}')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'posterior_global.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)

    # --- local (per-terminal) parameters ----------------------------------
    n_terminals = theta_l_samples.shape[1]
    if n_terminals == 0:
        # plt.subplots(0, ...) raises; there is nothing to plot anyway.
        return
    # squeeze=False keeps `axes` 2-D even with a single terminal, so axes[s, j] always works.
    fig, axes = plt.subplots(n_terminals, len(local_names), figsize=(10, 3 * n_terminals), squeeze=False)
    for s in range(n_terminals):
        for j, param_name in enumerate(local_names):
            ax = axes[s, j]
            _hist_with_reference_lines(ax, theta_l_samples[:, s, j], true_theta_l[s, j])
            ax.set_xlabel(param_name)
            ax.set_title(f'{SITE_ORDER[s]}: {param_name}')
            if s == 0 and j == 0:
                # One legend is enough: every local panel uses the same line styles.
                ax.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'posterior_local.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


def _loss_yscale(*series):
    """Return ``'log'`` when every value in ``series`` is strictly positive, else ``'symlog'``.

    Likelihood/posterior training losses (presumably negative log-densities) can drop
    to or below zero; a plain log axis would silently hide those points.
    """
    for values in series:
        if np.any(np.asarray(values, dtype=float) <= 0):
            return 'symlog'
    return 'log'


def _plot_loss_panel(ax, train, val, title):
    """Draw one train/validation loss-curve pair on ``ax``; empty series are skipped."""
    plotted = False
    for values, label in ((train, 'Train'), (val, 'Val')):
        if len(values) > 0:
            ax.plot(values, label=label, alpha=0.8)
            plotted = True
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    if plotted:
        # Only add a legend when there are labelled artists (avoids matplotlib's warning).
        ax.legend()
    ax.set_yscale(_loss_yscale(train, val))
    ax.grid(True, alpha=0.3)


def plot_training_losses(all_losses, output_dir):
    """Save per-round training/validation loss curves to ``training_losses.png``.

    Args:
        all_losses: sequence with one entry per round; each entry unpacks as
            ``(train_local, val_local, train_global, val_global)``, each a 1-D sequence of
            losses per iteration. The global lists may be empty (e.g. a round where the
            global posterior was not trained).
        output_dir: existing directory that receives the PNG file.

    Nothing is written when ``all_losses`` is empty.
    """
    n_rounds = len(all_losses)
    if n_rounds == 0:
        return
    # squeeze=False keeps `axes` 2-D even for a single round.
    fig, axes = plt.subplots(n_rounds, 2, figsize=(12, 4 * n_rounds), squeeze=False)
    for r, (train_local, val_local, train_global, val_global) in enumerate(all_losses):
        _plot_loss_panel(axes[r, 0], train_local, val_local, f'Round {r}: Local Likelihood')
        _plot_loss_panel(axes[r, 1], train_global, val_global, f'Round {r}: Global Posterior')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_losses.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)


def plot_posterior_predictive(net, base, config, theta_g_samples, theta_l_samples, y_obs, output_dir, n_ppc=20):
    """Overlay posterior-predictive flow curves on the observed data, one panel per site.

    Up to ``n_ppc`` posterior draws are picked at random (without replacement) once.
    Each draw is simulated once with ``simulate_5cycles_then_sample``, and the same draws
    are shown in every site panel. Simulations that raise are skipped, and a
    ``RuntimeWarning`` reports how many failed.

    Args:
        net, base, config: forwarded to the simulator. ``base.T`` and ``config.N_t`` define
            the plotting time grid; ``config.nx`` and ``config.dt_init`` are passed through.
        theta_g_samples: posterior samples of global parameters, indexed ``[sample, ...]``.
        theta_l_samples: posterior samples of local parameters, indexed ``[sample, ...]``.
        y_obs: observed data; site ``s`` is read as ``y_obs[0, s, :, 0]`` (assumed
            ``(batch, site, time, channel)`` -- TODO confirm against the caller).
        output_dir: existing directory that receives ``posterior_predictive.png``.
        n_ppc: maximum number of posterior draws to simulate.
    """
    n_sites = len(SITE_ORDER)
    n_cols = 2
    n_rows = max(1, -(-n_sites // n_cols))  # ceil division; 4 sites -> the original 2x2 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)
    axes = axes.ravel()
    t_grid = np.linspace(0, base.T, config.N_t, endpoint=False)

    # Every simulation yields all sites, so draw and simulate once rather than once per site.
    n_draws = min(n_ppc, len(theta_g_samples))
    indices = np.random.choice(len(theta_g_samples), n_draws, replace=False)
    predicted_curves = []  # one {site_name: flow array} dict per successful draw
    n_failed = 0
    for idx in tqdm(indices, desc="PPC", leave=False):
        try:
            resampled = simulate_5cycles_then_sample(
                net, base,
                jnp.array(theta_g_samples[idx]), jnp.array(theta_l_samples[idx]),
                N_t=config.N_t, nx=config.nx, dt_init=config.dt_init,
            )
            curves = {name: np.asarray(resampled[name]['q']) for name in SITE_ORDER}
        except Exception:
            # Best-effort: some posterior draws may make the simulator fail; skip them but count.
            n_failed += 1
            continue
        predicted_curves.append(curves)
    if n_failed:
        warnings.warn(
            f"plot_posterior_predictive: {n_failed}/{n_draws} simulations failed and were skipped.",
            RuntimeWarning,
            stacklevel=2,
        )

    for s_idx, (ax, site_name) in enumerate(zip(axes, SITE_ORDER)):
        ax.plot(t_grid, y_obs[0, s_idx, :, 0], 'k-', linewidth=2, label='Observed', zorder=10)
        for curves in predicted_curves:
            ax.plot(t_grid, curves[site_name], 'b-', alpha=0.2, linewidth=0.5)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Flow (m³/s)')
        ax.set_title(f'{site_name}')
        ax.grid(True, alpha=0.3)
        if s_idx == 0:
            # Empty proxy line so the legend gets a single "Posterior predictive" entry.
            ax.plot([], [], 'b-', alpha=0.5, label='Posterior predictive')
            ax.legend()
    for ax in axes[n_sites:]:
        ax.set_visible(False)  # hide grid cells left over when n_sites is odd
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'posterior_predictive.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)

def _summary_row(param_name, site, samples, true_value, prior_std):
    """Build one posterior-summary row; numeric fields are formatted to 3 decimals."""
    lo, hi = np.percentile(samples, [2.5, 97.5])  # 95% central credible interval
    return {
        'Parameter': param_name, 'Site': site, 'True': f'{true_value:.3f}',
        'Mean': f'{np.mean(samples):.3f}', 'Std': f'{np.std(samples):.3f}',
        'Prior Std': f'{prior_std:.3f}',
        '2.5%': f'{lo:.3f}', '97.5%': f'{hi:.3f}',
        'Covers True': 'Yes' if lo <= true_value <= hi else 'No',
    }


def create_summary_table(theta_g_samples, theta_l_samples, true_theta_g, true_theta_l, output_dir,
                         *, site_names=None, prior_g_std=(0.3, 0.2, 0.2), prior_l_std=0.4):
    """Summarize posterior samples against the true values and write ``posterior_summary.csv``.

    One row per global parameter (``log_beta``, ``log_mu``, ``log_Qin``), followed by
    one row per (site, local parameter) pair (``log_R_T``, ``log_C_T``). Each row holds
    the true value, the posterior mean and std (``np.std``, ddof=0), the prior std, and the
    2.5%/97.5% percentiles. It also records whether that interval contains the true value.
    All numeric values are stored as strings with 3 decimals.

    Args:
        theta_g_samples: 2-D array indexed ``[sample, global_param]``.
        theta_l_samples: 3-D array indexed ``[sample, site, local_param]``.
        true_theta_g: true global parameters, indexed like the columns of ``theta_g_samples``.
        true_theta_l: true local parameters, indexed ``[site, local_param]``.
        output_dir: existing directory that receives the CSV file.
        site_names: site labels in the order of axis 1 of ``theta_l_samples``;
            defaults to ``SITE_ORDER``.
        prior_g_std: prior std reported for each global parameter (display only).
        prior_l_std: prior std reported for every local parameter (display only).

    Returns:
        The summary as a ``pandas.DataFrame`` (the same content as the written CSV).
    """
    if site_names is None:
        site_names = SITE_ORDER
    global_names = ['log_beta', 'log_mu', 'log_Qin']
    local_names = ['log_R_T', 'log_C_T']

    rows = [
        _summary_row(name, 'Global', theta_g_samples[:, i], true_theta_g[i], prior_g_std[i])
        for i, name in enumerate(global_names)
    ]
    rows.extend(
        _summary_row(param_name, site_name, theta_l_samples[:, s_idx, j],
                     true_theta_l[s_idx, j], prior_l_std)
        for s_idx, site_name in enumerate(site_names)
        for j, param_name in enumerate(local_names)
    )
    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(output_dir, 'posterior_summary.csv'), index=False)
    return df