import torch
import torch.nn as nn
import math

class FourierTimePolicy(nn.Module):
    """MLP policy over ``[x, y]`` plus Fourier features of a time value ``tau``.

    Observation columns 0-1 are used as ``[x, y]``. Column 2 is ``tau``,
    clamped to ``[0, 1]`` and expanded into ``sin(2*pi*f_k*tau)`` and
    ``cos(2*pi*f_k*tau)`` with ``f_k = 2**k`` for ``k = 0 .. n_freq-1``.
    The raw ``tau`` can optionally be kept as an extra feature. The
    concatenated features are fed to an MLP that outputs 5 logits.

    Args:
        hidden_dim: Width of each hidden layer.
        num_layers: Total number of ``nn.Linear`` layers: ``num_layers - 1``
            hidden layers, each followed by ReLU, plus the output layer.
            Must be >= 1.
        n_freq: Number of Fourier frequencies. Must be >= 0.
        include_tau: If True, the clamped ``tau`` itself is also an input
            feature.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``num_layers < 1`` or ``n_freq < 0``.
    """

    def __init__(self, hidden_dim: int = 128, num_layers: int = 2,
                 n_freq: int = 8, include_tau: bool = True, **kwargs):
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if n_freq < 0:
            raise ValueError(f"n_freq must be >= 0, got {n_freq}")

        self.n_freq = n_freq
        self.include_tau = include_tau
        dim_out = 5

        # f_k = 2^k, k = 0..n_freq-1. Registered as a non-persistent buffer.
        # It follows .to(device/dtype) but is not saved in state_dict, since
        # __init__ rebuilds it.
        freqs = 2.0 ** torch.arange(n_freq, dtype=torch.get_default_dtype())
        self.register_buffer("freqs", freqs, persistent=False)

        # Input layout: [x, y] + [tau]? + [sin, cos] * n_freq
        in_dim = 2 + (1 if include_tau else 0) + 2 * n_freq
        layers = []
        d = in_dim
        for _ in range(num_layers - 1):
            layers += [nn.Linear(d, hidden_dim), nn.ReLU()]
            d = hidden_dim
        layers += [nn.Linear(d, dim_out)]  # logits
        self.net = nn.Sequential(*layers)

    def _time_features(self, tau: torch.Tensor) -> torch.Tensor:
        """Expand ``tau`` of shape ``(...)`` into time features ``(..., F)``.

        ``F = int(include_tau) + 2 * n_freq``. The layout is
        ``[tau?, sin(2*pi*f_k*tau) for each k, cos(2*pi*f_k*tau) for each k]``.
        """
        tau_col = tau.unsqueeze(-1)  # (..., 1)
        if self.n_freq > 0:
            # (..., 1) * (n_freq,) broadcasts to (..., n_freq).
            angles = 2 * math.pi * tau_col * self.freqs
            sincos = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., 2*n_freq)
            return torch.cat([tau_col, sincos], dim=-1) if self.include_tau else sincos
        # No frequencies: return only tau, or an empty feature block.
        return tau_col if self.include_tau else tau.new_zeros(tau.shape + (0,))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map observations of shape ``(..., >=3)`` to logits of shape ``(..., 5)``.

        Column 0-1 are ``[x, y]`` and column 2 is ``tau``, clamped to
        ``[0, 1]``. Any further columns are ignored. Accepts a single
        observation ``(3,)`` as well as batched ``(B, 3)`` or higher-rank
        input.
        """
        xy = obs[..., :2]
        tau = obs[..., 2].clamp(0.0, 1.0)
        tfeat = self._time_features(tau)
        x = torch.cat([xy, tfeat], dim=-1)
        return self.net(x)  # (..., dim_out)


class Policy(nn.Module):
    """Plain MLP policy mapping 3 input features to 5 output values.

    Args:
        hidden_dim: Width of every hidden layer.
        num_layers: Total number of ``nn.Linear`` layers. The first
            ``num_layers - 1`` are each followed by ReLU; the last maps
            to the 5 outputs. Values <= 1 yield a single linear layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, hidden_dim: int = 128, num_layers: int = 2, **kwargs):
        super().__init__()
        in_features, out_features = 3, 5

        # Feature widths entering each hidden layer, plus the width entering
        # the output layer: [3, hidden, hidden, ...].
        widths = [in_features] + [hidden_dim] * max(num_layers - 1, 0)

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        modules.append(nn.Linear(widths[-1], out_features))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to ``x`` whose last dimension has size 3."""
        return self.net(x)


class UniformPolicy(nn.Module):
    """Parameter-free baseline that outputs equal (all-ones) logits for 5 actions.

    Equal logits give a uniform distribution after a softmax.

    Args:
        **kwargs: Accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x):
        """Return all-ones logits of shape ``(..., 5)`` for input ``(..., D)``.

        The output follows ``x``'s leading (batch) dimensions and device, so
        it lines up with per-sample logits. A single observation of shape
        ``(D,)`` yields shape ``(5,)``. The dtype is torch's default float
        dtype.
        """
        return torch.ones((*x.shape[:-1], 5), device=x.device)

