import argparse
from typing import Optional, Sequence

import torch
import torch.nn as nn

from pulse_common import (
    ModelConfig,
    PULSEBackbone,
    PoseHead,
    build_single_frame_input,
    set_seed,
)


class PULSESingleFrame(nn.Module):
    """Single-frame PULSE pose regressor.

    Backbone spatial tokens are additively conditioned on a projected Doppler
    prompt (scaled by ``cfg.gate_strength`` and a Doppler-derived gate), then
    refined by a batch-first Transformer encoder, mean-pooled over the token
    axis, and decoded to joints by ``PoseHead``.
    """

    def __init__(self, cfg: ModelConfig) -> None:
        super().__init__()
        self.cfg = cfg
        # Keep this construction order: main() seeds before building the model,
        # so reordering submodules would change the initial weights.
        self.backbone = PULSEBackbone(cfg)
        width = cfg.embed_dim
        self.prompt_proj = nn.Linear(width, width)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=cfg.heads,
                dim_feedforward=4 * width,
                dropout=cfg.dropout,
                batch_first=True,
            ),
            num_layers=cfg.depth,
        )
        self.head = PoseHead(cfg)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Predict joints for one frame.

        Args:
            h: Input of shape (B, R, A, D).

        Returns:
            The ``PoseHead`` output for the pooled features (``main`` labels
            it (B, J, 3)).
        """
        bb = self.backbone

        # Dual-domain features (single-frame setting): spatial tokens plus a
        # Doppler branch that yields both a gate and a projected prompt.
        spatial = bb.spatial_tokens(h)
        doppler = bb.doppler_tokens(h)
        gate = bb.patch_gate(bb.doppler_gate(doppler))
        prompt = self.prompt_proj(bb.patch_tokens(doppler))

        # Controlled prompting: gated additive injection of the Doppler prompt,
        # then self-attention over the conditioned spatial tokens.
        conditioned = spatial + self.cfg.gate_strength * gate * prompt
        encoded = self.encoder(conditioned)

        # batch_first=True => dim 1 is the token axis; average it away.
        return self.head(encoded.mean(dim=1))


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run single-frame PULSE on a synthetic batch and print the prediction.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            running the script directly behaves as before; passing a list
            lets other code or tests drive the entry point.

    Raises:
        SystemExit: On invalid arguments, including a non-positive
            ``--batch-size`` (argparse exits with status 2 and a usage message).
    """
    parser = argparse.ArgumentParser(
        description="Single-person PULSE prediction on synthetic input."
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="number of synthetic samples to run (must be >= 1)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="seed passed to set_seed"
    )
    args = parser.parse_args(argv)
    # pred[0] below needs at least one sample; fail early with a usage error
    # rather than an IndexError / tensor error from inside the model.
    if args.batch_size < 1:
        parser.error(
            f"--batch-size must be a positive integer, got {args.batch_size}"
        )

    cfg = ModelConfig()
    # Seed before the model and dummy input are created so that (assuming
    # set_seed seeds torch's RNG) weights and input are reproducible per --seed.
    set_seed(args.seed)
    model = PULSESingleFrame(cfg)
    model.eval()

    dummy_input = build_single_frame_input(cfg, args.batch_size)
    with torch.no_grad():
        pred = model(dummy_input)

    print("Input shape [batch_size, R, A, D]:", tuple(dummy_input.shape))
    print("Output shape (B, J, 3):", tuple(pred.shape))
    print("Output:")
    print(pred[0])

# Script entry point: runs the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
