# layers/periodic_marker.py
import torch
import torch.nn as nn

class PeriodicMarker(nn.Module):
    """Turn calendar time-feature indices into a per-channel sigmoid gate.

    Each calendar field (month, day, weekday, hour and, optionally, minute)
    has its own learned embedding table. The per-field embeddings are summed,
    optionally projected from ``embed_dim`` to ``c_in`` channels, and squashed
    with a sigmoid so every gate value lies in (0, 1).

    Args:
        c_in: Number of channels of the returned gate.
        embed_dim: Width of the embedding tables. Defaults to ``c_in``; when
            it differs, a bias-free linear layer projects to ``c_in``.
        use_minute: If True, also embed a fifth "minute" column (0..59) when
            the input provides one.
    """

    # Number of mandatory leading feature columns: month, day, weekday, hour.
    _NUM_BASE_FIELDS = 4

    def __init__(self, c_in, embed_dim=None, use_minute=False):
        super().__init__()
        self.c_in = c_in
        self.embed_dim = c_in if embed_dim is None else embed_dim
        self.use_minute = use_minute

        # Embedding tables; sizes bound the valid index range of each column.
        self.month_emb = nn.Embedding(13, self.embed_dim)   # 0..12
        self.day_emb = nn.Embedding(32, self.embed_dim)     # 0..31
        self.weekday_emb = nn.Embedding(7, self.embed_dim)  # 0..6
        self.hour_emb = nn.Embedding(24, self.embed_dim)    # 0..23
        if self.use_minute:
            self.minute_emb = nn.Embedding(60, self.embed_dim)  # 0..59

        # Optional projector so the gate has exactly c_in channels.
        if self.embed_dim != self.c_in:
            self.proj = nn.Linear(self.embed_dim, self.c_in, bias=False)
        else:
            self.proj = None

        # Small init: summed embeddings stay near 0, so the initial gate is
        # close to sigmoid(0) = 0.5 (before any projection).
        nn.init.normal_(self.month_emb.weight, 0.0, 0.02)
        nn.init.normal_(self.day_emb.weight, 0.0, 0.02)
        nn.init.normal_(self.weekday_emb.weight, 0.0, 0.02)
        nn.init.normal_(self.hour_emb.weight, 0.0, 0.02)
        if self.use_minute:
            nn.init.normal_(self.minute_emb.weight, 0.0, 0.02)

    def forward(self, x_mark):
        """Compute the periodic gate.

        Args:
            x_mark: Tensor of shape ``[..., D_timefeat]`` (typically
                ``[B, L, D_timefeat]``), or None. The last-axis columns are
                read in this order: month (0..12), day (0..31),
                weekday (0..6), hour (0..23) and, when ``use_minute`` is set
                and a fifth column exists, minute (0..59). Any dtype is
                accepted; values are cast with ``.long()``, which truncates
                floats toward zero.

        Returns:
            Gate of shape ``[..., c_in]`` with values in (0, 1). When
            ``x_mark`` is None, returns an all-zero tensor of shape
            ``[1, 1, c_in]`` on the module's device and dtype.

        Raises:
            ValueError: If ``x_mark`` has fewer than 4 feature columns.
        """
        if x_mark is None:
            # No time features: return a broadcastable all-zero gate that
            # matches the device and dtype of the module's parameters, as the
            # regular path does.
            ref = self.month_emb.weight
            return torch.zeros(1, 1, self.c_in, device=ref.device, dtype=ref.dtype)

        if x_mark.dim() == 0 or x_mark.size(-1) < self._NUM_BASE_FIELDS:
            raise ValueError(
                "PeriodicMarker expects at least 4 time-feature columns "
                "(month, day, weekday, hour) in the last dimension; "
                f"got shape {tuple(x_mark.shape)}"
            )

        xm = x_mark.long()
        # Column mapping: month, day, weekday, hour, (minute).
        emb = (
            self.month_emb(xm[..., 0])
            + self.day_emb(xm[..., 1])
            + self.weekday_emb(xm[..., 2])
            + self.hour_emb(xm[..., 3])
        )
        # The minute column is optional even when use_minute is set; it is
        # only used if the input actually provides a fifth column.
        if self.use_minute and xm.size(-1) > self._NUM_BASE_FIELDS:
            emb = emb + self.minute_emb(xm[..., 4])

        # emb: [..., embed_dim] -> [..., c_in] when a projector exists.
        if self.proj is not None:
            emb = self.proj(emb)

        return torch.sigmoid(emb)  # [..., c_in], values in (0, 1)
