import torch
import torch.nn as nn
import d3rlpy
from d3rlpy.models.encoders import EncoderFactory

class CustomEncoder(nn.Module):
    """Two-layer fully connected encoder for vector observations.

    Args:
        observation_shape: Shape of a single observation. Only
            ``observation_shape[0]`` is read; it is the input width of the
            first linear layer.
        feature_size: Width of both linear layers and of the returned
            feature vector.
    """

    def __init__(self, observation_shape, feature_size):
        super(CustomEncoder, self).__init__()
        self.feature_size = feature_size
        self.fc1 = nn.Linear(observation_shape[0], feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)

    def forward(self, x):
        """Encode observations into ``feature_size``-wide features.

        Args:
            x: Tensor whose last dimension equals ``observation_shape[0]``.

        Returns:
            ReLU-activated output of the second linear layer.
        """
        h = torch.relu(self.fc1(x))
        # The second layer must consume the first layer's activations, not
        # the raw input: passing ``x`` here would skip fc1 entirely and fail
        # with a shape error whenever the observation width differs from
        # feature_size.
        h = torch.relu(self.fc2(h))
        return h

    # Kept as an explicit method: the consuming library presumably calls it
    # to size the layers built on top of this encoder (verify against the
    # d3rlpy version in use).
    def get_feature_size(self):
        """Return the width of the feature vector produced by ``forward``."""
        return self.feature_size


class CustomEncoderWithAction(nn.Module):
    """Two-layer fully connected encoder over observation/action pairs.

    The observation and action tensors are concatenated along dimension 1
    and passed through two ReLU-activated linear layers.

    Args:
        observation_shape: Shape of a single observation; only
            ``observation_shape[0]`` is read.
        action_size: Width of the action tensor appended to the observation.
        feature_size: Width of both linear layers and of the returned
            feature vector.
        output_size: Accepted for interface compatibility; not used by the
            current implementation.
    """

    def __init__(self, observation_shape, action_size, feature_size, output_size=1):
        super().__init__()
        self.feature_size = feature_size
        self.action_size = action_size
        # Input width of the first layer: observation and action side by side.
        joint_width = observation_shape[0] + action_size
        self.fc1 = nn.Linear(joint_width, feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)

    def forward(self, x, action):
        """Encode a batch of observation/action pairs.

        Args:
            x: Observation tensor; concatenated with ``action`` on dim 1.
            action: Action tensor; must match ``x`` in every dimension
                except dim 1.

        Returns:
            ReLU-activated output of the second linear layer.
        """
        joint = torch.cat((x, action), dim=1)
        hidden = self.fc1(joint).relu()
        return self.fc2(hidden).relu()

    def get_feature_size(self):
        """Return the width of the feature vector produced by ``forward``."""
        return self.feature_size

class CustomEncoderFactory(EncoderFactory):
    """Factory that builds the two-layer custom encoders defined in this file.

    Args:
        feature_size: Hidden/output width handed to every encoder created.
    """

    # Identifier string for this factory type.
    TYPE = "custom2layer_action"

    def __init__(self, feature_size):
        self.feature_size = feature_size

    def create(self, observation_shape):
        """Build an observation-only encoder for ``observation_shape``."""
        return CustomEncoder(
            observation_shape=observation_shape,
            feature_size=self.feature_size,
        )

    def create_with_action(self, observation_shape, action_size, discrete_action=False):
        """Build an encoder that takes both an observation and an action.

        ``discrete_action`` is accepted but not used: the same encoder is
        returned regardless of its value.
        """
        return CustomEncoderWithAction(
            observation_shape=observation_shape,
            action_size=action_size,
            feature_size=self.feature_size,
        )

    def get_params(self, deep=False):
        """Return the constructor arguments of this factory as a dict.

        ``deep`` is accepted but has no effect here.
        """
        params = {"feature_size": self.feature_size}
        return params

