"""JEPA Predictor."""

import math
import torch
import torch.nn as nn
from functools import partial

from equitiesjepa.modules import Block, DeltaPositionEmbedding, trunc_normal_


class JEPAPredictor(nn.Module):
    """Predicts representations at masked positions from visible context.

    Visible context tokens (width ``encoder_dim``) are linearly projected to
    ``predictor_dim`` and scattered into a flattened sequence of ``L * K``
    positions. Every position not marked visible is filled with a learned mask
    token. Positional embeddings from ``DeltaPositionEmbedding`` are added. The
    sequence then passes through ``depth`` ``Block`` layers, a final norm, and a
    linear projection back to ``encoder_dim``.

    Args:
        encoder_dim: Feature width of the incoming context tokens and of the
            returned predictions.
        predictor_dim: Internal feature width of the predictor.
        depth: Number of ``Block`` layers.
        num_heads: Forwarded to each ``Block``.
        mlp_ratio: Forwarded to each ``Block``.
        qkv_bias: Forwarded to each ``Block``.
        drop_rate: Forwarded to each ``Block`` as ``drop``.
        attn_drop_rate: Forwarded to each ``Block`` as ``attn_drop``.
        drop_path_rate: Largest drop-path rate. Per-block rates are spaced
            linearly from 0 (first block) to this value (last block).
        norm_layer: Callable ``dim -> nn.Module`` used for the blocks and the
            final norm. Defaults to ``nn.LayerNorm(eps=1e-6)``.
        max_delta: Forwarded to ``DeltaPositionEmbedding``.
        num_slots: Forwarded to ``DeltaPositionEmbedding``.
        use_fourier_pos: Forwarded to ``DeltaPositionEmbedding`` as
            ``use_fourier``.
        init_std: Standard deviation of the truncated-normal init applied to
            every ``nn.Linear`` weight and to the mask token.
    """

    def __init__(
        self,
        encoder_dim: int = 128,
        predictor_dim: int = 64,
        depth: int = 4,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer=None,
        # Position embedding params
        max_delta: int = 600,
        num_slots: int = 24,
        use_fourier_pos: bool = True,
        init_std: float = 0.02,
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.predictor_dim = predictor_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.embed = nn.Linear(encoder_dim, predictor_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_dim))
        self.pos_embed = DeltaPositionEmbedding(
            dim=predictor_dim,
            max_delta=max_delta,
            num_slots=num_slots,
            use_fourier=use_fourier_pos,
        )

        # Stochastic-depth schedule: rate grows linearly with block index.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=predictor_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ])

        self.norm = norm_layer(predictor_dim)
        self.proj = nn.Linear(predictor_dim, encoder_dim)
        self.init_std = init_std
        self.apply(self._init_weights)
        trunc_normal_(self.mask_token, std=init_std)
        self._rescale_blocks()

    def _init_weights(self, m):
        """Initialize one submodule; applied recursively via ``self.apply``.

        ``nn.Linear``: truncated-normal weights with ``self.init_std``, zero bias.
        ``nn.LayerNorm``: weight 1, bias 0. Parameters that are ``None`` are
        skipped (e.g. ``elementwise_affine=False``).
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # A user-supplied norm_layer may have no affine parameters.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _rescale_blocks(self):
        """Scale down the output projections of deeper blocks.

        For the block at 1-based index ``i``, the weights of ``attn.proj`` and
        ``mlp.fc2`` are divided by ``sqrt(2 * i)``. These are presumably the
        residual-branch output projections inside ``Block`` (depth-scaled init as
        in BEiT / I-JEPA); confirm against ``equitiesjepa.modules.Block``.
        """
        with torch.no_grad():
            for layer_id, layer in enumerate(self.blocks, start=1):
                scale = math.sqrt(2.0 * layer_id)
                layer.attn.proj.weight.div_(scale)
                layer.mlp.fc2.weight.div_(scale)

    def forward(
        self,
        context: torch.Tensor,
        mask: torch.Tensor,
        stride: int = 1,
        has_cls: bool = True,
    ) -> torch.Tensor:
        """Predict encoder-space features for every position.

        Args:
            context: Encoder tokens of width ``encoder_dim``. Expected shape is
                ``(B, 1 + L*K, encoder_dim)`` when ``has_cls`` is true (the
                leading token is dropped) or ``(B, L*K, encoder_dim)`` otherwise.
                Only rows where ``mask`` is true are read.
            mask: ``(B, L, K)`` tensor. True/nonzero entries mark visible
                positions whose context tokens are kept. False/zero entries are
                filled with the learned mask token. Any dtype is accepted and is
                converted to bool.
            stride: Passed through to ``self.pos_embed``. Its meaning is defined
                by ``DeltaPositionEmbedding``.
            has_cls: Whether ``context`` carries an extra leading token to drop.

        Returns:
            Tensor of shape ``(B, L*K, encoder_dim)`` with predictions for all
            positions, including the visible ones.
        """
        B = context.size(0)
        _, L, K = mask.shape
        N = L * K

        ctx_tokens = context[:, 1:, :] if has_cls else context
        ctx_tokens = self.embed(ctx_tokens)

        # Coerce to bool: an integer/float 0/1 mask would otherwise be treated as
        # integer indices by advanced indexing and gather the wrong rows.
        visible = mask.reshape(B, N).to(torch.bool)

        # expand() returns a stride-0 view; clone() gives writable storage for the
        # in-place masked assignment below.
        x = self.mask_token.to(ctx_tokens.dtype).expand(B, N, -1).clone()
        x[visible] = ctx_tokens[visible]

        # NOTE(review): pos is assumed to broadcast against (B, N, predictor_dim).
        pos = self.pos_embed(L, K, stride=stride, include_cls=False, device=context.device)
        x = x + pos
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.proj(x)


def compute_jepa_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, loss_exp: float = 1.0):
    B, L, K = mask.shape
    mask_flat = mask.reshape(B, L * K)
    loss_mask = ~mask_flat
    pred_masked = pred[loss_mask]
    target_masked = target[loss_mask]
    loss = torch.mean(torch.abs(pred_masked - target_masked) ** loss_exp) / loss_exp
    return loss

