from __future__ import annotations

import csv
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Optional

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator, FormatStrFormatter


# -----------------------------
# Constants
# -----------------------------
# Column headers expected in the ground-truth (SVBMC) CSV files read by
# load_real_data_csv.
COL_PARTICIPANT = "Participant"
COL_RHO_1 = "rho = 1"
COL_RHO_4_3 = "rho = 4/3"

# Default grid iterated by print_side_by_side: participants
# 1..NUM_PARTICIPANTS and splits 1..NUM_SPLITS (both 1-based).
NUM_PARTICIPANTS = 15
NUM_SPLITS = 2


# -----------------------------
# Data structures
# -----------------------------
@dataclass(frozen=True)
class Entry:
    """
    One (participant, split) measurement with a value for each rho setting.

    Instances are immutable and hashable (``frozen=True`` dataclass).
    """

    # 1-based participant index. The ground-truth loader shifts the CSV's
    # 0-based "Participant" column by +1 when reading; the Fast-AR loader
    # derives it from batch_index via index_to_tuple.
    participant: int
    # Split number (1 or 2).
    split: int
    # Value measured under rho = 1.
    rho_1: float
    # Value measured under rho = 4/3.
    rho_4_3: float

    @property
    def diff(self) -> float:
        """Return ``rho_4_3 - rho_1`` (positive when the rho = 4/3 value is larger)."""
        larger_model, baseline = self.rho_4_3, self.rho_1
        return larger_model - baseline


# -----------------------------
# Utilities
# -----------------------------
def index_to_tuple(n: int, splits: int = 2) -> Tuple[int, int]:
    """
    Convert a flat batch index into a 1-based ``(participant, split)`` pair.

    Batch indices enumerate splits fastest. With the default ``splits=2``:
    index 0 -> (1, 1), 1 -> (1, 2), 2 -> (2, 1), 3 -> (2, 2), ...

    Args:
        n: flat, 0-based batch index.
        splits: number of splits per participant; must be positive.

    Returns:
        ``(participant, split)``, both 1-based.

    Raises:
        ValueError: if ``splits`` is not positive.
    """
    if splits <= 0:
        raise ValueError(f"splits must be positive, got {splits}")
    participant_idx, split_idx = divmod(n, splits)
    return participant_idx + 1, split_idx + 1


def _read_csv_rows(file_path: Path) -> Iterable[Dict[str, str]]:
    """
    Lazily yield each data row of a CSV file as a header-keyed dict.

    The file is decoded as ``utf-8-sig``. A leading byte-order mark, as written
    by many spreadsheet exports, is therefore stripped instead of being glued
    onto the first header name. Otherwise lookups such as ``row["Participant"]``
    would fail. Plain UTF-8/ASCII files decode exactly as before. The file stays
    open until the generator is exhausted or closed.
    """
    with file_path.open(newline="", encoding="utf-8-sig") as csvfile:
        yield from csv.DictReader(csvfile)


# -----------------------------
# Loaders
# -----------------------------
def load_real_data_csv(file_path: str | Path, split: int) -> List[Entry]:
    """
    Read one split's ground-truth (SVBMC ELBO) CSV into ``Entry`` objects.

    Columns used:
        "Participant" -- participant index as stored in the file (0-based)
        "rho = 1"     -- value under rho = 1
        "rho = 4/3"   -- value under rho = 4/3

    Every returned entry carries the given ``split``. Its participant index is
    shifted to 1-based, matching the numbering index_to_tuple produces for the
    Fast-AR results.
    """
    return [
        Entry(
            # The file counts participants from 0; Entry uses 1-based indices.
            participant=int(record[COL_PARTICIPANT]) + 1,
            split=split,
            rho_1=float(record[COL_RHO_1]),
            rho_4_3=float(record[COL_RHO_4_3]),
        )
        for record in _read_csv_rows(Path(file_path))
    ]


def load_np_data_csv(file_path: str | Path) -> List[Entry]:
    """
    Load Fast-AR results (log-likelihood means) as ``Entry`` objects.

    Each row's ``batch_index`` is mapped to a 1-based ``(participant, split)``
    pair via ``index_to_tuple``.

    Expected columns:
        - "batch_index": int
        - "engine_B_joint_mean" -> rho = 1
        - "engine_A_joint_mean" -> rho = 4/3

    Raises:
        KeyError: if a required column is absent from the header.
        ValueError: if a cell cannot be parsed as int/float.
    """
    entries: List[Entry] = []
    for row in _read_csv_rows(Path(file_path)):
        participant, split = index_to_tuple(int(row["batch_index"]))
        entries.append(
            Entry(
                participant=participant,
                split=split,
                rho_1=float(row["engine_B_joint_mean"]),
                rho_4_3=float(row["engine_A_joint_mean"]),
            )
        )
    return entries


# -----------------------------
# Organization helpers
# -----------------------------
def index_by_participant_split(entries: Iterable[Entry]) -> Dict[Tuple[int, int], Entry]:
    """
    Build a lookup table from ``(participant, split)`` to its entry.

    When several entries share a key, the one encountered last is kept.
    """
    lookup: Dict[Tuple[int, int], Entry] = {}
    for entry in entries:
        key = (entry.participant, entry.split)
        lookup[key] = entry
    return lookup


# A (ground_truth, fast_ar) pair; either side is None when that entry is missing.
_Pair = Tuple[Optional[float], Optional[float]]


def print_side_by_side(
    real_index: Dict[Tuple[int, int], Entry],
    np_index: Dict[Tuple[int, int], Entry],
    participants: int = NUM_PARTICIPANTS,
    splits: int = NUM_SPLITS,
) -> Tuple[List[_Pair], List[_Pair], List[_Pair]]:
    """
    Print ground-truth vs Fast-AR values and collect paired numbers for plotting.

    Iterates participants ``1..participants`` and splits ``1..splits`` (both
    1-based). If either side is missing for a key, the missing side(s) are
    printed as MISSING and stored as ``None`` in the returned pairs.

    Args:
        real_index: ground-truth (SVBMC) entries keyed by (participant, split).
        np_index: Fast-AR entries keyed by (participant, split).
        participants: number of participants to iterate.
        splits: number of splits per participant.

    Returns:
        ``(diffs, values_rho_4_3, values_rho1)``. Each is a list with one
        ``(ground_truth, fast_ar)`` pair per (participant, split), in iteration
        order. ``diffs`` holds ``rho_4_3 - rho_1`` differences; the other two
        hold the raw values for each rho.
    """
    diffs: List[_Pair] = []
    values_rho_4_3: List[_Pair] = []
    values_rho1: List[_Pair] = []

    for i in range(1, participants + 1):
        for j in range(1, splits + 1):
            real_entry = real_index.get((i, j))
            np_entry = np_index.get((i, j))

            gt_diff = real_entry.diff if real_entry is not None else None
            np_diff = np_entry.diff if np_entry is not None else None

            print(f"Participant {i}, Split {j}:")
            if real_entry is not None and np_entry is not None:
                print(f"  SVBMC ELBO - rho=1: {real_entry.rho_1:.6g}, rho=4/3: {real_entry.rho_4_3:.6g}, diff: {gt_diff:.6g}")
                print(f"  Fast-AR LL - rho=1: {np_entry.rho_1:.6g}, rho=4/3: {np_entry.rho_4_3:.6g}, diff: {np_diff:.6g}")
            else:
                # Report missing pairings explicitly so gaps are obvious.
                if real_entry is None:
                    print("  SVBMC ELBO - MISSING")
                if np_entry is None:
                    print("  Fast-AR LL - MISSING")
            print()

            diffs.append((gt_diff, np_diff))
            values_rho_4_3.append((
                real_entry.rho_4_3 if real_entry is not None else None,
                np_entry.rho_4_3 if np_entry is not None else None,
            ))
            values_rho1.append((
                real_entry.rho_1 if real_entry is not None else None,
                np_entry.rho_1 if np_entry is not None else None,
            ))

    return diffs, values_rho_4_3, values_rho1

def _metrics(values: np.ndarray) -> Tuple[float, float]:
    """
    Compute RMSE and R² of predictions against ground truth.

    Args:
        values: array of shape (N, 2). Column 0 is the ground truth and
            column 1 the prediction. An empty array of any shape (e.g. the
            shape-(0,) result of ``np.array([])``) is accepted.

    Returns:
        ``(rmse, r2)``. R² is NaN when the ground truth has zero variance
        (undefined). Both are NaN for empty input.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        # Avoid IndexError on shape (0,) and mean-of-empty warnings on (0, 2).
        return np.nan, np.nan
    y_true = values[:, 0]
    y_pred = values[:, 1]
    rmse = np.sqrt(np.mean((y_pred - y_true) ** 2))
    ss_res = np.sum((y_pred - y_true) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else np.nan
    return rmse, r2

def plot_differences(
    gt_vs_np: np.ndarray,
    *,
    lim: float = 30.0,
    filename: str = "bav_model_selection_comparison_ground_truth.png",
) -> None:
    """
    Scatter ground-truth differences against Fast-AR differences.

    Draws a y = x reference line and annotates RMSE and R² in the upper-left
    corner. Saves the figure to ``filename`` (200 dpi), shows it, then prints
    the metrics.

    Args:
        gt_vs_np: Nx2 array of [ground_truth_diff, np_diff]. These are the
            ELBO and LL differences ``rho=4/3 - rho=1`` respectively.
        lim: both axes span ``[-lim, lim]``. Points outside this window are
            not drawn but still count toward the metrics.
        filename: output path for the saved image.
    """
    rmse, r2 = _metrics(gt_vs_np)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(gt_vs_np[:, 0], gt_vs_np[:, 1], color="blue")

    ax.set_xlabel("Ground Truth (SVBMC) Difference (ELBO rho=4/3 - rho=1)")
    ax.set_ylabel("Fast-AR Log-Likelihood Difference (LL rho=4/3 - rho=1)")
    ax.set_title("Comparison of Model Selection Differences")
    ax.set_ylim(-lim, lim)
    ax.set_xlim(-lim, lim)

    # y = x reference: points on it mean Fast-AR reproduces the ground truth.
    ax.plot([-lim, lim], [-lim, lim], color="red", linestyle="--", label="y = x")

    ax.grid(True)
    ax.legend()

    text = f"RMSE={rmse:.3g}, R²={r2:.3f}"
    ax.text(0.02, 0.98, text, transform=ax.transAxes,
            va='top', ha='left',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

    fig.tight_layout()
    fig.savefig(filename, dpi=200)
    plt.show()

    print("Differences -> RMSE:", rmse, " R²:", r2)

def plot_error(values_rho_4_3: np.ndarray, values_rho1: np.ndarray) -> None:
    """
    Scatter ground-truth values against Fast-AR values for both rho settings.

    Each input is an (N, 2) array of ``[ground_truth, fast_ar]`` rows. An empty
    array (including the shape-(0,) result of ``np.array([])``) is accepted
    and contributes no points. The figure has a y = x line spanning the data
    and RMSE/R² annotations per rho and for both sets combined. It is saved to
    ``bav_model_selection_comparison_ground_truth_ll.png`` and shown, and the
    same metrics are printed. If both inputs are empty, only a notice is
    printed.
    """

    def _as_pairs(arr: np.ndarray) -> np.ndarray:
        # np.array([]) has shape (0,); give it two columns so slicing and vstack work.
        arr = np.asarray(arr, dtype=float)
        return arr.reshape(0, 2) if arr.size == 0 else arr

    def _safe_metrics(arr: np.ndarray):
        # Metrics of an empty set are undefined; report NaN without numpy warnings.
        return _metrics(arr) if arr.size else (np.nan, np.nan)

    vals_43 = _as_pairs(values_rho_4_3)
    vals_1 = _as_pairs(values_rho1)
    combined = np.vstack([vals_43, vals_1])
    if combined.size == 0:
        print("No complete (ground truth, np) value pairs available to plot.")
        return

    rmse_43, r2_43 = _safe_metrics(vals_43)
    rmse_1, r2_1 = _safe_metrics(vals_1)
    rmse_total, r2_total = _safe_metrics(combined)

    plt.figure(figsize=(6, 6))
    plt.scatter(vals_43[:, 0], vals_43[:, 1], label='rho=4/3')
    plt.scatter(vals_1[:, 0], vals_1[:, 1], label='rho=1')

    # y = x line spanning the full range of both axes' data.
    x_all = combined[:, 0]
    y_all = combined[:, 1]
    lo = min(x_all.min(), y_all.min())
    hi = max(x_all.max(), y_all.max())
    plt.plot([lo, hi], [lo, hi], linestyle="--", color="red", label="y = x")

    plt.xlabel("Ground Truth (SVBMC) Value (ELBO)")
    plt.ylabel("Fast-AR Log-Likelihood Value (LL)")
    plt.grid(True)
    plt.legend()

    text = (
        f"rho=4/3: RMSE={rmse_43:.3g}, R²={r2_43:.3f}\n"
        f"rho=1:   RMSE={rmse_1:.3g}, R²={r2_1:.3f}\n"
        f"TOTAL:   RMSE={rmse_total:.3g}, R²={r2_total:.3f}"
    )
    plt.gca().text(0.02, 0.98, text, transform=plt.gca().transAxes,
                   va='top', ha='left',
                   bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))

    plt.tight_layout()
    plt.savefig("bav_model_selection_comparison_ground_truth_ll.png", dpi=200)
    plt.show()

    print("rho=4/3  -> RMSE:", rmse_43, " R²:", r2_43)
    print("rho=1    -> RMSE:", rmse_1,  " R²:", r2_1)
    print("TOTAL    -> RMSE:", rmse_total, " R²:", r2_total)

def _style_publication_axes(ax) -> None:
    """Apply the shared publication look: light dashed grid, full box, small ticks, square box."""
    ax.grid(True, alpha=0.25, linestyle='--', linewidth=0.5)
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.8)
    ax.tick_params(axis='both', labelsize=8, length=3.0, width=0.8)
    # Same (square) box aspect on every panel so the subplot boxes match in size.
    ax.set_box_aspect(1)


def _set_three_ticks(ax, ticks) -> None:
    """Pin the x and y ticks of *ax* to *ticks* (min, mid, max), labelled without decimals."""
    for axis in (ax.xaxis, ax.yaxis):
        # A fresh locator/formatter per axis: matplotlib ties each instance to one axis.
        axis.set_major_locator(FixedLocator(ticks))
        axis.set_major_formatter(FormatStrFormatter('%.0f'))


def plot_combined_publication(
    gt_vs_np_diff: np.ndarray,
    values_rho_4_3: np.ndarray,
    values_rho1: np.ndarray,
    filename: str = "bav_model_selection_combined.png",
    pdf_filename: Optional[str] = "bav_model_comparison.pdf",
) -> None:
    """
    Create a two-panel, publication-ready comparison figure.

    Layout (side by side, both panels square with a red y = x line):
      - Left: value scatter for rho=4/3 (circles) and rho=1 (triangles),
        annotated with R² over both sets combined.
      - Right: difference scatter (Δ = rho=4/3 - rho=1), annotated with R².

    Args:
        gt_vs_np_diff: Nx2 array of [ground_truth_diff, np_diff].
        values_rho_4_3: Mx2 array of [ground_truth_value, np_value] for rho=4/3.
        values_rho1: Kx2 array of [ground_truth_value, np_value] for rho=1.
        filename: path of the main output image (300 dpi).
        pdf_filename: path of an additional copy (defaults to
            ``bav_model_comparison.pdf``); ``None`` or ``""`` skips it.

    Any input may be empty. The corresponding panel then has no points and
    its R² is shown as nan.
    """
    # Only R² is annotated on this figure, so the RMSE halves are discarded.
    _, r2_diff = _metrics(gt_vs_np_diff) if gt_vs_np_diff.size else (np.nan, np.nan)
    if values_rho_4_3.size and values_rho1.size:
        combined_vals = np.vstack([values_rho_4_3, values_rho1])
    else:
        combined_vals = values_rho_4_3 if values_rho_4_3.size else values_rho1
    _, r2_total = _metrics(combined_vals) if combined_vals.size else (np.nan, np.nan)

    # Half of an A4 page width (8.27 in / 2 ≈ 4.135 in) by 3 in tall.
    a4_half_width_in = 8.27 / 2.0
    fig_height_in = 3.0
    # First (left) axes: value scatter. Second (right) axes: differences.
    fig, (ax_vals, ax_diff) = plt.subplots(
        1, 2, figsize=(a4_half_width_in, fig_height_in), constrained_layout=True
    )
    for ax in (ax_diff, ax_vals):
        _style_publication_axes(ax)

    # ----- Right panel: differences scatter -----
    if gt_vs_np_diff.size:
        ax_diff.scatter(
            gt_vs_np_diff[:, 0], gt_vs_np_diff[:, 1],
            s=40, color="#05650f", edgecolors='k', linewidths=0.6, alpha=0.9,
        )
        # Symmetric limits around 0: 5% headroom, rounded up to a multiple of 5.
        vals = np.concatenate([gt_vs_np_diff[:, 0], gt_vs_np_diff[:, 1]])
        max_abs = np.max(np.abs(vals))
        span = np.ceil((max_abs * 1.05) / 5.0) * 5.0 if max_abs > 0 else 1.0
        ax_diff.set_xlim(-span, span)
        ax_diff.set_ylim(-span, span)
        _set_three_ticks(ax_diff, [int(-span), 0, int(span)])
        ax_diff.plot([-span, span], [-span, span], color="red", linestyle="--", linewidth=1.2)

    # R² label near the top, right-aligned at 45% of the axes width.
    ax_diff.text(0.45, 0.9, f"R² = {r2_diff:.2f}", transform=ax_diff.transAxes, ha="right", fontsize=8)
    ax_diff.set_aspect('equal', adjustable='box')
    ax_diff.set_xlabel("Δ LML (True)", fontsize=9)
    ax_diff.set_ylabel("Δ LML (TNP w/ buffer)", fontsize=9)
    # No legend here: nothing on this panel carries a label.

    # ----- Left panel: values scatter -----
    if values_rho_4_3.size:
        ax_vals.scatter(
            values_rho_4_3[:, 0], values_rho_4_3[:, 1],
            s=40, marker='o', color="#1f77b4", edgecolors='k', linewidths=0.6, alpha=0.9, label='ρ=4/3'
        )
    if values_rho1.size:
        ax_vals.scatter(
            values_rho1[:, 0], values_rho1[:, 1],
            s=40, marker='^', color="#ff7f0e", edgecolors='k', linewidths=0.6, alpha=0.9, label='ρ=1'
        )

    if combined_vals.size:
        x_all = combined_vals[:, 0]
        y_all = combined_vals[:, 1]
        lo = np.min([x_all.min(), y_all.min()])
        hi = np.max([x_all.max(), y_all.max()])
        pad = 0.05 * (hi - lo) if hi > lo else 1.0
        # Shared square limits, widened to whole numbers for clean tick labels.
        minv_r = np.floor(lo - pad)
        maxv_r = np.ceil(hi + pad)
        if maxv_r <= minv_r:
            maxv_r = minv_r + 1.0
        midv_r = np.round((minv_r + maxv_r) / 2.0)
        ax_vals.set_xlim(minv_r, maxv_r)
        ax_vals.set_ylim(minv_r, maxv_r)
        _set_three_ticks(ax_vals, [minv_r, midv_r, maxv_r])
        ax_vals.plot([minv_r, maxv_r], [minv_r, maxv_r], color="red", linestyle="--", linewidth=1.2)

    # R² label near the top, right-aligned at 45% of the axes width.
    ax_vals.text(0.45, 0.9, f"R² = {r2_total:.2f}", transform=ax_vals.transAxes, ha="right", fontsize=8)
    ax_vals.set_aspect('equal', adjustable='box')
    ax_vals.set_xlabel("LML (True)", fontsize=9)
    ax_vals.set_ylabel("LML (TNP w/ buffer)", fontsize=9)
    if values_rho_4_3.size or values_rho1.size:
        ax_vals.legend(frameon=False, fontsize=7, loc='lower right', ncol=1)

    fig.savefig(filename, dpi=300, bbox_inches='tight')
    if pdf_filename:
        fig.savefig(pdf_filename, dpi=300, bbox_inches='tight')
    plt.show()

# -----------------------------
# Main
# -----------------------------
def _complete_pairs(
    pairs: Iterable[Tuple[Optional[float], Optional[float]]],
) -> np.ndarray:
    """
    Keep only the pairs where both sides are present.

    Returns a float array of shape (N, 2). When nothing survives, the shape is
    still (0, 2) rather than (0,), so callers can slice columns safely.
    """
    kept = [[gt, pred] for gt, pred in pairs if gt is not None and pred is not None]
    return np.array(kept, dtype=float).reshape(-1, 2)


def main() -> None:
    """Load both result sets, print them side by side, and draw the comparison figures."""
    real_split1 = load_real_data_csv("data/bav_real/groundtruth_split_1.csv", split=1)
    real_split2 = load_real_data_csv("data/bav_real/groundtruth_split_2.csv", split=2)
    real_data = real_split1 + real_split2

    np_data = load_np_data_csv("eval_results/bav_real_update_model/real_data_batch_stats_TNPB-K16_seed1.csv")

    # Index by (participant, split) for quick lookup.
    real_index = index_by_participant_split(real_data)
    np_index = index_by_participant_split(np_data)

    # Print side-by-side and collect paired numbers.
    diffs, values_rho_4_3, values_rho1 = print_side_by_side(real_index, np_index)

    # Drop pairs with a missing side; every result is (N, 2), possibly N == 0.
    cleaned = _complete_pairs(diffs)
    values_rho_4_3_arr = _complete_pairs(values_rho_4_3)
    values_rho1_arr = _complete_pairs(values_rho1)

    if cleaned.size > 0:
        plot_differences(cleaned)
    else:
        print("No complete (ground truth, np) pairs available to plot.")

    # plot_error needs at least one complete value pair to draw anything.
    if values_rho_4_3_arr.size or values_rho1_arr.size:
        plot_error(values_rho_4_3_arr, values_rho1_arr)
    else:
        print("No complete (ground truth, np) value pairs available to plot.")

    # Combined, publication-ready figure.
    if cleaned.size > 0:
        plot_combined_publication(cleaned, values_rho_4_3_arr, values_rho1_arr)


if __name__ == "__main__":
    main()
