import torch as th
from torch import nn

from typing import List, Callable, Union, Any, TypeVar, Tuple, overload


class Decoder(nn.Module):
    def __init__(self, 
                 in_dim: int,  # z dim
                 out_dim: int,  # state dim
                 hidden_dims: List = [],
                 device: str | th.device = 'auto',
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dims = hidden_dims

        if device == 'auto':
            self.device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
        elif device == 'cpu' or 'cuda' in device:
            self.device = th.device(device)
        else:
            assert type(device) == th.device
            self.device = self.device
        
        modules = []
        layer_in_dim = self.in_dim
        for layer_out_dim in self.hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(layer_in_dim, layer_out_dim),
                    nn.LeakyReLU()
                    )
                )
            layer_in_dim = layer_out_dim
        
        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Linear(self.hidden_dims[-1], self.out_dim)

    def decode(self, input: th.Tensor) -> th.Tensor:
        decoder_result = self.decoder(input)
        recon = self.final_layer(decoder_result)
        return recon
    
    def forward(self, input: th.Tensor, **kwargs) -> th.Tensor:
        recon = self.decode(input)
        return recon



        
