import math
from typing import Dict, Tuple

import numpy as np
import torch
import wandb
from torchdiffeq import odeint

try:
    import ot  # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    ot = None

try:
    from geomloss import SamplesLoss  # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    SamplesLoss = None

import learn_noise.utils.plotting as plot

# Default cap on the number of samples used for the W2 / energy-MMD estimates.
_DEFAULT_W2_SAMPLES = 5_000

# Modulus for derived seeds: the largest signed 32-bit integer.
_MAX_SEED = (1 << 31) - 1


def _get_seed(base_seed: int, offset: int) -> int:
    """Derive a positive seed from ``base_seed`` shifted by ``offset``.

    The sum is reduced modulo ``_MAX_SEED``; a reduced value of 0 (or less)
    is bumped up by one so the returned seed is never zero.
    """
    wrapped = (base_seed + offset) % _MAX_SEED
    return wrapped + 1 if wrapped <= 0 else wrapped


def _get_target_cache(args) -> Dict[Tuple[str, int], torch.Tensor]:
    """Return the per-run cache of fixed ground-truth samples kept on ``args``.

    The dict lives on ``args._eval_fixed_targets`` and is created lazily when
    that attribute is missing or ``None``.
    """
    if getattr(args, "_eval_fixed_targets", None) is None:
        args._eval_fixed_targets = {}
    return args._eval_fixed_targets


def _get_uniform_cache(args) -> Dict[Tuple[int, int], torch.Tensor]:
    """Return the per-run cache of fixed uniform draws kept on ``args``.

    Installs an empty dict on ``args._eval_fixed_u`` when the attribute is
    missing or ``None``.
    """
    existing = getattr(args, "_eval_fixed_u", None)
    if existing is None:
        setattr(args, "_eval_fixed_u", {})
    return args._eval_fixed_u


def _get_energy_metric(args):
    """Return a lazily created GeomLoss energy-distance loss cached on ``args``.

    Returns ``None`` when GeomLoss is not installed; otherwise the
    ``SamplesLoss("energy")`` instance stored on ``args._eval_energy_metric``
    (built on first use, or when that attribute is ``None``).
    """
    if SamplesLoss is None:
        return None
    if getattr(args, "_eval_energy_metric", None) is None:
        setattr(args, "_eval_energy_metric", SamplesLoss("energy"))
    return args._eval_energy_metric


def _to_raw_if_needed(sampler, tensor: torch.Tensor) -> torch.Tensor:
    """Apply ``sampler.to_raw`` to ``tensor`` when the sampler defines it; else return ``tensor`` unchanged."""
    if not hasattr(sampler, "to_raw"):
        return tensor
    return sampler.to_raw(tensor)


def _base_sampler(sampler):
    """Return the wrapped ``sampler.base`` when present, otherwise ``sampler`` itself."""
    return getattr(sampler, "base", sampler)


def _fixed_ground_truth(args, sampler, total: int, device: torch.device) -> torch.Tensor:
    """Return ``total`` reproducible target samples, cached on ``args``.

    Samples are drawn once per ``(target_dataset, total)`` key with a seed
    derived from ``args.seed`` (0 if absent) and ``total``. They are mapped
    through ``sampler.to_raw`` when available, moved to CPU and cached, so
    later calls with the same key return the same tensor.

    Args:
        args: run config; reads ``target_dataset`` and ``seed`` if present.
        sampler: object exposing ``sample(n, device=..., dtype=...)``.
        total: number of samples to draw.
        device: device to sample on.

    Returns:
        The cached CPU tensor of samples.
    """
    cache = _get_target_cache(args)
    target_dataset = args.target_dataset if hasattr(args, "target_dataset") else "unknown"
    key = (target_dataset, total)
    if key not in cache:
        base_seed = int(args.seed) if hasattr(args, "seed") else 0
        seed = _get_seed(base_seed, 1009 + 131 * total)
        # ``torch.manual_seed`` reseeds the CPU generator *and every CUDA
        # device*, so all CUDA generators must be forked whenever CUDA is in
        # use; forking only ``device`` (or nothing when sampling on CPU) would
        # silently reset the caller's other GPU RNG streams. If CUDA was never
        # initialised and we sample on CPU, skip CUDA so eval does not
        # initialise it just to save/restore RNG state.
        if device.type == "cuda" or torch.cuda.is_initialized():
            fork_devices = list(range(torch.cuda.device_count()))
        else:
            fork_devices = []
        with torch.random.fork_rng(devices=fork_devices):
            torch.manual_seed(seed)
            samples = sampler.sample(total, device=device, dtype=torch.float32)
            samples = _to_raw_if_needed(sampler, samples).cpu()
        cache[key] = samples
    return cache[key]


def _fixed_uniform(args, total: int, dim: int, offset: int = 0) -> torch.Tensor:
    """Return a reproducible ``(total, dim)`` float32 tensor of U[0, 1) draws.

    The tensor is generated on CPU from a private ``torch.Generator`` seeded
    from ``args.seed`` (0 if absent), ``total`` and ``offset``. The global
    RNG is not touched. The result is cached on ``args`` so repeated
    evaluations reuse the same uniforms.

    Args:
        args: run config; ``args.seed`` (if present) is the base seed.
        total: number of rows.
        dim: number of columns.
        offset: extra seed offset, giving an independent stream for the same shape.

    Returns:
        The cached CPU tensor of uniforms.
    """
    cache = _get_uniform_cache(args)
    # ``offset`` changes the seed, so it must be part of the cache key; otherwise
    # a call with a different offset would return the stream cached for the
    # first one. The legacy ``(total, dim)`` key is kept for ``offset == 0``.
    key = (total, dim) if offset == 0 else (total, dim, offset)
    if key not in cache:
        base_seed = int(args.seed) if hasattr(args, "seed") else 0
        seed = _get_seed(base_seed, 2027 + 137 * total + offset)
        gen = torch.Generator()
        gen.manual_seed(seed)
        cache[key] = torch.rand((total, dim), generator=gen, dtype=torch.float32)
    return cache[key]


def _compute_w2_distance(
    samples_a: torch.Tensor,
    samples_b: torch.Tensor,
    *,
    max_samples: int = _DEFAULT_W2_SAMPLES,
    num_iter_max: int = 1_000_000,
) -> float | None:
    """Compute the 2-Wasserstein distance between two empirical sample sets.

    The first ``count = min(max_samples, len(a), len(b))`` rows of each set
    get uniform weights, and the exact OT cost is solved with POT's
    ``ot.emd2`` on a squared-Euclidean cost matrix.

    Args:
        samples_a: 2-D ``(batch, dim)`` tensor.
        samples_b: 2-D ``(batch, dim)`` tensor.
        max_samples: cap on the number of rows used from each set.
        num_iter_max: iteration budget forwarded to ``ot.emd2``. POT's default
            (100000) can be exhausted on problems with thousands of points.
            The solver then stops early with a warning and returns a
            sub-optimal cost.

    Returns:
        The W2 distance as a float, or ``None`` if POT is not installed or
        there are no samples to compare.

    Raises:
        ValueError: if either input is not 2-D.
    """
    if ot is None:
        return None
    if samples_a.ndim != 2 or samples_b.ndim != 2:
        raise ValueError("W2 computation expects 2D tensors (batch, dim).")

    count = int(min(max_samples, samples_a.shape[0], samples_b.shape[0]))
    if count <= 0:
        return None

    # Work on CPU/NumPy to avoid excessive GPU memory usage: the dense cost
    # matrix is count x count.
    a_np = samples_a[:count].detach().cpu().numpy().astype(np.float64, copy=False)
    b_np = samples_b[:count].detach().cpu().numpy().astype(np.float64, copy=False)
    weights = np.full(count, 1.0 / count, dtype=np.float64)
    cost = ot.dist(a_np, b_np, metric="euclidean") ** 2
    w2_sq = ot.emd2(weights, weights, cost, numItermax=int(num_iter_max))
    # Clamp tiny negative round-off before the square root.
    return float(math.sqrt(max(float(w2_sq), 0.0)))

@torch.no_grad()
def heavy_eval_batched(
    args,
    x_0,
    ode_func,
    sampler,
    step,
    big_eval=False,
    device='cpu',
    noise=None,
    quantile=None,
):
    """Run a large, batched evaluation of the learned ODE sampler.

    For each batch this function:
    - draws initial noise at τ=1 from ``quantile``, else ``noise``, else a
      standard normal. The ``quantile`` option applies the quantile to
      uniforms cached on ``args``, so repeated calls reuse the same latents;
    - integrates ``ode_func`` from t=1 to t=0 with ``dopri5``;
    - accumulates ``-sampler.log_prob`` of the endpoints and the trajectory
      length measured on the ``args.num_steps_eval`` output grid.

    Afterwards it plots the generated samples and the latent points coloured
    by target norm. It then logs to wandb:
    - the W2 distance (POT) between the generated samples and a cached
      ground-truth set, and between the initial noise and that set;
    - the energy MMD (GeomLoss) between the generated samples and the set;
    - the average path length;
    - the mean NLL.

    Args:
        args: run config; reads ``dim``, ``runs_dir``, ``big_eval_samples`` /
            ``eval_sample``, ``eval_batch``, ``num_steps_eval`` and optionally
            ``target_dataset``, ``seed``, ``q_u_eps``, ``w2_eval_samples``.
        x_0: unused; kept for call-site compatibility.
        ode_func: vector field passed to ``torchdiffeq.odeint``.
        sampler: target sampler exposing ``log_prob`` and ``sample``, and
            optionally ``to_raw`` / ``base``.
        step: global step used for plot names and wandb logging.
        big_eval: use ``args.big_eval_samples`` instead of ``args.eval_sample``.
        device: device to integrate on.
        noise: optional callable ``noise(shape)`` returning the initial noise.
        quantile: optional callable ``quantile(u, t)`` mapping uniforms to noise.

    Returns:
        None. The function returns early, doing nothing, if the sample count
        is not positive.

    Raises:
        ValueError: if ``args.eval_batch`` is not positive.
    """
    dim = args.dim
    output_dir = args.runs_dir
    device = torch.device(device)

    total = int(args.big_eval_samples) if big_eval else int(args.eval_sample)
    if total <= 0:
        return

    batch_size = int(args.eval_batch)
    if batch_size <= 0:
        # Explicit check instead of ``assert`` so it also holds under ``python -O``.
        raise ValueError(f"eval_batch must be > 0, got {batch_size}")

    t_vals = torch.linspace(1.0, 0.0, args.num_steps_eval, device=device)

    u_unit_cache = None
    u_eps = 0.0
    if quantile is not None:
        u_unit_cache = _fixed_uniform(args, total, dim)
        # Uniforms are squeezed into [u_eps, 1 - u_eps) to keep the quantile
        # away from the endpoints 0 and 1.
        u_eps = float(args.q_u_eps) if hasattr(args, "q_u_eps") else 5e-5

    target_name = (args.target_dataset if hasattr(args, "target_dataset") else "funnel").lower()
    raw_sampler = _base_sampler(sampler)

    nll_sum = 0.0
    path_length_sum = 0.0
    seen = 0
    kept_x = []
    kept_eps = []

    while seen < total:
        cur_bs = min(batch_size, total - seen)

        # Initial noise at τ=1: prefer the quantile so evaluations are comparable.
        if quantile is not None:
            u_slice = u_unit_cache[seen:seen + cur_bs].to(device)
            u_squeezed = u_eps + (1 - 2 * u_eps) * u_slice
            ones_t = torch.ones(cur_bs, 1, device=device)
            eps = quantile(u_squeezed, ones_t)
        elif noise is not None:
            eps = noise((cur_bs, dim)).to(device)
        else:
            eps = torch.randn(cur_bs, dim, device=device)

        trajectory = odeint(ode_func, eps, t_vals, method="dopri5")
        x_gen_batch = trajectory[-1]

        # Polyline length of each trajectory on the output time grid.
        segment_lengths = torch.linalg.norm(trajectory[1:] - trajectory[:-1], dim=-1)
        path_length_sum += float(segment_lengths.sum(dim=0).sum().item())

        nll_sum += float((-sampler.log_prob(x_gen_batch)).sum().item())
        seen += cur_bs

        kept_x.append(x_gen_batch.detach().cpu())
        kept_eps.append(eps.detach().cpu())

    # ``torch.cat`` rather than ``torch.stack``: the last batch is smaller
    # whenever ``total % batch_size != 0``.
    x_gen = torch.cat([x.reshape(-1, dim) for x in kept_x], dim=0)
    eps_kept = torch.cat([e.reshape(-1, dim) for e in kept_eps], dim=0)
    x_gen_raw = _to_raw_if_needed(sampler, x_gen)

    avg_path_length = path_length_sum / seen
    nll_mean = nll_sum / seen

    # Choose plotting pipeline based on target.
    if target_name in {"funnel", "nealfunnel"}:
        plot.plot_funnel_2d(x_gen_raw, raw_sampler, step, big_eval, output_dir)
    else:
        plot.plot_generic_2d(x_gen, sampler, step, big_eval, output_dir)

    # Latent points coloured by the norm of the target point they reach.
    plot.plot_latent_colored_by_target_norm(eps_kept, x_gen_raw, step, output_dir, big_eval=big_eval)

    metrics = {}
    sample_cap = int(getattr(args, "w2_eval_samples", _DEFAULT_W2_SAMPLES))
    gen_count = min(sample_cap, x_gen_raw.shape[0])
    q_count = min(sample_cap, eps_kept.shape[0])

    # One ground-truth draw large enough for every comparison below.
    gt_needed = max(gen_count, q_count)
    gt_samples = None
    if gt_needed >= 2:
        gt_samples = _fixed_ground_truth(args, sampler, gt_needed, device=device)

    if gt_samples is not None and gen_count >= 2:
        w2_gen = _compute_w2_distance(
            x_gen_raw[:gen_count],
            gt_samples[:gen_count],
            max_samples=gen_count,
        )
        if w2_gen is not None:
            metrics["metrics/w2_generation"] = float(w2_gen)
        elif not getattr(args, "_w2_ot_missing_warned", False):
            print("[eval] POT not installed; skipping W2 computation.")
            args._w2_ot_missing_warned = True

    # Compare the learned quantile (at τ=1) directly with the target.
    # NOTE(review): eps is compared as-is, whereas x_gen is mapped through
    # ``to_raw`` before comparison; confirm eps lives in the same space as
    # gt_samples.
    if gt_samples is not None and q_count >= 2:
        w2_quantile = _compute_w2_distance(
            eps_kept[:q_count],
            gt_samples[:q_count],
            max_samples=q_count,
        )
        if w2_quantile is not None:
            metrics["metrics/w2_quantile"] = float(w2_quantile)

    # Energy MMD (GeomLoss) between generated samples and the target.
    if gt_samples is not None and gen_count >= 2:
        energy_metric = _get_energy_metric(args)
        if energy_metric is not None:
            energy_value = energy_metric(x_gen_raw[:gen_count], gt_samples[:gen_count])
            metrics["metrics/mmd_energy"] = float(energy_value.item())
        elif not getattr(args, "_geomloss_missing_warned", False):
            print("[eval] GeomLoss not installed; skipping energy MMD computation.")
            args._geomloss_missing_warned = True

    metrics["metrics/path_length_avg"] = float(avg_path_length)
    # Previously accumulated but never reported.
    metrics["metrics/nll_mean"] = float(nll_mean)

    wandb.log(metrics, step=step)
