import logging
import numpy as np
from pathlib import Path
from typing import Any
from pipeline.training.runtime_resolve import Experiment
from pipeline.training.train_configs import TrainConfig
from pipeline.eval.eval_configs import EvaluateConfig

def init_logging(experiment: Experiment, train_config: TrainConfig):
    """Configure this module's logger (file + console) and optionally start a W&B run.

    Records at INFO and above go to ``<experiment_dir>/logs/training.log`` and to
    the console, both with a timestamp/name/level prefix. Handlers attached by a
    previous call are closed and removed first. Calling this once per run in the
    same process therefore neither duplicates output nor keeps writing to an
    earlier run's log file.

    Args:
        experiment: Supplies ``experiment_dir`` (log/W&B directory) and
            ``experiment_name`` / ``run_name`` (joined with ``_`` as the W&B run name).
        train_config: Supplies ``use_wandb``; its ``__dict__`` is passed to
            ``wandb.init`` as the run config and it is logged via ``str()``.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        ImportError: If ``train_config.use_wandb`` is set but ``wandb`` is not installed.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # getLogger() hands back the same object on every call, so handlers from a
    # previous call must be dropped (and their files closed) to avoid double output.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    log_dir = experiment.experiment_dir / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    file_handler = logging.FileHandler(log_dir / "training.log", encoding="utf-8")
    file_handler.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    if train_config.use_wandb:
        try:
            import wandb
        except ImportError as e:
            raise ImportError(
                "wandb is not installed. Install it with:\n"
                "  uv sync --extra wandb"
            ) from e
        wandb.init(
            project="Learning_Adversarial_Optimal_Transport_Regularization",
            config=train_config.__dict__,
            dir=str(experiment.experiment_dir),
            name=experiment.experiment_name + "_" + experiment.run_name,
        )
        logger.info("Initialized Weights & Biases logging.")

    logger.info("Experiment Directory: %s", experiment.experiment_dir)
    logger.info("Training Configuration: %s", train_config)

    return logger

def log_epoch(
    ckpt_dir: Path,
    epoch: int,
    train_metrics: dict[str, float],
    eval_metrics: dict[str, dict[str, dict[str, Any]]],
    use_wandb: bool,
):

    figure_cooldown = 100
    if not use_wandb:
        return

    import wandb

    log_dict: dict[str, object] = {
        "epoch": epoch,
    }

    figures_enabled = epoch % figure_cooldown == 0
    if figures_enabled:
        import matplotlib.pyplot as plt
        from pipeline.visualization import plot_error_evolution
        from pipeline.visualization.lorentz63_plots import (
            plot_l63_full_double,
            plot_l63_full_single,
        )
        from pipeline.visualization.metric_plots import plot_histogram_comparison

    for k, v in train_metrics.items():
        log_dict[f"train/{k}"] = float(v)

    for group, metrics in eval_metrics['metrics'].items():
        for name, value in metrics.items():
            key = f"val/{group}/{name}"

            # Scalar metrics
            if np.isscalar(value):
                log_dict[key] = float(value)

            # Vector metrics (time / component curves)
            elif isinstance(value, np.ndarray) and value.ndim == 1 and figures_enabled:
                fig_dir = ckpt_dir / f"epoch_{epoch:04d}" / "figures"
                fig_dir.mkdir(parents=True, exist_ok=True)

                # -------- per-time error curves --------
                if name.endswith("_per_time"):
                    fig = plot_error_evolution(
                        mse_per_time=value,
                        title=key,
                    )

                    fig_path = fig_dir / f"{key.replace('/', '_')}.pdf"
                    fig.savefig(fig_path, dpi=300, bbox_inches="tight", )
                    plt.close(fig)

                    wandb.log(
                        {f"{key}_figure": wandb.Image(str(fig_path), caption=key)},
                        step=epoch,
                    )

                # -------- per-component diagnostics --------
                elif name.endswith("_per_component"):
                    # Bar plot is the correct semantic object here
                    fig, ax = plt.subplots(figsize=(6, 4))
                    ax.bar(np.arange(len(value)), value)
                    ax.set_xlabel("Component")
                    ax.set_ylabel(name.replace("_per_component", ""))
                    ax.set_title(key)
                    ax.grid(True, ls="--", alpha=0.5)

                    fig_path = fig_dir / f"{key.replace('/', '_')}.pdf"
                    fig.savefig(fig_path, dpi=300, bbox_inches="tight", )
                    plt.close(fig)

                    wandb.log(
                        {f"{key}_figure": wandb.Image(str(fig_path), caption=key)},
                        step=epoch,
                    )
            else:
                pass

    wandb.log(log_dict, step=epoch)

    if figures_enabled:
        fig_dir = ckpt_dir / f"epoch_{epoch:04d}" / "figures"
        fig_dir.mkdir(parents=True, exist_ok=True)
        # ─────────────────────────────
        u_true = eval_metrics['diagnostics']["u_true"]     # (T, d)
        u_hat = eval_metrics['diagnostics']["u_hat"]
        s_true = eval_metrics['diagnostics']["s_true"]
        s_hat = eval_metrics['diagnostics']["s_hat"]
        hist_true = eval_metrics['diagnostics']["hist_true"]
        hist_hat = eval_metrics['diagnostics']["hist_hat"]

        images = []

        # # ─────────────────────────────
        # # 1. Trajectory comparison (true vs pred)
        # # ─────────────────────────────
        # fig = plot_l63_full_double(
        #     u_true,
        #     u_hat,
        #     plotted_variable_1="True trajectory",
        #     plotted_variable_2="Predicted trajectory",
        # )
        
        # p = fig_dir / "trajectory_double.pdf"
        # fig.savefig(p, dpi=150, )
        # images.append(wandb.Image(str(p), caption=p.name))

        # # ─────────────────────────────
        # # 2. Predicted trajectory only
        # # ─────────────────────────────
        # fig = plot_l63_full_single(
        #     u_hat,
        #     plotted_variable="Predicted trajectory",
        # )
        # p = fig_dir / "trajectory_pred_only.pdf"
        # fig.savefig(p, dpi=150, )
        # images.append(wandb.Image(str(p), caption=p.name))

        # # ─────────────────────────────
        # # 3. Histogram comparison
        # # ─────────────────────────────
        # fig = plot_histogram_comparison(
        #     hist_true=hist_true,
        #     hist_pred=hist_hat,
        #     title="State histogram comparison",
        # )
        # p = fig_dir / "histogram_comparison.pdf"
        # fig.savefig(p, dpi=150, )
        # images.append(wandb.Image(str(p), caption=p.name))

        # # ─────────────────────────────
        # # Log figures
        # # ─────────────────────────────
        # wandb.log(
        #     {"val/figures": images},
        #     step=epoch,
        # )

def eval_log(
    eval_cfg: EvaluateConfig,
    *,
    wandb_run: Any,
    image_paths: list[Path],
    video_path: Path | None,
    out_dir: Path,
    ckpt_dir: Path,
) -> None:
    """Upload evaluation outputs to Weights & Biases, then finish the run.

    Does nothing when ``wandb_run`` is falsy. Otherwise it performs these steps:

    * saves ``<out_dir>/results.txt`` if it exists;
    * logs ``image_paths`` as a list of images under ``eval/figures``;
    * logs ``video_path`` under ``eval/video`` if it exists and
      ``eval_cfg.use_wandb`` is set;
    * logs an ``"evaluation"`` artifact containing ``out_dir`` and, if present,
      ``<ckpt_dir>/checkpoint.pkl``;
    * calls ``wandb_run.finish()``.

    Args:
        eval_cfg: Evaluation config; only ``use_wandb`` is read.
        wandb_run: Active W&B run, or a falsy value when W&B is disabled.
        image_paths: Image files to log.
        video_path: Optional rendered video.
        out_dir: Evaluation output directory.
        ckpt_dir: Checkpoint directory; its name and grandparent's name form
            the artifact name.
    """
    # Single guard: every step below needs both an active run and the wandb module.
    # (Previously the video step was gated only on eval_cfg.use_wandb and could
    # reference wandb without it having been imported.)
    if not wandb_run:
        return

    import wandb

    results_file = out_dir / "results.txt"
    if results_file.exists():
        wandb.save(str(results_file), base_path=str(out_dir))

    # Images
    if image_paths:
        wandb.log(
            {"eval/figures": [wandb.Image(str(p), caption=p.name) for p in image_paths]}
        )

    # Video
    if video_path is not None and video_path.exists() and eval_cfg.use_wandb:
        wandb.log({"eval/video": wandb.Video(str(video_path), fps=60, format="mp4")})

    # Artifact: output directory + checkpoint.pkl
    art = wandb.Artifact(
        name=f"eval-output-{ckpt_dir.parent.parent.name}-{ckpt_dir.name}",
        type="evaluation",
        metadata={
            "checkpoint_dir": str(ckpt_dir),
            "output_dir": str(out_dir),
        },
    )
    art.add_dir(str(out_dir))
    ckpt_file = ckpt_dir / "checkpoint.pkl"
    if ckpt_file.exists():
        art.add_file(str(ckpt_file))
    wandb_run.log_artifact(art)

    wandb_run.finish()
