import abc
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

from lambda_ac.nn.common import MLP
from lambda_ac.rl_types import Encoder, EncoderCriticModule, FeatureInput, QHead


class MLPQHead(QHead):
    """Q-value head made of two separately parameterised MLPs.

    Both MLPs receive the same input vector (features concatenated with the
    action along the last dimension) and each returns its own estimate.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        action_dim: int,
        hidden_layers: int,
        spectral_norm: bool = False,
        depend_on_hidden: bool = False,
    ):
        """Build the two MLP heads.

        Args:
            input_dim: Width of ``features.encoded``. When
                ``depend_on_hidden`` is set, the input width is computed as
                ``2 * input_dim + action_dim``, i.e. ``features.hidden`` is
                assumed to have the same width as ``features.encoded``.
            hidden_dim: Forwarded to ``MLP`` (third positional argument).
            action_dim: Width of the action vector appended to the features.
            hidden_layers: Forwarded to ``MLP`` (fourth positional argument).
            spectral_norm: Forwarded to both heads as ``apply_spectral_norm``.
                Previously this flag was accepted but silently ignored.
            depend_on_hidden: If true, ``forward`` also concatenates
                ``features.hidden`` into the head input.
        """
        super().__init__()
        self.depend_on_hidden = depend_on_hidden

        feature_dim = 2 * input_dim if depend_on_hidden else input_dim
        input_size = feature_dim + action_dim

        # Settings shared by both heads; only ``apply_spectral_norm`` is
        # caller-controlled.
        head_kwargs = dict(
            normalize_input=True,
            apply_spectral_norm=spectral_norm,
            normalize_last_layer=False,
            batch_norm=False,
        )
        # Construction order (head1, then head2) is kept so parameter
        # initialisation consumes the RNG in the same order as before.
        # The literal 1 is presumably the MLP output width -- TODO confirm
        # against the MLP signature.
        self.head1 = MLP(input_size, 1, hidden_dim, hidden_layers, **head_kwargs)
        self.head2 = MLP(input_size, 1, hidden_dim, hidden_layers, **head_kwargs)

    def forward(
        self, features: FeatureInput, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate both heads on the same feature/action input.

        Args:
            features: Provides ``encoded`` and, when ``depend_on_hidden`` is
                set, ``hidden``.
            action: Tensor concatenated to the features along the last
                dimension.

        Returns:
            ``(head1 output, head2 output)``.
        """
        if self.depend_on_hidden:
            x = torch.cat([features.encoded, features.hidden], dim=-1)
        else:
            x = features.encoded
        head_input = torch.cat([x, action], dim=-1)
        return self.head1(head_input), self.head2(head_input)

    def roll(self):
        """Do nothing; presumably required by the ``QHead`` interface."""
        pass


class QNetwork(EncoderCriticModule):
    """Critic that runs states through an encoder and then an ``MLPQHead``."""

    def __init__(self, encoder: Encoder, head: MLPQHead):
        """Store the Q-head; the encoder is handed to the base class."""
        super().__init__(encoder)
        self.head: MLPQHead = head

    def forward(
        self, states: torch.Tensor, actions: torch.Tensor, detach_encoder: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the two head outputs for the given states and actions.

        Args:
            states: Input passed to ``self.encoder``.
            actions: Passed to the head together with the encoded features.
            detach_encoder: If true, ``features.detach()`` is called before
                the head is evaluated.

        Returns:
            The pair of tensors produced by the head.
        """
        encoder_output = self.encoder(states)
        features = FeatureInput.from_output(encoder_output)
        if detach_encoder:
            # NOTE(review): the return value is discarded, so this only has an
            # effect if FeatureInput.detach() works in place -- confirm.
            features.detach()
        # NOTE(review): calls .forward directly rather than the module itself;
        # if QHead is an nn.Module this skips registered hooks -- confirm intent.
        first_q, second_q = self.head.forward(features, actions)
        return first_q, second_q
