import torch as th
import torch.nn as nn
from typing import Any, Dict, Generator, List, Optional, Union


class Discriminator(nn.Module):
    def __init__(self, 
                 in_dim: int, 
                 hidden_dims: List = [],
                 current_obs_only: bool = True,
                 device: str | th.device = 'auto',
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.in_dim = in_dim
        self.hidden_dims = hidden_dims
        self.current_obs_only = current_obs_only

        if device == 'auto':
            self.device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
        elif device == 'cpu' or 'cuda' in device:
            self.device = th.device(device)
        else:
            assert type(device) == th.device
            self.device = self.device
        
        modules = []
        layer_in_dim = self.in_dim
        for layer_out_dim in self.hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(layer_in_dim,
                              layer_out_dim),
                    nn.LeakyReLU()
                )
            )
            layer_in_dim = layer_out_dim
        
        self.feature_extractor = nn.Sequential(*modules)
        self.disc_output = nn.Linear(layer_in_dim, 1)
    
    def forward(self, obs: th.Tensor, next_obs: th.Tensor | None = None) -> th.Tensor:
        obs_feature = self.feature_extractor(obs)
        feature = obs_feature

        if not self.current_obs_only:
            next_obs_feature = self.feature_extractor(next_obs)
            feature = (obs_feature + next_obs_feature) / 2

        logit = self.disc_output(feature)
        logit = th.clamp(logit, -5, 5)
        return logit  
 