import torch as th
import torch.nn as nn
import torch.nn.functional as F

from typing import Any, Dict, Generator, List, Optional, Union

from .vae.encoder import Encoder


class Reward(nn.Module):
    def __init__(self, 
                 in_dim: int,
                 hidden_dims: List = [],
                 current_obs_only: bool = True,
                 use_encoder: bool = False,                 
                 device: str | th.device = 'auto',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.in_dim = in_dim
        self.hidden_dims = hidden_dims
        self.current_obs_only = current_obs_only
        self.use_encoder = use_encoder

        if device == 'auto':
            self.device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
        elif device == 'cpu' or 'cuda' in device:
            self.device = th.device(device)
        else:
            assert type(device) == th.device
            self.device = self.device

        modules = []
        layer_in_dim = self.in_dim
        for layer_out_dim in self.hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(layer_in_dim, layer_out_dim),
                    nn.LeakyReLU()
                    )
                )
            layer_in_dim = layer_out_dim

        self.feature_extractor = nn.Sequential(*modules)
        self.reward_output = nn.Linear(self.hidden_dims[-1], 1)
    
    def forward(self, obs: th.Tensor, next_obs: th.Tensor | None = None, encoder: Encoder | None = None) -> List[th.Tensor]:
        if self.use_encoder:
            with th.no_grad():
                z_obs, mu_obs, logvar_obs, obs = encoder(obs)
                if not self.current_obs_only:
                    z_next_obs, mu, logvar, next_obs = encoder(next_obs)

            obs_feature = self.feature_extractor(z_obs)
            feature = obs_feature

            if not self.current_obs_only:
                next_obs_feature = self.feature_extractor(z_next_obs)
                feature = (obs_feature + next_obs_feature) / 2
        else:
            obs_feature = self.feature_extractor(obs)
            feature = obs_feature

            if not self.current_obs_only:
                next_obs_feature = self.feature_extractor(next_obs)
                feature = (obs_feature + next_obs_feature) / 2

        logit = self.reward_output(feature)
        clamp_logit = th.clamp(logit, -5, 5)
        reward = F.tanh(clamp_logit)
        return reward, clamp_logit