#!/usr/bin/env python3
"""
Visualization for C2 (metric vs constraint): density-time heatmaps with geodesic quantile lines,
and bar plots for residuals/normalized deviations. Publication-quality styling, consistent with the global colour scheme.
"""

from __future__ import annotations

import os
from typing import Sequence
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Global Matplotlib style for publication-quality figures: serif (Times New
# Roman) text, fixed font sizes, constrained layout, and 200 dpi on screen /
# 400 dpi for saved files (individual savefig calls below may pass their own
# dpi). Each ``plt.rc(group, **kw)`` call sets ``rcParams['<group>.<kw>']``.
plt.rc('font', family='serif', serif=['Times New Roman'], size=12)
plt.rc('axes', labelsize=14, titlesize=16, unicode_minus=False)
plt.rc('legend', fontsize=11)
plt.rc(('xtick', 'ytick'), labelsize=11)
plt.rc('figure', dpi=200)
plt.rc('savefig', dpi=400)
plt.rc('text', usetex=False)
plt.rc('figure.constrained_layout', use=True)

# Shared colour palette (hex RGB strings) used by every C2 figure.
COLORS = dict(
    primary='#9D110E',     # accent: quantile / median curves
    secondary='#000000',   # bar fill
    background='#FFFFFF',  # figure and saved-image background
    grid='#E5E5E5',        # grid lines (also the dashed mean line on residual bars)
)


def _save_density_time_with_quantiles(
    rho_series: jnp.ndarray,  # (K,n)
    x: jnp.ndarray,           # (n,)
    quantile_lines: jnp.ndarray,  # (Q,K)
    title: str,
    out_path: str,
    *,
    dpi: int = 300,
):
    """Save a density-vs-time heatmap with quantile curves overlaid.

    Args:
        rho_series: Array of shape ``(K, n)``; row ``k`` is the density at time
            step ``k`` sampled on ``x``. It is drawn with time on the horizontal
            axis and state on the vertical axis.
        x: State grid of shape ``(n,)``. The pixel extent assumes it is
            uniformly spaced (spacing taken from ``x[1] - x[0]``).
        quantile_lines: Array of shape ``(Q, K)``; each row is plotted against
            the time index ``0..K-1``.
        title: Axes title.
        out_path: Destination image path.
        dpi: Output resolution passed to ``savefig`` (default 300, as before).
    """
    rho = np.asarray(rho_series)
    grid = np.asarray(x)
    K, n = rho.shape

    # imshow's ``extent`` describes the outer *edges* of the image. Pad by half
    # a cell so that pixel centres land exactly on k = 0..K-1 and on the grid
    # points x[j], i.e. where the quantile curves are drawn. The former
    # [0, K-1] x [x0, x_last] extent shifted pixels by up to half a cell and
    # collapsed to zero width when K == 1.
    dx = float(grid[1] - grid[0]) if n > 1 else 1.0
    extent = [-0.5, K - 0.5, float(grid[0]) - dx / 2, float(grid[-1]) + dx / 2]

    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=COLORS['background'])
    im = ax.imshow(
        rho.T,
        origin='lower',
        aspect='auto',
        extent=extent,
        cmap='viridis'
    )
    steps = np.arange(K)
    for q_line in np.asarray(quantile_lines):
        ax.plot(steps, q_line, color=COLORS['primary'], linewidth=1.5, alpha=0.9)
    ax.set_xlabel('time k')
    ax.set_ylabel('state x')
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label='density')
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(out_path, dpi=dpi, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)


def _save_bar(values: jnp.ndarray, title: str, ylabel: str, out_path: str, yscale: str | None = None):
    """Save ``values`` as a single-series bar chart indexed by interval.

    Bars sit at positions ``0..len(values)-1``. When ``yscale`` is truthy
    (e.g. ``'log'``) it is applied to the y-axis. The figure is written to
    ``out_path`` at 300 dpi and then closed.
    """
    background = COLORS['background']
    fig, ax = plt.subplots(1, 1, figsize=(10, 4), facecolor=background)
    positions = jnp.arange(values.shape[0])
    ax.bar(positions, values, color=COLORS['secondary'], alpha=0.85)
    ax.set(xlabel='interval k', ylabel=ylabel, title=title)
    if yscale:
        ax.set_yscale(yscale)
    ax.grid(True, alpha=0.3, color=COLORS['grid'])
    plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor=background)
    plt.close(fig)


def _c2_shared_upper_ylim(a: np.ndarray, b: np.ndarray, pad: float = 1.05) -> float | None:
    """Return ``pad`` times the largest finite value found in ``a`` and ``b``.

    The two inputs may differ in length. Returns ``None`` (so Matplotlib
    autoscales the top) when there is no finite positive value, which would
    otherwise give a NaN or degenerate ``[0, 0]`` y-range.
    """
    both = np.concatenate([np.ravel(a), np.ravel(b)])
    finite = both[np.isfinite(both)]
    if finite.size == 0 or finite.max() <= 0:
        return None
    return float(finite.max()) * pad


def _c2_shared_log_ylim(
    a: np.ndarray, b: np.ndarray, floor: float = 1e-9, margin: float = 0.2
) -> tuple[float, float] | None:
    """Return common log-scale y-limits covering ``a`` and ``b``.

    Values are clipped at ``floor`` before ``log10`` and the range is widened
    by ``margin`` decades on each side. NaN/+Inf entries are ignored; ``None``
    means no usable data, in which case Matplotlib autoscales.
    """
    both = np.concatenate([np.ravel(a), np.ravel(b)])
    with np.errstate(invalid='ignore'):
        logs = np.log10(np.maximum(both, floor))
    logs = logs[np.isfinite(logs)]
    if logs.size == 0:
        return None
    return 10 ** (float(logs.min()) - margin), 10 ** (float(logs.max()) + margin)


def _save_c2_residual_bar(
    values: np.ndarray,
    title: str,
    out_path: str,
    y_max: float | None,
    peaks_k: Sequence[int] | None,
) -> None:
    """Save one FP-residual bar chart with a dashed mean line and optional peak labels."""
    xk = np.arange(values.shape[0])
    fig, ax = plt.subplots(1, 1, figsize=(10, 4), facecolor=COLORS['background'])
    ax.bar(xk, values, color=COLORS['secondary'], alpha=0.85)
    mean = float(np.mean(values)) if values.size else float('nan')
    if np.isfinite(mean):
        ax.axhline(mean, color=COLORS['grid'], linestyle='--', linewidth=1.5)
    for k in peaks_k or ():
        k = int(k)
        # JAX clamps (and NumPy wraps negative) indices, which would label the
        # wrong bar at the wrong x; skip anything outside the series instead.
        if not 0 <= k < values.shape[0]:
            continue
        v = float(values[k])
        if np.isfinite(v):
            ax.text(k, v * 1.02, f"{v:.2f}", ha='center', va='bottom', fontsize=10)
    # set_ylim(0, None) keeps the bottom at 0 and leaves the top autoscaled.
    ax.set_ylim(0, y_max)
    ax.set_xlabel('interval k')
    ax.set_ylabel('L2 residual')
    ax.set_title(title)
    ax.grid(True, alpha=0.3, color=COLORS['grid'])
    fig.savefig(out_path, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)


def _save_c2_log_bar(
    values: np.ndarray,
    title: str,
    out_path: str,
    ylim: tuple[float, float] | None,
) -> None:
    """Save one normalized-residual bar chart on a log y-axis with fixed limits."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 4), facecolor=COLORS['background'])
    ax.bar(np.arange(values.shape[0]), values, color=COLORS['secondary'], alpha=0.85)
    ax.set_yscale('log')
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel('interval k')
    ax.set_ylabel('normalized residual (log)')
    ax.set_title(title)
    ax.grid(True, which='both', alpha=0.3, color=COLORS['grid'])
    fig.savefig(out_path, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)


def save_c2_figures(
    rho_series_1: jnp.ndarray,
    rho_series_2: jnp.ndarray,
    geo_q_lines: jnp.ndarray,
    x: jnp.ndarray,
    residuals_1: jnp.ndarray,
    residuals_2: jnp.ndarray,
    r_norm_1: jnp.ndarray,
    r_norm_2: jnp.ndarray,
    out_dir: str = 'theoretical_verification/results',
    meta_title: str | None = None,
    peaks_k: Sequence[int] | None = None,
):
    """Write the full set of C2 figures into ``out_dir`` (created if missing).

    Files written (relative to ``out_dir``):

    * ``c2_density_time_b1.png`` / ``c2_density_time_b2.png``: density-time
      heatmaps of ``rho_series_1`` / ``rho_series_2`` on grid ``x``, with
      ``geo_q_lines`` overlaid on both.
    * ``c2_fp_residual_b1.png`` / ``c2_fp_residual_b2.png``: per-interval FP
      L2 residual bars sharing one linear y-range, each with a dashed mean
      line and value labels at ``peaks_k``.
    * ``c2_mid_dev_b1.png`` / ``c2_mid_dev_b2.png``: ``r_norm_1`` /
      ``r_norm_2`` as bars sharing one log-scale y-range.

    Args:
        meta_title: If given, the heatmap titles become
            ``"C2: density-time (bX) [meta_title]"`` instead of the default
            ``"... + geodesic quantiles"`` suffix.
        peaks_k: Interval indices to annotate with their residual value;
            indices outside a residual series are skipped.
    """
    os.makedirs(out_dir, exist_ok=True)
    title_b1 = 'C2: density-time (b1) + geodesic quantiles' if meta_title is None else f'C2: density-time (b1) [{meta_title}]'
    title_b2 = 'C2: density-time (b2) + geodesic quantiles' if meta_title is None else f'C2: density-time (b2) [{meta_title}]'
    _save_density_time_with_quantiles(rho_series_1, x, geo_q_lines, title_b1,
                                      os.path.join(out_dir, 'c2_density_time_b1.png'))
    _save_density_time_with_quantiles(rho_series_2, x, geo_q_lines, title_b2,
                                      os.path.join(out_dir, 'c2_density_time_b2.png'))

    # FP residuals: unified y-limits so b1 and b2 are directly comparable.
    # (Each file is rendered exactly once; previously b1 was also rendered via
    # _save_bar and then immediately overwritten.)
    res_1 = np.asarray(residuals_1)
    res_2 = np.asarray(residuals_2)
    y_max = _c2_shared_upper_ylim(res_1, res_2)
    _save_c2_residual_bar(res_1, 'FP residual L2 per interval (b1)',
                          os.path.join(out_dir, 'c2_fp_residual_b1.png'), y_max, peaks_k)
    _save_c2_residual_bar(res_2, 'FP residual L2 per interval (b2)',
                          os.path.join(out_dir, 'c2_fp_residual_b2.png'), y_max, peaks_k)

    # Normalized residuals (log scale), forced onto the same y-limits.
    rn_1 = np.asarray(r_norm_1)
    rn_2 = np.asarray(r_norm_2)
    log_ylim = _c2_shared_log_ylim(rn_1, rn_2)
    _save_c2_log_bar(rn_1, 'Normalized FP residual (b1)', os.path.join(out_dir, 'c2_mid_dev_b1.png'), log_ylim)
    _save_c2_log_bar(rn_2, 'Normalized FP residual (b2)', os.path.join(out_dir, 'c2_mid_dev_b2.png'), log_ylim)

def save_density_overlay(
    rho_series_1: jnp.ndarray,
    rho_series_2: jnp.ndarray,
    x: jnp.ndarray,
    geo_q_lines: jnp.ndarray,
    peaks_k: Sequence[int] | None,
    title: str,
    out_path: str,
):
    """Overlay the geodesic median and the b2 density centre on the b1 heatmap.

    The b1 density (``rho_series_1``, rows = time steps, columns = grid ``x``)
    is drawn as a heatmap. On top of it go:

    * the geodesic median curve ``geo_q_lines[1]``;
    * the density-weighted mean of ``x`` under ``rho_series_2`` at each time
      step, used as a proxy for the b2 centre;
    * dashed vertical markers at every index in ``peaks_k``.

    The figure is written to ``out_path`` at 400 dpi.
    """
    rho1 = np.asarray(rho_series_1)
    rho2 = np.asarray(rho_series_2)
    grid = np.asarray(x)
    K, n = rho1.shape

    # Half-cell padding so pixel centres sit on k = 0..K-1 and on x[j], where
    # the overlaid curves are drawn (imshow's extent gives image *edges*).
    dx = float(grid[1] - grid[0]) if n > 1 else 1.0
    extent = [-0.5, K - 0.5, float(grid[0]) - dx / 2, float(grid[-1]) + dx / 2]

    fig, ax = plt.subplots(1, 1, figsize=(10, 6), facecolor=COLORS['background'])
    # base: b1 density
    im = ax.imshow(rho1.T, origin='lower', aspect='auto', extent=extent, cmap='viridis')

    # NOTE(review): assumes row 1 of geo_q_lines is the q=0.5 quantile
    # (e.g. quantiles ordered as 0.25, 0.5, 0.75); confirm with the caller.
    med1 = np.asarray(geo_q_lines[1])

    # b2 centre: weighted mean sum(rho*x)/sum(rho) per time step. Normalizing
    # by each row's own mass keeps this a true mean even when rho does not
    # integrate to exactly 1 on the grid, and needs no grid spacing. Zero-mass
    # rows become NaN, which matplotlib leaves as a gap in the curve.
    mass2 = rho2.sum(axis=1)
    with np.errstate(invalid='ignore', divide='ignore'):
        mean2 = np.where(mass2 != 0, (rho2 * grid[None, :]).sum(axis=1) / mass2, np.nan)

    # Each curve is plotted against its own length so b2 need not match b1's K.
    ax.plot(np.arange(med1.shape[0]), med1, color=COLORS['primary'], linewidth=2.0, label='geodesic q=0.5')
    ax.plot(np.arange(mean2.shape[0]), mean2, color='white', linewidth=2.0, alpha=0.7, label='b2 center (mean)')
    if peaks_k is not None:
        for k in peaks_k:
            ax.axvline(k, color='orange', linestyle='--', linewidth=1.2, alpha=0.9)
    ax.set_xlabel('time k')
    ax.set_ylabel('state x')
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label='density')
    ax.legend(loc='upper left')
    fig.savefig(out_path, dpi=400, bbox_inches='tight', facecolor=COLORS['background'])
    plt.close(fig)


