"""
For evaluating parameterized diffusion policy (conditioned on z embeddings).

Similar to DiffusionEval but passes z embeddings to the network.
"""

import logging
import torch

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

from model.diffusion.diffusion import DiffusionModel
from model.diffusion.sampling import extract


class ParameterizedDiffusionEval(DiffusionModel):
    """
    Evaluation wrapper for parameterized diffusion models.

    Loads EMA weights from a checkpoint into the network and passes a z
    embedding to the network during sampling, optionally combining a
    z-conditioned and a "z_empty"-conditioned prediction via classifier-free
    guidance (CFG).

    Attributes expected to be set by the caller before sampling:
        current_z: z embedding passed to the network as ``z``. Must not be
            None when ``p_mean_var`` is called.
        current_start, current_end: optional tensors; when not None they are
            merged into ``cond`` under the keys "start" and "end".
    """

    def __init__(
        self,
        network_path,
        use_ddim=False,
        cfg_guidance_scale=1.0,  # ω: guidance scale for CFG (1.0 = standard conditional, >1 = amplified)
        **kwargs,
    ):
        """
        Build the model via the base class, then load EMA weights.

        Args:
            network_path: path of the checkpoint; it must contain an "ema"
                state dict whose network entries contain "network." in
                their keys.
            use_ddim: forwarded to the base class.
            cfg_guidance_scale: CFG scale ω; 1.0 disables guidance.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        # Do not let base class load model
        super().__init__(use_ddim=use_ddim, network_path=None, **kwargs)

        # CFG guidance scale
        self.cfg_guidance_scale = cfg_guidance_scale
        if cfg_guidance_scale != 1.0:
            log.info("CFG guidance enabled with scale ω=%s", cfg_guidance_scale)

        # Load checkpoint.
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from trusted sources.
        checkpoint = torch.load(network_path, map_location=self.device, weights_only=False)

        # Set up model with EMA weights
        self.actor = self.network

        # Keep only the network entries of the EMA state dict and strip the
        # prefix up to and including the FIRST "network.". maxsplit=1 matters:
        # without it a nested name such as "network.sub_network.weight" would
        # be truncated to "sub_".
        ema_weights = {
            key.split("network.", 1)[1]: value
            for key, value in checkpoint["ema"].items()
            if "network." in key
        }
        # Use strict=False to handle models trained without CFG (missing z_empty buffer)
        missing_keys, unexpected_keys = self.actor.load_state_dict(ema_weights, strict=False)

        # Log any missing/unexpected keys for debugging
        if "z_empty" in missing_keys:
            # z_empty is expected to be missing for non-CFG checkpoints
            log.info("Missing z_empty buffer (expected for non-CFG checkpoints)")
            if cfg_guidance_scale != 1.0:
                log.warning(
                    "cfg_guidance_scale=%s requested but the checkpoint has no "
                    "z_empty; guidance will use the network's own (not loaded) "
                    "z_empty value",
                    cfg_guidance_scale,
                )
        critical_missing = [k for k in missing_keys if k != "z_empty"]
        if critical_missing:
            log.warning("Missing keys in state_dict: %s", critical_missing)
        if unexpected_keys:
            log.warning("Unexpected keys in state_dict: %s", unexpected_keys)

        if cfg_guidance_scale != 1.0 and not hasattr(self.actor, "z_empty"):
            # p_mean_var only applies guidance when the network has z_empty.
            log.warning(
                "cfg_guidance_scale=%s requested but the network has no z_empty; "
                "falling back to standard conditional prediction",
                cfg_guidance_scale,
            )

        log.info("Loaded parameterized diffusion EMA weights from %s", network_path)

        self.use_guidance = False

        # Store current z embedding (set by eval agent before each episode)
        self.current_z = None

        # Store current start/end positions (for styles setup)
        self.current_start = None
        self.current_end = None

    @staticmethod
    def _broadcast_batch(tensor, batch_size):
        """Expand a leading dimension of size 1 to ``batch_size``; otherwise return ``tensor`` unchanged."""
        if tensor.shape[0] == 1 and batch_size > 1:
            # assumes a 2-D tensor, e.g. (1, latent_dim) -> (B, latent_dim)
            return tensor.expand(batch_size, -1)
        return tensor

    def _predict_noise_with_z(self, x, t, cond, z):
        """Run the network with z conditioning, applying CFG when enabled and supported."""
        if self.cfg_guidance_scale != 1.0 and hasattr(self.actor, "z_empty"):
            # Classifier-Free Guidance:
            # ε_cfg = ε_uncond + ω * (ε_cond - ε_uncond)

            # Conditional prediction: ε_cond = ε_θ(x_t, t, c, z)
            noise_cond = self.actor(x, t, cond=cond, z=z)

            # Unconditional prediction: ε_uncond = ε_θ(x_t, t, c, z_∅)
            z_empty = self.actor.z_empty.expand(x.shape[0], -1)
            noise_uncond = self.actor(x, t, cond=cond, z=z_empty)

            return noise_uncond + self.cfg_guidance_scale * (noise_cond - noise_uncond)

        # Standard conditional prediction (no guidance)
        return self.actor(x, t, cond=cond, z=z)

    def p_mean_var(
        self,
        x,
        t,
        cond,
        index=None,
        deterministic=False,
    ):
        """
        Compute the mean and log-variance of one reverse step with z conditioning.

        Args:
            x: current noisy sample; its first dimension is the batch size.
            t: timesteps used to index the DDPM coefficient tables.
            cond: conditioning dict passed to the network; "start"/"end"
                entries are added (on a copy) when current_start/current_end
                are set.
            index: indices into the DDIM coefficient tables; used only when
                ``use_ddim`` is set.
            deterministic: with DDIM, use eta = 0 instead of ``self.eta(cond)``.

        Returns:
            Tuple ``(mu, logvar)``.

        Raises:
            ValueError: if ``current_z`` has not been set.
        """
        if self.current_z is None:
            raise ValueError("current_z must be set before calling forward()")

        # Expand z (and optional start/end) to match the batch size
        batch_size = x.shape[0]
        z_expanded = self._broadcast_batch(self.current_z, batch_size)

        # Add start/end to cond if available (for styles setup)
        if self.current_start is not None:
            cond = {**cond, "start": self._broadcast_batch(self.current_start, batch_size)}
        if self.current_end is not None:
            cond = {**cond, "end": self._broadcast_batch(self.current_end, batch_size)}

        # Network output: ε when predict_epsilon, otherwise x₀ directly
        noise = self._predict_noise_with_z(x, t, cond, z_expanded)

        # DDIM coefficients are needed on every DDIM path (x₀ reconstruction,
        # ε re-computation and μ), not only when predicting ε, so extract them
        # up front.
        if self.use_ddim:
            alpha = extract(self.ddim_alphas, index, x.shape)
            alpha_prev = extract(self.ddim_alphas_prev, index, x.shape)
            sqrt_one_minus_alpha = extract(
                self.ddim_sqrt_one_minus_alphas, index, x.shape
            )

        # Predict x_0
        if self.predict_epsilon:
            if self.use_ddim:
                # x₀ = (xₜ - √(1-αₜ) ε) / √αₜ
                x_recon = (x - sqrt_one_minus_alpha * noise) / (alpha**0.5)
            else:
                # x₀ = √(1/α̅ₜ) xₜ - √(1/α̅ₜ - 1) ε
                x_recon = (
                    extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                    - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * noise
                )
        else:  # directly predicting x₀
            x_recon = noise

        clip_denoised = self.denoised_clip_value is not None
        if clip_denoised:
            x_recon.clamp_(-self.denoised_clip_value, self.denoised_clip_value)

        # DDIM needs an ε consistent with the final x₀: re-derive it when x₀
        # was clamped, or when the network predicted x₀ rather than ε.
        if self.use_ddim and (clip_denoised or not self.predict_epsilon):
            noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha

        # Clip epsilon for numerical stability
        if self.use_ddim and self.eps_clip_value is not None:
            noise.clamp_(-self.eps_clip_value, self.eps_clip_value)

        # Get mu
        if self.use_ddim:
            # μ = √αₜ₋₁ x₀ + √(1-αₜ₋₁ - σₜ²) ε
            if deterministic:
                etas = torch.zeros((x.shape[0], 1, 1), device=x.device)
            else:
                etas = self.eta(cond).unsqueeze(1)  # B x 1 x (Da or 1)
            sigma = (
                etas
                * ((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev)) ** 0.5
            ).clamp_(min=1e-10)  # keep σ > 0 so log(σ²) stays finite
            dir_xt_coef = (1.0 - alpha_prev - sigma**2).clamp_(min=0).sqrt()
            mu = (alpha_prev**0.5) * x_recon + dir_xt_coef * noise
            var = sigma**2
            logvar = torch.log(var)
        else:
            # μₜ = coef1 · x₀ + coef2 · xₜ, with coefficients precomputed by the
            # base class (presumably the standard DDPM posterior mean — confirm
            # against DiffusionModel).
            mu = (
                extract(self.ddpm_mu_coef1, t, x.shape) * x_recon
                + extract(self.ddpm_mu_coef2, t, x.shape) * x
            )
            logvar = extract(self.ddpm_logvar_clipped, t, x.shape)
        return mu, logvar
