"""
base_vlm.py

Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions,
and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate
from src.vision_backbone.prismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS,
PALI, Fuyu) in the future.

We use abstract base classes *sparingly* -- mostly as a way to encapsulate redundant logic or nested inheritance
(e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms),
prefer Protocol definitions instead.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Optional

import torch
import torch.nn as nn
from transformers import GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from src.vision_backbone.prismatic.models.backbones.llm import LLMBackbone
from src.vision_backbone.prismatic.models.backbones.llm.prompting import PromptBuilder
from src.vision_backbone.prismatic.models.backbones.vision import VisionBackbone


# === Abstract Base Class for arbitrary Vision-Language Models ===
class VLM(nn.Module, GenerationMixin, ABC):
    """Abstract base for Vision-Language Models composed of a vision backbone and an LLM backbone.

    Combines `nn.Module` (submodule/parameter management), HF `GenerationMixin` (`generate()` support), and `ABC`
    (abstract-method enforcement). Generation-related state (`generation_config`, `config`, `_reorder_cache`) is
    delegated to the HF model exposed at `self.llm_backbone.llm`.
    """

    def __init__(
        self,
        model_family: str,
        model_id: str,
        vision_backbone: VisionBackbone,
        llm_backbone: LLMBackbone,
        enable_mixed_precision_training: bool = True,
    ) -> None:
        """Store identifiers and backbones, and set the attributes `GenerationMixin` reads.

        :param model_family: Model family identifier.
        :param model_id: Model identifier.
        :param vision_backbone: Vision backbone (assigned before `llm_backbone`; see note on ordering below).
        :param llm_backbone: LLM backbone; must expose an HF model at `.llm`, whose `generation_config` is read here.
        :param enable_mixed_precision_training: Flag stored on the instance for subclasses / training code.
        """
        super().__init__()
        self.model_family, self.model_id = model_family, model_id
        # Assignment order fixes submodule registration order (vision, then LLM) -- assuming both backbones are
        # `nn.Module`s. That order drives `parameters()` / `state_dict()` iteration, and hence which tensor `device`
        # inspects first.
        self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone
        self.enable_mixed_precision_training = enable_mixed_precision_training

        # Instance Attributes for a generic VLM -- left as `None` here; presumably populated by subclasses.
        self.all_module_keys, self.trainable_module_keys = None, None

        # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* ===
        self.generation_config = self.llm_backbone.llm.generation_config
        self.main_input_name = "input_ids"

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter (as in `transformers.modeling_utils`); assumes *ONE* device!

        Falls back to the first registered buffer when the module has no parameters.

        :raises RuntimeError: If the module has neither parameters nor buffers (instead of leaking a bare
                              `StopIteration` out of the property).
        """
        first_tensor = next(self.parameters(), None)
        if first_tensor is None:
            first_tensor = next(self.buffers(), None)
        if first_tensor is None:
            raise RuntimeError(f"Cannot infer device for `{type(self).__name__}`: it has no parameters or buffers.")
        return first_tensor.device

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls,
        pretrained_checkpoint: Path,
        model_family: str,
        model_id: str,
        vision_backbone: VisionBackbone,
        llm_backbone: LLMBackbone,
        **kwargs: Any,
    ) -> VLM:
        """Alternate constructor: build a VLM from `pretrained_checkpoint` and the given backbones.

        Extra keyword arguments are subclass-specific -- presumably forwarded to `__init__`, which takes non-`str`
        options such as the `bool` `enable_mixed_precision_training` -- hence `Any` rather than `str`.
        """

    @abstractmethod
    def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder:
        """Return a `PromptBuilder`, optionally configured with `system_prompt` (semantics are subclass-defined)."""

    @abstractmethod
    def freeze_backbones(self, stage: str) -> None:
        """Subclass hook to freeze/unfreeze backbone parameters for training `stage` (valid stages: subclass-defined)."""

    @abstractmethod
    def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None:
        """Subclass hook to load weights for `stage` from `run_dir` and/or `pretrained_checkpoint`."""

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a wrapping-policy callable -- presumably for FSDP's `auto_wrap_policy`; confirm in subclass."""

    @abstractmethod
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        multimodal_indices: Optional[torch.LongTensor] = None,
    ) -> CausalLMOutputWithPast:
        """Multimodal forward pass.

        Mirrors the HF causal-LM `forward` signature, plus `pixel_values` and `multimodal_indices` (presumably the
        batch indices of examples that carry images -- confirm in subclass).
        """

    # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) ===
    @staticmethod
    def can_generate() -> bool:
        """Always `True`: advertise `generate()` support to HF generation utilities."""
        return True

    @property
    def config(self) -> PretrainedConfig:
        """HF config of the wrapped LLM (`llm_backbone.llm.config`)."""
        return self.llm_backbone.llm.config

    # => Beam Search Utility
    def _reorder_cache(self, past_key_values, beam_idx):
        """Delegate beam-search cache reordering to the wrapped LLM."""
        return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx)
