# modules/opponent/context_builder.py
# Builds a per-step feature vector (context) used by ContextEncoder.

import torch
import pdb

class ContextBuilder:
    """Assemble per-agent context features from an episode batch.

    Each step's context is the concatenation (along the last dim) of the
    enabled feature groups, in this order:

    * observations at step ``t``               (``args.ctx_mean_obs``, default on)
    * one-hot actions from step ``t - 1``      (``args.ctx_last_actions``, default on)
    * reward from step ``t - 1``, broadcast    (``args.ctx_reward``, default off)
      across agents when stored per-team

    At ``t == 0`` the "previous step" groups are zero-filled.
    """

    def __init__(self, args, scheme):
        self.args = args
        self.scheme = scheme
        # NOTE(review): despite the flag name, build_step concatenates the raw
        # per-agent observations; no mean over agents is taken.
        self.include_mean_obs = getattr(args, "ctx_mean_obs", True)
        self.include_last_actions = getattr(args, "ctx_last_actions", True)
        self.include_reward = getattr(args, "ctx_reward", False)
        self.n_actions = args.n_actions
        self.obs_dim = scheme["obs"]["vshape"]
        self.n_agents = args.n_agents

    def build_step(self, ep_batch, t):
        """Build the context features for a single time step.

        Args:
            ep_batch: episode batch exposing ``batch_size``, ``device`` and
                item access by field name (``"obs"``, ``"actions_onehot"``,
                ``"reward"``); each field is indexed as ``field[:, step]``.
            t: time-step index.

        Returns:
            Tensor of shape ``[B, n_agents, C]`` where ``C`` is the sum of the
            widths of the enabled feature groups.

        Raises:
            ValueError: if every feature group is disabled.
        """
        B = ep_batch.batch_size
        device = ep_batch.device
        feats = []
        if self.include_mean_obs:
            obs = ep_batch["obs"][:, t]  # [B, n_agents, obs_dim]
            feats.append(obs)
        if self.include_last_actions:
            actions = ep_batch["actions_onehot"]
            if t > 0:
                last_a = actions[:, t - 1]  # [B, n_agents, n_actions]
            else:
                # No previous action at t == 0. Use the field's dtype so the
                # concatenated output has the same dtype at every step.
                last_a = torch.zeros(
                    B, self.n_agents, self.n_actions,
                    dtype=actions.dtype, device=device,
                )
            feats.append(last_a)
        if self.include_reward:
            rewards = ep_batch["reward"]
            if t > 0:
                last_r = rewards[:, t - 1]
                # Accept a bare [B] per-step reward as well as [B, 1]
                # (or per-agent [B, n_agents]).
                if last_r.dim() == 1:
                    last_r = last_r.unsqueeze(-1)
                last_r = last_r.unsqueeze(-1).expand(B, self.n_agents, 1)
            else:
                last_r = torch.zeros(
                    B, self.n_agents, 1, dtype=rewards.dtype, device=device
                )
            feats.append(last_r)
        if not feats:
            raise ValueError("ContextBuilder: no features enabled (enable ctx_mean_obs or ctx_last_actions or ctx_reward)")
        return torch.cat(feats, dim=-1)  # [B, N, C]

    def build_window(self, ep_batch, t, window):
        """Stack the contexts of the last ``window`` steps ending at ``t``.

        Steps before the start of the episode are left-padded with zeros, so
        the result always has ``window`` entries along dim 1.

        Args:
            ep_batch: episode batch, as accepted by :meth:`build_step`.
            t: index of the most recent step in the window.
            window: number of steps to return; must be at least 1.

        Returns:
            Tensor of shape ``[B, window, n_agents, C]``; index ``window - 1``
            along dim 1 holds the context of step ``t``.

        Raises:
            ValueError: if ``window`` is smaller than 1.
        """
        if window < 1:
            # Without this guard torch.stack fails on an empty list with an
            # unhelpful error.
            raise ValueError(
                f"ContextBuilder.build_window: window must be >= 1, got {window}"
            )
        start = max(0, t - window + 1)
        xs = [self.build_step(ep_batch, u) for u in range(start, t + 1)]
        n_pad = window - len(xs)
        if n_pad > 0:
            # pad at the beginning so the newest step stays last
            pad = xs[0].new_zeros(xs[0].shape)
            xs = [pad] * n_pad + xs
        return torch.stack(xs, dim=1)  # [B, T_ctx, n_agent, C]
