import math
import os
import json
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
import wandb
from tqdm import tqdm
from geomloss import SamplesLoss

try:
    import ot  # type: ignore
except ImportError:
    ot = None

from learn_noise.networks import RQSQuantile, RQSQuantileImage
import learn_noise.utils.sampler as smpl
from learn_noise.training.logging import log_real_rgb_histogram_once
from learn_noise.utils.quantile_viz import log_noise_slices, log_xt_slices, log_noise_slices_fixed
from learn_noise.utils.image_eval import reshape_flat_samples, save_image_grid
from learn_noise.utils.image_latent_viz import build_latent_visualizations


def build_quantile(args, device: torch.device, dim: int):
    """Instantiate the rational-quadratic-spline quantile network on ``device``.

    ``RQSQuantileImage`` is used when ``args`` has a non-None ``image_shape``
    attribute; otherwise ``RQSQuantile``. Hyper-parameters come from
    ``args.q_rqs_bins``, ``args.q_rqs_bound``, ``args.q_rqs_layers`` and
    ``args.q_u_eps``.
    """
    if getattr(args, "image_shape", None) is None:
        network_cls = RQSQuantile
    else:
        network_cls = RQSQuantileImage

    hyper = dict(
        dim=dim,
        n_bins=int(args.q_rqs_bins),
        bound=float(args.q_rqs_bound),
        num_layers=int(args.q_rqs_layers),
        eps=float(args.q_u_eps),
    )
    network = network_cls(**hyper)
    return network.to(device)


def build_transport_objective(
    args,
    device: torch.device,
) -> Tuple[Callable[[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Dict[str, float]]], bool, bool]:
    """Build the transport loss used to train the quantile network.

    The objective name is taken from ``args.q_objective``; if that is missing or
    None, ``args.q_ot_metric`` is used, defaulting to ``"energy"``. Matching is
    case-insensitive and ``"plan"`` is an alias for ``"plan_action"``.

    Supported objectives:

    * ``"energy"`` / ``"sinkhorn"``: a geomloss ``SamplesLoss`` between ``x0``
      and ``dqdt`` with ``blur = max(args.q_ot_epsilon, 1e-8)`` (``q_ot_epsilon``
      defaults to 1.0). ``args.q_ot_scaling`` is forwarded when set.
    * ``"plan_action"``: solves an exact discrete OT plan (POT ``ot.emd``) with
      uniform weights on the squared-Euclidean cost between ``x0`` and ``dqdt``,
      then returns the mean kinetic energy ``0.5 * ||dqdt_j - x0_{sigma(j)}||^2``
      over the matched pairs.

    Args:
        args: configuration namespace.
        device: unused; kept for call-site compatibility.

    Returns:
        ``(objective, flag_a, flag_b)``. ``objective(x0, dqdt)`` returns
        ``(loss_tensor, metrics_dict)``. Both flags are ``True`` for
        ``plan_action`` and ``False`` otherwise.

    Raises:
        ImportError: ``plan_action`` requested but POT is not installed.
        ValueError: unknown objective name.
    """
    metric_name_source = getattr(args, "q_objective", None)
    if metric_name_source is None:
        metric_name_source = getattr(args, "q_ot_metric", "energy")
    metric_name = metric_name_source.lower()
    if metric_name == 'plan':
        metric_name = 'plan_action'

    epsilon = float(args.q_ot_epsilon) if hasattr(args, "q_ot_epsilon") else 1.0
    # geomloss requires a strictly positive blur.
    blur = max(epsilon, 1e-8)

    if metric_name in {'energy', 'sinkhorn'}:
        samples_kwargs = {}
        scaling = getattr(args, "q_ot_scaling", None)
        if scaling is not None:
            samples_kwargs['scaling'] = float(scaling)
        samples_loss = SamplesLoss(metric_name, blur=blur, **samples_kwargs)
        metric_key = f"quantile_ot/{metric_name}"

        def objective(x0: torch.Tensor, dqdt: Optional[torch.Tensor] = None):
            if dqdt is None:
                raise ValueError(
                    f"{metric_name} objective requires dqdt samples to compare against x0."
                )
            value = samples_loss(x0, dqdt)
            value_f = float(value.item())
            metrics = {
                "quantile_ot/transport": value_f,
                metric_key: value_f,
            }
            return value, metrics

        return objective, False, False

    if metric_name == 'plan_action':
        if ot is None:
            raise ImportError("POT (Python Optimal Transport) package is required for q_objective='plan_action'.")

        def objective(x0: torch.Tensor, dqdt_selected: Optional[torch.Tensor] = None):
            if dqdt_selected is None:
                raise ValueError("plan_action objective requires dqdt values; set return_dqdt=True.")
            batch = x0.shape[0]

            # cost_matrix[i, j] = 0.5 * ||x0[i] - dqdt[j]||^2 (rows: x0, columns: dqdt).
            cost_matrix = 0.5 * (x0.unsqueeze(1) - dqdt_selected.unsqueeze(0)).pow(2).sum(dim=-1)
            # The plan is solved on a detached copy; gradients flow only through the
            # kinetic term below.
            cost_np = cost_matrix.detach().cpu().numpy()
            weights = np.full(batch, 1.0 / batch, dtype=cost_np.dtype)
            plan_np = ot.emd(weights, weights, cost_np)
            plan = torch.from_numpy(plan_np).to(device=x0.device, dtype=cost_matrix.dtype)

            # For each dqdt column j, pick the x0 row carrying the most transported mass.
            idx_best = torch.argmax(plan, dim=0)
            cols = torch.arange(batch, device=x0.device)

            v_selected = -x0[idx_best] + dqdt_selected
            kinetic = 0.5 * v_selected.pow(2).sum(dim=1)
            action_loss = kinetic.mean()

            # Index (row = matched x0, column = dqdt) to agree with the cost_matrix layout.
            matched_cost = cost_matrix[idx_best, cols].mean()
            action_f = float(action_loss.item())
            metrics = {
                "quantile_ot/action": action_f,
                "quantile_ot/plan_cost": float(matched_cost.item()),
                "quantile_ot/transport": action_f,
            }
            return action_loss, metrics

        return objective, True, True

    raise ValueError(
        f"Unsupported q_objective '{metric_name}' (expected 'energy', 'sinkhorn', or 'plan_action')."
    )


def _compute_kl_regularizer(
    quantile: nn.Module,
    U: torch.Tensor,
    tau: torch.Tensor,
    lam_kl: float,
    device: torch.device,
) -> torch.Tensor:
    """Return the batch mean of ``-sum(log(diag_du))`` for the quantile map.

    ``quantile.diag_du(U, tau, None, create_graph=True)`` supplies the diagonal
    Jacobian entries, which are clamped below at 1e-12 before taking the log
    and summed over dim 1. ``lam_kl`` only gates the computation: when it is
    ``<= 0`` a scalar zero on ``device`` is returned. The result is NOT scaled
    by ``lam_kl``.
    """
    if lam_kl <= 0.0:
        return torch.zeros((), device=device)

    diag_entries = quantile.diag_du(U, tau, None, create_graph=True)
    neg_log_det = -diag_entries.clamp_min(1e-12).log().sum(dim=1)
    return neg_log_det.mean()


def _get_quantile_eval_model(
    quantile: nn.Module,
    ema_quantile: Optional[AveragedModel],
) -> nn.Module:
    """Pick the model to evaluate: the EMA wrapper if one exists, else the raw network."""
    if ema_quantile is None:
        return quantile
    return ema_quantile


def _maybe_log_image_visualizations(
    *,
    args,
    quant_eval: nn.Module,
    step: int,
    global_step: int,
    is_image_task: bool,
    sample_vis_interval: int,
    sample_vis_count: int,
    sample_vis_nrow: int,
    sample_dir: str,
    fixed_u_vis: Optional[torch.Tensor],
    u_eps: float,
    image_shape: Optional[Tuple[int, ...]],
    device: torch.device,
    dim: int,
    sampler,
) -> None:
    """Periodically log image-space noise grids and latent diagnostics to wandb.

    Does nothing unless ``is_image_task`` is true and both ``sample_vis_interval``
    and ``sample_vis_count`` are positive. On steps where ``step + 1`` is a
    multiple of ``sample_vis_interval``:

    * ``sample_vis_count`` uniforms are pushed through ``quant_eval`` at
      ``tau = 1``. These are the first rows of ``fixed_u_vis`` when it has enough
      rows, and fresh ``torch.rand`` draws otherwise. The outputs are saved as a PNG
      grid in ``sample_dir`` and logged as ``quantile_ot/noise_grid``.
    * if ``args.latent_viz_samples > 0``, up to 1024 fresh uniforms are pushed
      through ``quant_eval`` (with a zero ``x_aux``). The ``latent/*`` figures
      from ``build_latent_visualizations`` are then logged.

    ``quant_eval`` is put in eval mode while logging, and its previous training
    flag is restored afterwards, even if a forward pass or wandb call raises.
    ``sampler`` is accepted for call-site compatibility but is not used.
    """
    if not (is_image_task and sample_vis_interval > 0 and sample_vis_count > 0):
        return

    log_noise = (step + 1) % sample_vis_interval == 0
    latent_viz_samples = int(args.latent_viz_samples)
    # Latent diagnostics share the noise-grid cadence.
    log_latent = log_noise and latent_viz_samples > 0

    if not log_noise:
        return

    os.makedirs(sample_dir, exist_ok=True)

    prev_mode = quant_eval.training
    quant_eval.eval()
    try:
        with torch.inference_mode():
            if fixed_u_vis is not None and fixed_u_vis.shape[0] >= sample_vis_count:
                u_noise = fixed_u_vis[:sample_vis_count].to(device)
            else:
                u_noise = torch.rand(sample_vis_count, dim, device=device)
            # Affine map into [u_eps, 1 - u_eps] keeps the quantile away from the tails.
            U_noise = u_eps + (1 - 2 * u_eps) * u_noise
            tau_noise = torch.ones(sample_vis_count, 1, device=device)
            eps_noise = quant_eval(U_noise, tau_noise)
            noise_imgs = reshape_flat_samples(eps_noise, torch.Size(image_shape))
            grid_np = save_image_grid(
                noise_imgs,
                path=os.path.join(sample_dir, f'step_{global_step:06d}_noise.png'),
                nrow=sample_vis_nrow,
            )
            wandb.log({"quantile_ot/noise_grid": wandb.Image(grid_np)}, step=global_step)

            if log_latent:
                num_latent = min(latent_viz_samples, 1024)
                unit_u = torch.rand(num_latent, dim, device=device)
                U_latent = u_eps + (1 - 2 * u_eps) * unit_u
                ones_latent = torch.ones(num_latent, 1, device=device)
                x_aux_latent = torch.zeros(num_latent, dim, device=device)
                eps_latent = quant_eval(U_latent, ones_latent, x_aux=x_aux_latent)

                latents_cpu = eps_latent.detach().cpu()
                atlas_grid = int(args.latent_atlas_grid)
                atlas_images = None
                if atlas_grid > 1:
                    n_atlas = min(atlas_grid * atlas_grid, latents_cpu.shape[0])
                    # Only build an atlas when the flat latents can be reshaped to images.
                    if n_atlas > 0 and math.prod(image_shape) == latents_cpu.shape[1]:
                        perm = torch.randperm(latents_cpu.shape[0])[:n_atlas]
                        atlas_images = latents_cpu[perm].reshape(n_atlas, *image_shape)

                viz_payload = build_latent_visualizations(
                    latents_cpu,
                    image_shape=image_shape,
                    atlas_images=atlas_images,
                )
                figures = [
                    viz_payload.mean_std_fig,
                    viz_payload.hist_qq_fig,
                    viz_payload.pca_fig,
                    viz_payload.corr_fig,
                ]
                try:
                    wandb_viz = {
                        "latent/mean_std": wandb.Image(viz_payload.mean_std_fig),
                        "latent/hist_qq": wandb.Image(viz_payload.hist_qq_fig),
                        "latent/pca": wandb.Image(viz_payload.pca_fig),
                        "latent/correlation": wandb.Image(viz_payload.corr_fig),
                    }
                    if viz_payload.atlas_grid is not None:
                        wandb_viz["latent/atlas"] = wandb.Image(viz_payload.atlas_grid)
                    wandb.log(wandb_viz, step=global_step)
                finally:
                    # Release matplotlib figures even if logging fails.
                    for fig in figures:
                        plt.close(fig)
    finally:
        # Restore the caller's mode regardless of how logging ended.
        quant_eval.train(prev_mode)


def _maybe_run_low_dim_evals(
    *,
    args,
    quant_eval: nn.Module,
    sampler,
    device: torch.device,
    global_step: int,
    step: int,
    dim: int,
    u_eps: float,
    normal: torch.distributions.Normal,
) -> None:
    """Run validation metrics and slice plots for 1-D / 2-D targets.

    Skipped when ``dim > 2`` or ``args.q_val_interval <= 0``. Otherwise it runs on
    step 0 and then whenever ``step + 1`` is a multiple of ``q_val_interval``:

    * if ``args.q_val_samples > 0``, logs the MSE between ``quant_eval`` at
      ``tau = 1`` (with zero ``x_aux``) and ``normal.icdf`` of the same clamped
      uniforms as ``quantile_ot/val_endpt_q1_mse``;
    * logs noise / x_t / fixed-x0 slice plots via the ``quantile_viz`` helpers.

    ``quant_eval`` is put in eval mode for the duration, and its previous training
    flag is restored afterwards, even if an evaluation step raises.
    """
    if dim > 2:
        return

    interval = int(args.q_val_interval)
    if interval <= 0:
        return

    # Always evaluate at step 0, then every `interval` steps.
    if step != 0 and (step + 1) % interval != 0:
        return

    prev_mode = quant_eval.training
    quant_eval.eval()
    try:
        with torch.no_grad():
            q_val_samples = int(args.q_val_samples)
            if q_val_samples > 0:
                Uv = torch.rand(q_val_samples, dim, device=device).clamp(u_eps, 1.0 - u_eps)
                z_true = normal.icdf(Uv)
                ones = torch.ones(q_val_samples, 1, device=device)
                eps_tau1 = quant_eval(
                    Uv,
                    ones,
                    x_aux=torch.zeros(q_val_samples, dim, device=device),
                )
                end_mse = F.mse_loss(eps_tau1, z_true).item()
                wandb.log({"quantile_ot/val_endpt_q1_mse": float(end_mse)}, step=global_step)

        with torch.inference_mode():
            x_aux_viz = sampler.sample(4000, device=device)
            log_noise_slices(
                quant_eval,
                x_aux_viz,
                device,
                global_step,
                times=(0.0, 0.03, 0.07, 0.15, 0.25, 0.5, 0.75, 0.9, 0.98, 1.0),
                N=4000,
                dim=dim,
                u_eps=u_eps,
            )

            log_xt_slices(
                quant_eval,
                sampler,
                device,
                global_step,
                times=(0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0),
                N=4000,
                dim=dim,
                u_eps=u_eps,
            )

            x_fixed = sampler.sample(3, device=device)
            log_noise_slices_fixed(
                quant_eval,
                x_fixed,
                device,
                global_step,
                times=(0.0, 0.05, 0.12, 0.25, 0.5, 0.75, 0.9, 0.98, 1.0),
                Nu=4000,
                dim=dim,
                u_eps=u_eps,
                share_u=True,
                same_axes_across_rows=False,
                title_prefix="quantile_ot/noise_slices_fixedx0_with_data",
                include_country=True,
                sampler=sampler,
            )
    finally:
        # Restore the caller's mode even when an evaluation step fails.
        quant_eval.train(prev_mode)


def pretrain_quantile(args, *, log_step_offset: int = 0) -> Tuple[nn.Module, str, int]:
    """Quantile training with configurable OT objectives and structured logging.

    Each step draws a batch ``x0`` from the ``args.target_dataset`` sampler and
    evaluates the quantile network at ``tau = 1`` on uniforms mapped into
    ``[u_eps, 1 - u_eps]`` with ``return_dqdt=True``. It then minimises
    ``q_loss_weight * transport(x0, dqdt) + kl * kl_regularizer``.
    Per-step metrics go to wandb; image and low-dimensional diagnostics are
    logged periodically.

    After training it writes ``quantile.pt`` (plus ``quantile_ema.pt`` when EMA is
    enabled) and ``config.json`` under ``<args.runs_dir>/quantile_ot``.

    Args:
        args: run configuration namespace. It is also used as a cache: fixed
            batches are stored on it as ``_fixed_quantile_ot_*`` attributes.
        log_step_offset: added to the local step index to form wandb steps.

    Returns:
        ``(model, ckpt_path, total_steps)``. ``model`` is the EMA wrapper when
        EMA is enabled, else the raw quantile network. ``ckpt_path`` points at
        ``quantile.pt``.
    """
    device = torch.device(args.device)
    dim = args.dim
    sampler = smpl.get_distribution(args.target_dataset)

    quantile = build_quantile(args, device, dim)
    opt = torch.optim.Adam([{"params": quantile.parameters(), "lr": args.q_lr}])

    use_ema = bool(args.q_use_ema)
    ema_quantile = AveragedModel(quantile, multi_avg_fn=get_ema_multi_avg_fn(args.q_ema)) if use_ema else None

    # The builder's flags are not needed here: the model is always queried with return_dqdt=True.
    transport_objective, _, _ = build_transport_objective(args, device)
    # Resolve the objective name with the same precedence build_transport_objective uses
    # (q_objective, then q_ot_metric, then "energy") so config.json records what was trained.
    objective_name = getattr(args, "q_objective", None)
    if objective_name is None:
        objective_name = getattr(args, "q_ot_metric", "energy")

    u_eps = float(args.q_u_eps)
    batch_size = int(args.q_batch)
    q_loss_weight = float(getattr(args, "q_loss_weight", 1.0))
    lam_kl = float(args.kl)

    normal = torch.distributions.Normal(0.0, 1.0)

    def _fixed_sampler_batch(n: int, seed_offset: int = 0) -> torch.Tensor:
        # Seeded data batch drawn inside fork_rng so torch's global RNG state is restored afterwards.
        devices = [device] if device.type == "cuda" else []
        with torch.random.fork_rng(devices=devices):
            torch.manual_seed(args.seed + seed_offset)
            if device.type == "cuda" and torch.cuda.is_available():
                torch.cuda.manual_seed_all(args.seed + seed_offset)
            return sampler.sample(n, device=device, dtype=torch.float32)

    def _fixed_uniform(shape, seed_offset: int = 0) -> torch.Tensor:
        # Seeded U(0, 1) draws from a private CPU generator, then moved to `device`.
        gen = torch.Generator()
        gen.manual_seed(args.seed + seed_offset)
        return torch.rand(shape, generator=gen, dtype=torch.float32).to(device)

    # Fixed batches are cached on `args` (reused while the batch size matches). They are
    # not read in this function. NOTE(review): presumably consumed elsewhere — confirm.
    if not hasattr(args, "_fixed_quantile_ot_x0") or args._fixed_quantile_ot_x0.shape[0] != batch_size:
        args._fixed_quantile_ot_x0 = _fixed_sampler_batch(batch_size, seed_offset=5)
    if not hasattr(args, "_fixed_quantile_ot_u01") or args._fixed_quantile_ot_u01.shape[0] != batch_size:
        args._fixed_quantile_ot_u01 = _fixed_uniform((batch_size, dim), seed_offset=7)

    image_shape = args.image_shape
    image_dim = math.prod(image_shape) if image_shape is not None else None
    is_image_task = image_shape is not None and image_dim == dim

    if is_image_task:
        log_real_rgb_histogram_once(
            args=args,
            sampler=sampler,
            image_shape=image_shape,
            device=device,
            step=0,
        )

    sample_dir = os.path.join(args.runs_dir, "quantile_ot", "samples")
    fixed_u_vis: Optional[torch.Tensor] = None
    if is_image_task and args.sample_vis_interval > 0 and args.sample_vis_count > 0:
        if not hasattr(args, "_fixed_quantile_ot_vis_u") or args._fixed_quantile_ot_vis_u.shape[0] < args.sample_vis_count:
            args._fixed_quantile_ot_vis_u = _fixed_uniform((args.sample_vis_count, dim), seed_offset=19)
        fixed_u_vis = args._fixed_quantile_ot_vis_u

    # The evaluation model is the same object every step (EMA wrapper or raw network).
    quant_eval = _get_quantile_eval_model(quantile, ema_quantile)

    total_steps = int(args.q_ntrain)
    for step in tqdm(range(total_steps), desc="Quantile training (ot)"):
        quantile.train()
        opt.zero_grad(set_to_none=True)
        global_step = log_step_offset + step

        # --- Data ---
        with torch.no_grad():
            x0 = sampler.sample(batch_size, device=device, dtype=torch.float32)

        U = u_eps + (1 - 2 * u_eps) * torch.rand(batch_size, dim, device=device)
        tau_one = torch.ones(batch_size, 1, device=device)

        _, dqdt = quantile(
            U,
            tau_one,
            x_aux=None,
            return_dqdt=True,
            requires_grad=True,
        )

        transport_loss, transport_metrics = transport_objective(x0, dqdt)
        # Returns a scalar zero when lam_kl <= 0, so no extra gating is needed here.
        kl_loss = _compute_kl_regularizer(quantile, U, tau_one, lam_kl, device)

        loss = q_loss_weight * transport_loss + lam_kl * kl_loss
        loss.backward()
        grad_quant = torch.nn.utils.clip_grad_norm_(quantile.parameters(), args.model_grad_clip)
        opt.step()
        if ema_quantile is not None:
            ema_quantile.update_parameters(quantile)

        metrics = {
            "quantile_ot/loss_total": float(loss.item()),
            "quantile_ot/loss_q": float(transport_loss.item()),
            "quantile_ot/loss_kl": float(kl_loss.item()),
            "quantile_ot/grad": float(grad_quant.item()),
        }
        metrics.update(transport_metrics)
        wandb.log(metrics, step=global_step)

        _maybe_log_image_visualizations(
            args=args,
            quant_eval=quant_eval,
            step=step,
            global_step=global_step,
            is_image_task=is_image_task,
            sample_vis_interval=args.sample_vis_interval,
            sample_vis_count=args.sample_vis_count,
            sample_vis_nrow=max(1, args.sample_vis_nrow),
            sample_dir=sample_dir,
            fixed_u_vis=fixed_u_vis,
            u_eps=u_eps,
            image_shape=image_shape,
            device=device,
            dim=dim,
            sampler=sampler,
        )

        _maybe_run_low_dim_evals(
            args=args,
            quant_eval=quant_eval,
            sampler=sampler,
            device=device,
            global_step=global_step,
            step=step,
            dim=dim,
            u_eps=u_eps,
            normal=normal,
        )

    q_dir = os.path.join(args.runs_dir, "quantile_ot")
    os.makedirs(q_dir, exist_ok=True)
    ckpt_path = os.path.join(q_dir, "quantile.pt")
    ckpt_payload = {
        "state_dict": quantile.state_dict(),
        "dim": dim,
        "eps": args.q_u_eps,
        "type": "rqs",
        "n_bins": int(args.q_rqs_bins),
        "bound": float(args.q_rqs_bound),
        "layers": int(args.q_rqs_layers),
    }
    torch.save(ckpt_payload, ckpt_path)

    if ema_quantile is not None:
        ckpt_path_ema = os.path.join(q_dir, "quantile_ema.pt")
        # NOTE(review): this is the AveragedModel wrapper's state_dict (keys under `module.`
        # plus `n_averaged`), not the bare network's — confirm loaders expect that layout.
        ema_payload = {**ckpt_payload, "state_dict": ema_quantile.state_dict(), "ema": args.q_ema}
        torch.save(ema_payload, ckpt_path_ema)

    config = {
        "dim": dim,
        "type": "rqs",
        "mode": "ot",
        "ot_epsilon": float(args.q_ot_epsilon) if hasattr(args, "q_ot_epsilon") else 1.0,
        "objective": objective_name,
        "rqs": {
            "n_bins": int(args.q_rqs_bins),
            "bound": float(args.q_rqs_bound),
            "layers": int(args.q_rqs_layers),
        },
    }
    with open(os.path.join(q_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    return quant_eval, ckpt_path, total_steps
