"""Equities JEPA model."""

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from .encoder import TemporalEncoder
from .predictor import JEPAPredictor, compute_jepa_loss


class MarketJEPA(nn.Module):
    """JEPA model for self-supervised learning on factor token sequences.

    Bundles a ``TemporalEncoder`` (``self.encoder``) with a ``JEPAPredictor``
    (``self.predictor``) that consumes the encoder output.
    """

    def __init__(
        self,
        # Encoder config
        dim: int = 128,
        encoder_depth: int = 6,
        num_heads: int = 4,
        mlp_ratio: float = 4.0,
        # Predictor config
        predictor_dim: int = 64,
        predictor_depth: int = 4,
        # Position embedding config
        max_delta: int = 600,
        num_slots: int = 24,
        use_fourier_pos: bool = True,
        # Regularization
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        # Other
        use_cls: bool = True,
        init_std: float = 0.02,
    ):
        """Build the encoder and the predictor from one shared configuration.

        Args:
            dim: Forwarded to the encoder as ``dim`` and to the predictor as
                ``encoder_dim``.
            encoder_depth: Forwarded to the encoder as ``depth``.
            num_heads: Forwarded to both encoder and predictor.
            mlp_ratio: Forwarded to both encoder and predictor.
            predictor_dim: Forwarded to the predictor.
            predictor_depth: Forwarded to the predictor as ``depth``.
            max_delta: Forwarded to both encoder and predictor.
            num_slots: Forwarded to both encoder and predictor.
            use_fourier_pos: Forwarded to both encoder and predictor.
            drop_rate: Forwarded to both encoder and predictor.
            attn_drop_rate: Forwarded to both encoder and predictor.
            drop_path_rate: Forwarded to the encoder only.
            use_cls: Forwarded to the encoder and kept on this module so that
                ``forward`` can pass it to the predictor as ``has_cls``.
            init_std: Forwarded to both encoder and predictor.
        """
        super().__init__()

        self.dim = dim
        self.use_cls = use_cls

        # Settings that the encoder and the predictor receive under the
        # same keyword names.
        common = {
            "num_heads": num_heads,
            "mlp_ratio": mlp_ratio,
            "drop_rate": drop_rate,
            "attn_drop_rate": attn_drop_rate,
            "max_delta": max_delta,
            "num_slots": num_slots,
            "use_fourier_pos": use_fourier_pos,
            "init_std": init_std,
        }

        # Register the encoder before the predictor so the submodule (and
        # therefore parameter / state-dict) ordering is encoder-first.
        self.encoder = TemporalEncoder(
            dim=dim,
            depth=encoder_depth,
            drop_path_rate=drop_path_rate,
            use_cls=use_cls,
            **common,
        )
        self.predictor = JEPAPredictor(
            encoder_dim=dim,
            predictor_dim=predictor_dim,
            depth=predictor_depth,
            **common,
        )

    def forward(self, tokens, mask, stride=1):
        """Encode ``tokens`` under ``mask`` and run the predictor on the result.

        The same ``mask`` and ``stride`` are handed to both the encoder and
        the predictor; the predictor is additionally told whether a CLS token
        is in use via ``has_cls=self.use_cls``.

        Returns:
            Whatever the predictor returns.
        """
        context = self.encoder(tokens, mask=mask, stride=stride)
        return self.predictor(
            context, mask=mask, stride=stride, has_cls=self.use_cls
        )

    def encode(self, tokens, stride=1):
        """Run the encoder on ``tokens`` with ``mask=None``."""
        return self.encoder(tokens, mask=None, stride=stride)

    def get_cls_features(self, tokens, stride=1):
        """Encode ``tokens`` without a mask and return the encoder's CLS output.

        The extraction itself is delegated to ``self.encoder.get_cls_output``.
        """
        encoded = self.encode(tokens, stride=stride)
        return self.encoder.get_cls_output(encoded)

    @torch.no_grad()
    def create_target_encoder(self):
        """Return a deep copy of the encoder with gradients disabled.

        The copy is independent of ``self.encoder`` and every one of its
        parameters has ``requires_grad`` set to ``False``.
        """
        frozen = copy.deepcopy(self.encoder)
        frozen.requires_grad_(False)
        return frozen


@torch.no_grad()
def update_ema(encoder: nn.Module, target_encoder: nn.Module, momentum: float):
    """Update ``target_encoder`` in place as an EMA of ``encoder``.

    Each target parameter becomes
    ``momentum * target + (1 - momentum) * online``. Parameters are paired
    positionally, in the order yielded by ``Module.parameters()``.

    Only parameters are updated; buffers of ``target_encoder`` are left
    untouched.

    Args:
        encoder: Module whose current parameters are blended in.
        target_encoder: Module whose parameters are overwritten in place.
        momentum: Weight kept from the existing target parameters.

    Raises:
        ValueError: If the two modules expose a different number of
            parameters (a plain ``zip`` would silently skip the surplus).
    """
    online_params = list(encoder.parameters())
    target_params = list(target_encoder.parameters())
    if len(online_params) != len(target_params):
        raise ValueError(
            "encoder and target_encoder must have the same number of "
            f"parameters, got {len(online_params)} and {len(target_params)}"
        )

    blend = 1 - momentum
    for p_enc, p_tgt in zip(online_params, target_params):
        # Grad mode is already disabled by the decorator, so the tensors are
        # mutated directly instead of through `.data`; this keeps autograd's
        # in-place version tracking intact.
        p_tgt.mul_(momentum).add_(p_enc, alpha=blend)


def forward_target(target_encoder: nn.Module, tokens: torch.Tensor, stride: int = 1, normalize: bool = True):
    """Compute gradient-free target representations from ``target_encoder``.

    The encoder is called with ``mask=None``. When ``target_encoder.use_cls``
    is truthy, index 0 along the second dimension is dropped from the output.
    When ``normalize`` is true, the result is layer-normalized over its last
    dimension (no learned affine parameters).

    NOTE(review): the encoder's train/eval mode is not changed here; if it
    contains stochastic layers, confirm the caller puts it in eval mode.

    Args:
        target_encoder: Module exposing a ``use_cls`` attribute and accepting
            ``(tokens, mask=..., stride=...)``.
        tokens: Input passed straight to the encoder.
        stride: Forwarded to the encoder.
        normalize: Whether to apply layer normalization to the output.

    Returns:
        The (optionally CLS-stripped and normalized) encoder output, computed
        under ``torch.no_grad()``.
    """
    with torch.no_grad():
        features = target_encoder(tokens, mask=None, stride=stride)

        # Drop the leading CLS position so only the remaining tokens are kept.
        if target_encoder.use_cls:
            features = features[:, 1:, :]

        if not normalize:
            return features

        feature_dim = features.shape[-1]
        return F.layer_norm(features, (feature_dim,))


def compute_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, loss_exp: float = 1.0):
    """Compute the JEPA loss by delegating to ``compute_jepa_loss``.

    All four arguments are passed through positionally and unchanged.

    Args:
        pred: Predictor output.
        target: Target representations.
        mask: Mask forwarded to the loss function.
        loss_exp: Forwarded as the fourth positional argument of
            ``compute_jepa_loss``.

    Returns:
        Whatever ``compute_jepa_loss`` returns.
    """
    loss = compute_jepa_loss(pred, target, mask, loss_exp)
    return loss

