# rl_framework/models.py
from typing import Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Generic MLP backbone that flattens any non-batch input shape.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_sizes: Width of each hidden layer, in order. With an empty
            sequence the network is an empty ``nn.Sequential`` and the
            flattened input is returned unchanged.
        activation: Zero-argument callable (typically an ``nn.Module``
            class) that builds the activation placed after every linear
            layer, including the last one.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_sizes: Sequence[int] = (128, 128),
        activation=nn.ReLU,
    ):
        super().__init__()
        layers = []
        last_dim = input_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last_dim, h))
            layers.append(activation())
            last_dim = h
        self.net = nn.Sequential(*layers)
        # Width of the feature vector returned by forward(); equals
        # input_dim when hidden_sizes is empty.
        self.output_dim = last_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` from (B, *obs_shape) to (B, -1) and run the MLP."""
        # reshape() rather than view(): view() raises on non-contiguous
        # inputs (e.g. permuted/transposed observations), while reshape()
        # copies only when it has to and is otherwise identical.
        x = x.reshape(x.size(0), -1)
        return self.net(x)


class CNN(nn.Module):
    """
    Convolutional backbone for image observations.

    Three Conv2d+ReLU stages are followed by global average pooling
    (AdaptiveAvgPool2d(1)) and a linear projection to `feature_dim`.
    Because the pooling collapses whatever spatial size is left, any
    H, W is accepted as long as it survives the conv stack.
    """

    def __init__(self, in_channels: int, feature_dim: int = 256):
        super().__init__()
        # One (out_channels, kernel_size, stride) triple per conv stage.
        stage_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
        stages = []
        channels = in_channels
        for out_channels, kernel, stride in stage_specs:
            stages.append(
                nn.Conv2d(channels, out_channels, kernel_size=kernel, stride=stride)
            )
            stages.append(nn.ReLU())
            channels = out_channels
        self.conv = nn.Sequential(*stages)
        # Global average pool: (B, 64, h, w) -> (B, 64, 1, 1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, feature_dim)
        self.output_dim = feature_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raw uint8 frames (e.g., Atari) are rescaled to floats in [0, 1];
        # every other dtype is passed through untouched.
        is_raw_bytes = x.dtype == torch.uint8
        frames = x.float() / 255.0 if is_raw_bytes else x
        pooled = self.pool(self.conv(frames))
        flat = pooled.view(pooled.size(0), -1)  # (B, 64)
        return F.relu(self.fc(flat))


class QHead(nn.Module):
    """Single linear layer producing one Q-value per action from features."""

    def __init__(self, feature_dim: int, num_actions: int):
        super().__init__()
        self.fc = nn.Linear(in_features=feature_dim, out_features=num_actions)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Last dimension goes from feature_dim to num_actions.
        q_values = self.fc(features)
        return q_values


class ActorCriticHead(nn.Module):
    """Actor and critic linear layers that read the same feature vector."""

    def __init__(self, feature_dim: int, num_actions: int):
        super().__init__()
        self.actor = nn.Linear(in_features=feature_dim, out_features=num_actions)
        self.critic = nn.Linear(in_features=feature_dim, out_features=1)

    def forward(self, features: torch.Tensor):
        # Returns (logits, value); the value keeps its trailing size-1
        # dimension (no squeeze is applied here).
        return self.actor(features), self.critic(features)


class ActorCritic(nn.Module):
    """Complete actor-critic model: a backbone feeding an ActorCriticHead."""

    def __init__(self, backbone: nn.Module, feature_dim: int, num_actions: int):
        super().__init__()
        self.backbone = backbone
        # feature_dim is the input width handed to the head's linear layers.
        self.head = ActorCriticHead(feature_dim, num_actions)

    def forward(self, x: torch.Tensor):
        # Backbone features go straight into the head; its output is
        # returned as-is.
        return self.head(self.backbone(x))


class ValueHead(nn.Module):
    """Single linear layer that turns features into one scalar per sample."""

    def __init__(self, feature_dim: int):
        super().__init__()
        self.fc = nn.Linear(in_features=feature_dim, out_features=1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        values = self.fc(features)
        # Drop the trailing size-1 dimension: (B, 1) -> (B,)
        return values.squeeze(-1)


class ValueNet(nn.Module):
    """
    Scalar-output network: a backbone followed by a ValueHead.

    Serves both as a standard value network and as the Hodge potential u(s).
    """

    def __init__(self, backbone: nn.Module, feature_dim: int):
        super().__init__()
        self.backbone = backbone
        # feature_dim is the input width handed to the head's linear layer.
        self.head = ValueHead(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Backbone features go straight into the head.
        return self.head(self.backbone(x))


def _is_image_obs_shape(obs_shape: Tuple[int, ...]) -> bool:
    """Return True when obs_shape has exactly three dimensions.

    Heuristic only: a 3-D shape is taken to be an image laid out as
    (C, H, W); the dimension sizes themselves are not inspected.
    """
    num_dims = len(obs_shape)
    return num_dims == 3


def _build_backbone(
    obs_shape: Tuple[int, ...],
    use_cnn: bool = False,
    hidden_sizes: Sequence[int] = (128, 128),
    feature_dim: int = 256,
) -> Tuple[nn.Module, int]:
    """
    Internal helper: given obs_shape, decide MLP or CNN backbone,
    return (backbone_module, feature_dim).

    A CNN is built when `use_cnn` is True or when obs_shape looks like an
    image (see `_is_image_obs_shape`); passing use_cnn=False therefore
    does not force an MLP for 3-D shapes. Otherwise an MLP over the
    flattened observation is built.

    Args:
        obs_shape: Shape of one observation, without the batch dimension.
        use_cnn: Request a CNN explicitly. Only valid for 3-D shapes.
        hidden_sizes: Hidden layer widths for the MLP; ignored by the CNN.
        feature_dim: Output width of the CNN; ignored by the MLP, whose
            output width comes from `hidden_sizes` (or from the flattened
            input size when `hidden_sizes` is empty).

    Returns:
        The backbone module and its `output_dim`.

    Raises:
        ValueError: If a CNN is requested for an obs_shape that is not 3-D.
    """
    if use_cnn or _is_image_obs_shape(obs_shape):
        # Only reachable with a non-3-D shape through use_cnn=True. Fail
        # here with a clear message instead of building a CNN whose
        # in_channels is some unrelated dimension (or hitting a bare
        # IndexError on an empty shape).
        if len(obs_shape) != 3:
            raise ValueError(
                "CNN backbone requires a 3-D (C, H, W) obs_shape, "
                f"got {tuple(obs_shape)!r}"
            )
        backbone = CNN(in_channels=obs_shape[0], feature_dim=feature_dim)
    else:
        input_dim = int(np.prod(obs_shape))
        backbone = MLP(input_dim=input_dim, hidden_sizes=hidden_sizes)
    return backbone, backbone.output_dim


def build_q_network(
    obs_shape: Tuple[int, ...],
    num_actions: int,
    use_cnn: bool = False,
    hidden_sizes: Sequence[int] = (128, 128),
    feature_dim: int = 256,
) -> nn.Module:
    """
    Build a Q-network: an MLP or CNN backbone followed by a QHead.

    Backbone selection is delegated to `_build_backbone`. All Q-learning
    algorithms should go through this helper so they share one network
    structure.
    """
    backbone_options = dict(
        use_cnn=use_cnn,
        hidden_sizes=hidden_sizes,
        feature_dim=feature_dim,
    )
    backbone, backbone_dim = _build_backbone(obs_shape, **backbone_options)
    # The head is sized from the width the backbone actually reports.
    return nn.Sequential(backbone, QHead(backbone_dim, num_actions))


def build_actor_critic(
    obs_shape: Tuple[int, ...],
    num_actions: int,
    use_cnn: bool = False,
    hidden_sizes: Sequence[int] = (128, 128),
    feature_dim: int = 256,
) -> ActorCritic:
    """
    Build an ActorCritic model on top of an MLP or CNN backbone.

    Backbone selection is delegated to `_build_backbone`.
    """
    backbone_options = dict(
        use_cnn=use_cnn,
        hidden_sizes=hidden_sizes,
        feature_dim=feature_dim,
    )
    backbone, backbone_dim = _build_backbone(obs_shape, **backbone_options)
    # The head is sized from the width the backbone actually reports.
    model = ActorCritic(backbone, backbone_dim, num_actions)
    return model


def build_value_network(
    obs_shape: Tuple[int, ...],
    use_cnn: bool = False,
    hidden_sizes: Sequence[int] = (128, 128),
    feature_dim: int = 256,
) -> ValueNet:
    """
    Build a scalar value network V(s) or potential u(s).

    Backbone selection is delegated to `_build_backbone`. HFPS and any
    value-based algorithm share this structure.
    """
    backbone_options = dict(
        use_cnn=use_cnn,
        hidden_sizes=hidden_sizes,
        feature_dim=feature_dim,
    )
    backbone, backbone_dim = _build_backbone(obs_shape, **backbone_options)
    # The head is sized from the width the backbone actually reports.
    value_net = ValueNet(backbone, backbone_dim)
    return value_net
