# file: prism/models/heads.py
import numpy as np
import torch
import torch.nn as nn

from prism.core.registry import MODELS
from prism.models.backbones import ConvBackbone, _initialize_activation
from prism.utils.config import AttrDict


@MODELS.register("MLPHead")
class MLPHead(nn.Module):
    """Flatten the input, then apply a stack of linear layers.

    Each entry of ``h_units`` contributes one ``nn.Linear`` followed by the
    activation; a final ``nn.Linear`` with no activation maps to
    ``out_features``.
    """

    def __init__(self, in_features, out_features, h_units, use_bias, activation_type):
        """Assemble the layer stack.

        Args:
            in_features: Width of the flattened input.
            out_features: Width of the final linear layer.
            h_units: Iterable of hidden widths; a falsy value (``None`` or
                empty) yields a single linear layer.
            use_bias: Passed as ``bias=`` to every ``nn.Linear``.
            activation_type: Forwarded to ``_initialize_activation``.
        """
        super().__init__()
        self.flatten = nn.Flatten()

        # Built once: the very same module instance is placed after every
        # hidden layer rather than creating one activation per layer.
        nonlinearity = _initialize_activation(activation_type)

        stack = []
        width = in_features
        for hidden_width in h_units or ():
            stack.extend((nn.Linear(width, hidden_width, bias=use_bias), nonlinearity))
            width = hidden_width
        stack.append(nn.Linear(width, out_features, bias=use_bias))

        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Return the output of the layer stack applied to the flattened ``x``."""
        return self.mlp(self.flatten(x))


@MODELS.register("Classifier")
class Classifier(MLPHead):
    """``MLPHead`` whose dimensions and options are read from a config object."""

    def __init__(self, config):
        """Translate ``config`` into ``MLPHead`` constructor arguments.

        Reads ``config.model.latent_space.target_dim`` (input width),
        ``config.data.num_classes`` (output width),
        ``config.model.architecture.mlp.classifier_h_units`` (hidden widths)
        and ``use_bias`` / ``activation_type`` from
        ``config.model.architecture``.
        """
        model = config.model
        mlp_opts = model.architecture.mlp

        head_kwargs = {
            "in_features": model.latent_space.target_dim,
            "out_features": config.data.num_classes,
            "h_units": mlp_opts.classifier_h_units,
            "use_bias": model.architecture.use_bias,
            "activation_type": model.architecture.activation_type,
        }
        super().__init__(**head_kwargs)


@MODELS.register("LatentDiscriminator")
class LatentDiscriminator(nn.Module):
    """Shared MLP trunk with two linear heads.

    One head emits ``config.data.num_classes`` values and the other emits a
    single value; both consume the same trunk output.
    """

    def __init__(self, config):
        """Build the trunk and both heads from ``config``.

        Reads ``use_bias``, ``activation_type`` and
        ``mlp.latent_disc_h_units`` from ``config.model.architecture``, the
        input width from ``config.model.latent_space.nontarget_dim`` and the
        class count from ``config.data.num_classes``.
        """
        super().__init__()
        arch = config.model.architecture
        bias_flag = arch.use_bias

        width = config.model.latent_space.nontarget_dim
        # A single activation instance is reused after every trunk layer.
        nonlinearity = _initialize_activation(arch.activation_type)

        trunk = []
        for hidden_width in arch.mlp.latent_disc_h_units or ():
            trunk.extend((nn.Linear(width, hidden_width, bias=bias_flag), nonlinearity))
            width = hidden_width

        # With no hidden widths the trunk is an empty Sequential, so the
        # heads operate directly on the flattened input.
        self.backbone = nn.Sequential(*trunk)
        self.classification_head = nn.Linear(width, config.data.num_classes, bias=bias_flag)
        self.adversarial_head = nn.Linear(width, 1, bias=bias_flag)

    def forward(self, z0):
        """Return ``(class_logits, adversarial_logit)`` for ``z0``.

        ``z0`` is flattened from dimension 1 onward before entering the trunk.
        """
        trunk_out = self.backbone(torch.flatten(z0, start_dim=1))
        return self.classification_head(trunk_out), self.adversarial_head(trunk_out)


@MODELS.register("DiscriminatorQ")
class DiscriminatorQ(nn.Module):
    """Shared feature extractor feeding three linear heads.

    The extractor is selected by ``config.discriminator_q.type`` (``'mlp'``
    or ``'conv'``). Its flattened output feeds ``discriminator_head`` (one
    output), ``q_head_mu`` and ``q_head_logvar`` (each
    ``config.model.latent_space.nontarget_dim`` outputs).
    """

    def __init__(self, config):
        """Build the feature extractor, infer its output width, create the heads.

        Args:
            config: Object exposing ``discriminator_q``, ``data`` and
                ``model`` sections; ``data.image_shape`` is unpacked as the
                per-sample input shape.

        Raises:
            ValueError: If ``config.discriminator_q.type`` is not supported.
        """
        super().__init__()
        self.config = config
        self.features = self._build_backbone()

        model_cfg = self.config.model
        data_cfg = self.config.data

        # The extractor's output width is found by pushing one dummy batch
        # through it. The original batch size of 2 is kept so the amount of
        # randomness drawn by ``torch.randn`` here stays the same.
        dummy_input = torch.randn(2, *data_cfg.image_shape)

        # ``torch.no_grad()`` only disables gradient tracking; layers whose
        # behaviour depends on the training flag (e.g. BatchNorm running
        # statistics, dropout) would still act on the dummy batch. Run the
        # probe in eval mode, then restore every submodule's own flag.
        saved_modes = [(module, module.training) for module in self.features.modules()]
        self.features.eval()
        try:
            with torch.no_grad():
                backbone_out = self.features(dummy_input)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        feature_dim = int(np.prod(backbone_out.shape[1:]))

        nontarget_dim = model_cfg.latent_space.nontarget_dim

        self.discriminator_head = nn.Linear(feature_dim, 1)
        self.q_head_mu = nn.Linear(feature_dim, nontarget_dim)
        self.q_head_logvar = nn.Linear(feature_dim, nontarget_dim)

    def _build_backbone(self):
        """Create the feature extractor described by ``config.discriminator_q``.

        Returns:
            For ``'mlp'``: an ``MLPHead`` whose input width is the product of
            ``data.image_shape``, whose output width is the last entry of
            ``architecture.mlp.encoder_h_dims`` and whose hidden widths are
            the preceding entries.
            For ``'conv'``: ``ConvBackbone`` (``direction='encoder'``)
            followed by ``AdaptiveAvgPool2d((1, 1))`` and ``Flatten``.

        Raises:
            ValueError: For any other ``type`` value.
        """
        dq_cfg = self.config.discriminator_q
        data_cfg = self.config.data

        if dq_cfg.type == 'mlp':
            encoder_dims = dq_cfg.architecture.mlp.encoder_h_dims
            input_dim = int(np.prod(data_cfg.image_shape))
            # The last listed width is the MLP's output layer; the earlier
            # entries are its hidden layers.
            output_dim = encoder_dims[-1]
            return MLPHead(
                in_features=input_dim,
                out_features=output_dim,
                h_units=encoder_dims[:-1],
                use_bias=dq_cfg.architecture.use_bias,
                activation_type=dq_cfg.architecture.activation_type,
            )

        if dq_cfg.type == 'conv':
            # ConvBackbone receives the discriminator_q section under the
            # 'model' key of a temporary config.
            temp_config = AttrDict({
                'model': dq_cfg,
                'data': data_cfg,
            })
            return nn.Sequential(
                ConvBackbone(temp_config, direction='encoder'),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
            )

        raise ValueError(f"Unsupported model type for DiscriminatorQ: {dq_cfg.type}")

    def forward(self, x):
        """Return ``(d_logits, q_mu, q_logvar)`` computed from shared features of ``x``."""
        shared_features = self.features(x)
        d_logits = self.discriminator_head(shared_features)
        q_mu = self.q_head_mu(shared_features)
        q_logvar = self.q_head_logvar(shared_features)
        return d_logits, q_mu, q_logvar