# maml_rl/anil_networks.py
from typing import Tuple
import torch
import torch.nn as nn
from torch.distributions import Normal

def build_mlp(input_dim: int, hidden_dim: int) -> nn.Sequential:
    """Build a two-layer tanh MLP: input_dim -> hidden_dim -> hidden_dim.

    Layout is ``Linear, Tanh, Linear, Tanh`` so the returned ``Sequential``
    has parameters at indices 0 and 2.
    """
    layers = []
    # Each hidden layer reads from the previous width and emits hidden_dim.
    for in_features in (input_dim, hidden_dim):
        layers += [nn.Linear(in_features, hidden_dim), nn.Tanh()]
    return nn.Sequential(*layers)

class PolicyBackbone(nn.Module):
    """Feature extractor for the policy: a ``build_mlp`` tanh MLP.

    Maps inputs of width ``input_dim`` to features of width ``hidden_dim``.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # Attribute name ``net`` is kept so state_dict keys stay ``net.*``.
        self.net = build_mlp(input_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``hidden_dim``-wide features for ``x``."""
        features = self.net(x)
        return features

class GaussianPolicyHead(nn.Module):
    """Diagonal-Gaussian action head on top of backbone features.

    The mean is a linear function of the features. The log standard
    deviation is a free parameter vector that does not depend on the input.
    """

    def __init__(self, feat_dim: int, action_dim: int, deterministic: bool = False):
        super().__init__()
        self.mu = nn.Linear(feat_dim, action_dim)
        # Input-independent log-std, initialised to 0 (std = 1).
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        # When True, sample() returns the mean instead of drawing an action.
        self.deterministic = deterministic

    def dist(self, feats: torch.Tensor) -> Tuple[Normal, torch.Tensor]:
        """Return the action distribution for ``feats`` together with its mean.

        Args:
            feats: features whose last dimension is ``feat_dim``.

        Returns:
            ``(Normal(mean, std), mean)``. ``mean`` keeps the leading dims of
            ``feats`` and has last dimension ``action_dim``.
        """
        mean = self.mu(feats)
        # Broadcast the per-dimension std to the full shape of the mean.
        std = self.log_std.exp().expand_as(mean)
        return Normal(mean, std), mean

    def sample(self, feats: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick an action for ``feats`` and return it with its log-probability.

        Args:
            feats: features whose last dimension is ``feat_dim``. Any number
                of leading dims is accepted, including none (unbatched).

        Returns:
            ``(action, logp)``. ``action`` has the shape of the mean.
            ``logp`` has shape ``mean.shape[:-1] + (1,)``: the summed
            per-dimension log-density, keeping the last dim.
            In deterministic mode ``action`` is the mean and ``logp`` is all
            zeros. This is a placeholder, not the Gaussian density at the mean.
        """
        normal, mean = self.dist(feats)
        if self.deterministic:
            action = mean
            # Match the stochastic branch's shape for any leading dims, and
            # follow mean's dtype/device (previously 2-D batch + float32 only).
            logp = mean.new_zeros(mean.shape[:-1] + (1,))
        else:
            # rsample keeps the action differentiable w.r.t. mu and log_std.
            action = normal.rsample()
            logp = normal.log_prob(action).sum(dim=-1, keepdim=True)
        return action, logp

    def log_prob(self, feats: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the log-density of ``action`` under the policy at ``feats``.

        The sum over the action dimension does NOT keep the last dim, unlike
        ``sample``. The result therefore has shape ``mean.shape[:-1]``.
        """
        normal, _ = self.dist(feats)
        return normal.log_prob(action).sum(dim=-1)

class ValueBackbone(nn.Module):
    """Feature extractor for the value function: a ``build_mlp`` tanh MLP.

    Maps inputs of width ``input_dim`` to features of width ``hidden_dim``.
    It has its own parameters, separate from any policy backbone instance.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # Attribute name ``net`` is kept so state_dict keys stay ``net.*``.
        self.net = build_mlp(input_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``hidden_dim``-wide features for ``x``."""
        features = self.net(x)
        return features

class ValueHead(nn.Module):
    """Linear value head: one scalar per feature vector."""

    def __init__(self, feat_dim: int):
        super().__init__()
        # Attribute name ``out`` is kept so state_dict keys stay ``out.*``.
        self.out = nn.Linear(feat_dim, 1)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Return values with the trailing singleton dimension removed."""
        values = self.out(feats)
        return values.squeeze(dim=-1)
