"""
Model factory. Add more description
"""

import gym
import torch.nn as nn
import torch.nn.functional as F

from expground.types import Dict, Any, Union
from expground.utils.preprocessor import get_preprocessor


class Model(nn.Module):
    """Base ``nn.Module`` that records integer input/output dimensions.

    The dimensions are stored on ``self.input_dim`` and ``self.output_dim``.
    They are derived from gym spaces when spaces are given, or taken as-is
    otherwise.
    """

    def __init__(
        self, input_space: Union[int, gym.Space], output_space: Union[int, gym.Space]
    ):
        """Build the model and resolve its input/output dimensions.

        Args:
            input_space (Union[int, gym.Space]): A gym space describing the
                model input, or the input dimension itself.
            output_space (Union[int, gym.Space]): A gym space describing the
                model output, or the output dimension itself.
        """
        super().__init__()

        def _dim_of(space):
            # Non-space values are assumed to already be the dimension and are
            # stored unchanged (no validation is performed here).
            if not isinstance(space, gym.spaces.Space):
                return space
            # For gym spaces, use the ``size`` reported by the project's
            # preprocessor for that space (presumably the flattened size --
            # confirm in expground.utils.preprocessor).
            preprocessor_cls = get_preprocessor(space)
            return preprocessor_cls(space).size

        self.input_dim = _dim_of(input_space)
        self.output_dim = _dim_of(output_space)

    def reset(self):
        """Re-run ``reset_parameters()`` on each module that defines it.

        Iterates over ``self.modules()``, which per PyTorch includes this
        module itself as well as all nested submodules.
        """
        resettable = (
            module
            for module in self.modules()
            if hasattr(module, "reset_parameters")
        )
        for module in resettable:
            module.reset_parameters()
