import torch
from torch import nn

class Date2VecConvert:
    """Load a Date2Vec checkpoint and expose its ``encode`` step as a callable.

    Args:
        model_path: file passed to ``torch.load``. It may contain a dict with a
            ``"state_dict"`` entry, a bare state dict, or a pickled module.
        device: torch device string used for loading and for inference.
    """

    def __init__(self, model_path="./d2v_model/d2v_98291_17.169918439404636.pth", device="cpu"):
        self.device = torch.device(device)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints. On newer torch versions where weights_only defaults to
        # True, the pickled-module branch below may be rejected - confirm.
        checkpoint = torch.load(model_path, map_location=self.device)
        if not isinstance(checkpoint, dict):
            # Whole module was pickled; use it as-is.
            self.model = checkpoint
        else:
            # Either a wrapper dict holding "state_dict" or the state dict itself.
            weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
            # Weights are loaded into a default-configured Date2Vec (k=32, sin).
            self.model = Date2Vec()
            self.model.load_state_dict(weights)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, x):
        """Embed ``x`` with the loaded model's ``encode`` method.

        ``x`` is anything ``torch.as_tensor`` accepts. It is cast to float32 on
        ``self.device``, and a 1-D input is promoted to a batch of one row.
        Returns the embedding as a CPU tensor computed without autograd.
        """
        batch = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        batch = batch.unsqueeze(0) if batch.ndim == 1 else batch
        with torch.inference_mode():
            embedding = self.model.encode(batch)
        return embedding.detach().cpu()

class Date2Vec(nn.Module):
    """Time2Vec-style embedding network for 6-field timestamps.

    Each of the 6 input fields is wrapped onto one period (see
    ``_time_to_radians``). Two parallel linear layers then project the angles:
    ``fc1`` output is kept linear and ``fc2`` output goes through the periodic
    activation. Their concatenation, of width ``k``, is what ``encode`` returns.
    ``forward`` continues through ``fc3``-``fc5`` to map the embedding back to
    6 outputs (presumably the training head - confirm).

    Inputs may have any number of leading batch dimensions; the last dimension
    must hold the 6 fields.

    Args:
        k: embedding width, split as ``k // 2`` linear features and
            ``k - k // 2`` periodic features.
        act: ``'sin'`` selects ``torch.sin``; any other value selects
            ``torch.cos``.
    """

    # Period of each of the 6 input fields.
    # NOTE(review): presumably (month, day, weekday, hour, minute, second) -
    # confirm against the data pipeline.
    _PERIODS = (12., 31., 7., 24., 60., 60.)

    def __init__(self, k=32, act="sin"):
        super().__init__()
        k1 = k // 2
        k2 = k - k1  # for odd k the periodic branch gets the extra unit
        # Layer construction order is kept so parameter init and state_dict
        # keys match existing checkpoints.
        self.fc1 = nn.Linear(6, k1)
        self.fc2 = nn.Linear(6, k2)
        self.d2 = nn.Dropout(0.3)
        self.activation = torch.sin if act == 'sin' else torch.cos
        self.fc3 = nn.Linear(k, k // 2)
        self.d3 = nn.Dropout(0.3)
        self.fc4 = nn.Linear(k // 2, 6)
        self.fc5 = nn.Linear(6, 6)

    def _time_to_radians(self, x):
        """Map each field to an angle within one period (0 to 2*pi)."""
        period = torch.tensor(self._PERIODS, device=x.device, dtype=x.dtype)
        # torch's % takes the sign of the divisor, so negatives wrap too.
        return (x % period) / period * (2 * torch.pi)

    def _embed(self, x, dropout):
        """Concatenate linear and periodic features along the last dim.

        ``dropout`` controls whether ``d2`` is applied to the periodic part.
        """
        xr = self._time_to_radians(x)
        linear = self.fc1(xr)
        periodic = self.activation(self.fc2(xr))
        if dropout:
            periodic = self.d2(periodic)
        # dim=-1 (not 1) so any number of leading batch dims is supported.
        return torch.cat([linear, periodic], dim=-1)

    def forward(self, x):
        """Run the full network: embedding, then fc3 -> fc4 -> fc5 (6 outputs)."""
        out = self._embed(x, dropout=True)
        out = self.d3(self.fc3(out))
        out = self.fc4(out)
        return self.fc5(out)

    def encode(self, x):
        """Return the width-``k`` embedding of ``x`` (no dropout applied)."""
        return self._embed(x, dropout=False)

if __name__ == "__main__":
    # Smoke test: push one random 6-field row through an untrained network
    # and show the result and its shape.
    demo_model = Date2Vec().eval()
    demo_input = torch.randn(1, 6)
    demo_output = demo_model(demo_input)
    print(demo_output)
    print(demo_output.shape)