from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    """
    Hyper-parameters shared by the model components and the input builders.
    Field order is part of the positional constructor signature.
    """

    # --- Input cube dimensions (range x angle x Doppler) ---
    range_bins: int = 64
    angle_bins: int = 64
    doppler_bins: int = 16

    # --- Patch size along the range / angle axes ---
    # Used as both kernel size and stride of the spatial embedding and of the
    # average pooling, so range_bins/angle_bins must be multiples of these.
    patch_r: int = 4
    patch_a: int = 4

    # --- Token width ---
    embed_dim: int = 32

    # --- Not referenced in the visible part of this file ---
    # NOTE(review): presumably consumed by an attention/transformer stack
    # defined elsewhere - confirm against the consumer.
    depth: int = 4
    heads: int = 4
    dropout: float = 0.1
    gate_strength: float = 1.0

    # --- Pose head: number of joints, each regressed as 3 values ---
    num_joints: int = 17

    # --- Number of frames in a multi-frame input window ---
    window_k: int = 9


def set_seed(seed: int) -> None:
    """
    Seed PyTorch's random number generators with ``seed``.

    Only the two torch seeding calls below are made; Python's ``random``
    module and NumPy are not touched.
    """
    # Default generator first, then every CUDA device explicitly.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


class PULSEBackbone(nn.Module):
    """
    Shared blocks used in both single-frame and multi-frame inference.
    Paper mapping:
    - "Dual-Domain Feature Representation": spatial magnitude + Doppler tokens.
    - "Controlled Prompting via Conditional Attention": confidence gate.
    - "Alignment and Prompting": patch-level alignment.

    This class defines no ``forward``; callers invoke the stages individually.
    Shape legend used below: B = batch, R = cfg.range_bins, A = cfg.angle_bins,
    D = cfg.doppler_bins, d = cfg.embed_dim, R' = R // patch_r,
    A' = A // patch_a, N_s = R' * A'.
    """

    def __init__(self, cfg: ModelConfig) -> None:
        """
        Build the patch embedding, the Doppler MLP and the gate projection.

        Raises:
            ValueError: if a patch size is not positive, or if the range/angle
                bins are not divisible by the corresponding patch size.
        """
        super().__init__()
        # Checked first so that a zero patch size is reported as a config
        # error rather than a ZeroDivisionError from the modulo below, and a
        # negative one does not slip through to fail inside nn.Conv2d.
        if cfg.patch_r <= 0 or cfg.patch_a <= 0:
            raise ValueError("Patch sizes must be positive.")
        if cfg.range_bins % cfg.patch_r != 0 or cfg.angle_bins % cfg.patch_a != 0:
            raise ValueError("Range/angle bins must be divisible by patch size.")

        self.cfg = cfg
        # Non-overlapping patches: kernel size == stride.
        self.spatial_embed = nn.Conv2d(
            in_channels=1,
            out_channels=cfg.embed_dim,
            kernel_size=(cfg.patch_r, cfg.patch_a),
            stride=(cfg.patch_r, cfg.patch_a),
        )
        self.doppler_mlp = nn.Sequential(
            nn.Linear(cfg.doppler_bins, cfg.embed_dim),
            nn.GELU(),
            nn.Linear(cfg.embed_dim, cfg.embed_dim),
        )
        self.gate_proj = nn.Linear(cfg.embed_dim, 1)

    def _pool_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Average a channel-last (B, R, A, C) map over patches -> (B, N_s, C)."""
        patch = (self.cfg.patch_r, self.cfg.patch_a)
        x = x.permute(0, 3, 1, 2)  # (B, C, R, A)
        x = F.avg_pool2d(x, kernel_size=patch, stride=patch)  # (B, C, R', A')
        return x.flatten(2).transpose(1, 2)  # (B, N_s, C)

    def spatial_tokens(self, h: torch.Tensor) -> torch.Tensor:
        """
        "Magnitude Structural Stream" in Dual-Domain Feature Representation.

        Averages the magnitude of ``h`` over its last axis and embeds the
        resulting single-channel map patch by patch.

        Args:
            h: 4-D tensor laid out as (B, R, A, D).

        Returns:
            Spatial tokens of shape (B, N_s, d).
        """
        spatial_mag = h.abs().mean(dim=-1, keepdim=True)  # (B, R, A, 1)
        spatial_mag = spatial_mag.permute(0, 3, 1, 2)  # (B, 1, R, A)
        tokens = self.spatial_embed(spatial_mag)  # (B, d, R', A')
        return tokens.flatten(2).transpose(1, 2)  # (B, N_s, d)

    def doppler_tokens(self, h: torch.Tensor) -> torch.Tensor:
        """
        "Motion Cue Stream" in Dual-Domain Feature Representation.

        Feeds the magnitude of each cell's length-D vector through the Doppler
        MLP.

        Args:
            h: tensor laid out as (B, R, A, D).

        Returns:
            Per-cell Doppler tokens of shape (B, R, A, d).
        """
        v = h.abs()
        batch = v.shape[0]
        tokens = v.reshape(batch, -1, self.cfg.doppler_bins)  # (B, R*A, D)
        tokens = self.doppler_mlp(tokens)  # (B, R*A, d)
        return tokens.reshape(
            batch, self.cfg.range_bins, self.cfg.angle_bins, self.cfg.embed_dim
        )

    def doppler_gate(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        "Prompt confidence gating" in Controlled Prompting section.

        Args:
            tokens: Doppler tokens of shape (B, R, A, d).

        Returns:
            Sigmoid gate in (0, 1) of shape (B, R, A, 1).
        """
        batch = tokens.shape[0]
        gate = tokens.reshape(batch, -1, self.cfg.embed_dim)
        gate = torch.sigmoid(self.gate_proj(gate))  # (B, R*A, 1)
        return gate.reshape(batch, self.cfg.range_bins, self.cfg.angle_bins, 1)

    def patch_gate(self, gate: torch.Tensor) -> torch.Tensor:
        """
        "Patch-level alignment" in Alignment and Prompting section.

        Args:
            gate: per-cell gate of shape (B, R, A, 1).

        Returns:
            Patch-averaged gate of shape (B, N_s, 1).
        """
        return self._pool_patches(gate)

    def patch_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Patch-aligned Doppler prompts (Alignment and Prompting section).

        Args:
            tokens: per-cell Doppler tokens of shape (B, R, A, d).

        Returns:
            Patch-averaged tokens of shape (B, N_s, d).
        """
        return self._pool_patches(tokens)


class PoseHead(nn.Module):
    """
    Linear regression head that maps pooled features to 3D joint positions
    (Prediction head in Method).
    """

    def __init__(self, cfg: ModelConfig) -> None:
        super().__init__()
        self.cfg = cfg
        # Three regressed values per joint, produced as one flat vector.
        flat_out = 3 * cfg.num_joints
        self.head = nn.Linear(cfg.embed_dim, flat_out)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        """
        Regress joints from ``pooled`` (last dim = cfg.embed_dim).

        Returns a tensor of shape (-1, num_joints, 3): every leading dimension
        of ``pooled`` is folded into the first axis.
        """
        flat = self.head(pooled)
        out_shape = (-1, self.cfg.num_joints, 3)
        return flat.reshape(out_shape)


def build_single_frame_input(cfg: ModelConfig, batch_size: int) -> torch.Tensor:
    """
    Draw a standard-normal single-frame input of shape
    (batch_size, range_bins, angle_bins, doppler_bins).
    """
    frame_shape = (cfg.range_bins, cfg.angle_bins, cfg.doppler_bins)
    return torch.randn((batch_size, *frame_shape))


def build_multi_frame_input(cfg: ModelConfig, batch_size: int) -> torch.Tensor:
    """
    Draw a standard-normal multi-frame input of shape
    (batch_size, window_k, range_bins, angle_bins, doppler_bins).
    """
    frame_shape = (cfg.range_bins, cfg.angle_bins, cfg.doppler_bins)
    window_shape = (batch_size, cfg.window_k, *frame_shape)
    return torch.randn(window_shape)
