# Copyright (c) OpenMMLab. All rights reserved.
import random

import torch.nn as nn

from mmpretrain.registry import MODELS
from .modules import FlamingoLayer, GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


@MODELS.register_module()
class FlamingoLMAdapter:
    """Mixin to add gated cross-attention (Flamingo) layers to a language
    model.

    The instance methods assume ``self`` is the language model itself: they
    read and replace its decoder layers and delegate to the next class in the
    MRO via ``super().forward``. The class is therefore meant to be mixed into
    the language model's class, not used on its own.
    NOTE(review): how the mixin is attached to ``base`` before
    :meth:`extend_init` is called is not visible here — confirm at the caller.
    """

    @classmethod
    def extend_init(
        cls,
        base: object,
        vis_hidden_size: int,
        cross_attn_every_n_layers: int,
        use_media_placement_augmentation: bool,
        only_attend_previous: bool = False,
        decoder_layers_attr_name: str = 'model.layers',
    ):
        """Initialize Flamingo by adding gated cross attn to the decoder.

        Every decoder layer found at ``decoder_layers_attr_name`` on ``base``
        is wrapped in a :class:`FlamingoLayer`. Layers ``n, 2n, 3n, ...``
        (counting from 1, with ``n = cross_attn_every_n_layers``) are paired
        with a new :class:`GatedCrossAttentionBlock`; all other layers are
        wrapped with ``None`` in its place.

        This method does not set ``media_token_id``. :meth:`forward` reads
        ``self.media_token_id``, so the caller must assign it on ``base``
        before the first forward pass.

        Args:
            base (object): Language model instance to extend in place. It
                must already provide this mixin's methods
                (``set_decoder_layers_attr_name`` etc.) and a
                ``config.hidden_size`` attribute.
            vis_hidden_size (int): Hidden size of vision embeddings.
            cross_attn_every_n_layers (int): Attach a gated cross-attention
                block to every n-th decoder layer. Must be positive.
            use_media_placement_augmentation (bool): If True (and
                ``only_attend_previous`` is False), :meth:`forward` picks the
                ``attend_previous`` flag at random with probability 0.5 on
                every call.
            only_attend_previous (bool): If True, :meth:`forward` always
                conditions the layers with ``attend_previous=True``,
                overriding media placement augmentation. Defaults to False.
            decoder_layers_attr_name (str): Dotted attribute path of the
                decoder layer list on ``base``. Defaults to 'model.layers'.

        Returns:
            object: ``base``, modified in place.

        Raises:
            ValueError: If ``cross_attn_every_n_layers`` is not positive.
        """
        if cross_attn_every_n_layers <= 0:
            raise ValueError('cross_attn_every_n_layers must be a positive '
                             f'integer, but got {cross_attn_every_n_layers}.')

        base.set_decoder_layers_attr_name(decoder_layers_attr_name)
        decoder_layers = base._get_decoder_layers()
        hidden_size = base.config.hidden_size

        # Build all cross-attention blocks first, then wrap the layers, so
        # module construction happens in the same order as before.
        gated_cross_attn_layers = [
            GatedCrossAttentionBlock(dim=hidden_size,
                                     dim_visual=vis_hidden_size)
            if (layer_idx + 1) % cross_attn_every_n_layers == 0 else None
            for layer_idx, _ in enumerate(decoder_layers)
        ]
        base._set_decoder_layers(
            nn.ModuleList([
                FlamingoLayer(gated_cross_attn_layer, decoder_layer)
                for gated_cross_attn_layer, decoder_layer in zip(
                    gated_cross_attn_layers, decoder_layers)
            ]))

        base.use_media_placement_augmentation = (
            use_media_placement_augmentation)
        base.initialized_flamingo = True
        base.only_attend_previous = only_attend_previous
        return base

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Remember the dotted attribute path of the decoder layer list."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the decoder layers found at ``decoder_layers_attr_name``."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder layers at ``decoder_layers_attr_name``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def forward(self, *args, **kwargs):
        """Condition the Flamingo layers, then run the language model forward.

        ``input_ids`` is taken from the keyword arguments if present,
        otherwise from the first positional argument. Media locations are the
        positions where ``input_ids == self.media_token_id``; the
        ``media_token_id`` attribute must have been set beforehand.

        All positional and keyword arguments are forwarded unchanged to the
        next class's ``forward`` in the MRO.
        """
        input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else args[0]
        media_locations = input_ids == self.media_token_id

        if self.only_attend_previous:
            attend_previous = True
        elif self.use_media_placement_augmentation:
            # Media placement augmentation: choose per call with p = 0.5.
            attend_previous = random.random() < 0.5
        else:
            attend_previous = False

        # Use the same attribute-path lookup as ``extend_init`` so we
        # condition exactly the FlamingoLayer wrappers it installed.
        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)

        # Delegate to the wrapped language model's own forward.
        return super().forward(*args, **kwargs)

    def is_conditioned(self) -> bool:
        """Return True if every decoder layer reports it is conditioned."""
        return all(layer.is_conditioned()
                   for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset vision input, media locations and attend-previous flag on
        every decoder layer to ``None``."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)
