# file: prism/models/autoencoder.py
import numpy as np
import torch
import torch.nn as nn

from prism.core.base_objects import BaseModel
from prism.core.registry import MODELS
from prism.models.backbones import ConvBackbone
from prism.models.heads import MLPHead


@MODELS.register("Encoder")
class Encoder(BaseModel):
    """Map inputs to latent codes using the architecture named by ``config.model.type``.

    Supported types:

    * ``'mlp'``  -- a single ``MLPHead`` from ``prod(config.data.image_shape)``
      features to ``config.model.latent_space.latent_dim``.
    * ``'conv'`` -- a ``ConvBackbone`` (``direction='encoder'``) followed by an
      ``MLPHead`` projecting to ``latent_dim``; the head's input size is found
      by probing the backbone with a dummy batch.
    * ``'fcn'``  -- a ``ConvBackbone`` whose output is flattened, with no
      projection head, so the latent size is whatever the backbone emits.

    The built network is stored as ``self.net``. The submodule layout of each
    variant determines ``state_dict`` key names, so keep it stable.

    Raises:
        ValueError: if ``config.model.type`` is not one of the types above.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model_type = self.config.model.type

        builders = {
            'mlp': self._build_mlp_encoder,
            'conv': self._build_conv_encoder,
            'fcn': self._build_fcn_encoder,
        }
        if self.model_type not in builders:
            raise ValueError(f"Unsupported model type for Encoder: {self.model_type}")
        self.net = builders[self.model_type]()

    def _build_mlp_encoder(self):
        """Build an ``MLPHead`` from ``prod(image_shape)`` features to ``latent_dim``.

        NOTE(review): ``forward`` passes ``x`` to this head unchanged. This
        assumes ``MLPHead`` flattens its input, or that callers pass vectors
        that are already flat with length ``prod(image_shape)``. Confirm.
        """
        model_cfg = self.config.model
        arch_cfg = model_cfg.architecture
        input_dim = int(np.prod(self.config.data.image_shape))

        return MLPHead(
            in_features=input_dim,
            out_features=model_cfg.latent_space.latent_dim,
            h_units=arch_cfg.mlp.encoder_h_dims,
            use_bias=arch_cfg.use_bias,
            activation_type=arch_cfg.activation_type,
        )

    @staticmethod
    def _infer_flat_features(module, sample_shape):
        """Return the flattened per-sample size of ``module``'s output.

        Runs a single forward pass on a random batch of two samples of shape
        ``sample_shape``.

        The module is switched to eval mode for the probe. ``no_grad`` alone
        does not stop BatchNorm layers in training mode from folding the batch
        into their running statistics and incrementing ``num_batches_tracked``.
        That would contaminate a freshly initialised model before training
        starts. Every submodule's original ``training`` flag is restored
        afterwards, even if the forward pass raises.

        The batch size of two and the use of ``torch.randn`` are kept so the
        global RNG advances exactly as before. Parameters created after the
        probe are therefore initialised identically for a given seed.
        """
        saved_modes = [(m, m.training) for m in module.modules()]
        module.eval()
        try:
            with torch.no_grad():
                out = module(torch.randn(2, *sample_shape))
        finally:
            for m, mode in saved_modes:
                m.training = mode
        return int(np.prod(out.shape[1:]))

    def _build_conv_encoder(self):
        """Build ``ConvBackbone -> MLPHead``, sizing the head from a shape probe."""
        model_cfg = self.config.model
        arch_cfg = model_cfg.architecture

        # Construction order (backbone, probe, head) is deliberate: it fixes the
        # sequence of RNG draws used for parameter initialisation.
        backbone = ConvBackbone(self.config, direction='encoder')
        in_features = self._infer_flat_features(backbone, self.config.data.image_shape)

        head = MLPHead(
            in_features=in_features,
            out_features=model_cfg.latent_space.latent_dim,
            h_units=arch_cfg.conv.encoder.mlp_h_units,
            use_bias=arch_cfg.use_bias,
            activation_type=arch_cfg.activation_type,
        )
        # NOTE(review): the backbone output can have several non-batch dims
        # (``in_features`` is their product), but it reaches the head without an
        # explicit Flatten. This presumes MLPHead flattens internally. Confirm.
        return nn.Sequential(backbone, head)

    def _build_fcn_encoder(self):
        """Build a ``ConvBackbone`` whose output is flattened to ``(batch, -1)``."""
        return nn.Sequential(
            ConvBackbone(self.config, direction='encoder'),
            nn.Flatten(),
        )

    def forward(self, x):
        """Encode ``x`` by applying the network built for ``config.model.type``."""
        return self.net(x)


@MODELS.register("Generator")
class Generator(BaseModel):
    """Map latent codes back to data space using ``config.model.type``.

    Supported types:

    * ``'mlp'``  -- an ``MLPHead`` from ``latent_dim`` to
      ``prod(config.data.image_shape)``, then ``Tanh``, then an unflatten to
      ``image_shape``.
    * ``'conv'`` -- an ``MLPHead`` from ``latent_dim`` to the decoder
      backbone's input size, unflattened to
      ``backbone.get_decoder_input_shape()``, then a ``ConvBackbone``
      (``direction='decoder'``).
    * ``'fcn'``  -- the latent vector is unflattened to
      ``(latent_channels, latent_h, latent_w)`` taken from
      ``config.model.fcn_params``, then fed directly to a decoder
      ``ConvBackbone``.

    The built network is stored as ``self.net``. The submodule layout of each
    variant determines ``state_dict`` key names, so keep it stable.

    Raises:
        ValueError: if ``config.model.type`` is not one of the types above.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model_type = self.config.model.type

        builders = {
            'mlp': self._build_mlp_generator,
            'conv': self._build_conv_generator,
            'fcn': self._build_fcn_generator,
        }
        if self.model_type not in builders:
            raise ValueError(f"Unsupported model type for Generator: {self.model_type}")
        self.net = builders[self.model_type]()

    @staticmethod
    def _as_shape(dims):
        """Return ``dims`` as a plain tuple of Python ints.

        ``nn.Unflatten`` only accepts a ``tuple``/``list`` whose entries are
        ``int``. It raises ``TypeError`` for other sequence containers (for
        example config-library list wrappers) and for numpy integer entries,
        so sizes coming from config or helpers are normalised first.
        """
        return tuple(int(d) for d in dims)

    def _build_mlp_generator(self):
        """Build ``MLPHead -> Tanh -> Unflatten`` producing ``image_shape`` outputs."""
        model_cfg = self.config.model
        arch_cfg = model_cfg.architecture
        image_shape = self._as_shape(self.config.data.image_shape)
        output_dim = int(np.prod(image_shape))

        mlp_head = MLPHead(
            in_features=model_cfg.latent_space.latent_dim,
            out_features=output_dim,
            h_units=arch_cfg.mlp.generator_h_dims,
            use_bias=arch_cfg.use_bias,
            activation_type=arch_cfg.activation_type,
        )
        return nn.Sequential(
            mlp_head,
            nn.Tanh(),
            nn.Unflatten(1, image_shape),
        )

    def _build_conv_generator(self):
        """Build ``(MLPHead -> Unflatten) -> ConvBackbone`` for the decoder."""
        model_cfg = self.config.model
        arch_cfg = model_cfg.architecture

        # The backbone is built first because it defines the shape the MLP must
        # produce; this order also fixes the RNG draws used for initialisation.
        backbone = ConvBackbone(self.config, direction='decoder')
        unflatten_shape = self._as_shape(backbone.get_decoder_input_shape())
        pre_conv_features = int(np.prod(unflatten_shape))

        pre_processor = nn.Sequential(
            MLPHead(
                in_features=model_cfg.latent_space.latent_dim,
                out_features=pre_conv_features,
                h_units=arch_cfg.conv.decoder.mlp_h_units,
                use_bias=arch_cfg.use_bias,
                activation_type=arch_cfg.activation_type,
            ),
            nn.Unflatten(1, unflatten_shape),
        )
        return nn.Sequential(pre_processor, backbone)

    def _build_fcn_generator(self):
        """Build ``Unflatten -> ConvBackbone`` using the shape in ``model.fcn_params``."""
        fcn_cfg = self.config.model.fcn_params
        unflatten_shape = self._as_shape(
            (fcn_cfg.latent_channels, fcn_cfg.latent_h, fcn_cfg.latent_w)
        )
        return nn.Sequential(
            nn.Unflatten(1, unflatten_shape),
            ConvBackbone(self.config, direction='decoder'),
        )

    def forward(self, z):
        """Decode latent batch ``z`` using the network built for ``config.model.type``."""
        return self.net(z)


@MODELS.register("Autoencoder")
class Autoencoder(BaseModel):
    """Chain an ``Encoder`` and a ``Generator`` built from the same ``config``.

    Both halves read ``config.model.type``, so they use matching architectures.
    The attribute names ``encoder`` and ``generator`` prefix the ``state_dict``
    keys and should not be renamed.
    """

    def __init__(self, config):
        super().__init__(config)
        # The encoder is built before the generator; this order fixes the RNG
        # draws used for parameter initialisation.
        self.encoder = Encoder(config)
        self.generator = Generator(config)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(reconstruction, latent)`` tuple: the generator's output for
            the latent code, and the latent code the encoder produced.
        """
        latent = self.encoder(x)
        return self.generator(latent), latent