import math
import os
import time
from typing import Optional

import torch
import torch.nn.functional as F
import wandb
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
from torchdiffeq import odeint
from tqdm import tqdm
from geomloss import SamplesLoss

from learn_noise.training.pretrain_quantile import build_quantile, pretrain_quantile
from learn_noise.training.common import seed_all, make_fixed_uniform, minibatch_ot_pairing
from learn_noise.networks.model_wrapper import TorchWrapper, ODEWrapper
import learn_noise.utils.sampler as smpl
from learn_noise.training.logging import (
    log_quantile_image_metrics,
    log_quantile_low_dim_metrics,
    log_real_rgb_histogram_once,
)
from learn_noise.utils.image_eval import reshape_flat_samples



def _generate_samples(
    num_samples: int,
    *,
    batch_size: int,
    device: torch.device,
    dim: int,
    u_eps: float,
    quantile,
    ode_func,
    t_eval: torch.Tensor,
    wrapper: TorchWrapper,
    eval_model,
    u_source: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample outputs from the current flow model for evaluation/logging.

    Samples are produced in chunks of at most ``batch_size``. For each chunk,
    uniforms are rescaled into ``[u_eps, 1 - u_eps]``, mapped through
    ``quantile`` (called with an all-ones second argument and a zero
    ``x_aux``), and the result is integrated with fixed-step Euler over
    ``t_eval``. Only the final state of each trajectory is kept.

    Args:
        num_samples: Total number of samples to generate.
        batch_size: Maximum number of samples generated per ODE solve.
        device: Device on which noise is drawn and the ODE is integrated.
        dim: Width of the uniform noise drawn per sample.
        u_eps: Margin used to keep the uniforms away from 0 and 1.
        quantile: Callable invoked as ``quantile(U, ones, x_aux=zeros)``.
        ode_func: Right-hand side handed to ``odeint``.
        t_eval: Time grid handed to ``odeint``; the state at its last entry
            is returned.
        wrapper: Wrapper whose ``model`` attribute is pointed at
            ``eval_model`` and whose labels are set per chunk. Its ``model``
            is left pointing at ``eval_model`` on return.
        eval_model: Network to sample from; temporarily switched to eval mode.
        u_source: Optional pre-drawn uniforms, consumed row-wise in order
            instead of fresh ``torch.rand`` draws.
        labels: Optional per-sample labels, sliced in step with the samples.

    Returns:
        CPU tensor with the generated samples concatenated along dim 0. When
        ``num_samples <= 0`` an empty ``(0, dim)`` tensor is returned
        (assumes samples are ``(batch, dim)`` -- TODO confirm for callers
        that rely on the empty case).

    Raises:
        ValueError: If samples are requested with ``batch_size < 1``, which
            would otherwise never make progress.
    """
    if num_samples > 0 and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    wrapper.model = eval_model
    was_training = eval_model.training
    eval_model.eval()

    outputs = []
    produced = 0
    try:
        while produced < num_samples:
            cur_bs = min(batch_size, num_samples - produced)
            chunk = slice(produced, produced + cur_bs)
            if u_source is not None:
                u_unit = u_source[chunk].to(device)
            else:
                u_unit = torch.rand(cur_bs, dim, device=device)
            # Affine map of the unit uniforms into [u_eps, 1 - u_eps].
            U = u_eps + (1 - 2 * u_eps) * u_unit
            ones = torch.ones(cur_bs, 1, device=device)
            x_aux = torch.zeros(cur_bs, dim, device=device)
            eps_init = quantile(U, ones, x_aux=x_aux)
            if labels is not None:
                wrapper.set_labels(labels[chunk].to(device))
            else:
                wrapper.set_labels(None)
            traj = odeint(ode_func, eps_init, t_eval, method='euler')
            # Keep only the final state and move it off the device right away.
            outputs.append(traj[-1].detach().cpu())
            produced += cur_bs
    finally:
        # Restore state even if sampling fails, so a broken evaluation cannot
        # leave the training model in eval mode or stale labels on the wrapper.
        if was_training:
            eval_model.train()
        wrapper.set_labels(None)

    if not outputs:
        # torch.cat rejects an empty list; return an explicit empty batch.
        return torch.empty(0, dim)
    return torch.cat(outputs, dim=0)


def _load_quantile_state(quantile, ckpt_path, device) -> None:
    """Load quantile weights from ``ckpt_path`` into ``quantile`` in place."""
    # NOTE(security): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    state = torch.load(ckpt_path, map_location=device)
    if isinstance(state, dict):
        # Accept wrapped payloads ({'quantile': ...} or {'state_dict': ...})
        # as well as a bare state dict.
        if 'quantile' in state:
            state = state['quantile']
        elif 'state_dict' in state:
            state = state['state_dict']
    quantile.load_state_dict(state)
    print('[train_fm_quantile] Loaded quantile checkpoint from', ckpt_path)


def _quantile_schedule_iters(schedule_label: str):
    """Return ``(const_iters, decay_iters)`` for a named quantile LR schedule."""
    if schedule_label == "short":
        return 50_000, 5_000
    if schedule_label == "full":
        return 25_000, 10_000
    raise ValueError(f"Unknown quantile_schedule '{schedule_label}'")


def _save_training_checkpoints(checkpoint_dir, global_step, quantile, ema, u_eps, dim) -> None:
    """Write the quantile (and, if present, EMA) state for ``global_step``."""
    ckpt_suffix = f"step_{global_step:06d}.pt"
    quantile_payload = {
        "step": global_step,
        "state_dict": quantile.state_dict(),
        "u_eps": u_eps,
        "dim": dim,
    }
    torch.save(quantile_payload, os.path.join(checkpoint_dir, f"quantile_{ckpt_suffix}"))
    if ema is not None:
        ema_payload = {
            "step": global_step,
            "state_dict": ema.state_dict(),
        }
        torch.save(ema_payload, os.path.join(checkpoint_dir, f"ema_{ckpt_suffix}"))


def train_fm_quantile(args, model, optimizer):
    """Jointly train a velocity network and a quantile (noise) map.

    Each iteration draws target samples and times, maps uniforms through the
    quantile network, builds the interpolant ``x_t`` and the velocity target,
    and minimises the velocity MSE plus a weighted quantile loss and an
    optional KL-style log-determinant term. The quantile learning rate follows
    a constant-then-linear-decay schedule, and an EMA copy of the model is
    started once ``args.ema_start_step`` is reached.

    Args:
        args: Configuration namespace. Attributes are read throughout the
            function; two cache attributes (``_fixed_double_fm_vis_u`` and
            ``_fixed_quantile_eval_u``) are also written back onto it.
        model: Velocity network, called as ``model(t, x_t)``.
        optimizer: Ignored. A fresh Adam optimizer covering the model (and,
            unless frozen, the quantile) parameters is built internally. The
            parameter is kept so existing callers continue to work.

    Returns:
        Tuple ``(quantile, ema)``; ``ema`` is ``None`` if EMA never started.

    Raises:
        ValueError: If the resolved quantile schedule label is unknown.
        NotImplementedError: If ``args.metric`` is not 'ot', 'mmd' or 'SD'.
    """
    device = torch.device(args.device)
    seed_all(args.seed)

    global_step_offset = 0

    if args.mode == 'pretrain_quantile':
        quantile, _, _ = pretrain_quantile(args)
        # Continue the global step count after the pretraining iterations.
        global_step_offset += args.q_ntrain
    else:
        quantile = build_quantile(args, device, args.dim)

    ckpt_path = args.quantile_checkpoint
    if ckpt_path:
        _load_quantile_state(quantile, ckpt_path, device)

    freeze_quantile = bool(args.freeze_quantile)
    if freeze_quantile:
        for param in quantile.parameters():
            param.requires_grad_(False)

    base_lr = float(args.lr)
    param_groups = [{"params": list(model.parameters()), "lr": base_lr}]
    if not freeze_quantile and any(p.requires_grad for p in quantile.parameters()):
        param_groups.append({"params": list(quantile.parameters()), "lr": float(args.q_lr)})

    # NOTE: this deliberately replaces the ``optimizer`` argument (see docstring).
    optimizer = torch.optim.Adam(param_groups)

    warmup_steps = max(0, int(getattr(args, "warmup_lr", 0)))

    def _warmup_lambda(step: int) -> float:
        """Linear warm-up factor for the model group; 1.0 when disabled."""
        if warmup_steps <= 0:
            return 1.0
        return min(1.0, float(step + 1) / warmup_steps)

    if len(param_groups) > 1:
        # Warm-up applies to the model group only. The quantile group gets a
        # constant factor here; its lr is set by hand inside the loop.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=[_warmup_lambda, lambda _: 1.0],
        )
    else:
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_lambda)

    sampler = smpl.get_distribution(args.target_dataset)

    ema = None
    ema_started = False
    ema_avg_fn = get_ema_multi_avg_fn(args.ema)
    wrapper = TorchWrapper(model)
    ode_func = ODEWrapper(wrapper).to(device)

    # Default schedule depends on whether minibatch OT pairing is enabled.
    schedule_label = getattr(args, "quantile_schedule", None)
    if schedule_label is None:
        schedule_label = "full" if bool(getattr(args, "use_minibatch_ot", False)) else "short"
    schedule_label = str(schedule_label).lower()

    quantile_const_iters, quantile_decay_iters = _quantile_schedule_iters(schedule_label)

    quantile_schedule_total = quantile_const_iters + quantile_decay_iters
    if len(optimizer.param_groups) > 1:
        quantile_base_lr = optimizer.param_groups[1]["lr"]
    else:
        quantile_base_lr = 0.0
    ema_start_step = int(args.ema_start_step)

    print(
        f"[train_fm_quantile] quantile_schedule={schedule_label}"
        f" (const={quantile_const_iters}, decay={quantile_decay_iters})"
    )

    u_eps = float(args.q_u_eps)
    print("u_eps", u_eps)

    q_loss_weight = float(getattr(args, "q_loss_weight", 1.0))

    if args.metric == 'mmd':
        metric = SamplesLoss('energy')
    elif args.metric == 'SD':
        metric = SamplesLoss(blur=args.metric_blur)
    else:
        metric = None

    def _quantile_lr_schedule(step: int) -> float:
        """Quantile lr: constant, then linear decay to 0, then 0."""
        if step < quantile_const_iters:
            return quantile_base_lr
        if step < quantile_schedule_total:
            decay_progress = (step - quantile_const_iters + 1) / max(1, quantile_decay_iters)
            return max(0.0, quantile_base_lr * (1.0 - decay_progress))
        return 0.0

    def _active_eval_model():
        """Prefer the EMA copy for evaluation once it exists."""
        return ema if ema is not None else model

    image_shape = getattr(args, "image_shape", None)
    image_dim = math.prod(image_shape) if image_shape is not None else None
    is_image_task = image_shape is not None and image_dim == args.dim

    if is_image_task:
        log_real_rgb_histogram_once(
            args=args,
            sampler=sampler,
            image_shape=image_shape,
            device=device,
            step=0,
        )

    fid_interval = int(args.fid_eval_interval)
    fid_num_gen = int(args.fid_num_gen)
    fid_image_size = int(args.fid_image_size if image_shape else 0) if fid_interval > 0 else 0
    fid_batch_size = max(1, int(args.fid_batch_size)) if fid_interval > 0 else 0
    fid_gen_batch = max(1, int(args.fid_gen_batch)) if fid_interval > 0 else 0
    fid_real_cache = None
    if is_image_task and fid_interval > 0 and fid_num_gen > 0:
        # Draw the real reference set once and keep it on the CPU.
        with torch.no_grad():
            real_samples = sampler.sample(fid_num_gen, device=device, dtype=torch.float32)
            real_imgs = reshape_flat_samples(real_samples, torch.Size(image_shape))
        fid_real_cache = real_imgs.detach().cpu()

    ############## PATH ###############
    sample_dir = os.path.join(args.runs_dir, "samples")
    checkpoint_dir = os.path.join(args.runs_dir, "quantile_fm")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Fixed uniforms for sample grids are cached on ``args`` and reused.
    fixed_u_vis = None
    if is_image_task and args.sample_vis_interval > 0 and args.sample_vis_count > 0:
        if not hasattr(args, "_fixed_double_fm_vis_u") or args._fixed_double_fm_vis_u.shape[0] < args.sample_vis_count:
            args._fixed_double_fm_vis_u = make_fixed_uniform((args.sample_vis_count, args.dim), seed=args.seed + 73, device=device)
        fixed_u_vis = args._fixed_double_fm_vis_u

    fixed_eval_u = None if is_image_task else getattr(args, "_fixed_quantile_eval_u", None)

    t_eval = torch.linspace(1.0, 0.0, args.num_steps_eval, device=device)

    checkpoint_interval = 20_000
    train_time_accumulator = 0.0
    global_step = global_step_offset

    for step in tqdm(range(args.epochs), desc="Flow-matching (test trajectories)"):
        iter_start = time.perf_counter()
        model.train()
        # NOTE(review): train/eval mode and the requires_grad path below are
        # gated on the local ``step``, while the lr schedule uses
        # ``global_step`` (which includes the pretraining offset) -- confirm
        # this asymmetry is intended.
        if not freeze_quantile and step < quantile_schedule_total:
            quantile.train()
        else:
            quantile.eval()
        optimizer.zero_grad(set_to_none=True)
        global_step = global_step_offset + step

        # Re-apply the hand-rolled quantile lr before this iteration's update.
        quantile_lr = _quantile_lr_schedule(global_step)
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]["lr"] = quantile_lr

        if (not ema_started) and global_step >= ema_start_step:
            ema = AveragedModel(model, multi_avg_fn=ema_avg_fn)
            ema.to(device)
            ema.eval()
            wrapper.model = ema
            ema_started = True

        # NOTE(review): never populated -- the second return value of
        # minibatch_ot_pairing is discarded below, so the
        # 'metrics/minibatch_ot_cost' entry is currently never logged.
        pairing_cost = None

        x0_full = sampler.sample(args.batch_size, device=device, dtype=torch.float32)
        t_full = torch.rand(args.batch_size, 1, device=device)

        U = u_eps + (1 - 2 * u_eps) * torch.rand_like(x0_full)

        if step < quantile_schedule_total:
            eps, dqdt = quantile(U, t_full, x_aux=x0_full, return_dqdt=True, requires_grad=True)
        else:
            eps, dqdt = quantile(U, t_full, x_aux=x0_full, return_dqdt=True)

        x0_train = x0_full
        t_train = t_full

        if args.use_minibatch_ot:
            idx_best, _ = minibatch_ot_pairing(x0_full, dqdt)
            x0_train = x0_full[idx_best]

        # NOTE: for independent coupling + W2 this logic needs manual
        # adjustment: re-sort only for the w2_loss, not for the FM loss.
        x_t = (1.0 - t_train) * x0_train + eps
        v_target = -x0_train + dqdt

        v_net = model(t_train, x_t)

        w2_loss = 0.5 * v_target.pow(2).sum(dim=1).mean()

        if args.metric == 'ot':
            loss_q = w2_loss
        elif args.metric in {'mmd', 'SD'}:
            if metric is None:
                raise RuntimeError("Metric operator not initialised for metric='" + str(args.metric) + "'")
            loss_q = metric(dqdt, x0_train)
        else:
            raise NotImplementedError

        kl_loss = torch.zeros((), device=device)

        if args.kl > 0 and step < quantile_schedule_total and not freeze_quantile:
            with torch.set_grad_enabled(True):
                diag = quantile.diag_du(U, t_train, None, create_graph=True)
            # Clamp before the log so tiny or non-positive entries cannot give -inf/NaN.
            logdet = torch.log(diag.clamp_min(1e-12)).sum(dim=1)
            kl_loss = (-logdet).mean()

        loss_velocity = F.mse_loss(v_net, v_target)
        loss = loss_velocity + q_loss_weight * loss_q + args.kl * kl_loss

        loss.backward()
        # Only the velocity model's gradients are clipped.
        grad_model = torch.nn.utils.clip_grad_norm_(model.parameters(), args.model_grad_clip)
        optimizer.step()
        scheduler.step()
        if ema is not None:
            ema.update_parameters(model)

        # Everything below (logging, evaluation, checkpoints) is excluded
        # from the training-only runtime.
        train_time_accumulator += time.perf_counter() - iter_start

        log_payload = {
            'loss/velocity': float(loss_velocity.item()),
            'loss/q': float(loss_q.item()),
            'loss/w2': float(w2_loss.item()),
            'loss/kl': float(kl_loss.item()),
            'loss/total': float(loss.item()),
            'grad/model_velocity': float(grad_model.item()),
            'lr/quantile': float(quantile_lr),
        }
        if pairing_cost is not None:
            log_payload['metrics/minibatch_ot_cost'] = float(pairing_cost.item())
        wandb.log(log_payload, step=global_step)

        do_light = (args.eval_sample > 0) and (((step + 1) % args.eval_step) == 0)
        do_heavy = (args.big_eval_samples > 0) and (((step + 1) % args.big_eval_step) == 0)

        # Image-space evaluations (sample grids, latent viz, FID) run on their own schedules.
        if is_image_task:
            run_samples = (
                args.sample_vis_interval > 0
                and args.sample_vis_count > 0
                and ((step + 1) % args.sample_vis_interval == 0)
            )
            run_latent = run_samples and args.latent_viz_samples > 0
            run_fid = (
                fid_interval > 0
                and fid_num_gen > 0
                and fid_real_cache is not None
                and ((step + 1) % fid_interval == 0)
            )
            if run_samples or run_latent or run_fid:
                batch_size_for_logging = (
                    fid_gen_batch
                    if fid_gen_batch > 0
                    else max(1, args.sample_vis_count if args.sample_vis_count > 0 else args.batch_size)
                )
                eval_model_for_logging = _active_eval_model()

                def generate_for_logging(
                    count: int,
                    *,
                    u_source: Optional[torch.Tensor] = None,
                    labels: Optional[torch.Tensor] = None,
                ) -> torch.Tensor:
                    """Sampling callback bound to the current evaluation model."""
                    return _generate_samples(
                        count,
                        batch_size=batch_size_for_logging,
                        device=device,
                        dim=args.dim,
                        u_eps=u_eps,
                        quantile=quantile,
                        ode_func=ode_func,
                        t_eval=t_eval,
                        wrapper=wrapper,
                        eval_model=eval_model_for_logging,
                        u_source=u_source,
                        labels=labels,
                    )

                fixed_u_vis = log_quantile_image_metrics(
                    args=args,
                    step=global_step,
                    eval_model=eval_model_for_logging,
                    wrapper=wrapper,
                    quantile=quantile,
                    device=device,
                    image_shape=image_shape,
                    sampler=sampler,
                    sample_vis_interval=args.sample_vis_interval,
                    sample_vis_count=args.sample_vis_count,
                    sample_vis_nrow=max(1, args.sample_vis_nrow),
                    sample_dir=sample_dir,
                    fid_interval=fid_interval,
                    fid_num_gen=fid_num_gen,
                    fid_batch_size=fid_batch_size,
                    fid_image_size=fid_image_size,
                    fid_gen_batch=fid_gen_batch,
                    fid_real_cache=fid_real_cache,
                    generate_samples=generate_for_logging,
                    fixed_u_vis=fixed_u_vis,
                    u_eps=u_eps,
                )
                if fixed_u_vis is not None:
                    args._fixed_double_fm_vis_u = fixed_u_vis

        if (step + 1) % checkpoint_interval == 0:
            _save_training_checkpoints(checkpoint_dir, global_step, quantile, ema, u_eps, args.dim)

        # Low-dimensional trajectory / Sinkhorn metrics retain their own cadence.
        if not is_image_task and (do_light or do_heavy):
            eval_model = _active_eval_model()
            fixed_eval_u = log_quantile_low_dim_metrics(
                args=args,
                step=global_step,
                eval_model=eval_model,
                wrapper=wrapper,
                ode_func=ode_func,
                sampler=sampler,
                quantile=quantile,
                x0_batch=x0_full,
                device=device,
                do_light=do_light,
                do_heavy=do_heavy,
                u_eps=u_eps,
                fixed_eval_u=fixed_eval_u,
            )
            if fixed_eval_u is not None:
                args._fixed_quantile_eval_u = fixed_eval_u

    # Always write a final checkpoint tagged with the last global step.
    _save_training_checkpoints(checkpoint_dir, global_step, quantile, ema, u_eps, args.dim)

    runtime_path = os.path.join(args.runs_dir, "runtime_training_only.txt")
    os.makedirs(args.runs_dir, exist_ok=True)
    with open(runtime_path, "w", encoding="utf-8") as fh:
        fh.write(f"{train_time_accumulator:.6f}\n")

    return quantile, ema
