from typing import Optional, Iterable
from dataclasses import dataclass

import torch
import torch.nn as nn

from transformers.modeling_outputs import ModelOutput
from transformers.cache_utils import DynamicCache
from transformers import PreTrainedModel


@dataclass
class AIMOutput(ModelOutput):
    """Output container built by ``PreTrainedModelForAIM.forward``.

    Both ``attentions`` and ``value_states`` hold a single tensor (either one
    layer's tensor or several layers concatenated along ``dim=1``), not a
    per-layer tuple, so they are annotated as plain tensors.
    """

    # Not populated by the forward pass in this file (labels are rejected there).
    loss: Optional[torch.Tensor] = None
    # Attention weights of the selected layer, or of all kept layers
    # concatenated along dim=1.
    attentions: Optional[torch.Tensor] = None
    # Cached value states of the selected layer, or of all kept layers
    # concatenated along dim=1.
    value_states: Optional[torch.Tensor] = None


class PreTrainedModelForAIM(nn.Module):
    """Wrap a pretrained model truncated to its first layers.

    The wrapper keeps only a slice of the base model's layers and, on each
    forward pass, returns the attention weights and the cached value states
    produced by those layers.

    Architecture-specific behaviour is supplied by subclasses, which must
    override ``get_base_model``, ``get_layers``, ``set_layers`` and (when a
    frozen-embeddings mask is used) ``create_partly_frozen_embeddings``.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        n_layers: int = 1,
        frozen_embeddings_mask: Optional[torch.BoolTensor] = None,
        all_layer_outputs: bool = False,
    ):
        """Build the truncated wrapper around ``model``.

        Args:
            model: Model handed to ``get_base_model`` to obtain the module
                that is actually run.
            n_layers: Upper bound of the slice applied to ``get_layers()``;
                the layers in ``[:n_layers]`` are kept.
            frozen_embeddings_mask: When given, it is passed to
                ``create_partly_frozen_embeddings`` and the result replaces
                the base model's input embeddings.
            all_layer_outputs: If true, ``forward`` concatenates the outputs
                of every kept layer instead of returning only the last one.

        Raises:
            ValueError: If the slice selects no layers at all.
        """
        super().__init__()

        self.model = self.get_base_model(model)
        self.config = self.model.config
        # Stored before the hooks below run, in case an override reads it.
        self.n_layers = n_layers
        self.all_layer_outputs = all_layer_outputs

        layers = self.get_layers()[:n_layers]
        if len(layers) == 0:
            raise ValueError(
                f"n_layers={n_layers} selects no layers; at least one layer is required."
            )
        self.set_layers(layers)
        # Track the number of layers actually kept: the slice silently clamps
        # an oversized (or negative) n_layers, and forward() indexes per-layer
        # outputs with this value, so it must never exceed what exists.
        self.n_layers = len(layers)

        if frozen_embeddings_mask is not None:
            partly_frozen_embeddings = self.create_partly_frozen_embeddings(frozen_embeddings_mask)
            self.model.set_input_embeddings(partly_frozen_embeddings)

    def get_base_model(self, model: PreTrainedModel) -> PreTrainedModel:
        """Return the module to run for ``model``; must be overridden."""
        raise NotImplementedError("Override this method for your architecture.")

    def get_layers(self) -> Iterable[nn.Module]:
        """Return the base model's layers; must be overridden.

        The result is sliced and measured with ``len`` in ``__init__``, so it
        has to support both operations.
        """
        raise NotImplementedError("Override this method for your architecture.")

    def set_layers(self, layers: Iterable[nn.Module]):
        """Install ``layers`` as the base model's layers; must be overridden."""
        raise NotImplementedError("Override this method for your architecture.")

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, embeddings: nn.Embedding):
        """Replace the input embeddings of the wrapped model."""
        self.model.set_input_embeddings(embeddings)

    def create_partly_frozen_embeddings(self, frozen_mask: torch.Tensor) -> nn.Embedding:
        """Build embeddings from ``frozen_mask``; must be overridden."""
        raise NotImplementedError("Override this method for your architecture.")

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> AIMOutput:
        """Run the truncated model and collect attentions and value states.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Optional mask forwarded to the wrapped model.
            position_ids: Optional position ids forwarded to the wrapped model.
            labels: Must be ``None``; loss computation is not implemented.

        Returns:
            An ``AIMOutput`` whose ``attentions`` and ``value_states`` come
            from the last kept layer, or from all kept layers concatenated
            along ``dim=1`` when ``all_layer_outputs`` is set.

        Raises:
            NotImplementedError: If ``labels`` is provided.
        """
        # Reject unsupported input up front so no forward pass is wasted.
        if labels is not None:
            raise NotImplementedError('Labels are not supported yet')

        # A fresh cache per call; it is only used to read back value states.
        past_key_values = DynamicCache()

        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            # gathering required outputs
            past_key_values=past_key_values,
            output_attentions=True,
            use_cache=True,
        )

        if self.all_layer_outputs:
            # Concatenate per-layer tensors along dim=1
            # (presumably the heads axis - confirm for your architecture).
            attn_weights = torch.cat(output.attentions, dim=1)
            value_states = torch.cat(
                # Index 1 of a cache entry is read as the value tensor.
                [past_key_values[i][1] for i in range(self.n_layers)],
                dim=1,
            )
        else:
            last_layer = self.n_layers - 1
            attn_weights = output.attentions[last_layer]
            value_states = past_key_values[last_layer][1]

        return AIMOutput(
            attentions=attn_weights,
            value_states=value_states,
        )
