"""
Parameterized Behavior Cloning (PBC) Evaluation Model.

Wraps a trained PBC network for evaluation, conditioning each prediction on a z embedding.
Similar to ParameterizedDiffusionEval but for BC (no diffusion sampling).
"""

import logging
import torch
import torch.nn as nn
from collections import namedtuple

# Logger for this module (named after the module itself).
log = logging.getLogger(name=__name__)

# Result container returned by ParameterizedBCEval.forward: `trajectories`
# holds the network output; `chains` is always None for BC (no sampling chain).
Sample = namedtuple(typename="Sample", field_names="trajectories chains")


class ParameterizedBCEval(nn.Module):
    """
    Evaluation wrapper for Parameterized BC.

    Provides interface compatible with EvalCloseDrawerPDPAgent:
    - Has current_z attribute that eval agent sets before each episode
    - Forward pass returns Sample namedtuple with trajectories

    Unlike ParameterizedDiffusionEval, this is a simple forward pass
    with no diffusion sampling - just network(state, z).

    Args:
        network: nn.Module invoked as ``network(state, z)``.
        horizon_steps: stored as an attribute; not used inside this class.
        obs_dim: stored as an attribute; not used inside this class.
        action_dim: stored as an attribute; not used inside this class.
        network_path: optional checkpoint path; weights are loaded into
            ``network`` (see ``_load_checkpoint``).
        device: device the network is moved to and checkpoint tensors are
            mapped to; ``set_z`` also moves z here.
        **kwargs: accepted and ignored (presumably so a shared config can
            be passed through — verify against callers).
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        **kwargs,
    ):
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.network = network.to(device)

        # current_z is set by the eval agent (directly or via set_z) before
        # each episode.
        self.current_z = None

        # Load weights from checkpoint
        if network_path is not None:
            self._load_checkpoint(network_path)

        # This wrapper is evaluation-only: put any dropout / batch-norm layers
        # in inference mode so predictions are deterministic.
        self.eval()

        log.info(
            f"ParameterizedBCEval initialized with {sum(p.numel() for p in self.parameters())} parameters"
        )

    @staticmethod
    def _extract_network_state_dict(state_dict, prefix="network."):
        """
        Select the entries of ``state_dict`` whose key starts with ``prefix``.

        Only the leading prefix is removed; occurrences of the same text deeper
        in a key (e.g. a submodule named ``sub_network``) are left intact.

        Returns:
            dict: ``{key_without_prefix: value}``; empty if no key matches.
        """
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def _load_network_weights(self, state_dict):
        """
        Load ``state_dict`` into ``self.network`` non-strictly.

        If any key carries the ``"network."`` prefix, only those entries are
        used (with the prefix stripped). Otherwise the dict is assumed to be
        keyed relative to the network itself and is loaded as-is.

        Missing and unexpected keys are logged as warnings, because
        ``strict=False`` would otherwise hide a checkpoint that silently
        loaded nothing.

        Returns:
            bool: True if "network."-prefixed keys were found and used.
        """
        network_state_dict = self._extract_network_state_dict(state_dict)
        used_prefix = len(network_state_dict) > 0
        if not used_prefix:
            # The wrapper's own state-dict keys all start with "network.", so an
            # unprefixed dict can only match the wrapped network directly.
            network_state_dict = state_dict

        result = self.network.load_state_dict(network_state_dict, strict=False)
        if result.missing_keys:
            log.warning(
                f"Network keys missing from checkpoint (left unchanged): {result.missing_keys}"
            )
        if result.unexpected_keys:
            log.warning(
                f"Checkpoint keys not used by network: {result.unexpected_keys}"
            )
        return used_prefix

    def _load_checkpoint(self, network_path):
        """
        Load network weights from a checkpoint file.

        Preference order: the ``"ema"`` entry (preferred for evaluation), then
        ``"model"``, then the checkpoint itself treated as a state dict.

        Args:
            network_path: path passed to ``torch.load`` (loaded with
                ``weights_only=True`` onto ``self.device``).
        """
        log.info(f"Loading PBC weights from {network_path}")
        checkpoint = torch.load(
            network_path, map_location=self.device, weights_only=True
        )

        for key, label in (("ema", "EMA weights"), ("model", "model weights")):
            if key in checkpoint:
                used_prefix = self._load_network_weights(checkpoint[key])
                # "(full)" marks a state dict without "network."-prefixed keys.
                suffix = "" if used_prefix else " (full)"
                log.info(f"Loaded {label}{suffix} from {network_path}")
                return

        # No wrapper keys: treat the checkpoint itself as a state dict.
        self._load_network_weights(checkpoint)
        log.info(f"Loaded weights directly from {network_path}")

    @torch.no_grad()
    def forward(self, cond, deterministic=True):
        """
        Forward pass for evaluation.

        Args:
            cond: dict with a 'state' tensor. If it is 3-D — presumably
                (B, cond_steps, obs_dim) — only the last step along dim 1 is
                used; a 2-D state is used as-is.
            deterministic: unused; kept for interface compatibility.

        Returns:
            Sample: ``trajectories`` is ``network(state, z)``; ``chains`` is None.

        Raises:
            ValueError: if ``current_z`` has not been set.
        """
        if self.current_z is None:
            raise ValueError(
                "current_z must be set before calling forward(). "
                "Use model.current_z = z_embedding.unsqueeze(0)"
            )

        # Get state
        state = cond["state"]
        if state.dim() == 3:
            state = state[:, -1, :]  # Use most recent state: (B, obs_dim)

        batch_size = state.shape[0]

        z = self.current_z
        # current_z may be assigned directly (bypassing set_z): accept an
        # unbatched 1-D embedding and align it with the state's device.
        if z.dim() == 1:
            z = z.unsqueeze(0)
        z = z.to(state.device)

        # Broadcast a single z across the batch.
        if z.shape[0] == 1 and batch_size > 1:
            z = z.expand(batch_size, -1)

        # Direct forward pass (no diffusion sampling)
        trajectory = self.network(state, z)

        return Sample(trajectory, None)

    def set_z(self, z):
        """
        Set ``current_z`` from a tensor.

        A 1-D ``z`` gets a leading batch dimension of size 1. The result is
        moved to ``self.device``.
        """
        if z.dim() == 1:
            z = z.unsqueeze(0)
        self.current_z = z.to(self.device)
