"""
world_model.py
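Full Adaptive World Model (AWML).
Combines an encoder, an optional FNO backbone, causal modular dynamics blocks,
and a decoder. NOTE: the FNO backbone is constructed but is not yet applied in
``AWMLWorldModel.forward``.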
"""

import torch
import torch.nn as nn
from .neural_operator import FNO2d
from .causal_modules import ModularDynamics


class Encoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping observations to latents.

    ``nn.Linear`` acts on the last dimension, so an input of shape
    ``(..., obs_dim)`` produces an output of shape ``(..., latent_dim)``.

    Args:
        obs_dim: Size of the observation feature (last) dimension.
        latent_dim: Size of the produced latent vector.
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            previously hard-coded value, so existing call sites and
            checkpoints keep working.
    """

    def __init__(self, obs_dim: int, latent_dim: int, *, hidden_dim: int = 128) -> None:
        super().__init__()
        # The ``net`` Sequential layout (net.0 / net.1 / net.2) is kept as-is
        # so state_dict keys stay compatible with existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent encoding of ``x`` (last dimension ``obs_dim``)."""
        return self.net(x)


class Decoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping latents back to observations.

    ``nn.Linear`` acts on the last dimension, so an input of shape
    ``(..., latent_dim)`` produces an output of shape ``(..., obs_dim)``.

    Args:
        latent_dim: Size of the latent (last) dimension of the input.
        obs_dim: Size of the reconstructed observation vector.
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            previously hard-coded value, so existing call sites and
            checkpoints keep working.
    """

    def __init__(self, latent_dim: int, obs_dim: int, *, hidden_dim: int = 128) -> None:
        super().__init__()
        # The ``net`` Sequential layout (net.0 / net.1 / net.2) is kept as-is
        # so state_dict keys stay compatible with existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Return the observation-space decoding of ``z`` (last dimension ``latent_dim``)."""
        return self.net(z)


class AWMLWorldModel(nn.Module):
    """World model: encode observation -> step latent dynamics -> decode.

    Args:
        obs_dim: Observation feature size; input size of the encoder and
            output size of the decoder.
        action_dim: Accepted for interface compatibility; not used directly
            here. NOTE(review): it is not forwarded to ``ModularDynamics`` —
            confirm the dynamics module does not need it.
        latent_dim: Latent vector size shared by encoder, dynamics and
            decoder; also passed as the FNO ``width``.
        num_modules: Number of modules passed to ``ModularDynamics``.
        use_fno: If True, build an ``FNO2d`` as ``self.operator``; otherwise
            ``self.operator`` is None.
        fno_modes: Fourier modes used for both height and width of the FNO.
            Defaults to 12, the previously hard-coded value.
        dynamics_hidden_dim: ``hidden_dim`` passed to ``ModularDynamics``.
            Defaults to 128, the previously hard-coded value.

    NOTE(review): ``self.operator`` is never called in ``forward``, so
    ``use_fno=True`` currently only registers extra, unused parameters
    (which e.g. DistributedDataParallel reports as unused unless
    ``find_unused_parameters=True``). It is kept so existing state_dicts
    still load; wire it into ``forward`` once the FNO2d input layout is
    confirmed.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        latent_dim: int,
        num_modules: int,
        use_fno: bool = True,
        *,
        fno_modes: int = 12,
        dynamics_hidden_dim: int = 128,
    ) -> None:
        super().__init__()
        self.encoder = Encoder(obs_dim, latent_dim)
        self.decoder = Decoder(latent_dim, obs_dim)

        self.use_fno = use_fno
        if use_fno:
            self.operator = FNO2d(
                modes_height=fno_modes, modes_width=fno_modes, width=latent_dim
            )
        else:
            self.operator = None

        self.dynamics = ModularDynamics(
            num_modules, latent_dim, hidden_dim=dynamics_hidden_dim
        )

    def forward(self, obs, action):
        """Predict the next latent state and decode it.

        Args:
            obs: Observation tensor; its last dimension must be ``obs_dim``.
            action: Passed unchanged to ``ModularDynamics``
                (NOTE(review): expected shape is defined there — confirm).

        Returns:
            Tuple ``(obs_recon, z_next, dists)``. ``obs_recon`` is the decoder
            output for ``z_next`` (i.e. a prediction for the next step, not a
            reconstruction of ``obs``). ``z_next`` and ``dists`` are the two
            outputs of the dynamics model (``dists`` presumably per-module
            distributions — confirm in ModularDynamics).
        """
        z = self.encoder(obs)
        z_next, dists = self.dynamics(z, action)
        obs_recon = self.decoder(z_next)
        return obs_recon, z_next, dists
