import torch
import torch.nn as nn
import torch.nn.functional as F


class Critic(nn.Module):
    """Three-layer MLP that maps a (state, action) pair to one scalar per row.

    Args:
        obs_size: Number of features in the state input.
        action_size: Number of features in the action input.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, obs_size: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        # The first layer must emit ``hidden_size`` features so that it lines
        # up with ``fc2``; a hard-coded width only works for the default.
        self.fc1 = nn.Linear(obs_size + action_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, state, action=None):
        """Compute the critic output for a batch of state/action pairs.

        Args:
            state: Tensor whose dim 1 holds ``obs_size`` features.
            action: Tensor whose dim 1 holds ``action_size`` features. The
                signature defaults it to ``None``, but a tensor is required:
                ``torch.cat`` does not accept ``None``.

        Returns:
            Tensor with a trailing dimension of size 1.
        """
        # Join state and action along the feature axis before the MLP.
        joint_input = torch.cat([state, action], dim=1)  # type: ignore
        hidden = F.relu(self.fc1(joint_input))
        hidden = F.relu(self.fc2(hidden))
        # No activation on the output layer.
        return self.fc3(hidden)


class ModelPredictiveCodingCritic(nn.Module):
    """Critic with a shared trunk feeding a value head and a prediction head.

    The trunk (``fc1`` -> ``fc2``) is shared. ``fc3`` maps the trunk features
    to one scalar per row, while ``mpc_1`` -> ``mpc_2`` ->
    ``next_obs_predictor`` maps the same features to a vector of size
    ``obs_size`` (named the next-state prediction in the code).

    Args:
        obs_size: Number of features in the state input; also the size of the
            prediction head's output.
        action_size: Number of features in the action input.
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, obs_size: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        # The first layer must emit ``hidden_size`` features so that it lines
        # up with ``fc2``; a hard-coded width only works for the default.
        self.fc1 = nn.Linear(obs_size + action_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.mpc_1 = nn.Linear(hidden_size, hidden_size)
        self.mpc_2 = nn.Linear(hidden_size, hidden_size)
        self.next_obs_predictor = nn.Linear(hidden_size, obs_size)

    def forward(self, state, action=None):
        """Run the shared trunk and both heads.

        Args:
            state: Tensor whose dim 1 holds ``obs_size`` features.
            action: Tensor whose dim 1 holds ``action_size`` features. The
                signature defaults it to ``None``, but a tensor is required:
                ``torch.cat`` does not accept ``None``.

        Returns:
            A tuple ``(value, next_state_prediction)`` whose trailing
            dimensions are 1 and ``obs_size`` respectively.
        """
        joint_input = torch.cat([state, action], dim=1)  # type: ignore

        # Shared trunk.
        trunk = F.relu(self.fc2(F.relu(self.fc1(joint_input))))

        # Prediction head, branching off the trunk features.
        mpc_hidden = F.relu(self.mpc_1(trunk))
        mpc_hidden = F.relu(self.mpc_2(mpc_hidden))
        next_state_prediction = self.next_obs_predictor(mpc_hidden)

        # Value head: a single linear layer on the trunk features.
        value = self.fc3(trunk)
        return value, next_state_prediction


class ModelPredictiveCodingSeperateCritic(nn.Module):
    """Critic whose value and prediction branches share no layers.

    Both branches read the same concatenated (state, action) input, but each
    has its own stack: ``q_fc1`` -> ``q_fc2`` -> ``q_fc3`` yields one scalar
    per row, and ``mpc_fc1`` -> ``mpc_fc2`` -> ``next_obs_predictor`` yields a
    vector of size ``obs_size`` (named the next-state prediction in the code).

    Args:
        obs_size: Number of features in the state input; also the size of the
            prediction branch's output.
        action_size: Number of features in the action input.
        hidden_size: Width of every hidden layer in both branches.
    """

    def __init__(self, obs_size: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        input_size = obs_size + action_size

        # Each first layer must emit ``hidden_size`` features so that it lines
        # up with the layer after it; a hard-coded width only works for the
        # default.
        self.q_fc1 = nn.Linear(input_size, hidden_size)
        self.q_fc2 = nn.Linear(hidden_size, hidden_size)
        self.q_fc3 = nn.Linear(hidden_size, 1)

        self.mpc_fc1 = nn.Linear(input_size, hidden_size)
        self.mpc_fc2 = nn.Linear(hidden_size, hidden_size)
        self.next_obs_predictor = nn.Linear(hidden_size, obs_size)

    def forward(self, state, action=None):
        """Run both independent branches on the joint input.

        Args:
            state: Tensor whose dim 1 holds ``obs_size`` features.
            action: Tensor whose dim 1 holds ``action_size`` features. The
                signature defaults it to ``None``, but a tensor is required:
                ``torch.cat`` does not accept ``None``.

        Returns:
            A tuple ``(value, next_state_prediction)`` whose trailing
            dimensions are 1 and ``obs_size`` respectively.
        """
        joint_input = torch.cat([state, action], dim=1)  # type: ignore

        # Value branch.
        q_hidden = F.relu(self.q_fc1(joint_input))
        q_hidden = F.relu(self.q_fc2(q_hidden))
        value = self.q_fc3(q_hidden)

        # Prediction branch; starts again from the raw joint input.
        mpc_features = F.relu(self.mpc_fc1(joint_input))
        mpc_features = F.relu(self.mpc_fc2(mpc_features))
        next_state_prediction = self.next_obs_predictor(mpc_features)

        return value, next_state_prediction


class ModelHydraPredictiveCodingCritic(nn.Module):
    """Critic with one shared trunk and three output heads.

    The trunk (``fc1`` -> ``fc2``) feeds:

    * ``fc3``: one scalar per row.
    * ``mpc_1`` -> ``mpc_2`` -> ``next_obs_predictor``: a vector of size
      ``obs_size`` (named the next-state prediction in the code).
    * ``hidden_no_reg`` -> ``fc_no_reg``: a second scalar per row. The
      "no_reg" naming presumably means this head is kept out of some
      regularisation applied elsewhere -- TODO confirm against the training
      code.

    Args:
        obs_size: Number of features in the state input; also the size of the
            prediction head's output.
        action_size: Number of features in the action input.
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, obs_size: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        # The first layer must emit ``hidden_size`` features so that it lines
        # up with ``fc2``; a hard-coded width only works for the default.
        self.fc1 = nn.Linear(obs_size + action_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.hidden_no_reg = nn.Linear(hidden_size, hidden_size)
        self.fc_no_reg = nn.Linear(hidden_size, 1)
        self.mpc_1 = nn.Linear(hidden_size, hidden_size)
        self.mpc_2 = nn.Linear(hidden_size, hidden_size)
        self.next_obs_predictor = nn.Linear(hidden_size, obs_size)

    def forward(self, state, action=None):
        """Run the shared trunk and all three heads.

        Args:
            state: Tensor whose dim 1 holds ``obs_size`` features.
            action: Tensor whose dim 1 holds ``action_size`` features. The
                signature defaults it to ``None``, but a tensor is required:
                ``torch.cat`` does not accept ``None``.

        Returns:
            A tuple ``(value, next_state_prediction, q_value_no_reg)`` whose
            trailing dimensions are 1, ``obs_size`` and 1 respectively.
        """
        joint_input = torch.cat([state, action], dim=1)  # type: ignore

        # Shared trunk.
        trunk = F.relu(self.fc2(F.relu(self.fc1(joint_input))))

        # Prediction head.
        mpc_hidden = F.relu(self.mpc_1(trunk))
        mpc_hidden = F.relu(self.mpc_2(mpc_hidden))
        next_state_prediction = self.next_obs_predictor(mpc_hidden)

        # Second scalar head, with its own hidden layer on top of the trunk.
        q_value_no_reg = self.fc_no_reg(F.relu(self.hidden_no_reg(trunk)))

        # Primary scalar head: a single linear layer on the trunk features.
        value = self.fc3(trunk)
        return value, next_state_prediction, q_value_no_reg


class BigModelHydraPredictiveCodingCritic(nn.Module):
    """Three-headed critic with a four-layer prediction head.

    The trunk (``fc1`` -> ``fc2``) feeds:

    * ``fc3``: one scalar per row.
    * ``mpc_1`` -> ``mpc_2`` -> ``mpc_3`` -> ``mpc_4`` ->
      ``next_obs_predictor``: a vector of size ``obs_size`` (named the
      next-state prediction in the code).
    * ``hidden_no_reg`` -> ``fc_no_reg``: a second scalar per row. The
      "no_reg" naming presumably means this head is kept out of some
      regularisation applied elsewhere -- TODO confirm against the training
      code.

    Args:
        obs_size: Number of features in the state input; also the size of the
            prediction head's output.
        action_size: Number of features in the action input.
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, obs_size: int, action_size: int, hidden_size: int = 256):
        super().__init__()
        # The first layer must emit ``hidden_size`` features so that it lines
        # up with ``fc2``; a hard-coded width only works for the default.
        self.fc1 = nn.Linear(obs_size + action_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.hidden_no_reg = nn.Linear(hidden_size, hidden_size)
        self.fc_no_reg = nn.Linear(hidden_size, 1)
        self.mpc_1 = nn.Linear(hidden_size, hidden_size)
        self.mpc_2 = nn.Linear(hidden_size, hidden_size)
        self.mpc_3 = nn.Linear(hidden_size, hidden_size)
        self.mpc_4 = nn.Linear(hidden_size, hidden_size)
        self.next_obs_predictor = nn.Linear(hidden_size, obs_size)

    def forward(self, state, action=None):
        """Run the shared trunk and all three heads.

        Args:
            state: Tensor whose dim 1 holds ``obs_size`` features.
            action: Tensor whose dim 1 holds ``action_size`` features. The
                signature defaults it to ``None``, but a tensor is required:
                ``torch.cat`` does not accept ``None``.

        Returns:
            A tuple ``(value, next_state_prediction, q_value_no_reg)`` whose
            trailing dimensions are 1, ``obs_size`` and 1 respectively.
        """
        joint_input = torch.cat([state, action], dim=1)  # type: ignore

        # Shared trunk.
        trunk = F.relu(self.fc2(F.relu(self.fc1(joint_input))))

        # Prediction head: four ReLU-activated layers, then a linear read-out.
        mpc_hidden = trunk
        for layer in (self.mpc_1, self.mpc_2, self.mpc_3, self.mpc_4):
            mpc_hidden = F.relu(layer(mpc_hidden))
        next_state_prediction = self.next_obs_predictor(mpc_hidden)

        # Second scalar head, with its own hidden layer on top of the trunk.
        q_value_no_reg = self.fc_no_reg(F.relu(self.hidden_no_reg(trunk)))

        # Primary scalar head: a single linear layer on the trunk features.
        value = self.fc3(trunk)
        return value, next_state_prediction, q_value_no_reg

