"""Temporal Encoder for Equities JEPA."""

import math
import torch
import torch.nn as nn
from functools import partial

from equitiesjepa.modules import Block, DeltaPositionEmbedding, trunc_normal_


class TemporalEncoder(nn.Module):
    """Transformer encoder for temporal sequences of factor tokens.

    Input tokens of shape ``(B, L, K, dim)`` are flattened to a sequence of
    ``L * K`` tokens, optionally masked with a learned ``mask_token``,
    optionally prefixed with a learned CLS token, summed with the output of
    ``DeltaPositionEmbedding`` and passed through ``depth`` ``Block`` modules
    followed by a final norm layer.

    NOTE(review): ``L`` looks like the number of time steps and ``K`` the
    number of factor slots per step (cf. ``num_slots``) -- confirm with callers.

    Args:
        dim: Token width; input tokens must have this trailing size.
        depth: Number of transformer blocks.
        num_heads: Attention heads, forwarded to each ``Block``.
        mlp_ratio: Forwarded to each ``Block``.
        qkv_bias: Forwarded to each ``Block``.
        drop_rate: Forwarded to each ``Block`` as ``drop``.
        attn_drop_rate: Forwarded to each ``Block`` as ``attn_drop``.
        drop_path_rate: Drop-path rate of the last block; block ``i`` gets the
            ``i``-th value of ``linspace(0, drop_path_rate, depth)``.
        norm_layer: Factory called as ``norm_layer(dim)``; defaults to
            ``nn.LayerNorm`` with ``eps=1e-6``.
        max_delta: Forwarded to ``DeltaPositionEmbedding``.
        num_slots: Forwarded to ``DeltaPositionEmbedding``.
        use_fourier_pos: Forwarded to ``DeltaPositionEmbedding`` as
            ``use_fourier``.
        use_cls: If True, prepend a learned CLS token to the sequence.
        init_std: Std of the truncated-normal init used for ``nn.Linear``
            weights and for the CLS/mask tokens.
    """

    def __init__(
        self,
        dim: int = 128,
        depth: int = 6,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        norm_layer=None,
        # Position embedding params
        max_delta: int = 600,
        num_slots: int = 24,
        use_fourier_pos: bool = True,
        use_cls: bool = True,
        init_std: float = 0.02,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.use_cls = use_cls
        self.num_slots = num_slots
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) if use_cls else None
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = DeltaPositionEmbedding(
            dim=dim,
            max_delta=max_delta,
            num_slots=num_slots,
            use_fourier=use_fourier_pos,
        )

        # Drop-path rate grows linearly with depth: 0 for the first block,
        # ``drop_path_rate`` for the last.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ])

        self.norm = norm_layer(dim)
        # init_std must be set before ``apply`` because _init_weights reads it.
        self.init_std = init_std
        # NOTE: ``apply`` visits every submodule, including any Linear /
        # LayerNorm inside ``pos_embed``, so those are re-initialised here too.
        self.apply(self._init_weights)
        self._init_special_tokens()
        self._rescale_blocks()

    def _init_weights(self, m: nn.Module) -> None:
        """Initialise a single submodule (used via ``self.apply``).

        ``nn.Linear``: truncated-normal weight with std ``self.init_std`` and a
        zero bias. ``nn.LayerNorm``: unit weight and zero bias. Affine
        parameters that do not exist (``elementwise_affine=False`` or
        ``bias=False``, which a custom ``norm_layer`` may produce) are skipped
        instead of crashing on ``None``.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _init_special_tokens(self) -> None:
        """Draw the CLS token (if present) and the mask token from a truncated normal."""
        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=self.init_std)
        trunc_normal_(self.mask_token, std=self.init_std)

    def _rescale_blocks(self) -> None:
        """Shrink residual-branch output projections with depth.

        For block ``i`` (1-indexed), ``attn.proj.weight`` and ``mlp.fc2.weight``
        are divided in place by ``sqrt(2 * i)``. This appears to mirror the
        depth-dependent residual rescaling used by ViT/I-JEPA-style encoders.
        """
        for layer_id, layer in enumerate(self.blocks, start=1):
            scale = math.sqrt(2.0 * layer_id)
            layer.attn.proj.weight.data.div_(scale)
            layer.mlp.fc2.weight.data.div_(scale)

    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None, stride: int = 1):
        """Encode a batch of token grids.

        Args:
            tokens: Tensor of shape ``(B, L, K, dim)``.
            mask: Optional tensor with ``B * L * K`` elements (e.g. shape
                ``(B, L, K)``). True / nonzero entries keep the input token;
                False / zero entries are replaced by the learned ``mask_token``.
                Non-bool masks are cast to bool.
            stride: Forwarded to ``self.pos_embed``.

        Returns:
            Tensor of shape ``(B, 1 + L * K, dim)`` when ``use_cls`` is set,
            otherwise ``(B, L * K, dim)`` (assuming each ``Block`` preserves
            the shape of its input).

        Raises:
            ValueError: If ``tokens`` is not 4-D or its last dimension is not
                ``self.dim``.
        """
        if tokens.dim() != 4:
            raise ValueError(
                f"expected tokens of shape (B, L, K, d), got {tuple(tokens.shape)}"
            )
        B, L, K, d = tokens.shape
        if d != self.dim:
            # Without this check a width-1 input would silently broadcast
            # against the positional embedding instead of failing.
            raise ValueError(f"expected token dim {self.dim}, got {d}")

        x = tokens.reshape(B, L * K, d)
        if mask is not None:
            # torch.where needs a bool condition; also accept 0/1 int/float masks.
            keep = mask.reshape(B, L * K, 1).to(torch.bool)
            x = torch.where(keep, x, self.mask_token.expand(B, L * K, -1))
        if self.use_cls:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)

        pos = self.pos_embed(
            L, K, stride=stride, include_cls=self.use_cls, device=tokens.device
        )
        x = x + pos
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

    def forward_features(self, tokens, mask=None, stride=1):
        """Alias of :meth:`forward` with identical arguments and result."""
        return self.forward(tokens, mask, stride)

    def get_cls_output(self, x):
        """Return the CLS position ``x[:, 0]`` of an encoder output.

        Raises:
            ValueError: If the encoder was built with ``use_cls=False``.
        """
        if self.use_cls:
            return x[:, 0]
        raise ValueError("No CLS token")

    def get_token_output(self, x):
        """Return the non-CLS token outputs (``x[:, 1:]`` if a CLS token is used, else ``x``)."""
        if self.use_cls:
            return x[:, 1:]
        return x

