import argparse

import torch
import torch.nn as nn

from pulse_common import (
    ModelConfig,
    PULSEBackbone,
    PoseHead,
    build_multi_frame_input,
    set_seed,
)


class PULSEMultiFrame(nn.Module):
    """Multi-frame PULSE-style model that fuses Doppler prompts over a window.

    Doppler tokens are computed per frame by the shared backbone, combined
    with a gate-weighted average across the window, projected, and added to
    the spatial tokens of the centre frame. The fused tokens go through a
    Transformer encoder, are mean-pooled, and are passed to the pose head.
    """

    def __init__(self, cfg: ModelConfig) -> None:
        """Create the backbone, prompt projection, encoder stack and pose head.

        Args:
            cfg: Model configuration. This class reads ``embed_dim``,
                ``heads``, ``dropout``, ``depth``, ``window_k`` and
                ``gate_strength`` from it; the backbone and head receive the
                whole object.
        """
        super().__init__()
        self.cfg = cfg
        width = cfg.embed_dim

        # NOTE: sub-modules are created in this order on purpose; changing it
        # would change which random draws each layer gets under a fixed seed.
        self.backbone = PULSEBackbone(cfg)
        self.prompt_proj = nn.Linear(width, width)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=cfg.heads,
                dim_feedforward=4 * width,
                dropout=cfg.dropout,
                batch_first=True,
            ),
            num_layers=cfg.depth,
        )
        self.head = PoseHead(cfg)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Predict joints from a window of frames.

        Args:
            h: Input window, expected as (B, K, R, A, D) with K equal to
                ``cfg.window_k``.

        Returns:
            Whatever ``self.head`` produces for the mean-pooled encoder
            output.

        Raises:
            ValueError: If ``h.shape[1]`` differs from ``cfg.window_k``.
        """
        cfg = self.cfg
        backbone = self.backbone
        if h.shape[1] != cfg.window_k:
            raise ValueError("Input K does not match window_k.")
        num_frames = cfg.window_k

        # "Extension: Multi-frame Doppler Prompt Aggregation" (confidence-weighted).
        # Each frame yields Doppler tokens plus a gate derived from those tokens.
        frame_tokens = []
        frame_gates = []
        for frame_idx in range(num_frames):
            current_tokens = backbone.doppler_tokens(h[:, frame_idx])
            frame_gates.append(backbone.doppler_gate(current_tokens))
            frame_tokens.append(current_tokens)
        stacked_gates = torch.stack(frame_gates, dim=1)  # presumably (B, K, R, A, 1)
        stacked_tokens = torch.stack(frame_tokens, dim=1)  # presumably (B, K, R, A, d)

        # Gate-weighted average over the frame axis; the epsilon keeps the
        # division finite when the gates sum to zero.
        gate_total = stacked_gates.sum(dim=1)
        weighted_total = (stacked_gates * stacked_tokens).sum(dim=1)
        agg_tokens = weighted_total / (gate_total + 1e-6)
        mean_gate = gate_total / float(num_frames)

        # "Magnitude Structural Stream" from center frame (single-frame backbone).
        center = h[:, num_frames // 2]
        spatial = backbone.spatial_tokens(center)
        prompt_gate = backbone.patch_gate(mean_gate)
        prompt = self.prompt_proj(backbone.patch_tokens(agg_tokens))

        # "Controlled Prompting via Conditional Attention" + "Alignment and Prompting".
        fused = spatial + cfg.gate_strength * prompt_gate * prompt
        encoded = self.encoder(fused)
        return self.head(encoded.mean(dim=1))


def main() -> None:
    """Run one forward pass of ``PULSEMultiFrame`` and print the result.

    Reads ``--batch-size`` and ``--seed`` from the command line, seeds via
    ``set_seed``, builds the model in eval mode, feeds it the tensor returned
    by ``build_multi_frame_input`` and prints the input shape, the output
    shape and the first prediction in the batch.
    """
    cli = argparse.ArgumentParser(
        description="Multi-frame PULSE-style prediction on synthetic input."
    )
    cli.add_argument("--batch-size", type=int, default=2)
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    cfg = ModelConfig()
    # Seed before the model is constructed so its initial weights depend on
    # the chosen seed (assuming set_seed seeds torch's global RNG).
    set_seed(opts.seed)
    net = PULSEMultiFrame(cfg).eval()

    batch = build_multi_frame_input(cfg, opts.batch_size)
    with torch.no_grad():
        joints = net(batch)

    print("Input shape [batch_size, frames, R, A, D]:", tuple(batch.shape))
    print("Output shape (B, J, 3):", tuple(joints.shape))
    # Label on its own line, then the first sample's prediction.
    print("Output", joints[0], sep="\n")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
