# Names exported by ``from <this module> import *``. Other definitions in
# this module (e.g. LipsNet, CNN, MLP, StochaPolicyDis) are not listed here
# and must be imported explicitly by name.
__all__ = [
    "DetermPolicy",
    "FiniteHorizonPolicy",
    "StochaPolicy",
    "ActionValue",
    "ActionValueDis",
    "StateValue",
    "ActionValueDistri",
    "ActionValueDistri2",
]
import torch
import warnings
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from gops.utils.common_utils import get_activation_func
from functorch import jacrev, vmap
from gops.utils.act_distribution_cls import Action_Distribution
from .ode_cell.wirings import Connected
from .ode_cell.SmODE import SmODE
class Para_dict(dict):
    """A ``dict`` that proxies tensor-like attributes to its ``"params"`` entry.

    ``d.requires_grad`` reads/writes ``d["params"].requires_grad`` and
    ``d.data`` returns ``d["params"].data``; everything else behaves like a
    plain dict built from keyword arguments.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    requires_grad = property(
        lambda self: self["params"].requires_grad,
        lambda self, value: setattr(self["params"], "requires_grad", value),
        doc='Proxy for ``self["params"].requires_grad`` (read/write).',
    )

    data = property(
        lambda self: self["params"].data,
        doc='Proxy for ``self["params"].data`` (read-only).',
    )
class Lips_K(nn.Module):
    """Produces the Lipschitz constant K used by :class:`LipsNet`.

    * ``local=True``: K is computed per sample by an MLP with layer widths
      ``sizes`` (Tanh between hidden layers, Softplus on the output). The
      bias of its final Linear layer is shifted by ``Lips_start``.
    * ``local=False``: K is one learnable scalar initialised to
      ``Lips_start``. ``forward`` returns ``softplus(K)`` repeated into a
      ``(batch, 1)`` column, and ``sizes`` is ignored.
    """

    def __init__(self, local, Lips_start, sizes) -> None:
        super().__init__()
        self.local = local
        if not local:
            self.K = torch.nn.Parameter(
                torch.tensor(Lips_start, dtype=torch.float), requires_grad=True
            )
            return

        modules = []
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.Tanh()])
        modules.extend([nn.Linear(sizes[-2], sizes[-1], bias=True), nn.Softplus()])
        self.K = nn.Sequential(*modules)

        # Pick the weight init of each Linear from the module that follows it.
        children = list(self.K)
        for layer, follower in zip(children, children[1:]):
            if not isinstance(layer, nn.Linear):
                continue
            if isinstance(follower, nn.ReLU):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            elif isinstance(follower, nn.LeakyReLU):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="leaky_relu")
            else:
                nn.init.xavier_normal_(layer.weight)
        # Start the pre-Softplus output near Lips_start.
        self.K[-2].bias.data += torch.tensor(Lips_start, dtype=torch.float).data

    def forward(self, x):
        if self.local:
            return self.K(x)
        batch = x.shape[0]
        return F.softplus(self.K).repeat(batch).unsqueeze(1)
class LipsNet(nn.Module):
    """MLP whose output is rescaled by a (learnable) Lipschitz constant.

    ``forward(x)`` returns ``K(x) * f(x) / (norm + eps)``, where ``f`` is the
    MLP built from ``sizes`` and ``K`` comes from :class:`Lips_K`. ``norm``
    is ``torch.norm(J, 2, dim=(1, 2))`` of the per-sample Jacobian ``J`` of
    ``f``, computed with functorch ``vmap(jacrev(f))``. This presumably
    expects ``x`` to be ``(batch, features)``; TODO confirm against callers.

    When ``lips_auto_adjust`` is true, each training-mode forward with a
    gradient-tracking K adds ``loss_lambda * mean(K**2)`` to
    ``self.regular_loss``. ``backward_hook`` (registered as a full backward
    pre-hook) back-propagates that term and resets it.

    Construction prints the installed PyTorch version and the supported
    version range.
    """

    def __init__(
        self,
        sizes,
        activation,
        output_activation=nn.Identity,
        lips_init_value=100,
        eps=1e-5,
        lips_auto_adjust=True,
        loss_lambda=0.1,
        local_lips=False,
        lips_hidden_sizes=None,
    ) -> None:
        super().__init__()
        print("Your PyTorch version is", torch.__version__)
        print("To use LipsNet, the PyTorch version must be >=1.12 and <=2.2")

        modules = []
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            modules += [nn.Linear(fan_in, fan_out), activation()]
        modules += [nn.Linear(sizes[-2], sizes[-1]), output_activation()]
        self.mlp = nn.Sequential(*modules)

        # Each Linear is initialised according to the activation that follows it.
        children = list(self.mlp)
        for layer, follower in zip(children, children[1:]):
            if not isinstance(layer, nn.Linear):
                continue
            if isinstance(follower, nn.ReLU):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            elif isinstance(follower, nn.LeakyReLU):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="leaky_relu")
            else:
                nn.init.xavier_normal_(layer.weight)

        self.para_updated = False
        self.local = local_lips
        # Built after the MLP so parameter init consumes the RNG in the same order.
        self.K = Lips_K(local_lips, lips_init_value, lips_hidden_sizes)
        self.loss_lambda = loss_lambda
        self.eps = eps
        self.lips_auto_adjust = lips_auto_adjust
        if lips_auto_adjust:
            self.regular_loss = 0
            self.register_full_backward_pre_hook(backward_hook)

    def forward(self, x):
        K_value = self.K(x)
        tracks_grad = K_value.requires_grad
        if self.lips_auto_adjust and self.training and tracks_grad:
            self.regular_loss += self.loss_lambda * (K_value**2).mean()

        f_out = self.mlp(x)
        jacobian_fn = vmap(jacrev(self.mlp))
        if tracks_grad:
            jacobi = jacobian_fn(x)
        else:
            with torch.no_grad():
                jacobi = jacobian_fn(x)

        norm = torch.norm(jacobi, 2, dim=(1, 2)).unsqueeze(1)
        return K_value * f_out / (norm + self.eps)
def backward_hook(module, gout):
    """Full backward pre-hook that applies LipsNet's Lipschitz regulariser.

    Back-propagates the penalty that ``module.forward`` accumulated in
    ``module.regular_loss``, resets the accumulator to ``0`` and returns the
    incoming output gradients ``gout`` unchanged.

    The accumulator is only a tensor after a training-mode forward whose K
    required grad. In every other case (eval-mode forward, frozen K, or a
    second backward after the accumulator was already reset) it is the int
    ``0``. In that case there is nothing to back-propagate, and calling
    ``.backward()`` on it would raise ``AttributeError``.
    """
    regular_loss = module.regular_loss
    if torch.is_tensor(regular_loss) and regular_loss.requires_grad:
        # The penalty shares K's sub-graph with the module output that is
        # being back-propagated right now, so that graph must be retained.
        regular_loss.backward(retain_graph=True)
    module.regular_loss = 0
    return gout
def count_vars(module):
    """Return the total number of scalar entries in ``module.parameters()``.

    Uses ``Tensor.numel()`` so the result is always an exact Python ``int``.
    The previous ``np.prod(p.shape)`` returned the float ``1.0`` for 0-d
    parameters, which silently turned the whole count into a float.
    """
    return sum(p.numel() for p in module.parameters())
def CNN(kernel_sizes, channels, strides, activation, input_channel):
    """Implementation of CNN.

    Stacks one ``nn.Conv2d`` followed by ``activation()`` per entry of
    ``kernel_sizes``. Layer ``j`` maps ``channels[j - 1]`` input channels
    (``input_channel`` for the first layer) to ``channels[j]`` output
    channels, with kernel ``kernel_sizes[j]`` and stride ``strides[j]``.
    No padding is passed, so Conv2d's default of 0 applies.

    :param list kernel_sizes: list of kernel_size,
    :param list channels: list of channels,
    :param list strides: list of stride,
    :param activation: activation class, instantiated once per layer,
    :param int input_channel: number of channels of input image.
    Return CNN as an ``nn.Sequential``.
    Input shape for CNN: (batch_size, channel_num, height, width).
    """
    modules = []
    for idx, kernel in enumerate(kernel_sizes):
        in_ch = input_channel if idx == 0 else channels[idx - 1]
        modules.append(nn.Conv2d(in_ch, channels[idx], kernel, strides[idx]))
        modules.append(activation())
    return nn.Sequential(*modules)
def SmODE(sensory_units, command_units, state_dim, output_dim):
    """Build an SmODE cell on a ``Connected`` wiring.

    :param int sensory_units: forwarded to ``Connected``.
    :param int command_units: forwarded to ``Connected``; must be at least
        ``output_dim``.
    :param int state_dim: first constructor argument of the SmODE cell class.
        StochaPolicy passes the width of the features it feeds into the cell.
    :param int output_dim: forwarded to ``Connected``.
    :return: the SmODE cell, built with ``batch_first=True``.
    :raises AssertionError: if ``command_units < output_dim``.
    """
    assert command_units - output_dim >= 0, "command_units must be at least output_dim"
    # At module level the name ``SmODE`` is rebound to this factory function,
    # which hides the imported cell class. Re-import the class under an alias.
    # Assigning to a local called ``SmODE`` would also make the name local and
    # raise UnboundLocalError.
    from .ode_cell.SmODE import SmODE as _SmODECell

    wiring = Connected(
        sensory_units,
        command_units,
        output_dim,
        sensory_units,
        command_units,
        output_dim,
        output_dim,
    )
    return _SmODECell(state_dim, wiring, batch_first=True)
def MLP(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected ``nn.Sequential`` with layer widths ``sizes``.

    Every ``nn.Linear`` is followed by a fresh ``activation()`` instance,
    except the last one, which is followed by ``output_activation()``.
    Fewer than two sizes yields an empty ``nn.Sequential``.
    """
    last = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == last else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)
class DetermPolicy(nn.Module, Action_Distribution):
    """
    Approximated function of deterministic policy.
    Input: observation.
    Output: action.

    The observation is encoded by a CNN chosen by ``conv_type`` ("type_1" or
    "type_2") and flattened, then passed through an MLP. The MLP output is
    squashed with ``tanh`` and affinely mapped into
    ``[act_low_lim, act_high_lim]``.
    """

    # conv_type -> (conv kernel sizes, conv channels, conv strides, MLP hidden sizes)
    _CONV_ARCHS = {
        "type_1": ((8, 4, 3), (32, 64, 64), (4, 2, 1), (512, 256)),
        "type_2": (
            (4, 3, 3, 3, 3, 3),
            (8, 16, 32, 64, 128, 256),
            (2, 2, 2, 2, 1, 1),
            (256, 256, 256),
        ),
    }

    def __init__(self, **kwargs):
        super().__init__()
        act_dim = kwargs["act_dim"]
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        high = kwargs["act_high_lim"]
        low = kwargs["act_low_lim"]
        self.register_buffer("act_high_lim", torch.from_numpy(high))
        self.register_buffer("act_low_lim", torch.from_numpy(low))
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.action_distribution_cls = kwargs["action_distribution_cls"]

        arch = self._CONV_ARCHS.get(conv_type)
        if arch is None:
            raise NotImplementedError
        kernel_sizes, channels, strides, hidden_sizes = arch

        self.conv = CNN(kernel_sizes, channels, strides, nn.ReLU, obs_dim[0])
        # Run a dummy one-image batch through the encoder to size the MLP input.
        dummy = torch.ones(obs_dim).unsqueeze(0)
        flat_dim = self.conv(dummy).reshape(1, -1).shape[-1]
        self.mlp = MLP(
            [flat_dim, *hidden_sizes, act_dim],
            self.hidden_activation,
            self.output_activation,
        )

    def forward(self, obs):
        encoded = self.conv(obs)
        flat = encoded.view(encoded.size(0), -1)
        squashed = torch.tanh(self.mlp(flat))
        half_range = (self.act_high_lim - self.act_low_lim) / 2
        midpoint = (self.act_high_lim + self.act_low_lim) / 2
        return half_range * squashed + midpoint
class FiniteHorizonPolicy(nn.Module, Action_Distribution):
    """Placeholder: a finite-horizon policy is not implemented in this module."""

    def __init__(self, **kwargs):
        """Always raise ``NotImplementedError``; instances cannot be created."""
        raise NotImplementedError()
class StochaPolicy(nn.Module, Action_Distribution):
    """
    Approximated function of stochastic policy.
    Input: observation.
    Output: parameters of action distribution (mean and std, concatenated).

    * ``conv_type == "type_1"``: CNN encoder followed by two MLP heads that
      produce the mean and the log-std.
    * ``conv_type == "type_2"``: CNN encoder, then an MLP, then an SmODE cell
      whose output is split into mean and log-std. This variant also needs
      the ``sensory_units`` and ``command_units`` kwargs.

    In both cases the log-std is clamped to ``[min_log_std, max_log_std]``
    and exponentiated.
    """

    def __init__(self, **kwargs):
        super(StochaPolicy, self).__init__()
        act_dim = kwargs["act_dim"]
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        act_high_lim = kwargs["act_high_lim"]
        act_low_lim = kwargs["act_low_lim"]
        self.register_buffer("act_high_lim", torch.from_numpy(act_high_lim))
        self.register_buffer("act_low_lim", torch.from_numpy(act_low_lim))
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.min_log_std = kwargs["min_log_std"]
        self.max_log_std = kwargs["max_log_std"]
        self.action_distribution_cls = kwargs["action_distribution_cls"]
        # Regularisation term reported by the SmODE cell (type_2 only). It stays
        # 0 for type_1 so that ``forward(..., training=...)`` works for both.
        self.regular_loss = 0
        if conv_type == "type_1":
            conv_kernel_sizes = [8, 4, 3]
            conv_channels = [32, 64, 64]
            conv_strides = [4, 2, 1]
            conv_activation = nn.ReLU
            conv_input_channel = obs_dim[0]
            mlp_hidden_layers = [512, 256]
            self.conv = CNN(
                conv_kernel_sizes,
                conv_channels,
                conv_strides,
                conv_activation,
                conv_input_channel,
            )
            conv_num_dims = (
                self.conv(torch.ones(obs_dim).unsqueeze(0)).reshape(1, -1).shape[-1]
            )
            policy_mlp_sizes = [conv_num_dims] + mlp_hidden_layers + [act_dim]
            self.mean = MLP(
                policy_mlp_sizes, self.hidden_activation, self.output_activation
            )
            self.log_std = MLP(
                policy_mlp_sizes, self.hidden_activation, self.output_activation
            )
        elif conv_type == "type_2":
            # Only the ODE variant needs these sizes, so only require them here.
            sensory_units = kwargs["sensory_units"]
            command_units = kwargs["command_units"]
            conv_kernel_sizes = [4, 3, 3, 3, 3, 3]
            conv_channels = [8, 16, 32, 64, 128, 256]
            conv_strides = [2, 2, 2, 2, 1, 1]
            conv_activation = nn.ReLU
            conv_input_channel = obs_dim[0]
            mlp_hidden_layers = [256, 256, 256]
            self.conv = CNN(
                conv_kernel_sizes,
                conv_channels,
                conv_strides,
                conv_activation,
                conv_input_channel,
            )
            conv_num_dims = (
                self.conv(torch.ones(obs_dim).unsqueeze(0)).reshape(1, -1).shape[-1]
            )
            policy_mlp_sizes = [conv_num_dims] + mlp_hidden_layers
            # The cell emits 2 * act_dim values: mean and log-std.
            self.ode = SmODE(
                sensory_units, command_units, mlp_hidden_layers[-1], act_dim * 2
            )
            self.hx = None
            self.policy = MLP(
                policy_mlp_sizes, self.hidden_activation, self.output_activation
            )
        else:
            raise NotImplementedError

    def _ode_head(self, feature, sample):
        """Run the type_2 MLP + SmODE cell; return ``(mean, log_std)``."""
        x = self.policy(feature)
        # One step per sample; presumably (batch, time=1, features) because
        # the cell is built with batch_first=True. TODO confirm.
        x = x.reshape(-1, 1, x.shape[-1])
        if sample:
            # Carry the hidden state across consecutive sampling calls.
            x, self.hx, _ = self.ode(x, self.hx, sample=sample)
        else:
            x, _, self.regular_loss = self.ode(x, None, sample=sample)
        x = x.reshape(-1, x.shape[-1])
        action_mean, action_log_std = torch.chunk(x, chunks=2, dim=-1)
        return action_mean, action_log_std

    def forward(self, obs, sample=False, **kwargs):
        """Return ``cat(mean, std)`` along the last dim.

        If a ``training`` keyword is passed (any value), return
        ``(cat(mean, std), self.regular_loss)`` instead. ``sample`` only
        affects type_2 (see ``_ode_head``).
        """
        img = self.conv(obs)
        feature = img.view(img.size(0), -1)
        if hasattr(self, "ode"):
            action_mean, action_log_std = self._ode_head(feature, sample)
        else:
            # type_1: separate mean / log-std heads. Only type_1 builds these;
            # type_2 modules (policy/ode/hx) do not exist here.
            action_mean = self.mean(feature)
            action_log_std = self.log_std(feature)
        action_std = torch.clamp(
            action_log_std, self.min_log_std, self.max_log_std
        ).exp()
        out = torch.cat((action_mean, action_std), dim=-1)
        if "training" in kwargs:
            return out, self.regular_loss
        return out
class ActionValue(nn.Module, Action_Distribution):
    """
    Approximated function of action-value function.
    Input: observation, action.
    Output: action-value.

    The observation is encoded by a CNN chosen by ``conv_type`` and
    flattened, concatenated with the action along the last dim, and mapped
    by an MLP to a single output.
    """

    # conv_type -> (conv kernel sizes, conv channels, conv strides, MLP hidden sizes)
    _CONV_ARCHS = {
        "type_1": ((8, 4, 3), (32, 64, 64), (4, 2, 1), (512, 256)),
        "type_2": (
            (4, 3, 3, 3, 3, 3),
            (8, 16, 32, 64, 128, 256),
            (2, 2, 2, 2, 1, 1),
            (256, 256, 256),
        ),
    }

    def __init__(self, **kwargs):
        super().__init__()
        act_dim = kwargs["act_dim"]
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.action_distribution_cls = kwargs["action_distribution_cls"]

        arch = self._CONV_ARCHS.get(conv_type)
        if arch is None:
            raise NotImplementedError
        kernel_sizes, channels, strides, hidden_sizes = arch

        self.conv = CNN(kernel_sizes, channels, strides, nn.ReLU, obs_dim[0])
        # Run a dummy one-image batch through the encoder to size the MLP input.
        dummy = torch.ones(obs_dim).unsqueeze(0)
        flat_dim = self.conv(dummy).reshape(1, -1).shape[-1]
        self.mlp = MLP(
            [flat_dim + act_dim, *hidden_sizes, 1],
            self.hidden_activation,
            self.output_activation,
        )

    def forward(self, obs, act):
        encoded = self.conv(obs)
        flat = encoded.view(encoded.size(0), -1)
        return self.mlp(torch.cat([flat, act], -1))
class ActionValueDis(nn.Module, Action_Distribution):
    """
    Approximated function of action-value function for discrete action space.
    Input: observation.
    Output: action-value for all action.

    The observation is encoded by a CNN chosen by ``conv_type`` and
    flattened, then mapped by an MLP to ``act_num`` outputs. A trailing
    dimension of size 1 is squeezed away.
    """

    # conv_type -> (conv kernel sizes, conv channels, conv strides, MLP hidden sizes)
    _CONV_ARCHS = {
        "type_1": ((8, 4, 3), (32, 64, 64), (4, 2, 1), (512,)),
        "type_2": (
            (4, 3, 3, 3, 3, 3),
            (8, 16, 32, 64, 128, 256),
            (2, 2, 2, 2, 1, 1),
            (256, 256, 256),
        ),
    }

    def __init__(self, **kwargs):
        super().__init__()
        act_num = kwargs["act_num"]
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.action_distribution_cls = kwargs["action_distribution_cls"]

        arch = self._CONV_ARCHS.get(conv_type)
        if arch is None:
            raise NotImplementedError
        kernel_sizes, channels, strides, hidden_sizes = arch

        self.conv = CNN(kernel_sizes, channels, strides, nn.ReLU, obs_dim[0])
        # Run a dummy one-image batch through the encoder to size the MLP input.
        dummy = torch.ones(obs_dim).unsqueeze(0)
        flat_dim = self.conv(dummy).reshape(1, -1).shape[-1]
        self.mlp = MLP(
            [flat_dim, *hidden_sizes, act_num],
            self.hidden_activation,
            self.output_activation,
        )

    def forward(self, obs):
        encoded = self.conv(obs)
        flat = encoded.view(encoded.size(0), -1)
        return self.mlp(flat).squeeze(-1)
class ActionValueDistri(nn.Module):
    """
    Approximated function of distributed action-value function.
    Input: observation, action.
    Output: parameters of action-value distribution, ``cat(mean, std)``.

    The flattened CNN encoding of the observation, concatenated with the
    action, feeds two MLP heads. The ``mean`` head's output is used as is.
    The ``log_std`` head's output goes through softplus and is returned as
    the std. ``min_log_std`` / ``max_log_std`` kwargs are accepted but only
    trigger a deprecation warning.
    """

    # conv_type -> (conv kernel sizes, conv channels, conv strides, MLP hidden sizes)
    _CONV_ARCHS = {
        "type_1": ((8, 4, 3), (32, 64, 64), (4, 2, 1), (512, 256)),
        "type_2": (
            (4, 3, 3, 3, 3, 3),
            (8, 16, 32, 64, 128, 256),
            (2, 2, 2, 2, 1, 1),
            (256, 256, 256),
        ),
    }

    def __init__(self, **kwargs):
        super().__init__()
        act_dim = kwargs["act_dim"]
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.action_distribution_cls = kwargs["action_distribution_cls"]
        if any(key in kwargs for key in ("min_log_std", "max_log_std")):
            warnings.warn(
                "min_log_std and max_log_std are deprecated in ActionValueDistri."
            )

        arch = self._CONV_ARCHS.get(conv_type)
        if arch is None:
            raise NotImplementedError
        kernel_sizes, channels, strides, hidden_sizes = arch

        self.conv = CNN(kernel_sizes, channels, strides, nn.ReLU, obs_dim[0])
        # Run a dummy one-image batch through the encoder to size the MLP input.
        dummy = torch.ones(obs_dim).unsqueeze(0)
        flat_dim = self.conv(dummy).reshape(1, -1).shape[-1]
        head_sizes = [flat_dim + act_dim, *hidden_sizes, 1]
        self.mean = MLP(head_sizes, self.hidden_activation, self.output_activation)
        self.log_std = MLP(head_sizes, self.hidden_activation, self.output_activation)

    def forward(self, obs, act):
        encoded = self.conv(obs)
        joint = torch.cat([encoded.view(encoded.size(0), -1), act], -1)
        mean = self.mean(joint)
        std = F.softplus(self.log_std(joint))
        return torch.cat((mean, std), dim=-1)
class ActionValueDistri2(nn.Module):
    """
    Approximated function of distributed action-value function (mean only).
    Input: observation, action.
    Output: mean of the action-value distribution (last dim of size 1).

    Same network as ActionValueDistri: the flattened CNN encoding of the
    observation, concatenated with the action, feeds a ``mean`` head and a
    ``log_std`` head. ``forward`` only evaluates the ``mean`` head. The
    ``log_std`` head is still constructed so that parameter sets and state
    dicts keep the same layout.
    """

    # conv_type -> (conv kernel sizes, conv channels, conv strides, MLP hidden sizes)
    _CONV_ARCHS = {
        "type_1": ((8, 4, 3), (32, 64, 64), (4, 2, 1), (512, 256)),
        "type_2": (
            (4, 3, 3, 3, 3, 3),
            (8, 16, 32, 64, 128, 256),
            (2, 2, 2, 2, 1, 1),
            (256, 256, 256),
        ),
    }

    def __init__(self, **kwargs):
        super(ActionValueDistri2, self).__init__()
        act_dim = kwargs["act_dim"]
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.action_distribution_cls = kwargs["action_distribution_cls"]
        if "min_log_std" in kwargs or "max_log_std" in kwargs:
            # The message previously named ActionValueDistri by mistake.
            warnings.warn(
                "min_log_std and max_log_std are deprecated in ActionValueDistri2."
            )

        arch = self._CONV_ARCHS.get(conv_type)
        if arch is None:
            raise NotImplementedError
        kernel_sizes, channels, strides, hidden_sizes = arch

        self.conv = CNN(kernel_sizes, channels, strides, nn.ReLU, obs_dim[0])
        # Run a dummy one-image batch through the encoder to size the MLP input.
        dummy = torch.ones(obs_dim).unsqueeze(0)
        conv_num_dims = self.conv(dummy).reshape(1, -1).shape[-1]
        mlp_sizes = [conv_num_dims + act_dim, *hidden_sizes, 1]
        self.mean = MLP(mlp_sizes, self.hidden_activation, self.output_activation)
        # Unused by forward; kept so parameters/state dicts match ActionValueDistri.
        self.log_std = MLP(mlp_sizes, self.hidden_activation, self.output_activation)

    def forward(self, obs, act):
        img = self.conv(obs)
        feature = torch.cat([img.view(img.size(0), -1), act], -1)
        return self.mean(feature)
class StochaPolicyDis(ActionValueDis, Action_Distribution):
    """
    Approximated function of stochastic policy for discrete action space.
    Input: observation.
    Output: parameters of action distribution.

    The network is inherited unchanged from ActionValueDis (CNN encoder plus
    an MLP with ``act_num`` outputs). Its outputs are read as distribution
    parameters instead of action values.
    """
class StateValue(nn.Module, Action_Distribution):
    """
    Approximated function of state-value function.
    Input: observation.
    Output: state-value, with the trailing size-1 dimension squeezed away.

    The observation is encoded by a CNN chosen by ``conv_type`` and
    flattened, then mapped by an MLP to a single output.
    """

    # conv_type -> (conv kernel sizes, conv channels, conv strides, MLP hidden sizes)
    _CONV_ARCHS = {
        "type_1": ((8, 4, 3), (32, 64, 64), (4, 2, 1), (512,)),
        "type_2": (
            (4, 3, 3, 3, 3, 3),
            (8, 16, 32, 64, 128, 256),
            (2, 2, 2, 2, 1, 1),
            (256, 256, 256),
        ),
    }

    def __init__(self, **kwargs):
        super().__init__()
        obs_dim = kwargs["obs_dim"]
        conv_type = kwargs["conv_type"]
        self.hidden_activation = get_activation_func(kwargs["hidden_activation"])
        self.output_activation = get_activation_func(kwargs["output_activation"])
        self.action_distribution_cls = kwargs["action_distribution_cls"]

        arch = self._CONV_ARCHS.get(conv_type)
        if arch is None:
            raise NotImplementedError
        kernel_sizes, channels, strides, hidden_sizes = arch

        self.conv = CNN(kernel_sizes, channels, strides, nn.ReLU, obs_dim[0])
        # Run a dummy one-image batch through the encoder to size the MLP input.
        dummy = torch.ones(obs_dim).unsqueeze(0)
        flat_dim = self.conv(dummy).reshape(1, -1).shape[-1]
        self.mlp = MLP(
            [flat_dim, *hidden_sizes, 1],
            self.hidden_activation,
            self.output_activation,
        )

    def forward(self, obs):
        encoded = self.conv(obs)
        flat = encoded.view(encoded.size(0), -1)
        return self.mlp(flat).squeeze(-1)
