import math
import os
from typing import Any, Literal

import torch
import wandb

from eval_metrics import ess, logZ_bounds, sinkhorn_distance, mmd_median
from samplers.masked import sample_forward_trajectory, sample_backward_trajectory
from models import BaseModel
from targets import BaseTarget, GrayCodedTarget, Ising2D, Potts2D


@torch.no_grad()
def generate_eval_trajectories(
    model: BaseModel,
    target: BaseTarget,
    n_eval_samples: int,
    batch_size: int,
    masking_schedule: torch.Tensor,
    direction: Literal["fwd", "bwd"] = "fwd",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate trajectories for evaluation, one mini-batch at a time.

    Args:
        model: The model passed to the trajectory sampler.
        target: The target distribution; provides ``ndim``, ``device`` and,
            for ``direction="bwd"``, ``cached_sample``.
        n_eval_samples: Total number of trajectories to generate.
        batch_size: Maximum number of trajectories sampled per call
            (must be positive).
        masking_schedule: Masking schedule forwarded to the sampler.
        direction: ``"fwd"`` to call ``sample_forward_trajectory``, ``"bwd"``
            to draw samples from the target and call
            ``sample_backward_trajectory`` on them.

    Returns:
        A tuple ``(trajectories, log_density, log_rnd)`` with shapes
        ``(n_eval_samples, ndim + 1, ndim)``, ``(n_eval_samples,)`` and
        ``(n_eval_samples,)`` respectively.

    Raises:
        ValueError: If ``direction`` is not ``"fwd"``/``"bwd"``, if
            ``batch_size`` is not positive, or if ``direction="bwd"`` is
            requested for a target that cannot be sampled.
    """
    # Validate up front: the output buffers below are uninitialised memory, so
    # any configuration that skips the fill loop must not be allowed through.
    if direction not in ("fwd", "bwd"):
        raise ValueError(f"direction must be 'fwd' or 'bwd', got {direction!r}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if direction == "bwd" and not target.can_sample:
        raise ValueError("direction='bwd' requires a target that can be sampled")

    ndim = target.ndim
    device = target.device

    eval_trajectories = torch.empty(
        (n_eval_samples, ndim + 1, ndim), dtype=torch.long, device=device
    )
    eval_log_density = torch.empty((n_eval_samples,), device=device)
    eval_log_rnd = torch.empty((n_eval_samples,), device=device)

    for start_idx in range(0, n_eval_samples, batch_size):
        # The final batch may be smaller than batch_size.
        end_idx = min(start_idx + batch_size, n_eval_samples)
        bsz = end_idx - start_idx
        if direction == "fwd":
            trajectories, log_density, log_rnd, _ = sample_forward_trajectory(
                model, target, bsz, masking_schedule, no_grad=True
            )
        else:
            target_x, target_log_density = target.cached_sample(bsz)
            trajectories, log_density, log_rnd, _ = sample_backward_trajectory(
                model, target, target_x, masking_schedule, target_log_density, no_grad=True
            )
        eval_trajectories[start_idx:end_idx] = trajectories
        eval_log_density[start_idx:end_idx] = log_density
        eval_log_rnd[start_idx:end_idx] = log_rnd

    return eval_trajectories, eval_log_density, eval_log_rnd


def evaluate_model(
    model: BaseModel,
    target: BaseTarget,
    n_eval_samples: int,
    batch_size: int,
    masking_schedule: torch.Tensor,
    prefix: str = "",
    visualise: bool = True,
    save_plots: bool = False,
    save_dir: str = "",
    epoch: int | None = None,
) -> dict[str, Any]:
    """
    Sample evaluation trajectories from the model and return metrics.

    The model is switched to eval mode and ``target.invtemp`` is set to 1.0
    for the duration of the evaluation. Both are reset on exit (the model via
    ``model.train()``, the target to its previous ``invtemp``), even if the
    evaluation raises.

    Args:
        model: The model to evaluate.
        target: The target distribution.
        n_eval_samples: Number of samples to evaluate.
        batch_size: Batch size for evaluation.
        masking_schedule: Masking schedule for evaluation.
        prefix: Optional prefix for metric names (e.g., "buffer_", "mcmc_").
        visualise: Whether to visualise the samples.
        save_plots: Whether to save the visualisation plots.
        save_dir: Directory to save the visualisation plots.
        epoch: Optional epoch number to save the visualisation plots.
    Returns:
        log_dict: Dictionary of metrics. Always contains ``ESS``, ``ELBO`` and
            ``IWELBO`` (with ``prefix``); ``EUBO`` is included only when it is
            a real number; sample-based metrics are included only when the
            target can be sampled.
    """
    model.eval()
    saved_invtemp = target.invtemp
    target.invtemp = 1.0

    try:
        trajectories, _, fwd_log_rnd = generate_eval_trajectories(
            model, target, n_eval_samples, batch_size, masking_schedule, direction="fwd"
        )
        # The last step of each forward trajectory is the generated sample.
        samples = trajectories[:, -1, :]
        if target.can_sample:
            _, _, bwd_log_rnd = generate_eval_trajectories(
                model, target, n_eval_samples, batch_size, masking_schedule, direction="bwd"
            )
        else:
            bwd_log_rnd = None

        log_dict: dict[str, Any] = {}

        # ESS
        log_dict[f"{prefix}ESS"] = ess(fwd_log_rnd)

        # LogZ Bounds
        elbo, iwelbo, eubo = logZ_bounds(fwd_log_rnd, bwd_log_rnd)
        log_dict[f"{prefix}ELBO"] = elbo
        log_dict[f"{prefix}IWELBO"] = iwelbo
        # NaN never compares equal to anything (including itself), so an
        # equality test against float("nan") cannot detect it; use isnan.
        if eubo is not None and not math.isnan(float(eubo)):
            log_dict[f"{prefix}EUBO"] = eubo

        # Evaluate Samples (Sinkhorn, MMD, etc)
        if target.can_sample:
            log_dict.update(
                evaluate_samples(
                    target,
                    samples,
                    prefix=prefix,
                    visualise=visualise,
                    save_plots=save_plots,
                    save_dir=save_dir,
                    epoch=epoch,
                )
            )
    finally:
        # Always undo the temporary evaluation state, even on failure.
        model.train()
        target.invtemp = saved_invtemp

    return log_dict


def evaluate_samples(
    target: BaseTarget,
    samples: torch.Tensor,
    prefix: str = "",
    visualise: bool = True,
    save_plots: bool = False,
    save_dir: str = "",
    epoch: int | None = None,
) -> dict[str, Any]:
    """
    Compare model samples against samples drawn from the target.

    Args:
        target: The target distribution (must support sampling).
        samples: Tensor of samples to evaluate.
        prefix: Optional prefix for metric names (e.g., "buffer_", "mcmc_").
        visualise: Whether to visualise the samples.
        save_plots: Whether to save the visualisation plots.
        save_dir: Directory to save the visualisation plots.
        epoch: Optional epoch number to save the visualisation plots.

    Returns:
        log_dict: Dictionary of metrics (plus figures when ``visualise`` is set).
    """
    assert target.can_sample
    # Draw as many reference samples as there are samples to evaluate.
    n_reference = samples.shape[0]
    reference, _ = target.cached_sample(n=n_reference)

    # Discrete-space metrics: Sinkhorn distance with Hamming cost, then MMD.
    metrics = {
        f"{prefix}Sinkhorn_hamming": sinkhorn_distance(
            reference, samples, epsilon=1e-3, cost_fn="hamming"
        ),
        f"{prefix}MMD": mmd_median(reference, samples),
    }

    # Gray-coded targets: repeat both metrics on the decoded continuous values.
    if isinstance(target, GrayCodedTarget):
        decoded_samples = target._binary_to_continuous(samples)
        decoded_reference = target._binary_to_continuous(reference)

        metrics[f"{prefix}Sinkhorn_conti"] = sinkhorn_distance(
            decoded_reference, decoded_samples, epsilon=1e-3, cost_fn="l2"
        )
        metrics[f"{prefix}MMD_conti"] = mmd_median(decoded_reference, decoded_samples)

    # Ising/Potts targets: magnetization and two-point correlation errors.
    if isinstance(target, (Ising2D, Potts2D)):
        metrics[f"{prefix}MagnetizationError"] = target.magnetization_error(
            samples, reference
        )
        metrics[f"{prefix}TwoPointCorrelationError"] = target.two_point_correlation_error(
            samples, reference
        )

    if not visualise:
        return metrics

    # Attach the figures produced for these samples.
    figures = visualise_samples(
        target,
        samples,
        prefix=prefix,
        save_plots=save_plots,
        save_dir=save_dir,
        epoch=epoch,
    )
    metrics.update(figures)
    return metrics


def visualise_samples(
    target: BaseTarget,
    samples: torch.Tensor,
    prefix: str = "",
    save_plots: bool = False,
    save_dir: str = "",
    epoch: int | None = None,
) -> dict[str, wandb.Image]:
    """
    Visualise samples and optionally save plots.

    Args:
        target: The target distribution; its ``visualise`` method produces the
            images.
        samples: Tensor of samples to visualise.
        prefix: Prefix for filenames and wandb keys (e.g., "buffer_", "mcmc_").
            Any ``/`` in the prefix is replaced by ``_``.
        save_plots: Whether to save visualization plots.
        save_dir: Directory to save plots. It is created if missing; an empty
            string saves relative to the current working directory.
        epoch: Epoch number for saving plots.

    Returns:
        log_dict: Dictionary of wandb.Image objects for logging, keyed by
            ``"<prefix>Figures/<name>"``.
    """
    figure_prefix = f"{prefix.replace('/', '_')}Figures/"
    img_dict = {
        f"{figure_prefix}{key}": img for key, img in target.visualise(samples).items()
    }

    if save_plots:
        # Create the output directory so saving does not fail on a fresh run.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        epoch_tag = f"epoch{epoch}" if epoch is not None else ""
        for key, img in img_dict.items():
            filename = f"{epoch_tag}_{key.replace('/', '_')}.png"
            # os.path.join keeps an empty save_dir relative; plain "/"
            # concatenation would point at the filesystem root.
            img.save(os.path.join(save_dir, filename))

    # Wrap PIL Images in wandb.Image for proper logging
    return {key: wandb.Image(img) for key, img in img_dict.items()}


def print_log_dict(log_dict: dict[str, Any], prefix: str = "") -> None:
    """
    Print the numeric entries of a log dictionary on a single line.

    Only values that are ``int`` or ``float`` instances are printed, each
    formatted to four decimal places; all other entries are skipped.

    Args:
        log_dict: Dictionary containing metrics to log.
        prefix: Prefix for the log message.
    """
    formatted = [
        f"{name}: {metric:.4f}"
        for name, metric in log_dict.items()
        if isinstance(metric, (int, float))
    ]
    summary = ", ".join(formatted)
    print(f"{prefix}: {summary}")
