import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, Optional
from collections.abc import Mapping

import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

from se.configs import TrainConfig, PROJECT_ROOT
from se.models import build_model
from se.utils.psnr_plot import (
    compute_psnr_io,
    plot_psnr_curves,
    resolve_training_sigma_psnr,
)
from se.utils.train_utils import run_name
from se.utils.eval_utils import load_train_config, resolve_checkpoint_path

from model_logs import models_log


@dataclass
class EvalConfig:
    log_dirs: Optional[list[Path]] = None
    checkpoint: Optional[Path] = None
    epoch: Optional[int] = None
    test_path: list[str] | None = field(
        default_factory=lambda: [
            f"{PROJECT_ROOT}/data/Set12",
            # f"{PROJECT_ROOT}/data/Set68",
        ]
    )
    device: Optional[str] = None
    n_averages: int = 1
    sigma_values: Optional[list[float]] = field(
        default_factory=lambda: np.geomspace(5, 100, num=20).tolist()
    )
    save_name: str = "ne_wne_50.png"
    show_legend: bool = True
    psnr_log_mean_mse: bool = False
    models_names: Optional[list[str] | dict] = None
    model_colors: Optional[dict] = None
    use_cache: bool = True
    noise_level: Optional[float] = None  # when set, auto-pick logs matching this σ
    key_substr: Optional[str] = None  # e.g., "dncnn" to filter model keys


def normalize_sigma_values(values: Iterable[float]) -> list[float]:
    """Convert noise levels to the normalized scale.

    Any value above 1.0 is treated as an 8-bit level and divided by 255;
    values at or below 1.0 are assumed already normalized and are only cast
    to ``float``.
    """
    return [v / 255.0 if v > 1.0 else float(v) for v in values]


def select_logs_by_noise(
    noise_level: float, name_contains: Optional[str] = None
) -> tuple[list[Path], list[str]]:
    """Find ``models_log`` entries registered for a given noise level.

    A key qualifies when the text after its last underscore equals
    ``str(int(noise_level))`` (so ``"dncnn_50"`` matches ``50``). If
    ``name_contains`` is non-empty, the key must also contain it,
    case-insensitively.

    Returns:
        ``(log_dirs, names)``: the matching paths as ``Path`` objects and
        their keys, both ordered by key.
    """
    wanted_suffix = str(int(noise_level))
    needle = name_contains.lower() if name_contains else None

    def _qualifies(key: str) -> bool:
        if needle and needle not in key.lower():
            return False
        _, sep, suffix = key.rpartition("_")
        # No underscore at all means the key carries no noise suffix.
        return bool(sep) and suffix == wanted_suffix

    matching_keys = sorted(key for key in models_log if _qualifies(key))
    return [Path(models_log[key]) for key in matching_keys], matching_keys


def load_models_from_dirs(
    resolved_dirs: list[Path],
    provided_names: list[str],
    use_auto_names: bool,
    device: torch.device,
    checkpoint_override: Path | None,
    epoch: int | None,
) -> tuple[
    list[torch.nn.Module],
    list[str],
    list[Path],
    list[tuple[float, float]],
    list[TrainConfig],
]:
    """Build, load and place one model per training log directory.

    Args:
        resolved_dirs: Log directories, each expected to hold ``config.json``.
        provided_names: Display names matched to ``resolved_dirs`` by position.
        use_auto_names: When True, ignore ``provided_names`` and name every
            model with ``run_name(cfg)``.
        device: Device the weights are loaded onto and the models moved to.
        checkpoint_override: Explicit checkpoint passed to the resolver.
        epoch: Epoch passed to the resolver.

    Returns:
        ``(models, names, checkpoint_paths, training_sigmas, cfgs)``, one entry
        per directory in order. ``training_sigmas`` holds each config's
        ``(min_noise, max_noise)``. Models are in ``eval()`` mode.

    Raises:
        FileNotFoundError: if a directory has no ``config.json``.
    """
    models: list[torch.nn.Module] = []
    models_names: list[str] = []
    checkpoints: list[Path] = []
    training_sigmas: list[tuple[float, float]] = []
    cfgs: list[TrainConfig] = []

    for idx, log_dir in enumerate(resolved_dirs):
        config_path = log_dir / "config.json"
        if not config_path.is_file():
            raise FileNotFoundError(f"No config.json found in {log_dir}.")
        cfg = load_train_config(config_path)
        cfgs.append(cfg)

        checkpoint_path = resolve_checkpoint_path(log_dir, epoch, checkpoint_override)

        # ``map_location`` only places the state dict on ``device``, and
        # ``load_state_dict`` copies into the module's existing parameters.
        # Without ``.to(device)`` the model would stay wherever
        # ``build_model`` created it.
        model = build_model(cfg).to(device)
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()

        models.append(model)
        checkpoints.append(checkpoint_path)
        training_sigmas.append((cfg.min_noise, cfg.max_noise))

        # Use the provided name when one exists for this position; otherwise
        # (or when auto-naming is requested) derive it from the config.
        if not use_auto_names and idx < len(provided_names):
            name = provided_names[idx]
        else:
            name = run_name(cfg)
        models_names.append(name)

    return models, models_names, checkpoints, training_sigmas, cfgs


def prepare_sigma_values(
    eval_cfg: EvalConfig,
    training_sigmas: list[tuple[float, float]],
    noise_type: str | None = None,
) -> tuple[list[float], tuple[float, float] | None, float | None]:
    """Resolve the sweep grid and the training-noise range shared by all models.

    For JPEG (``noise_type`` equal to "jpeg", case-insensitively), the
    configured values are returned sorted and otherwise untouched, with no
    training range.

    Otherwise the configured sigmas are normalized with
    ``normalize_sigma_values``. When none are configured, the default grid is
    5, 15, ..., 95 divided by 255.

    If every ``(lo, hi)`` in ``training_sigmas`` is identical, that range is
    returned. If it collapses to a single σ, that σ (8-bit scale) is also
    returned and its normalized value is inserted into the grid when no
    close value is already present.

    Returns:
        ``(sorted_grid, shared_range_or_None, single_train_sigma_or_None)``.
    """
    requested = eval_cfg.sigma_values
    if (noise_type or "").lower() == "jpeg":
        return sorted(requested or []), None, None

    if requested is None:
        grid = [s / 255.0 for s in range(5, 100, 10)]
    else:
        grid = normalize_sigma_values(requested)

    first_range = training_sigmas[0]
    shared_range: tuple[float, float] | None = first_range
    if any(other != first_range for other in training_sigmas[1:]):
        print(
            "Warning: models have different training noise ranges; "
            "skipping shaded training-sigma region."
        )
        shared_range = None

    single: float | None = None
    if shared_range is not None and math.isclose(*shared_range):
        single = shared_range[0]
        train_point = single / 255.0
        already_present = any(
            math.isclose(train_point, s, rel_tol=1e-6, abs_tol=1e-8) for s in grid
        )
        if not already_present:
            grid = [*grid, train_point]
    return sorted(grid), shared_range, single


def _psnr_csv_row(
    sigma_val: float,
    input_psnr: float,
    model_psnrs: Iterable[tuple[str, float]],
    training_sigma: tuple[float, float] | None,
    noise_type: str,
) -> dict:
    """Build one PSNR-CSV row (base columns, then one column per model).

    ``sigma_val`` is used as-is for JPEG and multiplied by 255 otherwise to
    get ``sigma_8bit``. ``is_train_sigma`` tells whether that value lies in
    ``training_sigma`` (either order, with 1e-6 slack).
    """
    sigma_8bit = sigma_val if noise_type == "jpeg" else sigma_val * 255.0
    in_train_range = False
    if training_sigma is not None:
        lo, hi = training_sigma
        lo, hi = (min(lo, hi), max(lo, hi))
        in_train_range = (lo - 1e-6) <= sigma_8bit <= (hi + 1e-6)
    row = {
        "sigma_8bit": round(sigma_8bit, 6),
        "sigma": round(float(sigma_val), 8),
        "input_psnr_db": round(float(input_psnr), 6),
        "is_train_sigma": in_train_range,
    }
    for name, value in model_psnrs:
        row[name] = round(float(value), 6)
    return row


def cache_or_compute_psnr(
    data_dir: Path,
    dataset_save_root: Path,
    eval_cfg: EvalConfig,
    sigma_values: list[float],
    training_sigma: tuple[float, float] | None,
    models_names: list[str],
    models: list[torch.nn.Module],
    device: torch.device,
    dataset_mode: str,
    noise_type: str,
) -> tuple[
    np.ndarray,
    dict[str, list[float]],
    list[str],
    list[float],
    pd.DataFrame,
    Path,
    np.ndarray,
]:
    """Return PSNR curves for ``sigma_values`` on one dataset, reusing a CSV cache.

    The cache is ``<dataset_save_root>/<save_name stem>_psnr.csv`` with columns
    ``sigma_8bit``, ``sigma``, ``input_psnr_db``, ``is_train_sigma`` plus one
    PSNR column per model.

    - When ``eval_cfg.use_cache`` is set and the CSV has every requested model
      column, only sigmas absent from (or incomplete in) the cache are
      evaluated with ``compute_psnr_io``.
    - Otherwise every requested sigma is evaluated. If the cache was missing a
      model column, previously cached rows are not carried over.

    The merged table is written back to the same CSV, restricted to the base
    columns plus the returned model labels.

    Returns:
        ``(x_vals, per_model_curves, label_order, sigma_values_list, psnr_df,
        psnr_csv_path, input_psnr_vals)``, where ``psnr_df`` has one row per
        requested sigma. On the cache path ``x_vals`` is the input PSNR (the
        sigma column for JPEG); on the compute path it is the x values
        returned by ``compute_psnr_io``.

    Raises:
        RuntimeError: if ``compute_psnr_io`` returns a different number of
            points than requested, or a sigma is still missing after the
            cache refresh.
    """
    psnr_csv_path = dataset_save_root / f"{Path(eval_cfg.save_name).stem}_psnr.csv"
    psnr_df: pd.DataFrame | None = None
    existing_df: pd.DataFrame | None = None
    label_order: list[str] = []
    per_model_curves: dict[str, list[float]] = {}
    sigma_values_list: list[float]

    base_cols = ["sigma_8bit", "sigma", "input_psnr_db", "is_train_sigma"]
    if psnr_csv_path.is_file():
        existing_df = pd.read_csv(psnr_csv_path)
    cached_df: pd.DataFrame | None = None
    merged_df: pd.DataFrame | None = None
    if eval_cfg.use_cache and existing_df is not None:
        cached_df = existing_df.copy()
        model_cols = [c for c in cached_df.columns if c not in base_cols]
        expected_cols = models_names if models_names else model_cols
        missing_model_cols = [m for m in expected_cols if m not in model_cols]
        if missing_model_cols:
            print(
                f"Cache {psnr_csv_path} missing model columns {missing_model_cols}; recomputing requested sigmas."
            )
            # Discard the cache entirely; the CSV is rebuilt from fresh results.
            cached_df = None
            existing_df = None
        else:
            # Drop any cached columns that are not in the requested set to keep
            # the CSV aligned to canonical model names.
            cached_df = cached_df[base_cols + list(expected_cols)]
            label_order = list(expected_cols)
            sigma_values_list = list(sigma_values)

            def row_for_sigma(target: float) -> pd.Series | None:
                assert cached_df is not None  # guarded by file existence
                matches = cached_df.loc[
                    cached_df["sigma"].apply(
                        lambda s: math.isclose(target, s, rel_tol=1e-6, abs_tol=1e-8)
                    )
                ]
                if matches.empty:
                    return None
                # keep last in case cache has duplicates
                return matches.iloc[-1]

            def is_missing_value(row: pd.Series, column: str) -> bool:
                val = row.get(column, np.nan)
                return pd.isna(val) or val == ""

            # A sigma needs computing if it has no row or any model cell is empty.
            missing_sigmas = []
            for sigma_val in sigma_values:
                row = row_for_sigma(sigma_val)
                if row is None or any(
                    is_missing_value(row, name) for name in expected_cols
                ):
                    missing_sigmas.append(sigma_val)

            if missing_sigmas:
                print(
                    f"Cache {psnr_csv_path} missing sigmas {missing_sigmas}; computing only those."
                )
                x_new, y_new, labels_new, input_psnr_new = compute_psnr_io(
                    models=models,
                    models_names=expected_cols,
                    data_dirs=[str(data_dir)],
                    sigma_values=missing_sigmas,
                    n_averages=eval_cfg.n_averages,
                    device=str(device),
                    dataset_mode=dataset_mode,
                    log_mean_mse=eval_cfg.psnr_log_mean_mse,
                    noise_type=noise_type,  # type: ignore
                )
                if len(missing_sigmas) != len(x_new):
                    raise RuntimeError(
                        "Mismatch between missing sigma count and computed PSNR length."
                    )
                new_rows = []
                for idx, sigma_val in enumerate(missing_sigmas):
                    model_psnrs = [
                        (name, curve[idx]) for name, curve in zip(labels_new, y_new)
                    ]
                    new_rows.append(
                        _psnr_csv_row(
                            sigma_val,
                            input_psnr_new[idx],
                            model_psnrs,
                            training_sigma,
                            noise_type,
                        )
                    )
                cached_df = pd.concat(
                    [cached_df, pd.DataFrame(new_rows)], ignore_index=True
                )
                # New rows come last, so keep="last" prefers them over stale ones.
                cached_df = cached_df.drop_duplicates(
                    subset=["sigma"], keep="last"
                ).sort_values("sigma")

            # align to requested sigma order for plotting
            def pick_row(target: float) -> pd.Series:
                for _, row in cached_df.iterrows():
                    if math.isclose(target, row["sigma"], rel_tol=1e-6, abs_tol=1e-8):
                        return row
                raise RuntimeError(f"Sigma {target} missing even after refresh.")

            ordered_rows = [pick_row(s) for s in sigma_values_list]
            psnr_df = pd.DataFrame(ordered_rows)
            merged_df = cached_df

    input_psnr_vals: np.ndarray

    if psnr_df is None:
        x_vals, y_arrays, label_order, input_psnr_vals = compute_psnr_io(
            models=models,
            models_names=models_names,
            data_dirs=[str(data_dir)],
            sigma_values=sigma_values,
            n_averages=eval_cfg.n_averages,
            device=str(device),
            dataset_mode=dataset_mode,
            log_mean_mse=eval_cfg.psnr_log_mean_mse,
            noise_type=noise_type,  # type: ignore
        )
        sigma_values_list = list(sigma_values)
        if len(sigma_values_list) != len(x_vals):
            raise RuntimeError(
                "Mismatch between sigma_values and returned PSNR curve lengths."
            )

        per_model_curves = {
            label: [float(v) for v in curve]
            for label, curve in zip(label_order, y_arrays)
        }

        rows = []
        for idx, sigma_val in enumerate(sigma_values_list):
            model_psnrs = [(name, per_model_curves[name][idx]) for name in label_order]
            rows.append(
                _psnr_csv_row(
                    sigma_val,
                    input_psnr_vals[idx],
                    model_psnrs,
                    training_sigma,
                    noise_type,
                )
            )

        psnr_df = pd.DataFrame(rows)
    else:
        sigma_values_list = list(sigma_values)
        input_psnr_vals = psnr_df["input_psnr_db"].to_numpy()
        x_vals = (
            psnr_df["sigma"].to_numpy()
            if noise_type == "jpeg"
            else input_psnr_vals
        )

    if merged_df is None:
        merged_df = cached_df if cached_df is not None else existing_df
    if merged_df is None:
        merged_df = psnr_df
    else:
        # Deduplicate *before* sorting. ``sort_values`` is not stable by
        # default, so sorting first could reorder rows with equal sigma and
        # make keep="last" retain the stale CSV row instead of the fresh one.
        merged_df = (
            pd.concat([merged_df, psnr_df], ignore_index=True)
            .drop_duplicates(subset=["sigma"], keep="last")
            .sort_values("sigma")
        )

    # Cache path: derive curves and axes from the aligned cached rows.
    if not per_model_curves:
        if not label_order:
            label_order = [c for c in psnr_df.columns if c not in base_cols]
        per_model_curves = {
            label: psnr_df[label].to_list() for label in label_order if label in psnr_df
        }
        label_order = [lbl for lbl in label_order if lbl in per_model_curves]
        sigma_values_list = psnr_df["sigma"].to_list()
        input_psnr_vals = psnr_df["input_psnr_db"].to_numpy()
        x_vals = (
            psnr_df["sigma"].to_numpy()
            if noise_type == "jpeg"
            else input_psnr_vals
        )

    ordered_cols = base_cols + [c for c in label_order if c not in base_cols]
    merged_df[ordered_cols].to_csv(psnr_csv_path, index=False)

    return (
        x_vals,
        per_model_curves,
        label_order,
        sigma_values_list,
        psnr_df,
        psnr_csv_path,
        input_psnr_vals,
    )


def write_train_sigma_csv(
    dataset_save_root: Path,
    save_name: str,
    training_sigma_single: float | None,
    models_names: list[str],
    training_sigmas: list[tuple[float, float]],
    sigma_values_list: list[float],
    per_model_curves: dict[str, list[float]],
):
    if training_sigma_single is None:
        return
    train_csv_path = dataset_save_root / f"{Path(save_name).stem}_train_psnr.csv"
    train_rows = []
    for name, sigma_range in zip(models_names, training_sigmas):
        lo, hi = sigma_range
        if not math.isclose(lo, hi):
            train_rows.append(
                {
                    "model": name,
                    "train_sigma_8bit": f"{lo}-{hi}",
                    "train_sigma": "",
                    "psnr_db": "",
                }
            )
            continue
        target_sigma = lo / 255.0
        try:
            idx = next(
                i
                for i, s in enumerate(sigma_values_list)
                if math.isclose(s, target_sigma, rel_tol=1e-6, abs_tol=1e-8)
            )
            psnr_val = per_model_curves[name][idx]
        except StopIteration:
            psnr_val = ""
        train_rows.append(
            {
                "model": name,
                "train_sigma_8bit": lo,
                "train_sigma": round(target_sigma, 8),
                "psnr_db": round(float(psnr_val), 6) if psnr_val != "" else "",
            }
        )
    pd.DataFrame(train_rows).to_csv(train_csv_path, index=False)


def main():
    """Run a PSNR-versus-noise sweep for a fixed pair of models and plot it.

    The comparison is hard-coded to the ``b_swinir`` (baseline) and
    ``wne_swinir`` runs registered in ``models_log`` at noise level 50.

    For every test dataset directory, PSNR values are computed or read from
    the CSV cache under ``PROJECT_ROOT/eval_logs/<dataset name>``. The curves
    are then plotted to ``save_name`` in that folder. A per-model train-σ CSV
    is also written when all models share a single training σ.

    Raises:
        ValueError: missing ``models_log`` entries, no log dirs, a checkpoint
            override with several log dirs, or models whose
            ``train_dataset_type``/``noise_type`` differ.
        FileNotFoundError: a log or test-data directory does not exist.
        RuntimeError: no models were built.
    """
    # Default: Gaussian noise σ=50 (8-bit) on Set12, SwinIR baseline vs WNE
    noise_level = 50
    test_path = [f"{PROJECT_ROOT}/data/Set12"]

    # (models_log key prefix, pretty plot label) for each compared model.
    target_models = [
        ("b_swinir", "Baseline"),
        ("wne_swinir", r"$\mathbf{WNE}$"),
    ]

    desired_keys = [f"{prefix}_{noise_level}" for prefix, _ in target_models]
    missing_keys = [k for k in desired_keys if k not in models_log]
    if missing_keys:
        raise ValueError(f"models_log is missing entries for {missing_keys}.")
    configured_dirs = [models_log[k] for k in desired_keys]

    # Use canonical names for CSV; map to pretty labels for plotting.
    canonical_names = list(desired_keys)
    pretty_labels = {
        f"{prefix}_{noise_level}": pretty for prefix, pretty in target_models
    }

    # NOTE(review): integers presumably index the palette used by
    # plot_psnr_curves — confirm there.
    color_palette = {
        "Baseline": 5,
        "SE-arch": 6,
        "NE-arch": 7,
        r"$\mathbf{WNE}$": 8,
    }
    color_overrides = {
        pretty_labels[name]: color_palette[pretty_labels[name]]
        for name in canonical_names
        if pretty_labels[name] in color_palette
    }

    save_stem = f"swinir_sigma{noise_level}"

    eval_cfg = EvalConfig(
        log_dirs=configured_dirs,
        noise_level=noise_level,
        save_name=f"{save_stem}.pdf",
        test_path=test_path,
        models_names=canonical_names,
        model_colors=color_overrides,
        show_legend=True,
        n_averages=20,
    )

    # From here on, everything is driven by eval_cfg.
    configured_dirs = list(eval_cfg.log_dirs or [])
    provided_names_raw = eval_cfg.models_names
    model_colors = dict(eval_cfg.model_colors or {})
    if isinstance(provided_names_raw, Mapping):
        # A mapping is {name: color}: keys name the models, values add colors.
        provided_names = [str(k) for k in provided_names_raw.keys()]
        model_colors.update({str(k): v for k, v in provided_names_raw.items()})
        use_auto_names = False
    else:
        provided_names = list(provided_names_raw or [])
        # Derive names from configs only when no names were given at all.
        use_auto_names = provided_names_raw is None

    if not configured_dirs and eval_cfg.noise_level is not None:
        configured_dirs, auto_names = select_logs_by_noise(
            eval_cfg.noise_level, eval_cfg.key_substr
        )
        if not configured_dirs:
            raise ValueError(
                f"No log dirs found for noise={eval_cfg.noise_level} "
                f"with key containing '{eval_cfg.key_substr}'."
            )
        provided_names = auto_names
        use_auto_names = False

    if not configured_dirs:
        raise ValueError("Provide at least one log directory via log_dirs.")

    resolved_dirs: list[Path] = []
    for raw_dir in configured_dirs:
        candidate = Path(raw_dir).expanduser().resolve()
        if not candidate.is_dir():
            raise FileNotFoundError(f"Log directory {candidate} does not exist.")
        resolved_dirs.append(candidate)

    checkpoint_override = (
        Path(eval_cfg.checkpoint).expanduser().resolve()
        if eval_cfg.checkpoint is not None
        else None
    )
    if checkpoint_override is not None and len(resolved_dirs) > 1:
        raise ValueError(
            "checkpoint overrides are only supported when evaluating a single log dir."
        )

    device_str = (
        eval_cfg.device
        if eval_cfg.device is not None
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
    device = torch.device(device_str)
    (
        models,
        models_names,
        checkpoints,
        training_sigmas,
        cfgs,
    ) = load_models_from_dirs(
        resolved_dirs=resolved_dirs,
        provided_names=provided_names,
        use_auto_names=use_auto_names,
        device=device,
        checkpoint_override=checkpoint_override,
        epoch=eval_cfg.epoch,
    )

    if not models:
        raise RuntimeError("No models were built for evaluation.")

    # All models are evaluated on the same data, so their settings must agree.
    reference_cfg = cfgs[0]
    dataset_mode = reference_cfg.train_dataset_type.lower()
    noise_type = reference_cfg.noise_type
    if any(cfg.train_dataset_type.lower() != dataset_mode for cfg in cfgs[1:]):
        raise ValueError(
            "All models must use the same train_dataset_type to share evaluation data."
        )
    if any(cfg.noise_type != noise_type for cfg in cfgs[1:]):
        raise ValueError(
            "All models must use the same noise_type for a shared PSNR evaluation."
        )

    if eval_cfg.test_path is not None:
        test_path = [Path(p).expanduser() for p in eval_cfg.test_path]
    else:
        test_path = [Path(p) for p in reference_cfg.test_path]
        for cfg in cfgs[1:]:
            if cfg.test_path != reference_cfg.test_path:
                print(
                    "Warning: differing test_path values detected; "
                    "using the paths from the first config."
                )
                break

    for data_dir in test_path:
        if not data_dir.is_dir():
            raise FileNotFoundError(f"Test data directory {data_dir} does not exist.")

    sigma_values, training_sigma, training_sigma_single = prepare_sigma_values(
        eval_cfg, training_sigmas, noise_type=noise_type
    )
    save_root = PROJECT_ROOT / Path("eval_logs")
    if len(models) == 1:
        print(
            f"Running PSNR sweep for checkpoint {checkpoints[0].name} "
            f"on {device} using data at {', '.join(str(p) for p in test_path)}."
        )
    else:
        joined_paths = ", ".join(str(p) for p in test_path)
        print(
            f"Running PSNR sweep for {len(models)} models on {device} "
            f"using data at {joined_paths}."
        )
        for name, checkpoint_path in zip(models_names, checkpoints):
            print(f" - {name}: {checkpoint_path.name}")

    save_root.mkdir(parents=True, exist_ok=True)

    # JPEG sweeps use quality factors on the x axis and a custom identity curve.
    is_jpeg = noise_type == "jpeg"
    for data_dir in tqdm(test_path):
        dataset_name = data_dir.resolve().name
        dataset_save_root = save_root / dataset_name
        dataset_save_root.mkdir(parents=True, exist_ok=True)

        save_path = dataset_save_root / eval_cfg.save_name
        (
            x_vals,
            per_model_curves,
            label_order,
            sigma_values_list,
            _psnr_df,
            psnr_csv_path,
            input_psnr_vals,
        ) = cache_or_compute_psnr(
            data_dir=data_dir,
            dataset_save_root=dataset_save_root,
            eval_cfg=eval_cfg,
            sigma_values=sigma_values,
            training_sigma=training_sigma,
            models_names=models_names,
            models=models,
            device=device,
            dataset_mode=dataset_mode,
            noise_type=noise_type,
        )

        # Curves are keyed by canonical name; the plot shows pretty labels.
        plot_labels: list[str] = [pretty_labels.get(name, name) for name in label_order]
        plot_colors: dict[str, str | int] = {
            pretty_labels.get(name, name): color
            for name, color in model_colors.items()
        }
        x_axis_label = "quality factor" if is_jpeg else None
        identity_curve = input_psnr_vals if is_jpeg else None
        identity_line = identity_curve is None
        x_ticks_major = list(range(5, 86, 10)) if is_jpeg else None
        x_ticks_minor = list(range(5, 86, 5)) if is_jpeg else None
        x_limits = (5, 85) if is_jpeg else None
        y_limits = (24, 39) if is_jpeg else None
        training_sigma_psnr = resolve_training_sigma_psnr(
            training_sigma=training_sigma,
            sigma_values=sigma_values_list,
            input_psnr_vals=list(input_psnr_vals if is_jpeg else x_vals),
        )
        # NOTE(review): if this returns a matplotlib Figure, it is never closed;
        # consider closing it to avoid accumulating figures across datasets.
        plot_psnr_curves(
            x_vals=x_vals,
            per_model_curves=[per_model_curves[label] for label in label_order],
            label_order=plot_labels,
            training_sigma=training_sigma,
            training_sigma_psnr=training_sigma_psnr,
            save_path=save_path,
            model_colors=plot_colors,
            show_legend=eval_cfg.show_legend,
            x_label=x_axis_label,
            show_identity=identity_line,
            identity_curve=identity_curve,
            x_ticks=x_ticks_major,
            x_ticks_minor=x_ticks_minor,
            x_limits=x_limits,
            y_limits=y_limits,
        )

        write_train_sigma_csv(
            dataset_save_root=dataset_save_root,
            save_name=eval_cfg.save_name,
            training_sigma_single=training_sigma_single,
            models_names=models_names,
            training_sigmas=training_sigmas,
            sigma_values_list=sigma_values_list,
            per_model_curves=per_model_curves,
        )

        print(
            f"Saved PSNR plot to {save_path} and CSV to {psnr_csv_path} for dataset {dataset_name}."
        )


if __name__ == "__main__":
    # Run the evaluation sweep only when executed as a script, not on import.
    main()
