import numpy as np
import torch
import torch.nn as nn
from ray.rllib.models.utils import get_activation_fn


def check_obs_space(obs_space, model_config):
    """Verify the configured per-part observation shapes sum to the env's obs size.

    Each entry of ``model_config['custom_model_config']['obs_shape']`` is
    reduced to its element count with ``np.prod``. The total is compared
    against ``obs_space.shape[0]``. That presumes a flat (1-D) observation
    space such as a flattened ``Box`` — TODO confirm with callers.

    Args:
        obs_space: Observation space exposing a ``shape`` attribute.
        model_config: Model config dict containing
            ``['custom_model_config']['obs_shape']``, an iterable of shapes
            (tuples or ints; anything ``np.prod`` accepts).

    Raises:
        AssertionError: If the summed sizes differ from ``obs_space.shape[0]``.
            An explicit raise (not ``assert``) keeps the check active under
            ``python -O`` while preserving the exception type callers see.
    """
    total_obs_dim = sum(
        np.prod(v) for v in model_config['custom_model_config']['obs_shape']
    )

    if total_obs_dim != obs_space.shape[0]:
        raise AssertionError(
            'obs_shape is not consistent with the total '
            'number of obs_space. Check if the model config matches the '
            'observation space of the env.'
        )


def build_convnet(model_config, to_nn_seq=True, no_act_last_layer=False):
    """Build a stack of ``nn.Conv2d`` layers described by ``model_config``.

    Each entry of ``model_config['conv_filters']`` is unpacked positionally
    into ``nn.Conv2d``. It presumably reads ``[in_channels, out_channels,
    kernel_size, stride, ...]``; TODO confirm against the configs in use.
    After every conv layer, an instance of the activation resolved from
    ``model_config['conv_activation']`` is appended. The exception is the
    last layer when ``no_act_last_layer`` is True. If no activation is
    resolved (e.g. ``"linear"`` or missing), no activation modules are
    added.

    Args:
        model_config: Model config dict; reads ``'conv_filters'`` (required)
            and ``'conv_activation'`` (optional).
        to_nn_seq: If True, wrap the layers in ``nn.Sequential``; otherwise
            return the ``nn.ModuleList`` directly.
        no_act_last_layer: If True, omit the activation after the final conv.

    Returns:
        ``nn.Sequential`` if ``to_nn_seq`` else ``nn.ModuleList``.
    """
    activation = model_config.get('conv_activation')
    activation = get_activation_fn(activation, 'torch')
    filters = model_config['conv_filters']
    last_idx = len(filters) - 1
    modules = nn.ModuleList()
    for i, filt in enumerate(filters):
        modules.append(nn.Conv2d(*filt))
        # get_activation_fn returns None for "linear"/None (no nonlinearity);
        # calling it would raise TypeError, so skip the activation entirely.
        if activation is not None and not (no_act_last_layer and i == last_idx):
            modules.append(activation())
    if to_nn_seq:
        modules = nn.Sequential(*modules)
    return modules


def build_fcnet(model_config, to_nn_seq=True, no_act_last_layer=False):
    """Build a stack of ``nn.Linear`` layers described by ``model_config``.

    ``model_config['fc_filters']`` is a sequence of layer widths. One
    ``nn.Linear`` is created for each consecutive pair
    ``(filters[i], filters[i + 1])``. After every linear layer, an instance
    of the activation resolved from ``model_config['fc_activation']`` is
    appended. The exception is the last layer when ``no_act_last_layer`` is
    True. If no activation is resolved (e.g. ``"linear"`` or missing), no
    activation modules are added. Fewer than two widths yields an empty
    container.

    Args:
        model_config: Model config dict; reads ``'fc_filters'`` (required)
            and ``'fc_activation'`` (optional).
        to_nn_seq: If True, wrap the layers in ``nn.Sequential``; otherwise
            return the ``nn.ModuleList`` directly.
        no_act_last_layer: If True, omit the activation after the final layer.

    Returns:
        ``nn.Sequential`` if ``to_nn_seq`` else ``nn.ModuleList``.
    """
    activation = model_config.get('fc_activation')
    activation = get_activation_fn(activation, 'torch')
    filters = model_config['fc_filters']
    last_idx = len(filters) - 2
    modules = nn.ModuleList()
    for i, (in_dim, out_dim) in enumerate(zip(filters[:-1], filters[1:])):
        modules.append(nn.Linear(in_dim, out_dim))
        # get_activation_fn returns None for "linear"/None (no nonlinearity);
        # calling it would raise TypeError, so skip the activation entirely.
        if activation is not None and not (no_act_last_layer and i == last_idx):
            modules.append(activation())
    if to_nn_seq:
        modules = nn.Sequential(*modules)
    return modules
