""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
# Copyright (c) Meta Platforms, Inc. and affiliates
from .hf_model import HFTextEncoder
from collections import OrderedDict
from dataclasses import dataclass
import logging
import math
from typing import Tuple, Union, Callable, Optional, List, Text

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
import numbers

from .timm_model import TimmModel
from .utils import freeze_batch_norm_2d, to_2tuple

from .hooks.hook import HookManager


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions around a shortcut.

    Every convolution runs with stride 1. When ``stride > 1`` the spatial reduction is done by
    an average pool placed after the second convolution on the main path and in front of the
    1x1 projection on the shortcut path. The block outputs ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        widened = planes * self.expansion
        strided = stride > 1

        # 1x1 reduction
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        # 3x3 convolution (stride 1, spatial size preserved by the padding)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        # pooling takes the place of a strided convolution
        self.avgpool = nn.AvgPool2d(stride) if strided else nn.Identity()

        # 1x1 expansion
        self.conv3 = nn.Conv2d(planes, widened, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu3 = nn.ReLU(inplace=True)

        self.stride = stride

        if strided or inplanes != planes * Bottleneck.expansion:
            # the shortcut needs a projection: avgpool first, then a stride-1 1x1 convolution
            shortcut_layers = OrderedDict()
            shortcut_layers["-1"] = nn.AvgPool2d(stride)
            shortcut_layers["0"] = nn.Conv2d(inplanes, widened, 1, stride=1, bias=False)
            shortcut_layers["1"] = nn.BatchNorm2d(widened)
            self.downsample = nn.Sequential(shortcut_layers)
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.conv1(x)
        y = self.relu1(self.bn1(y))
        y = self.conv2(y)
        y = self.relu2(self.bn2(y))
        y = self.avgpool(y)
        y = self.conv3(y)
        y = self.bn3(y)

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu3(y)


class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map into one vector per sample using multi-head attention.

    The mean over all spatial positions is prepended as an extra token, a learned positional
    embedding is added, and the attention output at that first token is returned.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        num_tokens = spacial_dim ** 2 + 1  # one per spatial position plus the mean token
        self.positional_embedding = nn.Parameter(torch.randn(num_tokens, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Map ``x`` of shape (N, C, H, W) to a tensor of shape (N, output_dim)."""
        batch, channels = x.shape[0], x.shape[1]
        positions = x.shape[2] * x.shape[3]

        tokens = x.reshape(batch, channels, positions).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (HW+1)NC

        # q/k/v biases are concatenated in the order the functional API expects
        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            training=self.training,
            need_weights=False,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

        # only the output at the prepended mean token is kept
        return pooled[0]


class ModifiedResNet(nn.Module):
    """
    ResNet variant that differs from torchvision's in the following ways:
    - the stem has 3 convolutions instead of 1 and ends in an average pool rather than a max pool;
    - strided convolutions are anti-aliased: an avgpool is prepended to convolutions with stride > 1;
    - the final pooling layer is a QKV attention (``AttentionPool2d``) instead of an average pool.
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        stem_width = width // 2

        # three-convolution stem; only the first convolution is strided
        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(stem_width, stem_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(stem_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(stem_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # four residual stages; ``_inplanes`` is advanced by ``_make_layer`` as stages are built
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # channel count leaving layer4 (width * 8 * Bottleneck.expansion)
        embed_dim = width * 32
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a first block that may stride, followed by ``blocks - 1`` plain blocks."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def init_parameters(self):
        """Re-initialise the attention-pool projections and zero the last BN scale of each block."""
        pool = self.attnpool
        if pool is not None:
            std = pool.c_proj.in_features ** -0.5
            for projection in (pool.q_proj, pool.k_proj, pool.v_proj, pool.c_proj):
                nn.init.normal_(projection.weight, std=std)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for param_name, param in stage.named_parameters():
                if not param_name.endswith("bn3.weight"):
                    continue
                nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients for all parameters; optionally freeze batch-norm layers too."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for p in self.parameters():
            p.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        """Run the three conv/bn/relu stem layers followed by the average pool."""
        stem_layers = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, relu in stem_layers:
            x = relu(bn(conv(x)))
        return self.avgpool(x)

    def forward(self, x):
        """Stem -> four residual stages -> attention pooling."""
        x = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)


# class LayerNorm(nn.LayerNorm):
#     """Subclass torch's LayerNorm to handle fp16."""

#     def forward(self, x: torch.Tensor):
#         orig_type = x.dtype
#         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
#         return x.to(orig_type)

class LayerNorm(nn.Module):
    """Layer normalisation written out step by step so intermediate values can be hooked.

    The mean, the centred input, the standard deviation and the affine output are each passed
    through ``self.hook``; whatever the hook returns is the value used downstream. The result is
    cast back to the input dtype.

    Args:
        normalized_shape: int or sequence of ints; the trailing input dimensions to normalise over.
        eps: value added to the variance before the square root.
        elementwise_affine: when true, a learnable ``weight`` and ``bias`` are applied.
        device, dtype: forwarded to the ``weight``/``bias`` tensors.
        hook: hook manager; a fresh ``HookManager()`` is created when omitted.
    """
    def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True, device=None, dtype=None, 
                 hook: Optional[HookManager] = None):
        super().__init__()
        self.hook = hook or HookManager()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            # Start from the identity transform (weight = 1, bias = 0) like torch.nn.LayerNorm.
            # The previous torch.empty() left both parameters holding arbitrary memory unless a
            # checkpoint was loaded afterwards, and ignored the device/dtype arguments.
            factory_kwargs = {'device': device, 'dtype': dtype}
            self.weight = torch.nn.Parameter(torch.ones(self.normalized_shape, **factory_kwargs))
            self.bias = torch.nn.Parameter(torch.zeros(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` over its trailing ``normalized_shape`` dimensions.

        Hook names fired, in order: ``mean``, ``mean_reduced``, ``sqrt_var`` and, when the affine
        transform is enabled, ``renorm.post``; ``self.hook.finalize()`` is called before returning.
        """
        orig_type = x.dtype
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
        dims = [-(i + 1) for i in range(len(self.normalized_shape))]
        mean = self.hook('mean', ret=x.mean(dim=dims, keepdim=True))
        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
        # Biased variance via E[x^2] - E[x]^2. Floating-point cancellation can push the
        # difference slightly below zero, which would make the sqrt below return NaN, so the
        # value is clamped at zero.
        var = (mean_x2 - mean ** 2).clamp(min=0)
        x_norm = self.hook('mean_reduced', ret=(x - mean)) / self.hook('sqrt_var', ret=torch.sqrt(var + self.eps))
        if self.elementwise_affine:
            x_norm = self.hook('renorm.post', ret=self.weight * x_norm + self.bias)
        self.hook.finalize()
        return x_norm.to(orig_type)


class QuickGELU(nn.Module):
    """Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class MultiheadAttention(nn.Module):
    """
    Hook-instrumented multi-head self-attention with several interchangeable formulations.

    Every ``forward_*`` method takes ``x`` of shape (B, N, C) (batch, tokens, embed_dim) and
    returns a tensor of the same shape. They differ in how the computation is decomposed and
    therefore in which intermediate tensors are exposed to ``self.hook``:

    - ``forward_direct``: one fused q/k/v projection.
    - ``forward_qkv``: separate q, k, v projections; exposes the per-source-token values.
    - ``forward_per_head``: per-head projections; the output projection is applied before the
      source-token axis is summed out.
    - ``forward_ov_circuit``: the value and output projections are merged into one per-head
      matrix that is applied to the input directly.

    ``dropout``, ``batch_first``, ``add_zero_attn``, ``kdim``/``vdim`` and ``bias_k``/``bias_v``
    are stored for parity with ``torch.nn.MultiheadAttention`` but none of the forward methods
    reads them.
    """
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None, hook: Optional[HookManager] = None,):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.hook = hook or HookManager()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # rows [0, E) are the query projection, [E, 2E) the key, [2E, 3E) the value
        self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))

        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        # torch.empty() leaves arbitrary memory behind; give every parameter created above a
        # defined starting value so the module is usable without loading a checkpoint.
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the parameters allocated with ``torch.empty`` in ``__init__``."""
        nn.init.xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            nn.init.zeros_(self.in_proj_bias)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    @staticmethod
    def _add_bias(value, bias):
        """Return ``value + bias``, or ``value`` unchanged when the module was built with ``bias=False``."""
        return value if bias is None else value + bias

    def _attention_probs(self, q, k, attn_mask, prefix='attention.'):
        """Scale ``q``, form the q.k scores, add the mask and softmax over the last axis.

        ``prefix`` is prepended to the ``pre_mask`` / ``post_mask`` / ``post_softmax`` hook names.
        Returns a tensor of shape (B, H, N, N).
        """
        q = self.hook('q_norm', ret=q / math.sqrt(q.size(-1)))
        attn = self.hook(prefix + 'pre_mask', ret=q @ k.transpose(-2, -1))
        if attn_mask is not None:
            # Out-of-place on purpose: an in-place add would also alter the tensor that was just
            # handed to the pre_mask hook. The cast keeps the dtype of the scores, as the
            # in-place add did.
            attn = attn + attn_mask.to(attn.dtype)
        attn = self.hook(prefix + 'post_mask', ret=attn)
        return self.hook(prefix + 'post_softmax', ret=attn.softmax(dim=-1))

    def _project_flat(self, x, weight, bias, name):
        """Project (B, N, C) with an (E, C) weight and split heads -> (B, H, N, head_dim)."""
        B, N, _ = x.shape
        projected = self.hook(f'in_{name}.post', ret=x @ weight.T)
        projected = self.hook(f'in_{name}_bias.post', ret=self._add_bias(projected, bias))
        return projected.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def _project_per_head(self, x, weight, bias, name):
        """Project (B, N, C) with an (H, head_dim, C) weight -> (B, H, N, head_dim)."""
        projected = self.hook(f'in_{name}.post', ret=torch.einsum('bnc,hdc->bhnd', x, weight))
        return self.hook(f'in_{name}_bias.post', ret=self._add_bias(projected, bias))

    def forward_direct(self, x, attn_mask=None):
        """Attention with a single fused q/k/v projection."""
        B, N, C = x.shape
        qkv = self.hook('in_proj.post', ret=x @ self.in_proj_weight.T)
        qkv = self.hook('in_proj_bias.post', ret=self._add_bias(qkv, self.in_proj_bias))
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, H, N, head_dim)
        k = self.hook('k', ret=k)
        q = self.hook('q', ret=q)
        v = self.hook('v', ret=v)
        # this variant uses un-prefixed hook names for the attention scores
        attn = self._attention_probs(q, k, attn_mask, prefix='')  # [B, H, N, N]
        x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.hook('attn_v', ret=x)
        x = self.hook('out_proj.post', ret=x @ self.out_proj.weight.T)
        x = self.hook('out_proj_bias.post', ret=self._add_bias(x, self.out_proj.bias))
        return x

    def _split_qkv_weight(self):
        """Return the q, k, v projection weights, each reshaped to (num_heads, head_dim, embed_dim)."""
        q_weight, k_weight, v_weight = (
            part.reshape(self.num_heads, self.head_dim, -1)
            for part in self.in_proj_weight.chunk(3, dim=0)
        )
        return q_weight, k_weight, v_weight

    def _split_qkv_bias(self):
        """Return the q, k, v biases, each reshaped to (1, num_heads, 1, head_dim); all ``None`` without bias."""
        if self.in_proj_bias is None:
            return None, None, None
        q_bias, k_bias, v_bias = (
            part.reshape(1, self.num_heads, 1, self.head_dim)
            for part in self.in_proj_bias.chunk(3, dim=0)
        )
        return q_bias, k_bias, v_bias

    def forward_qkv(self, x, attn_mask=None):
        """Attention with separate q/k/v projections; exposes the per-source-token values."""
        B, N, C = x.shape
        q_weight, k_weight, v_weight = self.in_proj_weight.chunk(3, dim=0)
        if self.in_proj_bias is None:
            q_bias = k_bias = v_bias = None
        else:
            q_bias, k_bias, v_bias = self.in_proj_bias.chunk(3, dim=0)
        q = self._project_flat(x, q_weight, q_bias, 'q')
        k = self._project_flat(x, k_weight, k_bias, 'k')
        v = self._project_flat(x, v_weight, v_bias, 'v')
        attn = self._attention_probs(q, k, attn_mask)  # [B, H, N, N]
        # keep the source-token axis m so each token's contribution can be inspected
        x = torch.einsum('bhnm,bhmc->bhnmc', attn, v)
        x = self.hook('extended_attn_v', ret=x)
        x = x.sum(dim=3).transpose(1, 2).reshape(B, N, C)
        x = self.hook('attn_v', ret=x)
        x = self.hook('out.post', ret=x @ self.out_proj.weight.T)
        x = self.hook('out.post_bias', ret=self._add_bias(x, self.out_proj.bias))
        return x

    def forward_per_head(self, x, attn_mask=None):
        """Attention with per-head projections; the output projection precedes the token sum."""
        q_weight, k_weight, v_weight = self._split_qkv_weight()  # (num_heads, head_dim, embed_dim)
        q_bias, k_bias, v_bias = self._split_qkv_bias()
        q = self._project_per_head(x, q_weight, q_bias, 'q')
        k = self._project_per_head(x, k_weight, k_bias, 'k')
        v = self._project_per_head(x, v_weight, v_bias, 'v')  # (B, num_heads, N, head_dim)
        attn = self._attention_probs(q, k, attn_mask)  # [B, H, N, N]
        x = torch.einsum('bhnm,bhmc->bnmhc', attn, v)  # also switches from head-first to token-first
        x = self.hook('extended_attn_v', ret=x)
        # the output projection contracts the head and head_dim axes; m (source token) survives
        per_head_out = self.out_proj.weight.reshape(self.embed_dim, self.num_heads, self.head_dim)
        x = self.hook('out.post', ret=torch.einsum('bnmhc,dhc->bnmd', x, per_head_out))
        x = self.hook('out.post_collapse', ret=x.sum(dim=2))
        x = self.hook('out.post_bias', ret=self._add_bias(x, self.out_proj.bias))
        return x

    def _get_ov_circuit(self,):
        """Merge the value and output projections per head.

        Returns:
            ``(ov, ov_bias)`` where ``ov`` has shape (embed_dim, num_heads, embed_dim) and
            ``ov_bias`` has shape (1, num_heads, 1, embed_dim), or is ``None`` without bias.
        """
        reshaped_o = self.out_proj.weight.reshape(self.embed_dim, self.num_heads, self.head_dim)
        _, _, v_weight = self._split_qkv_weight()  # num_heads, head_dim, embed_dim
        _, _, v_bias = self._split_qkv_bias()  # 1, num_heads, 1, head_dim
        ov_circuit = torch.einsum('onh,nhi->oni', reshaped_o, v_weight)
        if v_bias is None:
            return ov_circuit, None
        ov_bias_circuit = torch.einsum('onh,bnxh->bnxo', reshaped_o, v_bias)  # [1, num_heads, 1, embed_dim]
        return ov_circuit, ov_bias_circuit

    def forward_ov_circuit(self, x, attn_mask=None):
        """Attention where the merged value/output matrix is applied to the input directly."""
        q_weight, k_weight, _ = self._split_qkv_weight()
        q_bias, k_bias, _ = self._split_qkv_bias()
        q = self._project_per_head(x, q_weight, q_bias, 'q')
        k = self._project_per_head(x, k_weight, k_bias, 'k')
        ov, ov_bias = self._get_ov_circuit()
        ov = self.hook('ov', ret=ov)
        if ov_bias is not None:
            ov_bias = self.hook('ov_bias', ret=ov_bias)
        v = self.hook('ov.post', ret=torch.einsum('bnc,dhc->bhnd', x, ov))
        v = self.hook('ov_bias.post', ret=self._add_bias(v, ov_bias))  # (B, H, N, embed_dim)

        attn = self._attention_probs(q, k, attn_mask)  # [B, H, N, N]
        x = torch.einsum('bhnm,bhmc->bnmhc', attn, v)  # also switches from head-first to token-first
        x = self.hook('extended_attn_ov', ret=x)
        # sum over source tokens (axis 2) and heads (axis 3)
        x = self.hook('out.post_collapse', ret=x.sum(dim=(2, 3)))
        x = self.hook('out.post_bias', ret=self._add_bias(x, self.out_proj.bias))
        return x

    def forward(self, x, attn_mask=None, method = 'ov_circuit'):
        """Run the attention variant selected by ``method``.

        Args:
            x: input of shape (B, N, C).
            attn_mask: optional mask added to the pre-softmax scores.
            method: one of ``'direct'``, ``'qkv'``, ``'head'`` or ``'ov_circuit'``.

        Raises:
            ValueError: if ``method`` is not one of the names above. (Previously an unknown
                name silently returned ``x`` unchanged.)
        """
        implementations = {
            'direct': self.forward_direct,
            'qkv': self.forward_qkv,
            'head': self.forward_per_head,
            'ov_circuit': self.forward_ov_circuit,
        }
        if method not in implementations:
            raise ValueError(
                f"Unknown attention method {method!r}; expected one of {sorted(implementations)}"
            )
        x = implementations[method](x, attn_mask=attn_mask)
        self.hook.finalize()

        return x

class MLP(nn.Module):
    """Two-layer feed-forward block (linear -> activation -> linear) with a hook after each stage."""

    def __init__(self, d_model: int, mlp_width: int, act_layer: Callable = nn.GELU, hook: Optional[HookManager] = None,):
        super().__init__()
        self.hook = hook or HookManager()
        self.c_fc = nn.Linear(d_model, mlp_width)
        # the attribute is called ``gelu`` whatever activation ``act_layer`` actually builds
        self.gelu = act_layer()
        self.c_proj = nn.Linear(mlp_width, d_model)

    def forward(self, x):
        """Apply the three stages in order, passing each result through its ``*.post`` hook."""
        stages = (
            ('c_fc.post', self.c_fc),
            ('gelu.post', self.gelu),
            ('c_proj.post', self.c_proj),
        )
        for hook_name, layer in stages:
            x = self.hook(hook_name, ret=layer(x))
        self.hook.finalize()
        return x

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    Attention is computed by ``torch.nn.MultiheadAttention``, so the attention weights can be
    returned. The block fires the hooks ``pre``, ``after_attn``, ``after_mlp`` and ``post``.

    Args:
        d_model: embedding width.
        n_head: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        act_layer: activation constructor used inside the MLP.
        hook: hook manager; a fresh ``HookManager()`` is created when omitted.
    """
    def __init__(self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, hook: Optional[HookManager] = None):
        super().__init__()
        self.hook =  hook or HookManager()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        # Fork from self.hook rather than the raw argument: with hook=None the argument has no
        # .fork() and construction used to fail with AttributeError.
        self.ln_1 = LayerNorm(d_model, hook=self.hook.fork('ln_1'))
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = MLP(d_model, mlp_width, act_layer=act_layer, hook=self.hook.fork('mlp'))
        self.ln_2 = LayerNorm(d_model, hook=self.hook.fork('ln_2'))

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, output_attention=True):
        """Self-attention over ``x``.

        Returns the ``(output, weights)`` tuple produced by ``nn.MultiheadAttention``;
        ``weights`` is ``None`` when ``output_attention`` is false.
        """
        return self.attn(x, x, x, need_weights=output_attention, attn_mask=attn_mask)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, output_attention=False, attn_method = 'direct',):
        """Apply the block.

        Args:
            x: input tensor, passed to ``nn.MultiheadAttention`` built with its default
                ``batch_first=False``.
            attn_mask: optional attention mask forwarded to the attention module.
            output_attention: when true, also return the attention weights.
            attn_method: accepted for interface compatibility. NOTE(review): it has no effect
                while ``self.attn`` is ``nn.MultiheadAttention`` -- confirm whether the
                hook-instrumented attention module is meant to be used here instead.

        Returns:
            ``x`` or, when ``output_attention`` is true, ``(x, attention_weights)``.
        """
        x = self.hook('pre', ret=x)
        # Both modes share one code path. The former ``output_attention=False`` branch called
        # self.attention() with keyword arguments it does not accept (q_x=, method=), skipped
        # ln_1, and did not unpack the (output, weights) tuple.
        hidden_states, attention_weights = self.attention(
            self.ln_1(x), attn_mask=attn_mask, output_attention=output_attention)
        hidden_states = self.hook('after_attn', ret=hidden_states)
        x = x + hidden_states
        after_ln2 = self.ln_2(x)
        after_mlp = self.mlp(after_ln2)
        after_mlp = self.hook('after_mlp', ret=after_mlp)
        x = x + after_mlp
        x = self.hook('post', ret=x)
        self.hook.finalize()
        if output_attention:
            return x, attention_weights
        return x


class Transformer(nn.Module):
    """Stack of ``layers`` ``ResidualAttentionBlock`` modules applied in sequence.

    Args:
        width: embedding width of every block.
        layers: number of blocks.
        heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        act_layer: activation constructor used inside each block's MLP.
        hook: hook manager; a fresh ``HookManager()`` is created when omitted. Block ``i``
            receives ``hook.fork(f'resblocks.{i}')``.
    """
    def __init__(self, width: int, layers: int, heads: int,  mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, hook: Optional[HookManager] = None):
        super().__init__()
        self.hook = hook or HookManager()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        # Fork from self.hook rather than the raw argument, which may be None.
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, hook=self.hook.fork(f'resblocks.{i}'))
            for i in range(layers)
        ])

    def forward_intermediates(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, indices=None, output_attention=True):
        """Run all blocks, collecting per-block outputs and attention weights.

        Args:
            x: input tensor (LND layout according to the permute comment below).
            attn_mask: optional attention mask passed to every block.
            indices: container of block indices whose outputs are collected; ``None`` collects
                every block.
            output_attention: when true, the attention weights of each block are collected.

        Returns:
            ``(x, intermediates, attention_weights)``: the final output, the collected block
            outputs permuted with ``(1, 0, 2)``, and the collected attention weights (an empty
            list when ``output_attention`` is false).
        """
        intermediates = []
        attention_weights = []
        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                out = checkpoint(r, x, attn_mask, output_attention)
            else:
                out = r(x, attn_mask=attn_mask, output_attention=output_attention)
            # A block returns (x, weights) only when asked for attention; otherwise it returns
            # the bare tensor, which must not be tuple-unpacked.
            if output_attention:
                x, attn_weights = out
            else:
                x, attn_weights = out, None
            if indices is None or i in indices:
                intermediates.append(x.permute(1, 0, 2))  # LND -> NLD
            if attn_weights is not None:
                attention_weights.append(attn_weights)
        return x, intermediates, attention_weights

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_method: Text = 'direct'):
        """Run all blocks and return the final output; ``attn_method`` is forwarded to each block."""
        for r in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # checkpoint() only forwards positional arguments here, so spell out
                # output_attention=False to get attn_method into the right slot (it used to
                # be dropped on this path).
                x = checkpoint(r, x, attn_mask, False, attn_method)
            else:
                x = r(x, attn_mask=attn_mask, attn_method=attn_method)
        self.hook.finalize()
        return x


class VisualTransformer(nn.Module):
    """Vision transformer image tower with hook instrumentation.

    The image is cut into patches by a strided convolution, a class token is prepended, a
    learned positional embedding is added, and the token sequence goes through ``ln_pre``, the
    transformer, ``ln_post`` and a final linear projection ``proj``.

    Args:
        image_size: input image size (int or pair).
        patch_size: patch size (int or pair).
        width: token embedding width.
        layers: number of transformer blocks.
        heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        output_dim: width of the projected output.
        act_layer: activation constructor for the transformer MLPs.
        hook: hook manager; a fresh ``HookManager()`` is created when omitted.
    """
    def __init__(
            self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float,
            output_dim: int, act_layer: Callable = nn.GELU, hook: Optional[HookManager] = None):
        super().__init__()
        self.hook = hook or HookManager()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
        # Child hooks are forked from self.hook rather than the raw argument: with hook=None
        # the argument has no .fork() and construction used to fail with AttributeError.
        self.ln_pre = LayerNorm(width, hook=self.hook.fork('ln_pre'))

        self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer, hook=self.hook.fork('transformer'))

        self.ln_post = LayerNorm(width, hook=self.hook.fork('ln_post'))
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Disable gradients for every parameter. ``freeze_bn_stats`` is accepted but not used."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Switch gradient checkpointing in the transformer on or off."""
        self.transformer.grad_checkpointing = enable

    def _to_token_sequence(self, feature_map: torch.Tensor) -> torch.Tensor:
        """Flatten the patch feature map into tokens and prepend the class token."""
        tokens = feature_map.reshape(feature_map.shape[0], feature_map.shape[1], -1)  # shape = [*, width, grid ** 2]
        tokens = tokens.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # adding to a zero tensor expands the class embedding to one token per sample
        class_token = self.class_embedding.to(tokens.dtype) + torch.zeros(
            tokens.shape[0], 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device)
        return torch.cat([class_token, tokens], dim=1)  # shape = [*, grid ** 2 + 1, width]

    def forward_intermediates(self, x: torch.Tensor, indices=None, stop_early=False, output_attention=True):
        """Encode images and also return per-block outputs and attention weights.

        Unlike ``forward``, this path does not fire this module's own hooks and applies
        ``ln_post``/``proj`` to the class token only.

        Args:
            x: image batch fed to the patch convolution.
            indices: block indices whose outputs are collected; ``None`` collects all.
            stop_early: accepted but not used.
            output_attention: when true, attention weights are collected from every block.

        Returns:
            dict with ``image_features`` (projected class token), ``image_intermediates``
            (collected block outputs) and ``attn_weights``.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = self._to_token_sequence(x)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        # ``indices`` must be passed by keyword: the second positional parameter of
        # Transformer.forward_intermediates is attn_mask, so the former positional call fed the
        # index list in as an attention mask and collected every block regardless.
        final, intermediates, attn_weights = self.transformer.forward_intermediates(
            x, indices=indices, output_attention=output_attention)
        final = final.permute(1, 0, 2)  # LND -> NLD
        final_feature = self.ln_post(final[:, 0, :])
        if self.proj is not None:
            final_feature = final_feature @ self.proj

        return {
            "image_features": final_feature,
            "image_intermediates": intermediates,
            "attn_weights": attn_weights,
        }


    def forward(self, x: torch.Tensor, attn_method: Text = 'direct'):
        """Encode images, firing hooks after each stage.

        ``ln_post`` and ``proj`` are applied to every token, so the result keeps the token axis:
        shape [*, grid ** 2 + 1, output_dim].

        Hook names fired by this module, in order: ``conv1.post``, ``positional_embedding.post``,
        ``ln_pre_post``, ``ln_post_post``, ``proj.pre``, ``proj.post``.
        """
        x = self.hook('conv1.post', ret=self.conv1(x))  # shape = [*, width, grid, grid]
        x = self._to_token_sequence(x)
        x = self.hook('positional_embedding.post', ret=x + self.positional_embedding.to(x.dtype))
        x = self.hook('ln_pre_post', ret=self.ln_pre(x))

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_method=attn_method)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.hook('ln_post_post', ret=self.ln_post(x))

        if self.proj is not None:
            x = self.hook('proj.post', ret=self.hook('proj.pre', ret=x) @ self.proj)
        self.hook.finalize()

        return x


@dataclass
class CLIPVisionCfg:
    """Configuration of the CLIP image tower.

    ``CLIP.__init__`` picks the tower from these fields: a timm model when ``timm_model_name``
    is set, a ``ModifiedResNet`` when ``layers`` is a tuple or list, and a ``VisualTransformer``
    otherwise.
    """
    # tuple/list -> blocks per ResNet stage; int -> number of transformer blocks
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    # the number of attention heads is derived from width and head_width
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    # None is the default, so the annotation must be Optional
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')


@dataclass
class CLIPTextCfg:
    """Configuration of the CLIP text tower.

    ``CLIP.__init__`` accepts either an instance or a dict of these fields and copies
    ``context_length`` onto the model.
    """
    context_length: int = 77  # presumably the maximum token sequence length -- confirm against the text encoder
    vocab_size: int = 49408  # presumably the tokenizer vocabulary size -- confirm
    width: int = 512  # presumably the text transformer embedding width -- confirm
    heads: int = 8  # presumably attention heads per text transformer block -- confirm
    layers: int = 12  # presumably the number of text transformer blocks -- confirm


class CLIP(nn.Module):
    """Two-tower CLIP model: a vision tower plus an optional HF text encoder.

    The vision tower is selected from ``vision_cfg``: a ``TimmModel`` when
    ``timm_model_name`` is set, a ``ModifiedResNet`` when ``layers`` is a
    tuple/list, otherwise a ``VisualTransformer``. The text tower is an
    ``HFTextEncoder`` and exists only when ``text_encoder_name`` is given;
    otherwise ``self.text_encoder`` is ``None`` and the text paths
    (``encode_text``, text input to ``forward``/``forward_intermediates``)
    cannot be used.
    """

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            text_encoder_name: Optional[str] = None,
            hook: Optional[HookManager] = None,
    ):
        """Build both towers and the learnable logit scale.

        Args:
            embed_dim: Dimension of the joint embedding space (output of both towers).
            vision_cfg: Vision tower config, as ``CLIPVisionCfg`` or a dict of its fields.
            text_cfg: Text tower config, as ``CLIPTextCfg`` or a dict of its fields.
                Only ``context_length`` is read here.
            quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` in the
                ``VisualTransformer`` tower.
            text_encoder_name: Name handed to ``HFTextEncoder``; ``None`` builds
                a model without a text tower.
            hook: Hook manager to attach; a fresh ``HookManager`` is created when
                omitted. It is forked as ``'visual'`` for the transformer tower
                and ``'textual'`` for the text encoder.
        """
        super().__init__()
        if isinstance(vision_cfg, dict):
            vision_cfg = CLIPVisionCfg(**vision_cfg)
        if isinstance(text_cfg, dict):
            text_cfg = CLIPTextCfg(**text_cfg)

        self.context_length = text_cfg.context_length
        self.hook_manager = hook or HookManager()

        # OpenAI models are pretrained w/ QuickGELU
        # NOTE: timm models always use native GELU regardless of quick_gelu flag.
        act_layer = QuickGELU if quick_gelu else nn.GELU

        if vision_cfg.timm_model_name:
            self.visual = TimmModel(
                vision_cfg.timm_model_name,
                pretrained=vision_cfg.timm_model_pretrained,
                pool=vision_cfg.timm_pool,
                proj=vision_cfg.timm_proj,
                embed_dim=embed_dim,
                image_size=vision_cfg.image_size
            )
        elif isinstance(vision_cfg.layers, (tuple, list)):
            vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
            self.visual = ModifiedResNet(
                layers=vision_cfg.layers,
                output_dim=embed_dim,
                heads=vision_heads,
                image_size=vision_cfg.image_size,
                width=vision_cfg.width
            )
        else:
            vision_heads = vision_cfg.width // vision_cfg.head_width
            self.visual = VisualTransformer(
                image_size=vision_cfg.image_size,
                patch_size=vision_cfg.patch_size,
                width=vision_cfg.width,
                layers=vision_cfg.layers,
                heads=vision_heads,
                mlp_ratio=vision_cfg.mlp_ratio,
                output_dim=embed_dim,
                act_layer=act_layer,
                hook=self.hook_manager.fork('visual'),
            )

        self.text_encoder = None
        if text_encoder_name is not None:
            self.text_encoder = HFTextEncoder(
                text_encoder_name,
                output_dim=embed_dim,
                proj_type='mlp',
                pooler_type='cls_last_hidden_state_pooler',
                pretrained=True,
                output_tokens=False,
                hook=self.hook_manager.fork('textual')
            )

        # Stored in log space; forward() returns its exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.init_parameters()

    def init_parameters(self):
        """Reset ``logit_scale`` to ``log(1 / 0.07)`` and re-initialise the vision tower if it supports it."""
        nn.init.constant_(self.logit_scale, np.log(1 / 0.07))

        if hasattr(self.visual, 'init_parameters'):
            self.visual.init_parameters()

    def build_attention_mask(self):
        """Return an additive causal mask of shape ``[context_length, context_length]``.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are zero.
        """
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze the vision tower by delegating to ``self.visual.lock``."""
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the vision tower and, if present, the text encoder."""
        self.visual.set_grad_checkpointing(enable)
        # The text tower is optional (text_encoder_name=None); skip it instead
        # of failing with an AttributeError on None.
        if self.text_encoder is not None:
            self.text_encoder.grad_checkpointing = enable

    def encode_image(self, image, attn_method='direct'):
        """Run the vision tower on ``image`` and return its (un-normalized) output.

        NOTE(review): ``attn_method`` is always forwarded; confirm that the
        timm and ResNet towers accept this keyword.
        """
        return self.visual(image, attn_method=attn_method)

    def encode_text(self, text):
        """Run the HF text encoder on ``text`` and return its (un-normalized) output."""
        text_features = self.text_encoder(text)  # For BioMedCLIP text encoder
        return text_features

    def forward_intermediates(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        image_indices: Optional[Union[int, List[int]]] = None,
        text_indices: Optional[Union[int, List[int]]] = None,
        normalize: bool = True,
        intermediates_only: bool = False
    ):
        """Collect intermediate activations from whichever towers receive input.

        Args:
            image: Image batch, or ``None`` to skip the vision tower.
            text: Token-id batch, or ``None`` to skip the text tower.
            image_indices: Forwarded as ``indices`` to ``visual.forward_intermediates``.
            text_indices: Forwarded as ``indices`` to the text encoder's
                ``forward_intermediates``.
            normalize: L2-normalize ``image_features`` / ``text_features`` along
                the last dimension when they are produced.
            intermediates_only: When true, skip pooling and projection of the
                text tower (no ``text_features`` entry). It is not passed on to
                the vision tower.

        Returns:
            Dict containing whatever the vision tower returned, plus
            ``text_intermediates`` and optionally ``text_features``.
        """
        output = {}

        if image is not None:
            image_output = self.visual.forward_intermediates(image, indices=image_indices)
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_hidden = self.text_encoder.model.embeddings(text)
            # Attend only to positions that are not the tokenizer's pad token.
            attention_mask = self.text_encoder.model.get_extended_attention_mask(
                text.ne(self.text_encoder.tokenizer.pad_token_id), text.shape, device=text.device
            )
            final, intermediates = self.text_encoder.model.encoder.forward_intermediates(
                text_hidden, attention_mask, indices=text_indices
            )
            output["text_intermediates"] = intermediates
            if not intermediates_only:
                pooled = self.text_encoder.model.pooler(final)
                proj = self.text_encoder.projector(pooled)
                if normalize:
                    proj = F.normalize(proj, dim=-1)
                output["text_features"] = proj

        return output

    def forward(self, image, text, clamp_logit_scale_to=None):
        """Encode and L2-normalize both modalities.

        Args:
            image: Image batch, or ``None`` to skip the vision tower.
            text: Token-id batch, or ``None`` to skip the text tower.
            clamp_logit_scale_to: When given, the stored (log-space)
                ``logit_scale`` is clamped in place to ``[0, clamp_logit_scale_to]``
                before being exponentiated.

        Returns:
            ``(image_features, text_features, logit_scale.exp())``; a feature
            entry is ``None`` when its input was ``None``.
        """
        if image is not None:
            image_features = self.encode_image(image)
            image_features = F.normalize(image_features, dim=-1)
        else:
            image_features = None
        if text is not None:
            text_features = self.encode_text(text)
            text_features = F.normalize(text_features, dim=-1)
        else:
            text_features = None
        if clamp_logit_scale_to is not None:
            with torch.no_grad():
                self.logit_scale.data.clamp_(0, clamp_logit_scale_to)
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Converted per sub-module (visited through ``model.apply``):
      * weight and bias of ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``;
      * the projection weights and biases of ``nn.MultiheadAttention``;
      * tensor attributes named ``text_projection`` or ``proj``.

    Everything else (e.g. normalization layers) keeps its dtype.

    Args:
        model: Module whose parameters are converted; nothing is returned.
    """

    def _convert_weights_to_fp16(module: nn.Module) -> None:
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            attn_attrs = [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias", "bias_k", "bias_v",
            ]
            for attr_name in attn_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ("text_projection", "proj"):
            attr = getattr(module, name, None)
            # Only raw tensors/parameters are handled here. A child *module*
            # that happens to be called "proj" (e.g. inside an nn.Sequential
            # head) has no ``.data`` and is converted on its own visit above.
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model_from_openai_state_dict(state_dict: dict):
    """Rebuild a ``CLIP`` model whose sizes are inferred from an OpenAI checkpoint.

    The architecture (ViT vs. ResNet vision tower, widths, depths, image size,
    text sizes) is read off the tensor shapes and key names in ``state_dict``.
    The model is built with ``quick_gelu=True``, converted to fp16, loaded
    from ``state_dict`` and returned in eval mode.

    The keys ``input_resolution``, ``context_length`` and ``vocab_size`` are
    removed from the caller's ``state_dict`` in place before loading.

    NOTE(review): ``CLIP`` is constructed here without ``text_encoder_name``,
    so no text tower is created; confirm that the strict ``load_state_dict``
    below is expected to accept the checkpoint's text-tower keys.
    """
    is_vit = "visual.proj" in state_dict

    if is_vit:
        patch_embed_weight = state_dict["visual.conv1.weight"]
        vision_width = patch_embed_weight.shape[0]
        # One attention in-projection per transformer block.
        vision_layers = sum(
            1 for key in state_dict.keys()
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = patch_embed_weight.shape[-1]
        num_positions = state_dict["visual.positional_embedding"].shape[0]
        # Positions = class token + a square grid of patches.
        grid_size = round((num_positions - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        stage_depths = []
        for stage in (1, 2, 3, 4):
            stage_prefix = f"visual.layer{stage}"
            block_ids = {key.split(".")[2] for key in state_dict if key.startswith(stage_prefix)}
            stage_depths.append(len(block_ids))
        vision_layers = tuple(stage_depths)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == num_positions
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    text_block_ids = {
        key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")
    }
    transformer_layers = len(text_block_ids)

    model = CLIP(
        embed_dim,
        vision_cfg=CLIPVisionCfg(
            layers=vision_layers,
            width=vision_width,
            patch_size=vision_patch_size,
            image_size=image_size,
        ),
        text_cfg=CLIPTextCfg(
            context_length=context_length,
            vocab_size=vocab_size,
            width=transformer_width,
            heads=transformer_heads,
            layers=transformer_layers,
        ),
        quick_gelu=True,  # OpenAI models were trained with QuickGELU
    )

    # Drop these three keys before loading.
    for stale_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(stale_key, None)

    convert_weights_to_fp16(model)
    model.load_state_dict(state_dict)
    return model.eval()


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Trace ``forward``, ``encode_text`` and ``encode_image`` with dummy inputs.

    The model is put in eval mode and traced using an all-ones image batch of
    shape ``(batch_size, 3, image_size, image_size)`` and an all-zeros int
    token batch of shape ``(batch_size, model.context_length)``.

    Args:
        model: Module exposing ``visual.image_size``, ``context_length`` and
            the three traced methods.
        batch_size: Batch size of the example inputs.
        device: Device on which the example inputs are created.

    Returns:
        The traced module, with ``visual.image_size`` set again on it.
    """
    model.eval()
    side = model.visual.image_size
    dummy_images = torch.ones((batch_size, 3, side, side), device=device)
    dummy_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)

    trace_inputs = {
        "forward": (dummy_images, dummy_text),
        "encode_text": (dummy_text,),
        "encode_image": (dummy_images,),
    }
    traced = torch.jit.trace_module(model, inputs=trace_inputs)
    # Re-attach the image size to the traced vision tower.
    traced.visual.image_size = side
    return traced


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Interpolate the checkpoint's visual positional embedding to the model's grid.

    Nothing happens when the checkpoint has no ``visual.positional_embedding``,
    when ``model.visual`` has no ``grid_size``, or when the sequence length
    already matches. Otherwise the patch part of the embedding is treated as a
    square grid, resized with ``F.interpolate`` and written back into
    ``state_dict`` in place; the leading class-token row is kept unchanged.

    Args:
        state_dict: Checkpoint dict; updated in place.
        model: Model whose ``visual.grid_size`` gives the target grid.
        interpolation: ``mode`` passed to ``F.interpolate``.
        seq_dim: Unused by this function.
    """
    source_embed = state_dict.get('visual.positional_embedding', None)
    if source_embed is None or not hasattr(model.visual, 'grid_size'):
        return

    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    num_patches = grid_size[0] * grid_size[1]
    if num_patches + extra_tokens == source_embed.shape[0]:
        return

    # Split off the non-patch rows (class token) from the patch grid.
    if extra_tokens:
        token_rows = source_embed[:extra_tokens]
        patch_rows = source_embed[extra_tokens:]
    else:
        token_rows = None
        patch_rows = source_embed
    old_side = int(math.sqrt(len(patch_rows)))
    old_grid_size = to_2tuple(old_side)

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # [seq, dim] -> [1, dim, H, W] so the grid can be resized like an image.
    as_image = patch_rows.reshape(1, old_grid_size[0], old_grid_size[1], -1)
    as_image = as_image.permute(0, 3, 1, 2)
    as_image = F.interpolate(
        as_image,
        size=grid_size,
        mode=interpolation,
        align_corners=True,
    )
    # Back to [seq, dim].
    patch_rows = as_image.permute(0, 2, 3, 1).reshape(num_patches, -1)

    if token_rows is None:
        resized = patch_rows
    else:
        resized = torch.cat([token_rows, patch_rows], dim=0)
    state_dict['visual.positional_embedding'] = resized
