from typing import List, Union, Tuple
import math

import torch
import torch.nn as nn

from .base_cd import BaseCDVAE
from .modules import (
    Conv2DBlock, 
    ConvTranspose2DBlock, 
    ResBlock, 
    ResTransposeBlock,
    FlattenAndLinear,
    FlattenAndLinear, 
)


class ConvCDVAE(BaseCDVAE):
    """Convolutional VAE built from a stack of (transpose) conv or residual blocks.

    The encoder is ``len(channels)`` blocks mapping
    ``input_channels -> channels[0] -> ... -> channels[-1]``. Its output feeds
    two ``FlattenAndLinear`` heads producing the latent mean and log-variance.
    The decoder mirrors the encoder: ``decoder_fc`` maps a latent vector back
    to the encoder output shape, then transpose blocks walk the channel list
    back down to ``input_channels``.

    NOTE(review): ``self.latent_dim`` is read here but never assigned in this
    class; it is presumably set by ``BaseCDVAE.__init__`` from ``**kwargs`` —
    confirm against the base class.
    """

    def __init__(
            self,
            input_channels: int,
            input_shape: Union[int, Tuple[int, ...], List[int]],
            channels: List[int],
            kernel_size: Union[int, List[int]] = 3,
            stride: Union[int, List[int]] = 2,
            layer_type: str = "conv",
            **kwargs
        ):
        """Validate the configuration and build encoder, latent heads and decoder.

        Args:
            input_channels: Number of channels of the input images.
            input_shape: Spatial size ``(H, W)`` of the input. An int ``n``
                means a square ``(n, n)`` input; lists and tuples are accepted.
            channels: Output channels of each encoder block, in encoder order.
            kernel_size: Kernel size per encoder block, or one int for all.
            stride: Stride per encoder block, or one int for all.
            layer_type: ``"conv"``/``"convolutional"`` or ``"res"``/``"residual"``
                (case-insensitive).
            **kwargs: Forwarded unchanged to ``BaseCDVAE.__init__``.

        Raises:
            ValueError: If ``input_shape`` is not an int or a length-2
                list/tuple, if ``channels`` is empty, if ``channels``,
                ``kernel_size`` and ``stride`` differ in length, or if
                ``layer_type`` is not recognised.
        """
        super().__init__(**kwargs)

        if isinstance(input_shape, int):
            input_shape = (input_shape, input_shape)
        elif isinstance(input_shape, (list, tuple)):
            # Tuples are advertised by the type hint and the error message
            # below, so accept them alongside lists.
            input_shape = tuple(input_shape)
        else:
            raise ValueError("Invalid input_shape. Expected an integer, list, or tuple.")
        if len(input_shape) != 2:
            raise ValueError(
                f"input_shape must have exactly 2 spatial dimensions, got {input_shape}."
            )

        # Copy to lists so tuple arguments work with the list arithmetic below
        # and later mutation of the caller's objects cannot affect the model.
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one layer.")

        self.num_layers = len(channels)
        if isinstance(kernel_size, int):
            kernel_size = self.num_layers * [kernel_size]
        if isinstance(stride, int):
            stride = self.num_layers * [stride]
        kernel_size = list(kernel_size)
        stride = list(stride)

        # Raise rather than assert: asserts are stripped under ``python -O``.
        if not (len(channels) == len(kernel_size) == len(stride)):
            raise ValueError(
                "Lengths of channels, kernel_size and stride must be the same."
            )

        self.input_channels = input_channels
        self.input_shape = input_shape
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_type = layer_type

        self._build_model()


    #==========Build methods====================
    def _get_layer_type(self, transpose=False):
        """Return the block class selected by ``self.layer_type``.

        Args:
            transpose: If True return the decoder (transpose) variant,
                otherwise the encoder variant.

        Raises:
            ValueError: If ``self.layer_type`` is not a recognised name.
        """
        kind = self.layer_type.upper()
        if kind in ["CONV", "CONVOLUTIONAL"]:
            return ConvTranspose2DBlock if transpose else Conv2DBlock
        if kind in ["RES", "RESIDUAL"]:
            return ResTransposeBlock if transpose else ResBlock
        raise ValueError(
            f"Unknown layer_type {self.layer_type!r}; expected 'conv' or 'res'."
        )


    def _calculate_encoder_output_size(self, flatten=True):
        """Probe the encoder with a dummy single-sample input to measure its output.

        Args:
            flatten: If True, return the number of elements in the encoder
                output for one sample (an ``int``). Otherwise return the
                per-sample shape (batch dimension dropped) as a 1-D
                ``torch.Tensor``.
        """
        probe = torch.zeros(1, self.input_channels, *self.input_shape)
        was_training = self.encoder.training
        # Probe in eval mode without autograd: no graph is built and, if the
        # blocks contain normalization layers, their running statistics are
        # not updated by the all-zeros dummy batch. Mode does not affect shape.
        self.encoder.eval()
        try:
            with torch.no_grad():
                encoder_output = self.encoder(probe)
        finally:
            self.encoder.train(was_training)
        size = torch.tensor(encoder_output.size())
        return int(torch.prod(size)) if flatten else size[1:]


    def _build_encoder(self):
        """Build ``self.encoder`` as a chain of forward blocks, one per ``channels`` entry."""
        block = self._get_layer_type()
        encoder_layers = []
        in_channels = self.input_channels
        for i, out_channels in enumerate(self.channels):
            encoder_layers.append(block(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size[i],
                stride=self.stride[i],
                relu=True,
            ))
            in_channels = out_channels
        self.encoder = nn.Sequential(*encoder_layers)


    def _build_latent_space(self):
        """Measure the encoder output and build the latent heads and ``decoder_fc``.

        Sets ``encoder_output_shape`` (per-sample shape tuple, used by
        ``decode``), ``encoder_output_size`` (its element count), the
        ``latent_mu`` / ``latent_logvar`` heads and ``decoder_fc``.
        """
        shape = self._calculate_encoder_output_size(flatten=False)
        self.encoder_output_shape = tuple(int(d) for d in shape)
        self.encoder_output_size = int(math.prod(self.encoder_output_shape))
        self.latent_mu = FlattenAndLinear(self.encoder_output_size, self.latent_dim)
        self.latent_logvar = FlattenAndLinear(self.encoder_output_size, self.latent_dim)
        self.decoder_fc = nn.Linear(self.latent_dim, self.encoder_output_size)


    def _build_decoder(self):
        """Build ``self.decoder`` as the mirror image of the encoder.

        Decoder step ``k`` undoes encoder block ``j = num_layers - 1 - k``: it
        maps ``channels[j]`` back to the channels that block received
        (``channels[j - 1]``, or ``input_channels`` for ``j == 0``) using that
        block's own kernel size and stride.
        """
        block_transpose = self._get_layer_type(transpose=True)
        encoder_inputs = [self.input_channels] + self.channels[:-1]
        decoder_layers = []
        for j in reversed(range(self.num_layers)):
            decoder_layers.append(block_transpose(
                in_channels=self.channels[j],
                out_channels=encoder_inputs[j],
                kernel_size=self.kernel_size[j],
                stride=self.stride[j],
                # The final (output) decoder block is built with relu=False.
                relu=(j != 0),
            ))
        self.decoder = nn.Sequential(*decoder_layers)


    def _build_model(self):
        """Build encoder, then latent heads (which probe the encoder), then decoder."""
        self._build_encoder()
        self._build_latent_space()
        self._build_decoder()


    #==========Forward methods====================
    def encode(self, x):
        """Encode ``x`` and return ``(latent_mu, latent_logvar)``.

        ``x`` presumably has shape ``(B, input_channels, *input_shape)`` —
        that is the shape used to probe the encoder at build time.
        """
        x = self.encoder(x)
        latent_mu = self.latent_mu(x)
        latent_logvar = self.latent_logvar(x)
        return latent_mu, latent_logvar


    def decode(self, z):
        """Decode latent vectors ``z`` (last dim ``latent_dim``) into decoder output."""
        x = self.decoder_fc(z)
        # Reshape to the encoder output shape measured at build time. Deriving
        # it as input_shape / prod(stride) breaks whenever a spatial size is not
        # an exact multiple of the total stride or padding changes the size.
        x = x.view(-1, *self.encoder_output_shape)
        return self.decoder(x)