# file: user_extensions/baselines/fader_networks/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from prism.core.registry import MODELS
from prism.models.heads import MLPHead
from prism.models.backbones import VGGStyleBlock, VGGUp, UpRes, ResidualBlock


def _build_conditional_upsampling_modules(config, in_channels, out_channels, n_attr):
    """Build the layers of an attribute-conditioned decoder.

    Every layer created here is fed its input concatenated with an ``n_attr``-channel
    attribute map, so each layer's input width is ``previous_width + n_attr``.

    For stage ``i`` the decoder holds ``block_repeats[i] - 1`` plain blocks followed by one
    upsampling block, all producing ``h_dims[i]`` channels. A repeat count below 1 therefore
    still yields exactly one (upsampling) block for that stage.

    Args:
        config: experiment config. Reads ``model.architecture`` (``use_bias``,
            ``activation_type``) and ``model.architecture.conv.decoder`` (``style``,
            ``h_dims``, ``block_repeats``, ``use_bn``, ``upsampling_method``,
            ``upsampling_factor``).
        in_channels: channels of the feature map entering the first stage, excluding the
            attribute channels.
        out_channels: channels produced by the final convolution.
        n_attr: number of attribute channels concatenated in front of every layer.

    Returns:
        ``(modules, final_conv, final_act)``: the stage blocks in execution order, the
        3x3 output convolution, and a ``Tanh`` output activation.

    Raises:
        ValueError: if ``h_dims`` and ``block_repeats`` have different lengths, or if the
            decoder style is neither ``'residual'`` nor ``'vgg'``.
    """
    arch_cfg = config.model.architecture.conv.decoder
    shared_cfg = config.model.architecture

    h_dims = list(arch_cfg.h_dims)
    block_repeats = list(arch_cfg.block_repeats)
    # zip() would silently drop the trailing stages of the longer list; fail loudly instead.
    if len(h_dims) != len(block_repeats):
        raise ValueError(
            f"decoder.h_dims ({len(h_dims)} entries) and decoder.block_repeats "
            f"({len(block_repeats)} entries) must have the same length."
        )

    if arch_cfg.style == 'residual':
        block_class, up_class = ResidualBlock, UpRes
    elif arch_cfg.style == 'vgg':
        block_class, up_class = VGGStyleBlock, VGGUp
    else:
        raise ValueError(f"Unsupported decoder style: {arch_cfg.style}")

    modules = []
    current_channels = in_channels
    for h_dim, repeat in zip(h_dims, block_repeats):
        # repeat - 1 non-upsampling blocks; the stage's last block is the upsampler.
        for _ in range(repeat - 1):
            modules.append(block_class(
                current_channels + n_attr, h_dim, arch_cfg.use_bn, shared_cfg.use_bias,
                shared_cfg.activation_type
            ))
            current_channels = h_dim

        modules.append(up_class(
            current_channels + n_attr, h_dim, arch_cfg.use_bn, shared_cfg.use_bias,
            shared_cfg.activation_type, arch_cfg.upsampling_method, arch_cfg.upsampling_factor
        ))
        current_channels = h_dim

    # The output conv also sees the attribute map concatenated to its input.
    final_conv = nn.Conv2d(
        current_channels + n_attr, out_channels, kernel_size=3, padding=1, bias=shared_cfg.use_bias
    )
    final_act = nn.Tanh()

    return modules, final_conv, final_act


class ConditionalConvBackbone(nn.Module):
    """Decoder trunk that re-injects an attribute vector before every layer.

    Before each stage block and before the final convolution, the attribute vector ``y`` is
    tiled over the current feature map's spatial grid and concatenated along the channel
    dimension.

    Args:
        config: experiment config passed to ``_build_conditional_upsampling_modules``.
        in_channels: channels of the input feature map (excluding attribute channels).
        out_channels: channels of the produced output.
        n_attr: length of the attribute vector ``y`` (number of conditioning channels).
    """

    def __init__(self, config, in_channels, out_channels, n_attr):
        super().__init__()
        self.n_attr = n_attr

        up_blocks, final_conv, final_act = _build_conditional_upsampling_modules(
            config, in_channels, out_channels, n_attr
        )
        self.up_blocks = nn.ModuleList(up_blocks)
        self.final_conv = final_conv
        self.final_act = final_act

    @staticmethod
    def _concat_condition(x, y):
        """Tile ``y`` (B, n_attr) over x's H x W grid and concatenate it after x's channels."""
        # Match x's dtype/device: otherwise torch.cat promotes half/bfloat16 activations to
        # the (float32) attribute dtype and the following layer receives mismatched inputs.
        y = y.to(device=x.device, dtype=x.dtype)
        y_map = y.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.size(2), x.size(3))
        return torch.cat([x, y_map], dim=1)

    def forward(self, x, y):
        """Decode feature map ``x`` (B, C, H, W) conditioned on attributes ``y`` (B, n_attr)."""
        for block in self.up_blocks:
            # Spatial size can change after each block, so the attribute map is rebuilt each time.
            x = block(self._concat_condition(x, y))

        x = self.final_conv(self._concat_condition(x, y))
        return self.final_act(x)


@MODELS.register("ConditionalGenerator")
class ConditionalGenerator(nn.Module):
    """Label-conditioned image decoder used as the Fader Networks generator.

    ``z`` is first turned into a feature map by ``pre_processor``: an ``MLPHead`` followed by
    ``nn.Unflatten`` for ``model.type == 'conv'``, or a bare ``nn.Unflatten`` for ``'fcn'``.
    The feature map and a one-hot encoding of ``y`` are then passed to a
    ``ConditionalConvBackbone`` that outputs ``data.image_shape[0]`` channels.

    Raises:
        NotImplementedError: if ``model.type`` is not ``'conv'`` or ``'fcn'``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_cfg = config.model
        self.data_cfg = config.data
        self.num_classes = self.data_cfg.num_classes
        self.model_type = self.model_cfg.type

        if self.model_type not in ['conv', 'fcn']:
            raise NotImplementedError(
                "The faithful FaderGenerator implementation is only available for 'conv' and 'fcn' model types."
            )

        self.pre_processor, backbone_in_channels = self._build_pre_processor()
        self.backbone = ConditionalConvBackbone(
            config=self.config,
            in_channels=backbone_in_channels,
            out_channels=self.data_cfg.image_shape[0],
            n_attr=self.num_classes
        )

    def _get_decoder_input_shape(self):
        """Return the ``(C, H, W)`` feature-map shape the decoder starts from.

        The decoder upsamples by ``upsampling_factor`` once per entry in ``h_dims``, so the
        starting spatial size is the image size divided by that total factor.

        Raises:
            ValueError: if ``h_dims`` is empty, or the image height/width is not divisible
                by the total up/down-sampling ratio.
        """
        dec_arch_cfg = self.model_cfg.architecture.conv.decoder
        num_downsamples = len(dec_arch_cfg.h_dims)
        if num_downsamples == 0:
            # h_dims[0] below would otherwise fail with an opaque IndexError.
            raise ValueError("decoder.h_dims must contain at least one stage.")
        downsample_ratio = dec_arch_cfg.upsampling_factor ** num_downsamples
        h, w = self.data_cfg.image_shape[1:]

        if h % downsample_ratio != 0 or w % downsample_ratio != 0:
            raise ValueError(
                f"Image dimensions ({h}, {w}) are not cleanly divisible by the total downsample ratio "
                f"of {downsample_ratio}. Adjust architecture or image size."
            )
        latent_h, latent_w = h // downsample_ratio, w // downsample_ratio
        latent_c = dec_arch_cfg.h_dims[0]
        return (latent_c, latent_h, latent_w)

    def _build_pre_processor(self):
        """Build the module mapping ``z`` to the backbone's input feature map.

        Returns:
            ``(pre_processor, backbone_in_channels)``.
        """
        if self.model_type == 'fcn':
            # z is expected to already hold latent_channels * latent_h * latent_w features.
            fcn_cfg = self.model_cfg.fcn_params
            unflatten_shape = (fcn_cfg.latent_channels, fcn_cfg.latent_h, fcn_cfg.latent_w)
            backbone_in_channels = fcn_cfg.latent_channels
            pre_processor = nn.Unflatten(1, unflatten_shape)
        else:  # 'conv' type
            dec_arch_cfg = self.model_cfg.architecture.conv.decoder
            unflatten_shape = self._get_decoder_input_shape()
            pre_conv_features = int(np.prod(unflatten_shape))
            backbone_in_channels = unflatten_shape[0]
            pre_processor = nn.Sequential(
                MLPHead(
                    in_features=self.model_cfg.latent_space.latent_dim,
                    out_features=pre_conv_features,
                    h_units=dec_arch_cfg.mlp_h_units,
                    use_bias=self.model_cfg.architecture.use_bias,
                    activation_type=self.model_cfg.architecture.activation_type
                ),
                nn.Unflatten(1, unflatten_shape)
            )
        return pre_processor, backbone_in_channels

    def forward(self, z, y):
        """Generate images from latent codes ``z`` conditioned on class indices ``y``.

        Args:
            z: latent codes, presumably (B, latent_dim) — TODO confirm against callers.
            y: integer class indices; must be 1-D (B,) so the one-hot is (B, num_classes).

        Returns:
            The backbone output (Tanh-activated image batch).
        """
        initial_features = self.pre_processor(z)
        # one_hot requires int64 indices. Matching the features' dtype/device (instead of a
        # hard-coded .float()) keeps half/bfloat16 models working: a float32 one-hot would be
        # promoted into the backbone's torch.cat and clash with lower-precision weights.
        y_one_hot = F.one_hot(y.long(), num_classes=self.num_classes).to(
            device=initial_features.device, dtype=initial_features.dtype
        )
        return self.backbone(initial_features, y_one_hot)


@MODELS.register("FaderDiscriminator")
class FaderDiscriminator(MLPHead):
    """MLP head mapping a ``latent_dim`` vector to ``num_classes`` outputs.

    Presumably applied adversarially to the encoder's latent code (Fader Networks) —
    confirm against the training loop. Layer widths, bias and activation come from
    ``config.fader_discriminator.architecture``.
    """

    def __init__(self, config):
        model_section = config.model
        disc_arch = config.fader_discriminator.architecture
        head_kwargs = dict(
            in_features=model_section.latent_space.latent_dim,
            out_features=config.data.num_classes,
            h_units=disc_arch.h_units,
            use_bias=disc_arch.use_bias,
            activation_type=disc_arch.activation_type,
        )
        super().__init__(**head_kwargs)