
from __future__ import annotations

import ast
import math, re
import os
import time
from typing import Callable, Optional

import torch
import torch.nn.functional as F
import wandb
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
from torchdiffeq import odeint
from tqdm import tqdm

import learn_noise.utils.sampler as smpl
from learn_noise.networks.model_wrapper import TorchWrapper, ODEWrapper
from learn_noise.training.common import seed_all, make_fixed_sampler, minibatch_ot_pairing
from learn_noise.training.logging import (
    log_baseline_evaluation,
    log_baseline_image_metrics,
    log_real_rgb_histogram_once,
)
from learn_noise.utils.image_eval import reshape_flat_samples

def _make_latent_sampler(name: str, *, device: torch.device, args: Optional[object] = None) -> Callable:
    lname = name.lower()

    if lname in {"gauss", "gaussian", "normal"}:
        def _sample(shape): return torch.randn(*shape, device=device)
        return _sample

    if lname in {"uniform", "uni"}:
        def _sample(shape): return torch.rand(*shape, device=device) * 2.0 - 1.0
        return _sample

    if lname in {"student_t", "student-t", "studentt"}:
        default_dtype = torch.get_default_dtype()

        def _coerce_param(value, fallback):
            if value is None:
                return fallback
            if isinstance(value, str):
                text = value.strip()
                try:
                    parsed = ast.literal_eval(text)
                except (ValueError, SyntaxError):
                    parsed = float(text)
                return parsed
            return value

        df_param = _coerce_param(getattr(args, "student_t_df", None) if args is not None else None, 4.0)
        scale_param = _coerce_param(getattr(args, "student_t_scale", None) if args is not None else None, 1.0)

        df_tensor = torch.as_tensor(df_param, dtype=default_dtype, device=device)
        scale_tensor = torch.as_tensor(scale_param, dtype=default_dtype, device=device)
        loc_tensor = torch.zeros_like(df_tensor, dtype=default_dtype, device=device)

        dist = torch.distributions.StudentT(df=df_tensor, loc=loc_tensor, scale=scale_tensor)
        batch_shape = dist.batch_shape  # usually () or (dim,)

        def _sample(shape):
            if not shape:
                sample_shape = torch.Size()
            elif len(batch_shape) == 0:
                sample_shape = torch.Size(shape)
            else:
                if len(shape) < len(batch_shape):
                    raise ValueError(
                        "Requested Student-t sample shape is too small for batch parameters: "
                        f"shape={shape}, batch_shape={tuple(batch_shape)}"
                    )
                expected = tuple(batch_shape)
                actual = tuple(shape[-len(batch_shape):])
                if actual != expected:
                    raise ValueError(
                        "Student-t latent requires the trailing dimensions to match the parameter shape: "
                        f"expected {expected}, got {shape}"
                    )
                sample_shape = torch.Size(shape[:-len(batch_shape)])

            samples = dist.sample(sample_shape)
            if isinstance(samples, torch.Tensor):
                return samples
            return torch.as_tensor(samples, dtype=default_dtype, device=device)

        return _sample

    if any(k in lname for k in {"stable", "alpha_stable", "alpha-stable"}):
        # Allow "stable:1.5" etc.; default α=1.5
        m = re.search(r"(?:stable|alpha[-_]?stable)[:_]?([0-9]*\.?[0-9]+)", lname)
        alpha = float(m.group(1)) if m else 1.5
        if not (0.0 < alpha <= 2.0):
            raise ValueError(f"alpha must be in (0, 2], got {alpha}")

        # α=2 reduces to Gaussian
        if abs(alpha - 2.0) < 1e-12:
            def _sample(shape): return torch.randn(*shape, device=device)
            return _sample

        gamma = alpha / 2.0  # index of positive-stable for the mixture

        def _pos_stable(gamma: float, size: tuple[int, ...], *, work_dtype=torch.float64) -> torch.Tensor:
            """
            Positive strictly stable S_γ with Laplace transform E[e^{-λ S_γ}] = exp(-λ^γ), 0<γ<1,
            via Kanter's algorithm in the log-domain for stability.
            Returns shape `size` tensor on `device` in `work_dtype`.
            """
            # U in (0,1), avoid endpoints (sin(πU)=0)
            U = torch.rand(*size, device=device, dtype=work_dtype)
            U = (U * (1.0 - 2e-12)) + 1e-12  # (0,1) open interval
            E = torch.distributions.Exponential(torch.tensor(1.0, device=device, dtype=work_dtype)).sample(size)

            pi = math.pi
            eps = torch.finfo(work_dtype).tiny

            # log K_γ = γ log sin(πγU) + (1-γ) log sin(π(1-γ)U) - log sin(πU)
            log_sin_piU      = torch.log(torch.sin(pi * U).clamp_min(eps))
            log_sin_gpiU     = torch.log(torch.sin(pi * gamma * U).clamp_min(eps))
            log_sin_1mgpiU   = torch.log(torch.sin(pi * (1.0 - gamma) * U).clamp_min(eps))
            logK = gamma * (log_sin_gpiU - log_sin_piU) + (1.0 - gamma) * (log_sin_1mgpiU - log_sin_piU)

            logE = torch.log(E.clamp_min(eps))
            # S_γ = [ K_γ / E^{1-γ} ]^{1/γ}
            logS = (logK - (1.0 - gamma) * logE) / gamma
            return torch.exp(logS)

        def _sample(shape):
            # Work in float64 for the mixing variable; output in default dtype.
            out_dtype = torch.get_default_dtype()

            # Treat the first axis as batch. One V per batch sample → isotropic SαS over the remaining dims.
            if len(shape) == 0:
                batch = 1
                expand_shape = ()
            else:
                batch = shape[0]
                expand_shape = (batch,) + (1,) * (len(shape) - 1)

            # g ~ N(0, I) with target dtype
            g = torch.randn(*shape, device=device, dtype=out_dtype)

            # V ~ S_{α/2} (positive-stable), then scale = sqrt(2 V)
            V = _pos_stable(gamma, (batch,), work_dtype=torch.float64)
            scale = torch.sqrt(2.0 * V).to(out_dtype).view(expand_shape)

            return g * scale

        return _sample

    raise ValueError(f"Unknown baseline latent '{name}'")


def _generate_baseline_samples(
    num_samples: int,
    *,
    batch_size: int,
    device: torch.device,
    dim: int,
    t_eval: torch.Tensor,
    ode_func: ODEWrapper,
    wrapper: TorchWrapper,
    eval_model,
    latent_sampler: Callable[[tuple[int, ...]], torch.Tensor],
    latents: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Integrate the learned flow to produce samples for logging/evaluation."""
    wrapper.model = eval_model
    prev_mode = eval_model.training
    eval_model.eval()

    outputs = []
    produced = 0
    while produced < num_samples:
        cur_bs = min(batch_size, num_samples - produced)
        if latents is not None:
            z0 = latents[produced:produced + cur_bs].to(device)
        else:
            z0 = latent_sampler((cur_bs, dim))
        if labels is not None:
            lbl_batch = labels[produced:produced + cur_bs].to(device)
            wrapper.set_labels(lbl_batch)
        else:
            wrapper.set_labels(None)
        traj = odeint(ode_func, z0, t_eval, method='euler')
        outputs.append(traj[-1].detach().cpu())
        produced += cur_bs

    if prev_mode:
        eval_model.train()

    wrapper.set_labels(None)

    return torch.cat(outputs, dim=0) if outputs else torch.empty(0, dim)


def _save_baseline_checkpoint(checkpoint_dir: str, step: int, model, ema) -> None:
    """Write ``model_step_XXXXXX.pt`` and ``ema_step_XXXXXX.pt`` for ``step``."""
    ckpt_suffix = f"step_{step:06d}.pt"
    for prefix, module in (("model", model), ("ema", ema)):
        if module is None:
            continue
        payload = {
            "step": step,
            "state_dict": module.state_dict(),
        }
        torch.save(payload, os.path.join(checkpoint_dir, f"{prefix}_{ckpt_suffix}"))


def train_fm_baseline(args, model, optimizer) -> None:
    """Train the flow-matching baseline and periodically evaluate/checkpoint it.

    Each of the ``args.epochs`` iterations performs one optimisation step on a
    fresh minibatch: data ``x_0`` and latent ``z`` are linearly interpolated at
    a random time ``t`` and the model regresses the velocity ``z - x_0`` with
    an MSE loss. An EMA copy of the model is updated after every step and is
    the model used for all evaluation.

    Side effects:
        * Logs scalars to ``wandb`` every step and calls the project logging
          helpers at the configured intervals.
        * Writes model/EMA checkpoints to ``<args.runs_dir>/baseline_fm`` every
          20000 steps and once more at the end of training.
        * Writes the accumulated optimisation wall-clock time (seconds) to
          ``<args.runs_dir>/runtime_training_only.txt``.
        * Caches ``_fixed_baseline_x0`` and ``_fixed_baseline_vis_noise`` on
          ``args``.

    Args:
        args: Run configuration namespace (device, seed, dataset, batch size,
            evaluation intervals, ...).
        model: Velocity network called as ``model(t, x_t)``.
        optimizer: Optimizer over ``model``'s parameters.
    """
    checkpoint_interval = 20_000

    device = torch.device(args.device)
    seed_all(args.seed)

    sampler = smpl.get_distribution(args.target_dataset)

    warmup_steps = max(0, int(getattr(args, "warmup_lr", 0)))

    def _warmup_lambda(step: int) -> float:
        # Linear ramp of the LR multiplier up to 1.0 over `warmup_steps` steps.
        if warmup_steps <= 0:
            return 1.0
        return min(1.0, float(step + 1) / warmup_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_lambda)

    latent_sampler = _make_latent_sampler(args.baseline_latent, device=device, args=args)

    ema = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(args.ema))
    wrapper = TorchWrapper(ema)
    ode_func = ODEWrapper(wrapper).to(device)

    # A fixed data batch is cached on `args` the first time through; it is not
    # read again in this function.
    fixed_sampler = make_fixed_sampler(sampler, seed=args.seed, device=device)
    if not hasattr(args, "_fixed_baseline_x0"):
        args._fixed_baseline_x0 = fixed_sampler(args.batch_size, seed_offset=0)

    # Image logging is enabled only when the flattened image shape matches the
    # model's data dimension.
    image_shape = getattr(args, "image_shape", None)
    image_dim = math.prod(image_shape) if image_shape is not None else None
    is_image_task = image_shape is not None and image_dim == args.dim

    if is_image_task:
        log_real_rgb_histogram_once(
            args=args,
            sampler=sampler,
            image_shape=image_shape,
            device=device,
            step=0,
        )

    checkpoint_dir = os.path.join(args.runs_dir, "baseline_fm")
    os.makedirs(checkpoint_dir, exist_ok=True)

    fid_interval = int(args.fid_eval_interval) if hasattr(args, "fid_eval_interval") else 0
    fid_num_gen = int(args.fid_num_gen) if hasattr(args, "fid_num_gen") else 0
    fid_batch_size = max(1, int(getattr(args, "fid_batch_size", args.batch_size))) if fid_interval > 0 else 0
    fid_gen_batch = max(1, int(getattr(args, "fid_gen_batch", args.batch_size))) if fid_interval > 0 else 0
    fid_image_size = (
        int(getattr(args, "fid_image_size", 0)) if (fid_interval > 0 and image_shape is not None) else 0
    )
    fid_real_cache = None
    if is_image_task and fid_interval > 0 and fid_num_gen > 0:
        # Real samples are drawn once up front and kept on the CPU for reuse.
        with torch.no_grad():
            real_samples = sampler.sample(fid_num_gen, device=device, dtype=torch.float32)
            fid_real_cache = reshape_flat_samples(real_samples, torch.Size(image_shape)).detach().cpu()

    sample_vis_interval = int(getattr(args, "sample_vis_interval", 0))
    sample_vis_count = int(getattr(args, "sample_vis_count", 0))
    sample_vis_nrow = int(getattr(args, "sample_vis_nrow", 8))

    sample_dir = os.path.join(checkpoint_dir, "samples") if is_image_task else ""
    # Generation integrates from t=1 (latent) down to t=0 (data), matching the
    # interpolation x_t = (1 - t) * x_0 + t * z used for training below.
    t_eval = torch.linspace(1.0, 0.0, args.num_steps_eval, device=device)

    fixed_vis_noise = getattr(args, "_fixed_baseline_vis_noise", None) if is_image_task else None

    # Chunk size for image-logging generation; both it and the closure below
    # depend only on loop-invariant values, so they are built once.
    if fid_gen_batch > 0:
        batch_size_for_logging = fid_gen_batch
    else:
        fallback_bs = sample_vis_count if sample_vis_count > 0 else args.batch_size
        batch_size_for_logging = max(1, fallback_bs)

    def generate_for_logging(
        count: int,
        *,
        latents: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return _generate_baseline_samples(
            count,
            batch_size=batch_size_for_logging,
            device=device,
            dim=args.dim,
            t_eval=t_eval,
            ode_func=ode_func,
            wrapper=wrapper,
            eval_model=ema,
            latent_sampler=latent_sampler,
            latents=latents,
            labels=labels,
        )

    # Host wall-clock time of the optimisation part of each iteration only
    # (logging, evaluation and checkpointing are excluded; no device
    # synchronisation is performed before reading the clock).
    train_time_accumulator = 0.0
    last_saved_step = None

    for step in tqdm(range(args.epochs), desc="Flow-matching baseline"):
        iter_start = time.perf_counter()
        model.train()
        optimizer.zero_grad(set_to_none=True)

        # NOTE(review): `pairing_cost` is never assigned, so the
        # 'metrics/minibatch_ot_cost' entry below is never logged. The second
        # return value of `minibatch_ot_pairing` is presumably that cost --
        # confirm against the helper before wiring it in.
        pairing_cost = None
        x_0 = sampler.sample(args.batch_size, device=device, dtype=torch.float32)
        z = latent_sampler((args.batch_size, args.dim))
        t = torch.rand(args.batch_size, 1, device=device)

        if args.use_minibatch_ot:
            # Reorder the data batch according to the pairing indices.
            idx_best, _ = minibatch_ot_pairing(x_0, z)
            x_0 = x_0[idx_best]

        x_t = (1 - t) * x_0 + t * z
        velocity_target = -x_0 + z
        velocity_pred = model(t, x_t)

        loss = F.mse_loss(velocity_pred, velocity_target)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.model_grad_clip)
        optimizer.step()
        scheduler.step()
        ema.update_parameters(model)

        train_time_accumulator += time.perf_counter() - iter_start

        log_payload = {
            'loss/velocity': float(loss.item()),
            'grad/model_velocity': float(grad_norm.item()),
        }
        if pairing_cost is not None:
            log_payload['metrics/minibatch_ot_cost'] = float(pairing_cost.item())
        wandb.log(log_payload, step=step)

        do_light = (args.eval_sample > 0) and (((step + 1) % args.eval_step) == 0)
        do_heavy = (args.big_eval_samples > 0) and (((step + 1) % args.big_eval_step) == 0)
        if not is_image_task and (do_light or do_heavy):
            log_baseline_evaluation(
                args=args,
                step=step,
                ema_model=ema,
                wrapper=wrapper,
                ode_func=ode_func,
                sampler=sampler,
                noise_sampler=latent_sampler,
                x0_batch=x_0,
                device=device,
                do_light=do_light,
                do_heavy=do_heavy,
            )

        if is_image_task:
            run_samples = (
                sample_vis_interval > 0
                and sample_vis_count > 0
                and ((step + 1) % sample_vis_interval == 0)
            )
            run_fid = (
                fid_interval > 0
                and fid_num_gen > 0
                and fid_real_cache is not None
                and ((step + 1) % fid_interval == 0)
            )
            if run_samples or run_fid:
                fixed_vis_noise = log_baseline_image_metrics(
                    args=args,
                    step=step,
                    eval_model=ema,
                    wrapper=wrapper,
                    device=device,
                    image_shape=image_shape,
                    sampler=sampler,
                    sample_vis_interval=sample_vis_interval,
                    sample_vis_count=sample_vis_count,
                    sample_vis_nrow=max(1, sample_vis_nrow),
                    sample_dir=sample_dir,
                    fid_interval=fid_interval,
                    fid_num_gen=fid_num_gen,
                    fid_batch_size=fid_batch_size,
                    fid_image_size=fid_image_size,
                    fid_gen_batch=fid_gen_batch,
                    fid_real_cache=fid_real_cache,
                    noise_sampler=latent_sampler,
                    generate_samples=generate_for_logging,
                    fixed_noise=fixed_vis_noise,
                )
                # Persist the returned noise on `args` so it can be reused.
                if fixed_vis_noise is not None:
                    args._fixed_baseline_vis_noise = fixed_vis_noise

        current_step = step + 1
        if current_step % checkpoint_interval == 0:
            _save_baseline_checkpoint(checkpoint_dir, current_step, model, ema)
            last_saved_step = current_step

    # Final checkpoint; skipped only when the periodic save already wrote the
    # very same step (the files would be overwritten with identical content).
    final_step = args.epochs
    if last_saved_step != final_step:
        _save_baseline_checkpoint(checkpoint_dir, final_step, model, ema)

    runtime_path = os.path.join(args.runs_dir, "runtime_training_only.txt")
    os.makedirs(args.runs_dir, exist_ok=True)
    with open(runtime_path, "w", encoding="utf-8") as fh:
        fh.write(f"{train_time_accumulator:.6f}\n")
