from typing import Optional, List
import copy
import math
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
import numpy as np
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from einops import rearrange, repeat
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
import loralib as lora

from .configuration_otter import OtterConfig

from flamingo.falcon.modelling_RW import RWForCausalLM
from flamingo.mpt.modeling_mpt import MPTForCausalLM

from transformers.models.auto import AutoTokenizer

from .biovil_encoder import (
    BioVilEncoder, get_image_inference, get_bert_inference,
    ImageTextInferenceEngineForTraining, ImageModelType, BertEncoderType,
    ImageInferenceEngine, create_chest_xray_transform_for_inference
)
from .unimedi_encoder3D import create_vit3D, load_state_with_same_shape
from .unimedi_encoder3D import UnifiedEncoder3D
from .unimedi_encoder2D import create_vit2D
from .unimedi_encoder2D import UniMedIEncoder2D

import sys
import random

# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata

import torch.distributed as dist

# Add this line at the beginning of your script or in your main function
# dist.init_process_group(backend='nccl')

# Optional xformers support.
# Prefer the xformers-backed CLIPVisionModel / LlamaForCausalLM from `xformers_model`;
# on ImportError fall back to the stock `transformers` implementations.
#
# NOTE(review): XFORMERS_AVAIL is initialised to False and is never set to True anywhere in
# this block, not even after a successful import. Code gated on it therefore always takes the
# non-xformers path. Confirm whether that is intentional before changing it.
XFORMERS_AVAIL = False
# Set to True only in the failure branch below, and only on rank 0 of an initialised
# process group. Both `if not XFORMERS_MSG_PRINTED` checks run while it is still False.
XFORMERS_MSG_PRINTED = False
try:
    if not XFORMERS_MSG_PRINTED:  # always true at this point (flag was just set to False)
        import xformers.ops as xops
        from xformers_model import CLIPVisionModel, LlamaForCausalLM
        from transformers import LlamaTokenizer

        _xformers_version = importlib_metadata.version("xformers")
        # Only rank 0 of an already-initialised process group reports; otherwise nothing is printed.
        if dist.is_initialized() and dist.get_rank() == 0:
            print(f"Successfully imported xformers version {_xformers_version}")
except ImportError as e:
    if not XFORMERS_MSG_PRINTED:  # always true at this point (flag is still False)
        # Fallback: use the reference implementations shipped with transformers.
        from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer

        # Only rank 0 of an already-initialised process group reports the failure.
        if dist.is_initialized() and dist.get_rank() == 0:
            print(f"Failed to import xformers: {e}")
            XFORMERS_AVAIL = False
            print("No xformers found. You are recommended to install xformers via `pip install xformers` or `conda install -c xformers xformers`")
            XFORMERS_MSG_PRINTED = True  # remember that the hint has been printed

# from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer

# Lookup table used by `_infer_decoder_layers_attr_name`.
# Key: substring matched case-insensitively against the language model's class name.
# Value: dotted attribute path (relative to the model) of the container holding the
# decoder's transformer layers.
# Entries are tried in insertion order and the first match wins.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "RWForCausalLM": "transformer.h",
    "MPTForCausalLM": "transformer.blocks",
}


def _infer_decoder_layers_attr_name(model: nn.Module):
    """
    Look up the dotted attribute path of the decoder layer list for `model`.

    The model's class name is compared case-insensitively against every key of
    `__KNOWN_DECODER_LAYERS_ATTR_NAMES`; the path of the first key contained in the
    class name is returned.

    Raises:
        ValueError: if no known key occurs in the class name.
    """
    class_name = model.__class__.__name__.lower()
    for pattern, attr_path in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if pattern.lower() in class_name:
            return attr_path

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


def extend_instance(obj, mixin):
    """
    Graft `mixin` onto an already-created instance.

    A new class with the same name as the instance's current class is built on the fly,
    inheriting from `(mixin, current_class)`, and assigned to `obj.__class__`.
    The object is modified in place; nothing is returned.
    """
    current_cls = obj.__class__
    # The mixin must come first in the bases so its forward() overrides the base one
    # and can delegate to it through super().
    patched_cls = type(current_cls.__name__, (mixin, current_cls), {})
    obj.__class__ = patched_cls


def getattr_recursive(obj, att):
    """
    Resolve a dotted attribute path on obj.
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    An empty path returns obj itself.
    """
    current = obj
    remaining = att
    # Peel off one path component per iteration until nothing is left.
    while remaining != "":
        head, _, remaining = remaining.partition(".")
        current = getattr(current, head)
    return current


def setattr_recursive(obj, att, val):
    """
    Assign val to a dotted attribute path on obj.
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    # Everything before the last dot names the owner; the rest is the attribute to set.
    owner_path, dot, leaf = att.rpartition(".")
    target = getattr_recursive(obj, owner_path) if dot else obj
    setattr(target, leaf, val)


def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' count as existing)."""
    if val is None:
        return False
    return True


class Projection(nn.Module):
    """
    Two-layer MLP adapter mapping vision-encoder features to the feature width
    expected downstream (1024 by default).
        (*, in_feature_size) -> (*, out_feature_size)
    """

    def __init__(self, in_feature_size, out_feature_size=1024):
        """
        Args:
            in_feature_size: size of the last dimension of the incoming features
            out_feature_size: size of the last dimension of the projected features
        """
        super().__init__()
        self.fc1 = nn.Linear(in_feature_size, out_feature_size)
        self.fc2 = nn.Linear(out_feature_size, out_feature_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Project the vision encoder output.
        Args:
            x: Tensor of shape (*, in_feature_size)
        Returns: Tensor of shape (*, out_feature_size)
        """
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class Classifier(nn.Module):
    """
    Multilabel classification head with one independent linear unit per class.
    Class i is scored from latent query i, so only the first n_cls queries are used.
        (*, n_queries, n_dim) -> (*, n_cls, 1)
    """

    def __init__(self, n_dim=1024, n_cls=14):
        """
        Args:
            n_dim: feature size of each latent query
            n_cls: number of classes (and of per-class linear heads)
        """
        super().__init__()
        self.n_cls = n_cls
        per_class_heads = [
            nn.Linear(in_features=n_dim, out_features=1)
            for _ in range(self.n_cls)
        ]
        self.classifier_list = nn.ModuleList(per_class_heads)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (*, n_queries, n_dim) with n_queries >= n_cls
        Returns: Tensor of raw logits (no sigmoid applied) of shape (*, n_cls, 1)
        """
        # x[..., idx, :] has shape (*, n_dim); each head maps it to (*, 1).
        logits = [
            self.classifier_list[idx](x[..., idx, :])
            for idx in range(self.n_cls)
        ]
        # Re-insert the class axis just before the trailing singleton: (*, n_cls, 1).
        return torch.stack(logits, dim=-2)


class OtterPerceiverBlock(nn.Module):
    """
    One perceiver-resampler layer.

    The latents attend over the concatenation of the (normalised) media features and the
    (normalised) latents themselves, followed by a feed-forward stack. Both sub-layers use
    a residual connection.
    """

    def __init__(
        self, *,
        dim: int,
        dim_head: int = 64,
        heads: int = 8,
        mult: int = 4,
        use_lora: bool = False,
        adapter_prompt_length: int = 0
    ):
        """
        Args:
            dim: feature size of both media features and latents
            dim_head: size of each attention head
            heads: number of attention heads
            mult: expansion factor of the feed-forward hidden layer
            use_lora: build the linear layers as `lora.Linear` (rank 32) instead of `nn.Linear`
            adapter_prompt_length: unused here; accepted so that all perceiver block variants
                can be constructed with the same keyword arguments
        """
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.ff_dim = dim * mult
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        if use_lora:
            self.to_q = lora.Linear(dim, self.inner_dim, bias=False, merge_weights=False, r=32)
            self.to_kv = lora.Linear(dim, self.inner_dim * 2, bias=False, merge_weights=False, r=32)
            self.to_out = lora.Linear(self.inner_dim, dim, bias=False, merge_weights=False, r=32)
        else:
            self.to_q = nn.Linear(dim, self.inner_dim, bias=False)
            self.to_kv = nn.Linear(dim, self.inner_dim * 2, bias=False)
            self.to_out = nn.Linear(self.inner_dim, dim, bias=False)

        # LayerNorm -> Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim)
        self.feed_forward = nn.ModuleList(
            [
                nn.LayerNorm(dim),
                lora.Linear(dim, self.ff_dim, bias=False, merge_weights=False, r=32)
                if use_lora else nn.Linear(dim, self.ff_dim, bias=False),
                nn.GELU(),
                lora.Linear(self.ff_dim, dim, bias=False, merge_weights=False, r=32)
                if use_lora else nn.Linear(self.ff_dim, dim, bias=False),
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        latents: torch.Tensor,
        adapter_x: Optional[torch.Tensor] = None,
        roi_adapter_x: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
            adapter_x: ignored; present only for call compatibility with the adapter blocks
            roi_adapter_x: ignored; present only for call compatibility with the adapter blocks
        Returns:
            updated latents, shape (b, T, n2, D)

        Both adapter arguments default to None so the block can be called either as
        `block(x, latents)` or as `block(x, latents, adapter_x, roi_adapter_x)`;
        OtterPerceiverResampler uses both forms.
        """
        x = self.norm_media(x)
        residual_latents = latents
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        # Keys/values come from the media features followed by the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
        k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
        v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = torch.einsum("... i d, ... j d  -> ... i j", q, k)
        # Subtract the per-row max before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        out = self.to_out(out) + residual_latents
        residual_out = out
        for layer in self.feed_forward:
            out = layer(out)
        return out + residual_out


class LlamaAdapterPlusPerceiverBlock(OtterPerceiverBlock):
    """
    Perceiver block with an additional, gated attention branch over adapter features.

    The regular attention output is computed as in OtterPerceiverBlock. When `adapter_x` is
    given, the same queries also attend over `adapter_x` (concatenated with the latents)
    through dedicated key/value and output projections; that branch is scaled by
    tanh(adapter_gate) and added to the regular output before the feed-forward stack.
    """

    def __init__(
        self, *,
        dim: int,
        dim_head: int = 64,
        heads: int = 8,
        mult: int = 4,
        use_lora: bool = False,
        **kwargs
    ):
        """
        Args:
            dim, dim_head, heads, mult, use_lora: forwarded to OtterPerceiverBlock
            **kwargs: ignored (absorbs e.g. `adapter_prompt_length` passed by the resampler)
        """
        super().__init__(dim=dim, dim_head=dim_head, heads=heads, mult=mult, use_lora=use_lora)

        # NOTE(review): created with torch.empty, i.e. uninitialised memory. Presumably the
        # value is expected to come from a checkpoint or a later init step; confirm.
        self.adapter_gate = nn.Parameter(torch.empty(1, self.heads, 1, 1, 1))
        self.adapter_to_kv = nn.Linear(dim, self.inner_dim * 2, bias=False)
        self.adapter_to_out = nn.Linear(self.inner_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        latents: torch.Tensor,
        adapter_x: Optional[torch.Tensor] = None,
        roi_adapter_x: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
            adapter_x (torch.Tensor): adapter image features
                shape (b, T, n1, D)
            roi_adapter_x: ignored by this block. It is accepted because
                OtterPerceiverResampler calls every adapter-stage block as
                `block(x, latents, adapter_x, roi_adapter_x)`; without this parameter that
                call raises TypeError.
        Returns:
            updated latents, shape (b, T, n2, D)
        """
        x = self.norm_media(x)
        residual_latents = latents
        latents = self.norm_latents(latents)
        h = self.heads
        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
        k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
        v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = torch.einsum("... i d, ... j d  -> ... i j", q, k)
        # Subtract the per-row max before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        out = self.to_out(out) + residual_latents

        if adapter_x is not None:
            # The adapter features share the media LayerNorm but use their own kv projection.
            adapter_x = self.norm_media(adapter_x)
            adapter_kv_input = torch.cat((adapter_x, latents), dim=-2)
            adapter_k, adapter_v = self.adapter_to_kv(adapter_kv_input).chunk(2, dim=-1)
            adapter_k = rearrange(adapter_k, "b t n (h d) -> b h t n d", h=h)
            adapter_v = rearrange(adapter_v, "b t n (h d) -> b h t n d", h=h)

            # attention (queries are reused from the main branch, already scaled)
            adapter_sim = torch.einsum("... i d, ... j d  -> ... i j", q, adapter_k)
            adapter_sim = adapter_sim - adapter_sim.amax(dim=-1, keepdim=True).detach()
            # Per-head tanh gate on the adapter attention weights.
            adapter_attn = self.adapter_gate.tanh() * adapter_sim.softmax(dim=-1)
            adapter_out = torch.einsum("... i j, ... j d -> ... i d", adapter_attn, adapter_v)
            adapter_out = rearrange(adapter_out, "b h t n d -> b t n (h d)", h=h)
            out = out + self.adapter_to_out(adapter_out)

        residual_out = out
        for layer in self.feed_forward:
            out = layer(out)
        return out + residual_out


class LlamaAdapterConcatPerceiverBlock(OtterPerceiverBlock):
    """
    Perceiver block that prepends adapter tokens to the attention keys/values.

    Adapter features (optionally followed by ROI adapter features) are offset by learned
    prompts, projected with the block's regular key/value projection and concatenated in
    front of the media/latent keys and values. The attention weights on the first
    `adapter_prompt_length` keys are normalised separately and scaled by tanh(adapter_gate).
    """

    def __init__(
        self, *,
        dim: int,
        dim_head: int = 64,
        heads: int = 8,
        mult: int = 4,
        adapter_prompt_length: int = 10,
        use_lora: bool = False
    ):
        """
        Args:
            dim, dim_head, heads, mult, use_lora: forwarded to OtterPerceiverBlock
            adapter_prompt_length: number of learned prompt tokens per adapter input
        """
        super().__init__(dim=dim, dim_head=dim_head, heads=heads, mult=mult, use_lora=use_lora)

        # NOTE(review): the three parameters below are created with torch.empty, i.e.
        # uninitialised memory. Presumably they are filled from a checkpoint; confirm.
        self.adapter_gate = nn.Parameter(torch.empty(1, self.heads, 1, 1, 1))
        self.adapter_prompt_length = adapter_prompt_length
        prompt_shape = (1, 1, self.adapter_prompt_length, dim)
        self.adapter_prompt = nn.Parameter(torch.empty(*prompt_shape))
        self.roi_adapter_prompt = nn.Parameter(torch.empty(*prompt_shape))

    def forward(
        self, x: torch.Tensor,
        latents: torch.Tensor,
        adapter_x: torch.Tensor = None,
        roi_adapter_x: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
            adapter_x (torch.Tensor): adapter image features
                shape (b, T, 1, D)
            roi_adapter_x (torch.Tensor): adapter zoomed in roi image features
                shape (b, T, 1, D); only used when adapter_x is given
        Returns:
            updated latents, shape (b, T, n2, D)
        """
        n_heads = self.heads
        to_heads = "b t n (h d) -> b h t n d"
        use_adapter = adapter_x is not None

        media = self.norm_media(x)
        skip = latents
        normed_latents = self.norm_latents(latents)

        queries = rearrange(self.to_q(normed_latents), to_heads, h=n_heads) * self.scale
        keys, values = self.to_kv(torch.cat((media, normed_latents), dim=-2)).chunk(2, dim=-1)
        keys = rearrange(keys, to_heads, h=n_heads)
        values = rearrange(values, to_heads, h=n_heads)

        if use_adapter:
            # Broadcasting the (1, 1, L, D) prompt over the adapter features yields L tokens.
            prompt_tokens = adapter_x + self.adapter_prompt
            if roi_adapter_x is not None:
                roi_tokens = roi_adapter_x + self.roi_adapter_prompt
                prompt_tokens = torch.cat([prompt_tokens, roi_tokens], dim=-2)
            prompt_tokens = self.norm_media(prompt_tokens)
            # The adapter tokens reuse the regular kv projection and go in FRONT of the others.
            prompt_k, prompt_v = self.to_kv(prompt_tokens).chunk(2, dim=-1)
            keys = torch.cat([rearrange(prompt_k, to_heads, h=n_heads), keys], dim=3)
            values = torch.cat([rearrange(prompt_v, to_heads, h=n_heads), values], dim=3)

        # attention
        scores = torch.einsum("... i d, ... j d  -> ... i j", queries, keys)
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        if use_adapter:
            # Two independent softmaxes: gated over the leading prompt keys, plain over the rest.
            # NOTE(review): the split is always at adapter_prompt_length, so when ROI tokens
            # were concatenated above they fall into the ungated part; confirm this is intended.
            boundary = self.adapter_prompt_length
            gated_part = self.adapter_gate.tanh() * scores[..., :boundary].softmax(dim=-1)
            plain_part = scores[..., boundary:].softmax(dim=-1)
            weights = torch.cat([gated_part, plain_part], dim=-1)
        else:
            weights = scores.softmax(dim=-1)

        attended = torch.einsum("... i j, ... j d -> ... i d", weights, values)
        attended = rearrange(attended, "b h t n d -> b t n (h d)", h=n_heads)
        attended = self.to_out(attended) + skip

        hidden = attended
        for sublayer in self.feed_forward:
            hidden = sublayer(hidden)
        return hidden + attended


class OtterPerceiverResampler(nn.Module):
    """
    Stack of perceiver blocks that compresses image features into a fixed number of
    learned latents per media item. The last `adapter_depth` blocks additionally
    receive the optional adapter / ROI adapter features.
    """

    def __init__(
        self,
        *,
        dim: int,
        mode: str,
        depth: int = 6,
        adapter_depth = 6,
        dim_head: int = 64,
        heads: int = 8,
        num_latents: int = 64,
        # max_num_frames: int = 128,
        max_num_media: Optional[int] = None,
        max_num_frames: Optional[int] = None,
        ff_mult: int = 4,
        adapter_prompt_length: int = 10,
        use_lora: bool = False,
    ):
        """
        Args:
            dim: feature size of inputs and latents
            mode: vision encode mode; "llama_adapter_plus" and "llama_adapter_concat" select
                the corresponding adapter block, anything else the plain OtterPerceiverBlock
            depth: number of perceiver blocks
            adapter_depth: number of trailing blocks that are passed the adapter inputs
            dim_head, heads, ff_mult, adapter_prompt_length, use_lora: forwarded to each block
            num_latents: number of learned latent vectors
            max_num_media: if given, learn a per-media time embedding of this length
            max_num_frames: if given, learn a per-frame embedding of this length
        """
        super().__init__()
        self.vision_encode_mode = mode
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        if exists(max_num_frames):
            self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        else:
            self.frame_embs = None

        if exists(max_num_media):
            self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        else:
            self.media_time_embs = None

        adapter_block_types = {
            "llama_adapter_plus": LlamaAdapterPlusPerceiverBlock,
            "llama_adapter_concat": LlamaAdapterConcatPerceiverBlock,
        }
        block_type = adapter_block_types.get(self.vision_encode_mode, OtterPerceiverBlock)

        self.layers = nn.ModuleList(
            block_type(
                dim=dim, dim_head=dim_head, heads=heads, mult=ff_mult,
                adapter_prompt_length=adapter_prompt_length, use_lora=use_lora
            )
            for _ in range(depth)
        )

        self.norm = nn.LayerNorm(dim)
        self.adapter_depth = adapter_depth

    def _embed_adapter_features(self, feats, frame_embs, num_media):
        """Apply frame embedding, (F v) flattening of 5-D inputs and media time embedding."""
        if frame_embs is not None:
            feats = feats + frame_embs
        if feats.ndim > 4:
            feats = rearrange(feats, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            feats = feats + self.media_time_embs[:num_media]
        return feats

    def forward(
        self, x: torch.Tensor,
        adapter_x: torch.Tensor = None,
        roi_adapter_x: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
            adapter_x (torch.Tensor): additional image features
                shape (b, T, F, v, D)
            roi_adapter_x (torch.Tensor): additional zoomed in roi image features
                shape (b, T, F, v, D); only embedded when adapter_x is given
        Returns:
            shape (b, T, n, D) where n is the number of latents
        """
        batch, num_media, num_frames, num_tokens = x.shape[:4]

        # frame and media time embeddings
        # The frame embedding depends only on the shape of x, so build it once and reuse it.
        frame_embs = None
        if exists(self.frame_embs):
            frame_embs = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=batch, T=num_media, v=num_tokens
            )
            x = x + frame_embs
        x = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:num_media]

        if adapter_x is not None:
            adapter_x = self._embed_adapter_features(adapter_x, frame_embs, num_media)
            if roi_adapter_x is not None:
                roi_adapter_x = self._embed_adapter_features(roi_adapter_x, frame_embs, num_media)

        # blocks
        latents = repeat(self.latents, "n d -> b T n d", b=batch, T=num_media)
        # NOTE(review): negative slicing means adapter_depth == 0 (or >= depth) sends EVERY
        # layer through the adapter call below; confirm that is the intended meaning of 0.
        plain_layers = self.layers[:-self.adapter_depth]
        adapter_layers = self.layers[-self.adapter_depth:]
        for block in plain_layers:
            latents = block(x, latents)
        for block in adapter_layers:
            latents = block(x, latents, adapter_x, roi_adapter_x)
        return self.norm(latents)


class OtterMaskedCrossAttention(nn.Module):
    """
    Cross-attention from text tokens (queries) to resampled media latents (keys/values).

    On the default (non-xformers) path, `media_locations` is used to restrict which media
    item each text position may attend to. On the xformers path no mask is applied.
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_visual: int,
        dim_head: int = 64,
        heads: int = 8,
        only_attend_immediate_media: bool = True,
        use_lora: bool = False,
    ):
        """
        Args:
            dim: feature size of the text stream (input and output of this module)
            dim_visual: feature size of the media latents
            dim_head: size of each attention head
            heads: number of attention heads
            only_attend_immediate_media: if True, a text position attends only to the media
                item whose index equals its media counter; otherwise to that one and all earlier
            use_lora: build the projections as `lora.Linear` (rank 32) instead of `nn.Linear`
        """
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        if use_lora:
            self.to_q = lora.Linear(dim, inner_dim, bias=False, merge_weights=False, r=32)  # (4096 -> 512)
            self.to_kv = lora.Linear(dim_visual, inner_dim * 2, merge_weights=False, bias=False, r=32)  # (1024 -> 512)
            self.to_out = lora.Linear(inner_dim, dim, bias=False, merge_weights=False, r=32)  # (512 -> 4096)
        else:
            self.to_q = nn.Linear(dim, inner_dim, bias=False)
            self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
            self.to_out = nn.Linear(inner_dim, dim, bias=False)


        # whether for text to only attend to immediate preceding image, or all previous images
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x: torch.Tensor,
        media: torch.Tensor,
        media_locations: Optional[torch.BoolTensor] = None,
        attend_previous: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): text features
                shape (B, T_txt, D_txt)
            media (torch.Tensor): image features
                shape (B, T_img, n, D_img) where n is the dim of the latents
            media_locations: boolean mask identifying the media tokens in x
                shape (B, T_txt)
            attend_previous: bool
                If false, ignores immediately preceding image and starts attending when following image
        Returns:
            cross-attended text features projected back to D_txt, shape (B, T_txt, D_txt)
        """
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        # Flatten (media item, latent) into a single key/value axis of length T_img * n.
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        if not XFORMERS_AVAIL:
            q = rearrange(q, "b n (h d) -> b h n d", h=h)
            k = rearrange(k, "b n (h d) -> b h n d", h=h)
            v = rearrange(v, "b n (h d) -> b h n d", h=h)
            q = q * self.scale

            sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
            if exists(media_locations):
                # at each boolean of True, increment the time counter (relative to media time)
                text_time = media_locations.cumsum(dim=-1)
                # Media items are numbered 1..T_img so that 0 means "no media seen yet".
                media_time = torch.arange(T_img, device=x.device) + 1

                if not attend_previous:
                    # Shift every non-media position to the index of the NEXT media item.
                    text_time[~media_locations] += 1
                    # make sure max is still the number of images in the sequence
                    # (positions pointing past the last media item are reset to 0)
                    text_time[
                        text_time
                        > repeat(
                            torch.count_nonzero(media_locations, dim=1),
                            "b -> b i",
                            i=text_time.shape[1],
                        )
                    ] = 0

                # text time must equal media time if only attending to most immediate image
                # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
                mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

                # Broadcasts to (b, 1, T_txt, T_img * n): True where the text position may attend.
                text_to_media_mask = mask_op(
                    rearrange(text_time, "b i -> b 1 i 1"),
                    repeat(media_time, "j -> 1 1 1 (j n)", n=n),
                )
                # Disallowed pairs get the most negative finite value of the dtype.
                sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

            # Subtract the per-row max before the softmax for numerical stability.
            sim = sim - sim.amax(dim=-1, keepdim=True).detach()
            attn = sim.softmax(dim=-1)

            if exists(media_locations) and self.only_attend_immediate_media:
                # any text without a preceding media needs to have attention zeroed out
                # (a fully masked row would otherwise softmax to a uniform distribution)
                text_without_media_mask = text_time == 0
                text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
                attn = attn.masked_fill(text_without_media_mask, 0.0)

            out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
            out = rearrange(out, "b h n d -> b n (h d)")
        else:
            # xformers path: heads stay on the third axis (b, n, h, d).
            # NOTE(review): attn_bias is None here, so media_locations / attend_previous are
            # not applied on this path, and `out` is not merged back to (b, n, h*d) before
            # to_out. This branch is reached only when XFORMERS_AVAIL is true; verify before
            # enabling it.
            q = rearrange(q, "b n (h d) -> b n h d", h=h)
            k = rearrange(k, "b n (h d) -> b n h d", h=h)
            v = rearrange(v, "b n (h d) -> b n h d", h=h)
            attn_mask = None
            out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_mask, scale=self.scale)
        return self.to_out(out)


class OtterGatedCrossAttentionBlock(nn.Module):
    """
    Gated cross-attention layer inserted in front of a decoder layer.

    Both the cross-attention output and the feed-forward output are multiplied by
    tanh(gate) before being added to their residual. The gates start at 0, so a
    freshly created block is the identity on the text stream.
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_visual: int,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        only_attend_immediate_media: bool = True,
        use_lora: bool = False,
    ):
        """
        Args:
            dim: feature size of the text stream
            dim_visual: feature size of the media latents
            dim_head, heads, only_attend_immediate_media: forwarded to OtterMaskedCrossAttention
            ff_mult: expansion factor of the feed-forward hidden layer
            use_lora: build the linear layers as `lora.Linear` (rank 32) instead of `nn.Linear`
        """
        super().__init__()
        self.attn = OtterMaskedCrossAttention(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
            use_lora=use_lora
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        ff_hidden = dim * ff_mult
        if use_lora:
            expand = lora.Linear(dim, ff_hidden, bias=False, merge_weights=False, r=32)
            contract = lora.Linear(ff_hidden, dim, bias=False, merge_weights=False, r=32)
        else:
            expand = nn.Linear(dim, ff_hidden, bias=False)
            contract = nn.Linear(ff_hidden, dim, bias=False)
        # LayerNorm -> expand -> GELU -> contract
        self.feed_forward = nn.ModuleList([nn.LayerNorm(dim), expand, nn.GELU(), contract])
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x: torch.Tensor,
        media: torch.Tensor,
        media_locations: Optional[torch.BoolTensor] = None,
        attend_previous: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            x: text features
            media: media latents the text attends to
            media_locations: boolean mask of media tokens in the text, passed to the attention
            attend_previous: passed to the attention
        Returns:
            text features of the same shape as x
        """
        attended = self.attn(
            x,
            media,
            media_locations=media_locations,
            attend_previous=attend_previous,
        )
        x = attended * self.attn_gate.tanh() + x

        hidden = x
        for sublayer in self.feed_forward:
            hidden = sublayer(hidden)
        return hidden * self.ff_gate.tanh() + x


class OtterLayer(nn.Module):
    """
    Wrapper that runs an optional gated cross-attention layer before a decoder layer.

    The visual features, media locations and `attend_previous` flag are not forward()
    arguments; they are stored on the layer beforehand through the `condition_*` methods.
    """

    def __init__(self, gated_cross_attn_layer: nn.Module, decoder_layer: nn.Module):
        """
        Args:
            gated_cross_attn_layer: cross-attention block, or None to run the decoder layer only
            decoder_layer: the wrapped language-model decoder layer
        """
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.media_locations = None
        # Initialised here so forward() never hits a missing attribute when
        # condition_attend_previous() has not been called yet. True matches the default of
        # the cross-attention forward().
        self.attend_previous = True

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Used this great idea from this implementation of Otter (https://github.com/dhansmair/otter-mini/)
    def condition_vis_x(self, vis_x) -> None:
        """Store the visual features used by the next forward pass (None clears them)."""
        self.vis_x = vis_x

    def condition_media_locations(self, media_locations) -> None:
        """Store the media-token mask used by the next forward pass (None clears it)."""
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous) -> None:
        """Store the `attend_previous` flag handed to the cross-attention layer."""
        self.attend_previous = attend_previous

    def forward(
        self,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **decoder_layer_kwargs,
    ):
        """
        Args:
            lang_x: text hidden states
            attention_mask: forwarded to the decoder layer
            **decoder_layer_kwargs: forwarded to the decoder layer
        Returns:
            whatever the wrapped decoder layer returns
        Raises:
            ValueError: if a cross-attention layer is present but `vis_x` or
                `media_locations` has not been conditioned
        """
        if self.gated_cross_attn_layer is None:
            return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)

        if self.vis_x is None:
            raise ValueError("vis_x must be conditioned before forward pass")

        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")

        lang_x = self.gated_cross_attn_layer(
            lang_x,
            self.vis_x,
            media_locations=self.media_locations,
            attend_previous=self.attend_previous,
        )
        lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
        return lang_x


class OtterLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    It is applied to an existing language-model instance (see `extend_instance`), so its
    forward() runs first and then delegates to the language model's own forward().
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Remember the dotted attribute path of the decoder layer list."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the decoder layer container found at `decoder_layers_attr_name`."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder layer container found at `decoder_layers_attr_name`."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_otter(
        self,
        media_token_id: int,
        vis_hidden_size: int,
        cross_attn_every_n_layers: int,
        use_media_placement_augmentation: bool,
        use_lora: bool = False
    ):
        """
        Initialize Otter by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.

        Args:
            media_token_id: token id marking a media item in `input_ids`
            vis_hidden_size: feature size of the visual latents
            cross_attn_every_n_layers: a cross-attention block is added to every n-th decoder
                layer (1-based); the other layers get None
            use_media_placement_augmentation: if True, forward() draws `attend_previous`
                at random for each call
            use_lora: forwarded to the cross-attention blocks
        """

        gated_cross_attn_layers = nn.ModuleList(
            [
                OtterGatedCrossAttentionBlock(
                    dim=self.config.hidden_size,
                    dim_visual=vis_hidden_size,
                    use_lora=use_lora
                )
                if (layer_idx + 1) % cross_attn_every_n_layers == 0
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )
        # Wrap every decoder layer in an OtterLayer (with or without a cross-attention block).
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    OtterLayer(gated_cross_attn_layer, decoder_layer)
                    for gated_cross_attn_layer, decoder_layer in zip(gated_cross_attn_layers, self._get_decoder_layers())
                ]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.initialized_otter = True

    def forward(self, *input, **kwargs):
        """
        Condition the Otter layers on the media locations before forward()

        `input_ids` is taken from the keyword arguments, or from the first positional
        argument when it is not passed by keyword.

        Raises:
            ValueError: if `init_otter` has not been called on this instance
        """
        # The attribute only exists after init_otter(); getattr keeps the intended
        # ValueError instead of an AttributeError on an uninitialised instance.
        if not getattr(self, "initialized_otter", False):
            raise ValueError("Otter layers are not initialized. Please call `init_otter` first.")

        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
        media_locations = input_ids == self.media_token_id
        # IMPORTANT: Force `attend_previous` to True when we place training data as <image>caption<|endofchunk|>
        attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else True

        # Iterate exactly the layers init_otter() wrapped in OtterLayer, using the same
        # configured attribute path as is_conditioned() and clear_conditioned_layers().
        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)

        return super().forward(*input, **kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self) -> None:
        """Reset visual features, media locations and attend_previous on every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)


class OtterPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    It only sets class-level options read by the `PreTrainedModel` machinery; the model
    itself is defined in subclasses.
    """

    # Configuration class associated with this model family.
    config_class = OtterConfig
    # Prefix used by PreTrainedModel for the base model attribute.
    base_model_prefix = "otter"
    supports_gradient_checkpointing = True
    # Module class names listed as not to be split.
    # Presumably consumed by device-map placement in transformers/accelerate; confirm there.
    _no_split_modules = ["OtterPerceiverBlock", "CLIPEncoderLayer", "OtterLayer"]

    def _init_weights(self, module):
        """Otter requires no specific initialization; defer to the parent implementation."""
        return super()._init_weights(module)


class OtterModel(OtterPreTrainedModel):
    """
    Otter model made of a CLIP vision encoder, a perceiver resampler and a
    causal language model extended with ``OtterLMMixin``.

    The vision features are resampled by the perceiver and handed to every
    decoder layer of the language model before the text forward pass runs.
    """

    config_class = OtterConfig

    def __init__(
        self,
        config: OtterConfig,
        **kwargs,
    ):
        """
        Build tokenizer, language model, vision encoder and perceiver.

        Args:
            config: Otter configuration; ``text_config``, ``vision_config``,
                ``cross_attn_every_n_layers`` and (optionally)
                ``max_num_frames`` are read.
            **kwargs: optional settings. Only ``use_lora`` is read
                (``None`` when absent) and forwarded to ``init_otter``.

        Raises:
            ValueError: if the text model is neither Llama-named nor one of the
                supported architectures.
        """
        super().__init__(config)
        # `init_otter` below receives `use_lora`; it has to exist on the instance
        # first, otherwise the attribute lookup raises AttributeError.
        self.use_lora = kwargs.get("use_lora")

        ### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
        if "llama" not in config.text_config._name_or_path:
            architecture = config.text_config.architectures[0]
            if architecture == "MPTForCausalLM":
                text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
                lang_encoder = MPTForCausalLM(config=config.text_config)
            elif architecture == "RWForCausalLM":
                text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
                lang_encoder = RWForCausalLM(config=config.text_config)
            elif architecture == "LlamaForCausalLM":
                text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
                lang_encoder = LlamaForCausalLM(config=config.text_config)
            else:
                # Fail with a clear message instead of an unbound-local error further down.
                raise ValueError(f"Unsupported text_config architecture: {architecture!r}")
        else:
            text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
            lang_encoder = LlamaForCausalLM(config=config.text_config)
        vision_encoder = CLIPVisionModel(config=config.vision_config)
        text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
        if text_tokenizer.pad_token is None:
            text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.text_tokenizer = text_tokenizer
        # The last id of the encoding is taken as the id of the special token itself.
        self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
        self.media_token_id = text_tokenizer.encode("<image>")[-1]

        extend_instance(lang_encoder, OtterLMMixin)
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
        if lang_encoder.__class__.__name__ != "MPTForCausalLM":
            lang_encoder.resize_token_embeddings(len(text_tokenizer))
        self.lang_encoder = lang_encoder

        self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
        # use_media_placement_augmentation is strictly false for Otter model
        self.use_media_placement_augmentation = False  # config.use_media_placement_augmentation
        self.max_num_frames = config.max_num_frames if hasattr(config, "max_num_frames") else None

        vision_encoder.output_tokens = True
        self.vision_encoder = vision_encoder

        self.vis_dim = 1024
        self.perceiver = OtterPerceiverResampler(dim=self.vis_dim, max_num_frames=self.max_num_frames)

        self.lang_encoder.init_otter(
            media_token_id=self.media_token_id,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=self.cross_attn_every_n_layers,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
            use_lora=self.use_lora,
        )
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the input embeddings of the language encoder."""
        return self.lang_encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the input embeddings of the language encoder."""
        self.lang_encoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        """Return the output embeddings of the language encoder."""
        return self.lang_encoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the output embeddings of the language encoder."""
        self.lang_encoder.set_output_embeddings(new_embeddings)

    def get_image_encoder(self) -> nn.Module:
        """Return the vision encoder."""
        return self.vision_encoder

    def get_lang_encoder(self) -> nn.Module:
        """Return the language encoder."""
        return self.lang_encoder

    def tie_weights(self):
        """Defer weight tying to the parent class."""
        return super().tie_weights()

    def init_weights(self):
        """Set which parameters are trainable and print the trainable parameter count.

        The vision encoder is frozen. In the language encoder every parameter whose
        name does not contain ``gated_cross_attn_layer`` is frozen, then the input
        embeddings (and ``lm_head`` for ``LlamaForCausalLM``) are unfrozen again.
        """
        # Freeze all parameters in vision encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        # Freeze all parameters in lang encoders except gated_cross_attn_layers
        for name, param in self.lang_encoder.named_parameters():
            if "gated_cross_attn_layer" not in name:
                param.requires_grad = False
        # Unfreeze LM input embeddings
        self.lang_encoder.get_input_embeddings().requires_grad_(True)
        ## MPTForCausalLM is tied word embedding
        if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
            self.lang_encoder.lm_head.requires_grad_(True)
        # print model size in billions of parameters in 2 decimal places
        print(f"Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.2f} B")

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass of Otter.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            use_cached_vision_x: if True, reuse the conditioning already stored
                in the decoder layers; ``vision_x`` must then be None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to false if the
                same set of images will be reused in another subsequent
                forward pass.
            past_key_values: pre-computed values to pass to language model.
                See past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.

        Returns:
            The output of the language encoder call.
        """
        assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True."

        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(vision_x=vision_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """

        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        # `num_frames` is the F axis; a separate name avoids shadowing the module-level `F`.
        b, T, num_frames = vision_x.shape[:3]

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        # The vision encoder runs without gradients; the first token of its output is dropped.
        with torch.no_grad():
            vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=num_frames)

        vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)


class OtterForConditionalGeneration(OtterPreTrainedModel):
    config_class = OtterConfig

    def __init__(
        self,
        config: OtterConfig,
        **kwargs
    ):
        """
        Build tokenizer, language model, vision encoder and perceiver from ``config``.

        Args:
            config: Otter configuration; ``text_config``, ``vision_config``,
                ``cross_attn_every_n_layers`` and (optionally)
                ``max_num_frames`` are read.
            **kwargs: optional settings read with ``kwargs.get`` (``None`` when
                absent): ``vision_encode_mode``, ``num_vision_tokens``,
                ``downsample_frames`` and ``use_lora``.

        Raises:
            ValueError: if the text model is neither Llama-named nor one of the
                supported architectures.
        """
        super().__init__(config)
        ### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
        if "llama" not in config.text_config._name_or_path:
            architecture = config.text_config.architectures[0]
            if architecture == "MPTForCausalLM":
                text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
                lang_encoder = MPTForCausalLM(config=config.text_config)
            elif architecture == "RWForCausalLM":
                text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
                lang_encoder = RWForCausalLM(config=config.text_config)
            # Llama checkpoints whose path does not contain "llama" are recognized by architecture.
            elif architecture == "LlamaForCausalLM":
                text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
                lang_encoder = LlamaForCausalLM(config=config.text_config)
            else:
                # A debugger breakpoint used to sit here; fail loudly instead so
                # non-interactive runs do not hang or continue with unbound locals.
                raise ValueError(f"Unsupported text_config architecture: {architecture!r}")
        else:
            text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
            lang_encoder = LlamaForCausalLM(config=config.text_config)
        vision_encoder = CLIPVisionModel(config=config.vision_config)

        self.vision_encoder = vision_encoder
        # Medical components stay unset until the dedicated init_* methods are called.
        self.medical_vision_encoder = None
        self.medical_position_embedding = None
        self.medical_vision_adapter = None
        self.vision_encode_mode = kwargs.get("vision_encode_mode")
        self.num_vision_tokens = kwargs.get("num_vision_tokens")
        self.downsample_frames = kwargs.get("downsample_frames")
        self.use_lora = kwargs.get("use_lora")

        text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]})
        if text_tokenizer.pad_token is None:
            text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.text_tokenizer = text_tokenizer
        # The last id of the encoding is taken as the id of the special token itself.
        self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
        self.media_token_id = text_tokenizer.encode("<image>")[-1]

        extend_instance(lang_encoder, OtterLMMixin)
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
        if lang_encoder.__class__.__name__ != "MPTForCausalLM":
            lang_encoder.resize_token_embeddings(len(text_tokenizer))
        self.lang_encoder = lang_encoder

        self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
        # use_media_placement_augmentation is strictly false for Otter model
        self.use_media_placement_augmentation = False  # config.use_media_placement_augmentation
        self.max_num_frames = config.max_num_frames if hasattr(config, "max_num_frames") else None

        # Informative print statement
        if self.max_num_frames is None or self.max_num_frames == 1:
            print(f"The current model version is configured for Otter-Image with max_num_frames set to {self.max_num_frames}.")
        else:
            print(f"The current model version is configured for Otter-Video with a maximum of {self.max_num_frames} frames.")

        # `self.vision_encoder` already refers to this object, so no re-assignment is needed.
        vision_encoder.output_tokens = True

        self.vis_dim = 1024
        self.perceiver = OtterPerceiverResampler(
            dim=self.vis_dim,
            mode=self.vision_encode_mode,
            adapter_prompt_length=self.num_vision_tokens,
            max_num_frames=self.max_num_frames,
            use_lora=self.use_lora
        )

        self.lang_encoder.init_otter(
            media_token_id=self.media_token_id,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=self.cross_attn_every_n_layers,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
            use_lora=self.use_lora
        )
        self.post_init()

    def init_medical_position_embedding(self, method, encoder_grid_size=15, med_image_size=480):
        """
        Create ``self.medical_position_embedding`` and its image-sized interpolation.

        Args:
            method: ``"flamingo"`` resamples the vision encoder's patch position
                embeddings to the new grid; ``"sin"`` fills a sinusoidal table.
            encoder_grid_size: side length of the (square) feature grid; the
                embedding has ``encoder_grid_size**2`` entries.
            med_image_size: side length the embedding grid is interpolated to
                for ``self.resized_medical_position_embedding``
                (shape ``(1, d, med_image_size, med_image_size)``).

        Raises:
            NotImplementedError: for any other ``method``.
        """
        emb_dim = self.vision_encoder.vision_model.embeddings.position_embedding.embedding_dim
        num_emb = encoder_grid_size**2
        self.medical_position_embedding = nn.Embedding(
            num_embeddings=num_emb,
            embedding_dim=emb_dim
        )
        if method == "flamingo":
            pos_emb = self.vision_encoder.vision_model.embeddings.position_embedding
            # Index 0 is the cls token, so the patch grid has num_embeddings - 1 entries.
            old_grid_size = int(math.sqrt(pos_emb.num_embeddings - 1))
            # Build the indices on the embedding's own device; a CPU index tensor
            # fails when the vision encoder has already been moved elsewhere.
            patch_ids = torch.arange(1, pos_emb.num_embeddings, device=pos_emb.weight.device)
            # Only the values are copied into `.data`, so no autograd graph is needed.
            with torch.no_grad():
                pos_emb = pos_emb(patch_ids)
                pos_emb = rearrange(pos_emb, "(h w) d -> 1 d h w", h=old_grid_size, w=old_grid_size)
                pos_emb = nn.functional.interpolate(pos_emb, size=encoder_grid_size, mode="bilinear")
                pos_emb = rearrange(pos_emb, "1 d h w -> (h w) d")
            self.medical_position_embedding.weight.data = pos_emb
        elif method == "sin":
            # Code from https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L23
            def get_position_angle_vec(position):
                return [position / np.power(10000, 2 * (i // 2) / emb_dim) for i in range(emb_dim)]
            sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_emb)])
            sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
            sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
            self.medical_position_embedding.weight.data = torch.FloatTensor(sinusoid_table).to(self.device)
        else:
            raise NotImplementedError(f"Unknown medical position embedding method: {method!r}")
        # Resized embedding for med roi embedding interpolation
        med_pos_emb = self.medical_position_embedding.weight.data
        med_pos_emb = rearrange(med_pos_emb, "(h w) d -> 1 d h w", h=encoder_grid_size, w=encoder_grid_size)
        self.resized_medical_position_embedding = nn.functional.interpolate(med_pos_emb, size=med_image_size, mode="bilinear")


    def init_medical_vision_encoder(self, args):
        """
        Create the medical vision encoder, its adapter and the position embedding.

        Nothing happens when ``args.vision_encode_mode == "original"``. Otherwise a
        ``Classifier`` is attached for the ``mimic_cxr`` dataset, and the encoder
        selected by ``args.vision_encoder_type`` (``biovil``, ``unimedi2d`` or
        ``unimedi3d``) is built and moved to ``self.device``.

        Raises:
            NotImplementedError: at the end of the ``unimedi2d`` and ``unimedi3d``
                branches, after all setup has run.
        """
        if args.vision_encode_mode == "original":
            return
        if args.dataset_type == "mimic_cxr":
            self.classifier = Classifier(n_dim=self.vis_dim, n_cls=14).to(self.device)

        def _make_adapter(in_features):
            # Projects encoder features to the visual width used by the perceiver.
            adapter = Projection(in_feature_size=in_features, out_feature_size=self.vis_dim)
            return adapter.to(self.device)

        encoder_type = args.vision_encoder_type
        if encoder_type == "biovil":
            biovil = BioVilEncoder(
                img_encoder_type="resnet50",
                joint_feature_size=128,
                freeze_encoder=True,
                pretrained_model_path=args.medical_vision_encoder_path,
            )
            self.medical_vision_encoder = biovil.to(self.device)
            self.medical_vision_adapter = _make_adapter(self.medical_vision_encoder.feature_size)
            self.init_medical_position_embedding(method=args.med_pos_emb, encoder_grid_size=15)
        elif encoder_type == "unimedi2d":
            checkpoint = torch.load(args.medical_vision_encoder_path)["model"]
            vit2d = create_vit2D(
                image2D_size=224,
                hidden_dim=2048,
                output_dim=128,
                patch_out_dim=65536,
                with_distill=True,
                masked_im_modeling=False,
                mask_ratio=0.75,
                drop_path_rate=0.0
            )
            self.medical_vision_encoder = vit2d.to(self.device)
            # Keep only the teacher image-encoder weights and strip their prefix.
            model_dict = {
                key.replace("module.image_encoder_q_teacher.", ""): value
                for key, value in checkpoint.items()
                if "image_encoder_q_teacher." in key
            }
            missing_keys, unexpected_keys = self.medical_vision_encoder.load_state_dict(model_dict, strict=False)
            assert len(missing_keys) == 0
            self.medical_vision_encoder = self.medical_vision_encoder.to(self.device)
            self.medical_vision_adapter = _make_adapter(self.medical_vision_encoder.embed_dim)
            # Grid size 15 is the BioViL grid feature shape.
            self.init_medical_position_embedding(method=args.med_pos_emb, encoder_grid_size=15)
            raise NotImplementedError("check medical position embedding")
        elif encoder_type == "unimedi3d":
            checkpoint = torch.load(args.medical_vision_encoder_path)["model"]
            vit3d = create_vit3D(
                "base",
                image2D_size=224,
                image3D_size=128,
                slices=32,
                hidden_dim=2048,
                output_dim=128,
            )
            self.medical_vision_encoder = vit3d.to(self.device)
            matched_weights = load_state_with_same_shape(self.medical_vision_encoder, checkpoint)
            # Start from the model's own state and overlay the matched checkpoint weights.
            model_dict = self.medical_vision_encoder.state_dict()
            model_dict.update(matched_weights)
            missing_keys, unexpected_keys = self.medical_vision_encoder.load_state_dict(model_dict, strict=False)
            assert len(missing_keys) == 0
            self.medical_vision_encoder = self.medical_vision_encoder.to(self.device)
            self.medical_vision_adapter = _make_adapter(self.medical_vision_encoder.embed_dim)
            self.init_medical_position_embedding(method=args.med_pos_emb, encoder_grid_size=15)
            raise NotImplementedError("check medical position embedding")


    def init_medical_roi_extractor(self, args):
        """
        Attach ``self.medical_roi_extractor`` for the BioViL / MIMIC-CXR setup.

        The extractor is only created when ``args.vision_encode_mode`` is not
        ``"original"``, ``args.dataset_type`` is ``"mimic_cxr"`` and
        ``args.vision_encoder_type`` is ``"biovil"``; otherwise this is a no-op.
        The text model of the extractor is moved to ``self.device``.
        """
        wants_roi_extractor = (
            args.vision_encode_mode != "original"
            and args.dataset_type == "mimic_cxr"
            and args.vision_encoder_type == "biovil"
        )
        if not wants_roi_extractor:
            return

        image_engine = ImageInferenceEngine(
            image_model=self.medical_vision_encoder,
        )
        text_engine = get_bert_inference(BertEncoderType.CXR_BERT)
        self.medical_roi_extractor = ImageTextInferenceEngineForTraining(
            image_inference_engine=image_engine,
            text_inference_engine=text_engine,
        )
        attached_text_engine = self.medical_roi_extractor.text_inference_engine
        attached_text_engine.model = attached_text_engine.model.to(
            device=self.device
        )

    def freeze_layers(self, mode):
        """
        Select the trainable parameters for the given vision encode mode.

        For ``"original"`` nothing is changed. For ``"medical_only"``,
        ``"llama_adapter_plus"`` and ``"llama_adapter_concat"`` a parameter is
        trainable iff its name contains ``lora_``, ``adapter``, ``classifier``
        or ``gated_cross_attn_layer``; afterwards the medical vision encoder
        and, if present, both models of the RoI extractor are frozen.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == "original":
            return
        if mode not in ("medical_only", "llama_adapter_plus", "llama_adapter_concat"):
            raise NotImplementedError

        trainable_markers = ("lora_", "adapter", "classifier", "gated_cross_attn_layer")
        for param_name, param in self.named_parameters():
            param.requires_grad = any(marker in param_name for marker in trainable_markers)

        # The medical encoder is frozen even where a marker matched its parameter names.
        for encoder_param in self.medical_vision_encoder.parameters():
            encoder_param.requires_grad = False

        # Same for the RoI extractor, image model first, then text model.
        if hasattr(self, "medical_roi_extractor"):
            roi_extractor = self.medical_roi_extractor
            for image_param in roi_extractor.image_inference_engine.model.parameters():
                image_param.requires_grad = False
            for text_param in roi_extractor.text_inference_engine.model.parameters():
                text_param.requires_grad = False


    def get_input_embeddings(self) -> nn.Module:
        """Return whatever the language encoder reports as its input embeddings."""
        language_model = self.lang_encoder
        return language_model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Hand ``new_embeddings`` to the language encoder as its input embeddings."""
        language_model = self.lang_encoder
        language_model.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        """Return whatever the language encoder reports as its output embeddings."""
        language_model = self.lang_encoder
        return language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Hand ``new_embeddings`` to the language encoder as its output embeddings."""
        language_model = self.lang_encoder
        language_model.set_output_embeddings(new_embeddings)

    def get_image_encoder(self) -> nn.Module:
        """Return the object stored as ``self.vision_encoder``."""
        image_encoder = self.vision_encoder
        return image_encoder

    def get_lang_encoder(self) -> nn.Module:
        """Return the object stored as ``self.lang_encoder``."""
        language_model = self.lang_encoder
        return language_model

    def init_weights(self):
        """
        Set the default trainable / frozen split of the model.

        All vision encoder parameters are frozen. In the language encoder every
        parameter whose name lacks ``gated_cross_attn_layer`` is frozen; the
        input embeddings are then unfrozen, and so is ``lm_head`` when the
        language encoder class is named ``LlamaForCausalLM``.
        """
        for vision_param in self.vision_encoder.parameters():
            vision_param.requires_grad = False

        language_model = self.lang_encoder
        for param_name, lang_param in language_model.named_parameters():
            if "gated_cross_attn_layer" in param_name:
                # Cross-attention parameters keep their current requires_grad value.
                continue
            lang_param.requires_grad = False

        # Input embeddings are trained again after the blanket freeze above.
        language_model.get_input_embeddings().requires_grad_(True)
        # MPTForCausalLM ties its word embeddings, so only Llama gets lm_head unfrozen.
        if language_model.__class__.__name__ == "LlamaForCausalLM":
            language_model.lm_head.requires_grad_(True)



    def init_adapter_weights(self, mode):
        """
        Initialize LoRA and adapter parameters, selected by parameter name.

        Nothing happens for ``mode == "original"``. Otherwise:
        ``lora_A`` parameters get Kaiming-uniform values, ``lora_B`` parameters
        are zeroed, and ``adapter`` parameters are zeroed when their name also
        contains ``gate`` and drawn from a (Kaiming-)normal distribution
        otherwise. For ``"llama_adapter_plus"`` every perceiver layer
        additionally receives ``adapter_to_kv`` / ``adapter_to_out`` as deep
        copies of its ``to_kv`` / ``to_out``.

        Args:
            mode: vision encode mode.
        """
        if mode == "original":
            return
        for name, param in self.named_parameters():
            if "lora_A" in name:
                nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            elif "lora_B" in name:
                nn.init.zeros_(param)
            elif "adapter" in name:
                if "gate" in name:
                    nn.init.zeros_(param)
                elif param.dim() >= 2:
                    nn.init.kaiming_normal_(param)
                else:
                    # Kaiming init needs fan-in, which is undefined below 2 dims
                    # (e.g. biases). Checking explicitly instead of catching every
                    # exception keeps genuine initialization errors visible.
                    nn.init.normal_(param)
        if mode == "llama_adapter_plus":
            for layer in self.perceiver.layers:
                layer.adapter_to_kv = copy.deepcopy(layer.to_kv)
                layer.adapter_to_out = copy.deepcopy(layer.to_out)

    def init_classifier_weights(self, mode):
        """
        Re-initialize every head in ``self.classifier.classifier_list``.

        Weights get Kaiming-uniform values (``a = sqrt(5)``); biases are drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``. Nothing happens
        for ``mode == "original"``.
        """
        if mode == "original":
            return
        negative_slope = math.sqrt(5)
        for head in self.classifier.classifier_list:
            nn.init.kaiming_uniform_(head.weight, a=negative_slope)
            fan_in = nn.init._calculate_fan_in_and_fan_out(head.weight)[0]
            limit = 1 / math.sqrt(fan_in)
            nn.init.uniform_(head.bias, -limit, limit)

    @staticmethod
    def crop_image(orig_image, coord, height=480, width=480):
        """
        Crop a ``height`` x ``width`` window around a center coordinate (y,x).

        The window is shifted back when it crosses the image border, so the
        center actually used can differ from ``coord``.

        Returns:
            ``(cropped, coord)``: the cropped image and the (y,x) midpoint of
            the shifted window.
        """
        def _shift_into_range(low, high, size):
            # Slide the [low, high] interval so it stays within 0 .. size - 1.
            if low < 0:
                high -= low
                low = 0
            if high >= size:
                low -= high - (size - 1)
                high = size - 1
            return low, high

        img_height, img_width = orig_image.shape[1:]
        half_h, half_w = height / 2, width / 2
        left, right = _shift_into_range(int(coord[1] - half_w), int(coord[1] + half_w), img_width)
        top, bottom = _shift_into_range(int(coord[0] - half_h), int(coord[0] + half_h), img_height)
        center = (int((top + bottom) / 2), int((left + right) / 2))
        window = F.crop(orig_image, top=top, left=left, height=height, width=width)
        return window, center

    def extract_medical_roi(self, noaug_vision_x: torch.Tensor, orig_vision_x, query_ids, query_masks):
        """
        Crop, for every original image, the region most similar to the text query.

        Args:
            noaug_vision_x: 6-D tensor ``(b, t, f, c, h, w)`` passed to the RoI
                extractor after flattening the first three axes.
            orig_vision_x: nested per-sample collection of original images; it
                is flattened into one list.
            query_ids: ``(b, d)`` query token ids, repeated for every (t, f).
            query_masks: ``(b, d)`` masks of ``query_ids``, repeated likewise.

        Returns:
            ``(cropped_images, coords)``: the stacked crops and the center
            coordinates (y,x) of the crops in the original images.
        """
        flat_images = []
        for sample in orig_vision_x:
            flat_images.extend(sample)
        heights = [image.shape[-2] for image in flat_images]
        widths = [image.shape[-1] for image in flat_images]

        assert noaug_vision_x.ndim == 6
        _, t, f = noaug_vision_x.shape[:3]
        flat_noaug = rearrange(noaug_vision_x, "b t f c h w -> (b t f) c h w")
        tiled_ids = repeat(query_ids, "b d -> (b t f) d", t=t, f=f)
        tiled_masks = repeat(query_masks, "b d -> (b t f) d", t=t, f=f)

        sim_maps = self.medical_roi_extractor.get_similarity_map_from_image_and_query_id(
            image=flat_noaug,
            orig_heights=heights,
            orig_widths=widths,
            query_ids=tiled_ids,
            query_masks=tiled_masks,
            interpolation="bilinear",
        )

        # Location of the (NaN-ignoring) maximum of each similarity map.
        peaks = [np.unravel_index(np.nanargmax(sim), sim.shape) for sim in sim_maps]

        crops, centers = [], []
        for image, peak in zip(flat_images, peaks):
            crop, center = self.crop_image(image, peak)
            crops.append(crop)
            centers.append(center)
        return torch.stack(crops), centers  # center coordinate (y,x) in orig images

    @staticmethod  # Deprecated
    def concat_vision_x(vision_x, med_roi):
        """
        Append RoI images to ``vision_x`` along the image axis (dim 1).

        ``med_roi`` (4-D) is resized to the spatial size of ``vision_x`` (6-D)
        when they differ, has its channels repeated when ``vision_x`` has more
        than one channel and the channel counts differ, and is reshaped to
        ``(b, t, f, c, h, w)`` before concatenation.
        """
        assert vision_x.ndim == 6
        assert med_roi.ndim == 4
        b, t, f, c, h, w = vision_x.shape

        # Bring the RoI images to the spatial size of vision_x.
        if (med_roi.shape[-2], med_roi.shape[-1]) != (h, w):
            med_roi = torch.stack([F.resize(image, [h, w]) for image in med_roi])

        # Repeat channels when the target has several and the counts disagree.
        if c != 1 and med_roi.shape[-3] != c:
            med_roi = repeat(med_roi, "b c h w -> b (repeat c) h w", repeat=c)

        roi_as_images = rearrange(med_roi, "(b t f) c h w -> b t f c h w", b=b, t=t, f=f)
        return torch.concat([vision_x, roi_as_images], dim=1)


    def forward(
        self,
        vision_x: Optional[torch.Tensor],
        lang_x: torch.Tensor,
        med_vision_x: Optional[torch.Tensor] = None,
        orig_vision_x: Optional[List] = None,
        noaug_vision_x: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: bool = False,
        use_med_roi: bool = True,
        query_ids: Optional[torch.Tensor] = None,
        query_masks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple:
        """
        Forward pass of Otter.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1.
                Must be None when use_cached_vision_x is True.
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            med_vision_x (torch.Tensor): Vision input for medical encoder
                shape (B, T_img, F, C, H, W) with F=1
            noaug_vision_x (torch.tensor) Vision input without any training augmentation
            orig_vision_x (List[List[torch.Tensor]]): Original untransformed images
                shape (B, T_img, (C, H, W))
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            use_cached_vision_x: if True, skip vision encoding and rely on the
                conditioning already stored in the decoder layers.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to false if the
                same set of images will be reused in another subsequent
                forward pass.
            past_key_values: pre-computed values to pass to language model.
                See past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.
            use_med_roi: whether to extract medical RoI from original images.
                RoI extraction only runs when query_ids and query_masks are
                both given as well.
            query_ids: query text tokens for extract RoI
                shape (B, T_txt)
            query_masks: attention masks of query_ids
                shape (B, T_txt)

        Returns:
            A tuple ``(output, pred)``: ``output`` is the result of the language
            encoder call; ``pred`` is the squeezed classifier output computed
            from the first decoder layer's ``vis_x``, or None when the model has
            no ``classifier`` attribute.
        """
        assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True."

        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert vision_x is None and med_vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()

        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            # (b,T,n=64,d=1024), 14 of 64 used for chexpert classification
            if use_med_roi and query_ids is not None and query_masks is not None:
                # RoI extraction needs the extractor and the original images.
                assert hasattr(self, "medical_roi_extractor") and orig_vision_x is not None
                med_roi_vision_x, med_roi_coords = self.extract_medical_roi(  # (B,1,h,w)
                    noaug_vision_x=noaug_vision_x,
                    orig_vision_x=orig_vision_x,
                    query_ids=query_ids,
                    query_masks=query_masks
                )
                # vision_x = self.concat_vision_x(vision_x, med_roi_vision_x)
                # med_vision_x = self.concat_vision_x(med_vision_x, med_roi_vision_x)
            else:
                med_roi_vision_x, med_roi_coords = None, None
            # Encodes the vision inputs; the classifier below reads vis_x from the decoder layers.
            self._encode_vision_x(
                vision_x=vision_x,
                med_vision_x=med_vision_x,
                med_roi_vision_x=med_roi_vision_x,
                med_roi_coords=med_roi_coords,
                orig_vision_x=orig_vision_x,
            )

        if hasattr(self, "classifier"):
            # Use vis_x (output of perceiver) for classification
            pred = self.classifier(self.lang_encoder._get_decoder_layers()[0].vis_x).squeeze()
        else:
            pred = None

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )

        # Clearing happens after `pred` was computed, since `pred` reads the layers' vis_x.
        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output, pred

    def roi_position_embedding(self, orig_shape, coords, height=480, width=480, resized_size=512):
        """
        Build zoomed-in positional embeddings for RoI crops.

        Args:
            orig_shape: per-image pair of original sizes, in the same (y,x)
                order as ``coords``.
            coords: list of RoI center coordinates (y,x) in the original images.
            height, width: RoI size in original-image pixels; also the size of
                the centered window inside the resized image.
            resized_size: side length the original image is scaled to.

        Returns:
            Tensor ``(b, h*w, c)`` obtained by slicing
            ``self.resized_medical_position_embedding`` at each RoI and
            interpolating the slice back to the embedding grid size.
        """
        half_h, half_w = int(height / 2), int(width / 2)
        # (dy, dx) offsets of the four corners from the center, ordered (tl,tr,bl,br).
        corner_offsets = torch.Tensor([
            [-half_h, -half_w],
            [-half_h, half_w],
            [half_h, -half_w],
            [half_h, half_w],
        ])
        corners = torch.Tensor(coords).unsqueeze(1) + corner_offsets  # (b, 4, 2)

        # Express corners as fractions of the original size, then in resized pixels.
        fractions = corners / torch.Tensor(orig_shape).unsqueeze(1)
        scaled = (resized_size * fractions).to(torch.int)

        # Clamp to the centered height x width window and make coordinates window-relative.
        y_lo, x_lo = int((resized_size - height) / 2), int((resized_size - width) / 2)
        y_hi, x_hi = y_lo + height, x_lo + width
        scaled[..., 0] = torch.clamp(scaled[..., 0], min=y_lo, max=y_hi) - y_lo
        scaled[..., 1] = torch.clamp(scaled[..., 1], min=x_lo, max=x_hi) - x_lo

        grid = int(math.sqrt(self.medical_position_embedding.num_embeddings))
        patches = [
            nn.functional.interpolate(
                self.resized_medical_position_embedding[:, :, box[0, 0]:box[2, 0], box[0, 1]:box[1, 1]],
                size=grid,
                mode="bilinear",
            )
            for box in scaled
        ]
        stacked = torch.concat(patches, dim=0)
        # TODO check all repeat, they are not copy but view.
        return rearrange(stacked, "b c h w -> b (h w) c")


    def _encode_vision_x(
        self,
        vision_x: torch.Tensor,
        med_vision_x: torch.Tensor = None,
        med_roi_vision_x: torch.Tensor = None,
        med_roi_coords = None,
        orig_vision_x = None,
    ):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        The result is not returned: the perceiver output is handed to every decoder layer of
        self.lang_encoder through layer.condition_vis_x().

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
            med_vision_x (torch.Tensor): Same as vision_x with different height and width.
                Required whenever self.vision_encode_mode != "original".
            med_roi_vision_x (torch.Tensor): Optional single-channel ROI crops of shape (B, 1, h, w);
                they are expanded to the channel count of med_vision_x before encoding.
            med_roi_coords: ROI centre coordinates forwarded to roi_position_embedding();
                required when med_roi_vision_x is given.
            orig_vision_x: Nested iterable (samples, then images) of the original images; only the
                last two dimensions of each image are read. Required when med_roi_vision_x is given.

        Raises:
            AssertionError: on inconsistent input shapes or missing medical/ROI inputs.
            NotImplementedError: for an unsupported vision_encode_mode or medical encoder type.

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """

        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        # Named `n_frames` rather than `F` so the module-level torchvision alias `F` is not shadowed.
        b, T, n_frames = vision_x.shape[:3]
        if med_vision_x is not None:
            assert med_vision_x.ndim == 6, "med_vision_x should be of shape (b, T_img, F, C, H, W)"
            assert (b, T, n_frames) == med_vision_x.shape[:3], "med_vision_x and vision_x should have same (B,T,F)"
            med_c, med_h, med_w = med_vision_x.shape[3:]

        use_medical_encoder = self.vision_encode_mode != "original"

        if 0 < self.downsample_frames < T:
            indices = sorted(random.sample(range(T), self.downsample_frames))
            vision_x = vision_x[:, indices]
            # The per-frame (2D) medical encoders are reshaped with the same T below, so they must
            # see the same frame subset. The 3D encoder path broadcasts its output over T instead
            # and keeps receiving every frame.
            if (
                use_medical_encoder
                and med_vision_x is not None
                and not isinstance(self.medical_vision_encoder, UnifiedEncoder3D)
            ):
                med_vision_x = med_vision_x[:, indices]
            T = self.downsample_frames

        if use_medical_encoder:  # Medical encoder used
            assert med_vision_x is not None
            is_llama_adapter = "llama_adapter" in self.vision_encode_mode
            if not (is_llama_adapter or self.vision_encode_mode == "medical_only"):
                raise NotImplementedError
            encoder = self.medical_vision_encoder
            has_roi = med_roi_vision_x is not None

            med_vision_x = rearrange(med_vision_x, "b T F c h w -> (b T F) c h w")  # ((b,T,F),C,H,W)
            with torch.no_grad():  # the medical encoder is run without gradients
                med_vision_x = encoder(med_vision_x)
                if has_roi:
                    # The ROI goes through the medical encoder, so expand it to that encoder's
                    # channel count (med_c), not to the channel count of vision_x.
                    med_roi_vision_x = repeat(med_roi_vision_x, "B 1 h w -> B c h w", c=med_c)
                    med_roi_vision_x = encoder(med_roi_vision_x)

            if isinstance(encoder, BioVilEncoder):
                med_vision_x = med_vision_x.patch_embeddings[-1]
                med_vision_x = rearrange(med_vision_x, "v d h w -> v (h w) d")
                if has_roi:
                    med_roi_vision_x = med_roi_vision_x.patch_embeddings[-1]
                    med_roi_vision_x = rearrange(med_roi_vision_x, "v d h w -> v (h w) d")
            elif isinstance(encoder, (UniMedIEncoder2D, UnifiedEncoder3D)):
                # Drop the leading token (presumably the CLS token -- TODO confirm).
                med_vision_x = med_vision_x[:, 1:, :]
                if has_roi:
                    # Same treatment for the ROI features so their token count matches the
                    # positional embedding added below.
                    med_roi_vision_x = med_roi_vision_x[:, 1:, :]
            else:
                raise NotImplementedError
            if is_llama_adapter:
                assert self.num_vision_tokens == med_vision_x.shape[1]  # biovil (*,225,2048) unimedi2d (*,196,768)

            med_vision_x = self.medical_vision_adapter(med_vision_x)  # Adapter will be trained
            if has_roi:
                med_roi_vision_x = self.medical_vision_adapter(med_roi_vision_x)

            # positional embedding interpolated from original encoder's
            position_ids = torch.arange(
                self.medical_position_embedding.num_embeddings, device=self.device
            ).unsqueeze(0)
            med_vision_x += self.medical_position_embedding(position_ids)
            if has_roi:
                assert med_roi_coords is not None and orig_vision_x is not None
                orig_shape = [x.shape[-2:] for sample in orig_vision_x for x in sample]
                med_roi_vision_x += self.roi_position_embedding(
                    coords=med_roi_coords,
                    orig_shape=orig_shape,
                    height=med_h,
                    width=med_w,
                )

            if isinstance(encoder, (BioVilEncoder, UniMedIEncoder2D)):
                med_vision_x = rearrange(med_vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=n_frames)
                if has_roi:
                    med_roi_vision_x = rearrange(
                        med_roi_vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=n_frames
                    )
            elif isinstance(encoder, UnifiedEncoder3D):
                med_vision_x = repeat(med_vision_x, "b v d -> b T F v d", T=T, F=n_frames)
            else:
                raise NotImplementedError
        else:  # Medical encoder not used
            med_vision_x = None

        if self.vision_encode_mode == "medical_only":
            # The medical features replace the Flamingo features entirely, so the Flamingo
            # encoder forward pass (whose output would be discarded) is skipped.
            vision_x = med_vision_x
            med_vision_x = None
        else:
            # Flamingo encoder
            with torch.no_grad():
                vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")  # ((b,T,F),C,H,W)
                vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]  # (1,(h*w),1024)
                vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=n_frames)  # (b,T,F,(h*w),1024)

        # reshapes to (b, T, n, d) where n is number of learnable queries
        vision_x = self.perceiver(vision_x, med_vision_x, med_roi_vision_x)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

    @torch.no_grad()
    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        med_vision_x: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        query_ids: Optional[torch.Tensor] = None,
        query_masks: Optional[torch.Tensor] = None,
        orig_vision_x: Optional[List] = None,
        noaug_vision_x: Optional[torch.Tensor] = None,
        label_only: bool = False,
        **generate_kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                images in the same chunk are collated along T_img, and frames are collated along F
                currently only F=1 is supported (single-frame videos)
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            med_vision_x (torch.Tensor, optional): Medical vision encoder input,
                same shape as vision_x except H and W. Defaults to None.
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            query_ids (torch.Tensor, optional): Query token ids passed to extract_medical_roi().
                ROI extraction only runs when query_ids and query_masks are both given and
                label_only is False.
            query_masks (torch.Tensor, optional): Masks matching query_ids.
            orig_vision_x (List, optional): Original images, one entry per batch sample.
                Required for ROI extraction.
            noaug_vision_x (torch.Tensor, optional): Un-augmented vision input.
                Required for ROI extraction.
            label_only (bool): If True, skip text generation and return the classifier
                prediction computed from the perceiver output instead.
            **generate_kwargs: Forwarded to self.lang_encoder.generate (e.g. max_length, num_beams).
        Returns:
            torch.Tensor: lang_x with generated tokens appended to it, or, when label_only is
                True, a boolean label tensor (at least one entry is always True).
        """
        if hasattr(self, "_hf_hook"):
            # add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
            hook = AlignDevicesHook(
                execution_device=lang_x.device,
                io_same_device=True,
                place_submodules=False,
            )
            add_hook_to_module(self.lang_encoder, hook)

        if not label_only and query_masks is not None and query_ids is not None:
            assert orig_vision_x is not None and noaug_vision_x is not None
            if query_ids.ndim == 1:
                query_ids = query_ids.unsqueeze(0)
            if query_masks.ndim == 1:
                query_masks = query_masks.unsqueeze(0)
            query_masks = query_masks.to(self.device)
            query_ids = query_ids.to(self.device)
            med_roi_vision_x, med_roi_coords = self.extract_medical_roi(  # (B,1,h,w)
                noaug_vision_x=noaug_vision_x,
                orig_vision_x=orig_vision_x,
                query_ids=query_ids,
                query_masks=query_masks
            )
        else:
            med_roi_vision_x, med_roi_coords = None, None

        # `or 1` also covers an explicit num_beams=None.
        num_beams = generate_kwargs.get("num_beams") or 1
        if num_beams > 1:
            def _per_beam(values):
                """Repeat every batch entry num_beams times, keeping the copies adjacent."""
                if values is None:
                    return None
                if torch.is_tensor(values):
                    return values.repeat_interleave(num_beams, dim=0)
                # `num_beams * values` would tile the list (a, b, a, b) and no longer line up
                # with the interleaved tensors (a, a, b, b) once the batch size exceeds one.
                return [item for item in values for _ in range(num_beams)]

            vision_x = _per_beam(vision_x)
            med_vision_x = _per_beam(med_vision_x)
            med_roi_vision_x = _per_beam(med_roi_vision_x)
            med_roi_coords = _per_beam(med_roi_coords)
            orig_vision_x = _per_beam(orig_vision_x)

        self._encode_vision_x(
            vision_x=vision_x,
            med_vision_x=med_vision_x,
            med_roi_vision_x=med_roi_vision_x,
            med_roi_coords=med_roi_coords,
            orig_vision_x=orig_vision_x,
        )

        try:
            if label_only:  # Use vis_x (output of perceiver) for classification
                assert hasattr(self, "classifier")
                pred = torch.sigmoid(self.classifier(self.lang_encoder._get_decoder_layers()[0].vis_x)[0].squeeze())
                # Only the first batch element is classified; all-zero images are ignored.
                valid_ids = [i for i, img in enumerate(vision_x[0]) if torch.count_nonzero(img) > 0]
                pred = pred[valid_ids].mean(dim=0)
                label = pred >= 0.5
                if not torch.any(label):  # Get one max if none above threshold
                    label[torch.argmax(pred)] = True
                return label

            return self.lang_encoder.generate(
                lang_x,
                attention_mask=attention_mask,
                eos_token_id=self.eoc_token_id,
                **generate_kwargs,
            )
        finally:
            # Release the conditioned vision features on every exit path, including the
            # label_only return and a failing generate() call.
            self.lang_encoder.clear_conditioned_layers()
