# file: prism/models/backbones.py
import torch.nn as nn


def _initialize_activation(activation_type):
    if activation_type == 'relu':
        return nn.ReLU()
    elif activation_type == 'leaky_relu':
        return nn.LeakyReLU()
    elif activation_type == 'elu':
        return nn.ELU()
    elif activation_type == 'gelu':
        return nn.GELU()
    elif activation_type == 'swish':
        return nn.SiLU()
    elif activation_type == 'tanh':
        return nn.Tanh()
    else:
        raise ValueError(f"Unsupported activation function: {activation_type}")


def _initialize_downsampling(method, stride, channels, use_bn, use_bias, activation_type):
    if method == 'maxpool':
        return [nn.MaxPool2d(kernel_size=stride, stride=stride)]
    elif method == 'avgpool':
        return [nn.AvgPool2d(kernel_size=stride, stride=stride)]
    elif method == 'conv':
        return [
            nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, bias=use_bias),
            nn.BatchNorm2d(channels) if use_bn else nn.Identity(),
            _initialize_activation(activation_type)
        ]
    else:
        raise ValueError(f"Unsupported downsampling method: {method}")


def _initialize_upsampling(method, stride, channels, use_bn, use_bias, activation_type):
    if method == 'convtranspose':
        return [
            nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=stride, bias=use_bias),
            nn.BatchNorm2d(channels) if use_bn else nn.Identity(),
            _initialize_activation(activation_type)
        ]
    elif method == 'pixelshuffle':
        return [
            nn.Conv2d(channels, channels * stride ** 2, kernel_size=1, bias=use_bias),
            nn.BatchNorm2d(channels * stride ** 2) if use_bn else nn.Identity(),
            nn.PixelShuffle(stride),
            _initialize_activation(activation_type)
        ]
    elif method == 'nearest':
        return [nn.Upsample(scale_factor=stride, mode='nearest')]
    else:
        raise ValueError(f"Unsupported upsampling method: {method}")


def _build_downsampling_modules(config, in_channels, out_channels, block_class, down_class):
    modules = []
    arch_cfg = config.model.architecture.conv.encoder
    shared_cfg = config.model.architecture

    current_channels = in_channels
    for h_dim, repeat in zip(arch_cfg.h_dims, arch_cfg.block_repeats):
        for _ in range(repeat - 1):
            modules.append(block_class(current_channels, h_dim, arch_cfg.use_bn, shared_cfg.use_bias, shared_cfg.activation_type))
            current_channels = h_dim
        modules.append(down_class(
            current_channels, h_dim, arch_cfg.use_bn, shared_cfg.use_bias, shared_cfg.activation_type,
            arch_cfg.downsampling_method, arch_cfg.downsampling_factor
        ))
        current_channels = h_dim

    modules.append(nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=1, bias=shared_cfg.use_bias))
    return modules


def _build_upsampling_modules(config, in_channels, out_channels, block_class, up_class):
    modules = []
    arch_cfg = config.model.architecture.conv.decoder
    shared_cfg = config.model.architecture

    current_channels = in_channels
    for h_dim, repeat in zip(arch_cfg.h_dims, arch_cfg.block_repeats):
        for _ in range(repeat - 1):
            modules.append(block_class(current_channels, h_dim, arch_cfg.use_bn, shared_cfg.use_bias, shared_cfg.activation_type))
            current_channels = h_dim
        modules.append(up_class(
            current_channels, h_dim, arch_cfg.use_bn, shared_cfg.use_bias, shared_cfg.activation_type,
            arch_cfg.upsampling_method, arch_cfg.upsampling_factor
        ))
        current_channels = h_dim

    modules.append(nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=1, bias=shared_cfg.use_bias))
    modules.append(nn.Tanh())
    return modules


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with an additive shortcut, then an activation.

    The main path is conv -> norm -> act -> conv -> norm. The shortcut is an
    identity when channel counts match, otherwise a 1x1 conv plus norm. The
    two paths are summed and passed through a final activation. Spatial size
    is preserved (3x3 convs use padding 1).
    """

    def __init__(self, in_channels, out_channels, use_bn, use_bias, activation_type):
        super().__init__()

        def make_norm():
            # BatchNorm when requested, otherwise a no-op placeholder that
            # keeps the Sequential index layout identical either way.
            return nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as before.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        first_norm = make_norm()
        inner_act = _initialize_activation(activation_type)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        second_norm = make_norm()
        self.main_flow = nn.Sequential(first_conv, first_norm, inner_act, second_conv, second_norm)

        if in_channels == out_channels:
            self.skip_flow = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main path's channels.
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=use_bias)
            self.skip_flow = nn.Sequential(projection, make_norm())

        self.activation = _initialize_activation(activation_type)

    def forward(self, x):
        residual = self.main_flow(x)
        shortcut = self.skip_flow(x)
        return self.activation(residual + shortcut)


class DownRes(ResidualBlock):
    """A ``ResidualBlock`` whose output is then passed through a downsampling stage.

    The stage comes from ``_initialize_downsampling`` and operates on
    ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, use_bn, use_bias, activation_type, downsampling_method, stride):
        super().__init__(in_channels, out_channels, use_bn, use_bias, activation_type)
        stage_layers = _initialize_downsampling(
            downsampling_method, stride, out_channels, use_bn, use_bias, activation_type
        )
        self.downsample = nn.Sequential(*stage_layers)

    def forward(self, x):
        return self.downsample(super().forward(x))


class UpRes(ResidualBlock):
    """A ``ResidualBlock`` whose output is then passed through an upsampling stage.

    The stage comes from ``_initialize_upsampling`` and operates on
    ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, use_bn, use_bias, activation_type, upsampling_method, stride):
        super().__init__(in_channels, out_channels, use_bn, use_bias, activation_type)
        stage_layers = _initialize_upsampling(
            upsampling_method, stride, out_channels, use_bn, use_bias, activation_type
        )
        self.upsample = nn.Sequential(*stage_layers)

    def forward(self, x):
        return self.upsample(super().forward(x))


class VGGStyleBlock(nn.Module):
    """Single 3x3 conv -> optional BatchNorm -> activation, stored in ``self.layer``.

    ``self.layer`` is an ``nn.Sequential`` so subclasses can append extra
    stages to it. Spatial size is preserved (padding 1).
    """

    def __init__(self, in_channels, out_channels, use_bn, use_bias, activation_type):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        # Identity keeps index positions stable whether or not BN is used.
        norm = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        act = _initialize_activation(activation_type)
        self.layer = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.layer(x)


class VGGDown(VGGStyleBlock):
    """``VGGStyleBlock`` with a downsampling stage appended to ``self.layer``."""

    def __init__(self, in_channels, out_channels, use_bn, use_bias, activation_type, downsampling_method, stride):
        super().__init__(in_channels, out_channels, use_bn, use_bias, activation_type)
        extra_layers = _initialize_downsampling(
            downsampling_method, stride, out_channels, use_bn, use_bias, activation_type
        )
        # Appending one at a time is what Sequential.extend does internally.
        for module in extra_layers:
            self.layer.append(module)


class VGGUp(VGGStyleBlock):
    """``VGGStyleBlock`` with an upsampling stage appended to ``self.layer``."""

    def __init__(self, in_channels, out_channels, use_bn, use_bias, activation_type, upsampling_method, stride):
        super().__init__(in_channels, out_channels, use_bn, use_bias, activation_type)
        extra_layers = _initialize_upsampling(
            upsampling_method, stride, out_channels, use_bn, use_bias, activation_type
        )
        # Appending one at a time is what Sequential.extend does internally.
        for module in extra_layers:
            self.layer.append(module)


class ConvBackbone(nn.Module):
    """Convolutional encoder or decoder assembled from ``config``.

    ``direction`` selects which half is built ('encoder' or 'decoder'). The
    layers live in ``self.net`` and are applied in order by ``forward``.

    Raises:
        ValueError: For an unknown ``direction`` or block ``style``.
    """

    def __init__(self, config, direction):
        super().__init__()
        self.config = config
        self.direction = direction

        if direction == 'encoder':
            self._build_encoder()
        elif direction == 'decoder':
            self._build_decoder()
        else:
            raise ValueError(f"Invalid direction for ConvBackbone: {direction}")

    def _decoder_in_channels(self):
        """Return the channel count the decoder's first layer consumes.

        For 'fcn' models this is ``fcn_params.latent_channels``. Otherwise it
        is the decoder's first hidden width.
        """
        if self.config.model.type == 'fcn':
            return self.config.model.fcn_params.latent_channels
        return self.config.model.architecture.conv.decoder.h_dims[0]

    def _build_encoder(self):
        """Build ``self.net``: a VGG stem to ``h_dims[0]``, then the downsampling stages."""
        enc_cfg = self.config.model.architecture.conv.encoder
        shared_cfg = self.config.model.architecture
        in_channels = self.config.data.image_shape[0]
        if self.config.model.type == 'fcn':
            out_channels = self.config.model.fcn_params.latent_channels
        else:
            out_channels = enc_cfg.h_dims[-1]

        if enc_cfg.style == 'residual':
            block, down_block = ResidualBlock, DownRes
        elif enc_cfg.style == 'vgg':
            block, down_block = VGGStyleBlock, VGGDown
        else:
            raise ValueError(f"Unsupported encoder style: {enc_cfg.style}")

        # The stem is always VGG-style, regardless of enc_cfg.style.
        init_block = VGGStyleBlock(in_channels, enc_cfg.h_dims[0], enc_cfg.use_bn, shared_cfg.use_bias, shared_cfg.activation_type)
        modules = _build_downsampling_modules(self.config, enc_cfg.h_dims[0], out_channels, block, down_block)
        self.net = nn.Sequential(init_block, *modules)

    def _build_decoder(self):
        """Build ``self.net`` from the upsampling stages, ending at the image channel count."""
        dec_cfg = self.config.model.architecture.conv.decoder
        out_channels = self.config.data.image_shape[0]
        in_channels = self._decoder_in_channels()

        if dec_cfg.style == 'residual':
            block, up_block = ResidualBlock, UpRes
        elif dec_cfg.style == 'vgg':
            block, up_block = VGGStyleBlock, VGGUp
        else:
            raise ValueError(f"Unsupported decoder style: {dec_cfg.style}")

        modules = _build_upsampling_modules(self.config, in_channels, out_channels, block, up_block)
        self.net = nn.Sequential(*modules)

    def get_decoder_input_shape(self):
        """Return the ``(C, H, W)`` tensor shape the decoder expects as input.

        H and W are the image size divided by ``upsampling_factor ** len(h_dims)``.
        C is the same channel count ``_build_decoder`` wires into the first
        layer. Previously this always reported ``h_dims[0]``, which was wrong
        for 'fcn' models.

        Raises:
            ValueError: If the image size is not divisible by the total ratio.
        """
        dec_arch_cfg = self.config.model.architecture.conv.decoder
        downsample_ratio = dec_arch_cfg.upsampling_factor ** len(dec_arch_cfg.h_dims)
        h, w = self.config.data.image_shape[1:]

        if h % downsample_ratio != 0 or w % downsample_ratio != 0:
            raise ValueError(
                f"Image dimensions ({h}, {w}) are not cleanly divisible by the total downsample ratio "
                f"of {downsample_ratio}. Adjust architecture or image size."
            )

        latent_h, latent_w = h // downsample_ratio, w // downsample_ratio
        latent_c = self._decoder_in_channels()
        return (latent_c, latent_h, latent_w)

    def forward(self, x):
        return self.net(x)