"""Policy and Discriminator are moved to d4rl-cd/common/models.py
"""
import torch
import torch.nn as nn


class DynamicsModel(nn.Module):
    """MLP over the concatenation of ``z``, ``act`` and ``domain_id``.

    The network consists of four linear layers separated by ReLU
    activations. The input width is ``state_dim + act_dim + domain_dim``,
    every hidden layer has width ``hid_dim`` and the output has width
    ``state_dim``.

    NOTE(review): judging by the class name, the output is presumably a
    predicted next state (or next latent) -- confirm against the training
    code that consumes it.

    Args:
        state_dim: Width of the ``z`` input and of the output.
        act_dim: Width of the ``act`` input.
        domain_dim: Width of the ``domain_id`` input.
        hid_dim: Width of each of the three hidden layers.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        domain_dim: int,
        hid_dim: int,
    ):
        super().__init__()
        in_dim = state_dim + act_dim + domain_dim

        # Consecutive width pairs describe the three Linear+ReLU stages;
        # the output projection is appended afterwards without activation.
        widths = (in_dim, hid_dim, hid_dim, hid_dim)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hid_dim, state_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, z, act, domain_id):
        """Concatenate the inputs on the last axis and run them through the MLP.

        Args:
            z: Tensor whose last dimension is ``state_dim``.
            act: Tensor whose last dimension is ``act_dim``.
            domain_id: Tensor whose last dimension is ``domain_dim``.

        Returns:
            Output of ``self.net``; its last dimension is ``state_dim``.
        """
        joint = torch.cat([z, act, domain_id], dim=-1)
        return self.net(joint)


class InverseDynamicsModel(nn.Module):
    """MLP over the concatenation of ``z``, ``z_next`` and ``domain_id``.

    The network consists of four linear layers separated by ReLU
    activations. The input width is ``2 * state_dim + domain_dim``, every
    hidden layer has width ``hid_dim`` and the output has width ``act_dim``.

    NOTE(review): judging by the class name, the output is presumably the
    action that links ``z`` to ``z_next`` -- confirm against the training
    code that consumes it.

    Args:
        state_dim: Width of each of the ``z`` and ``z_next`` inputs.
        act_dim: Width of the output.
        domain_dim: Width of the ``domain_id`` input.
        hid_dim: Width of each of the three hidden layers.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        domain_dim: int,
        hid_dim: int,
    ):
        super().__init__()
        n_inputs = state_dim * 2 + domain_dim

        # Input stage, then two hidden Linear+ReLU stages, then the output
        # projection (no activation after it).
        stack = [nn.Linear(n_inputs, hid_dim), nn.ReLU()]
        for _ in range(2):
            stack += [nn.Linear(hid_dim, hid_dim), nn.ReLU()]
        stack.append(nn.Linear(hid_dim, act_dim))

        self.net = nn.Sequential(*stack)

    def forward(self, z, z_next, domain_id):
        """Concatenate the inputs on the last axis and run them through the MLP.

        Args:
            z: Tensor whose last dimension is ``state_dim``.
            z_next: Tensor whose last dimension is ``state_dim``.
            domain_id: Tensor whose last dimension is ``domain_dim``.

        Returns:
            Output of ``self.net``; its last dimension is ``act_dim``.
        """
        pair = torch.cat([z, z_next, domain_id], dim=-1)
        return self.net(pair)