import torch as th
from torch import nn

from typing import List, Callable, Union, Any, TypeVar, Tuple, overload


class Encoder(nn.Module):
    def __init__(self, 
                 in_dim: int,  # state dim
                 out_dim: int,  # z dim
                 hidden_dims: List = [],
                 device: str | th.device = 'auto',
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dims = hidden_dims

        if device == 'auto':
            self.device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
        elif device == 'cpu' or 'cuda' in device:
            self.device = th.device(device)
        else:
            assert type(device) == th.device
            self.device = self.device
        
        modules = []
        layer_in_dim = self.in_dim
        for layer_out_dim in self.hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(layer_in_dim, layer_out_dim),
                    nn.LeakyReLU()
                    )
                )
            layer_in_dim = layer_out_dim
        
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(self.hidden_dims[-1], self.out_dim)
        self.fc_logvar = nn.Linear(self.hidden_dims[-1], self.out_dim)

    def encode(self, input: th.Tensor) -> List[th.Tensor]:
        encoder_result = self.encoder(input)
        mu = self.fc_mu(encoder_result)
        logvar = self.fc_logvar(encoder_result)

        return mu, logvar

    def reparameterize(self, mu: th.Tensor, logvar: th.Tensor) -> th.Tensor:
        std = th.exp(0.5 * logvar)
        eps = th.randn_like(std)
        return eps * std + mu
    
    def forward(self, input: th.Tensor, **kwargs) -> List[th.Tensor]:
        mu, logvar = self.encode(input)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar, input



        
