from collections import OrderedDict
from dataclasses import asdict
from functools import partial
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal

import torch
import torch.nn as nn
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from torch.nn import Module, ModuleList
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.utils.checkpoint import checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10

from math import log, pi, sqrt

from einops import rearrange, repeat
from torch import Tensor, broadcast_tensors, einsum, nn
from torch.amp import autocast

from .configuration_pe import PEConfig, PE_VISION_CONFIG, fetch_pe_checkpoint



from dataclasses import dataclass
from typing import Optional, Dict
@dataclass
class PeSkipLinkModelOutput(BaseModelOutput):
    """``BaseModelOutput`` extended with intermediate "skip-link" hidden states.

    NOTE(review): presumably populated from the ``skiplink_hidden_states`` dict
    returned by ``Transformer.forward`` (layer index -> hidden state); confirm
    against the model's forward, which is outside this view.
    """

    # Mapping of requested layer index -> hidden state tensor; None when unused.
    skiplink_hidden_states: Optional[Dict[int, torch.Tensor]] = None
@dataclass
class PeSkipLinkOutputWithPooling(BaseModelOutputWithPooling):
    """``BaseModelOutputWithPooling`` extended with skip-link hidden states.

    NOTE(review): presumably populated from the ``skiplink_hidden_states`` dict
    returned by ``Transformer.forward`` (layer index -> hidden state); confirm
    against the model's forward, which is outside this view.
    """

    # Mapping of requested layer index -> hidden state tensor; None when unused.
    skiplink_hidden_states: Optional[Dict[int, torch.Tensor]] = None


# Module-scoped logger. Calling getLogger() with no name would hand back the
# root logger, so this module's messages could not be filtered or configured
# independently. Records still propagate to the root handlers.
logger = getLogger(__name__)

def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor of rank >= 2.

    Axis 0 is treated as the output feature maps and axis 1 as the input
    feature maps. Any trailing axes form the receptive field, whose element
    count multiplies both fans.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # Elements in one (out, in) slice, e.g. the kernel area of a conv weight.
    receptive = tensor[0][0].numel() if tensor.dim() > 2 else 1
    fan_out, fan_in = (tensor.size(axis) * receptive for axis in (0, 1))
    return fan_in, fan_out

def lecun_normal_(tensor, mode='fan_in'):
    """Fill ``tensor`` in place with samples from ``N(0, 1 / fan)`` (LeCun normal).

    Args:
        tensor: an n-dimensional ``torch.Tensor`` (n >= 2).
        mode: ``'fan_in'`` (default) preserves the magnitude of the variance
            of the weights in the forward pass; ``'fan_out'`` preserves the
            magnitudes in the backwards pass.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions or ``mode`` is
            neither ``'fan_in'`` nor ``'fan_out'``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode not in ('fan_in', 'fan_out'):
        raise ValueError(f"Invalid mode: {mode}")
    denom = fan_in if mode == 'fan_in' else fan_out

    # In-place sampling on a possibly grad-tracking parameter: keep it out of autograd.
    with torch.no_grad():
        tensor.normal_(std=sqrt(1.0 / denom))

# helper functions
def exists(val):
    """Return ``True`` unless ``val`` is ``None`` (falsy values such as ``0`` still exist)."""
    return not (val is None)


def default(val, d):
    """Return ``val``, falling back to ``d`` only when ``val`` is ``None``."""
    if not exists(val):
        return d
    return val


# broadcat, as tortoise-tts was using it


def broadcat(tensors, dim=-1):
    """Broadcast ``tensors`` to a common shape, then concatenate them along ``dim``."""
    return torch.cat(broadcast_tensors(*tensors), dim=dim)


# rotary embedding helper functions


def rotate_half(x):
    """Map each interleaved pair ``(a, b)`` along the last axis to ``(-b, a)``.

    The last dimension is read as consecutive pairs, so its size must be even.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    swapped = torch.stack((-second, first), dim=-1)
    return rearrange(swapped, "... d r -> ... (d r)")


@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    """Rotate ``t[..., start_index:start_index + freqs.shape[-1]]`` by angles ``freqs``.

    Features outside that slice pass through unchanged. CUDA autocast is
    disabled for the computation, and the result is cast back to ``t``'s
    original dtype.

    Args:
        freqs: rotation angles, broadcastable against the rotated slice.
        t: tensor to rotate.
        start_index: first feature index to rotate.
        scale: multiplier on both the cos and sin terms (used for xPos).
        seq_dim: sequence axis. Consulted only when ``t`` is 3-D, to keep the
            trailing ``t.shape[seq_dim]`` rows of ``freqs``.
    """
    input_dtype = t.dtype

    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]

    rot_dim = freqs.shape[-1]
    stop = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    head = t[..., :start_index]
    middle = t[..., start_index:stop]
    tail = t[..., stop:]

    rotated = (middle * freqs.cos() * scale) + (rotate_half(middle) * freqs.sin() * scale)
    return torch.cat((head, rotated, tail), dim=-1).type(input_dtype)


# learned rotation helpers


def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply externally produced rotation angles to ``t``.

    When ``freq_ranges`` is given, each rotation value is multiplied by every
    frequency range (outer product on a new last axis) and the result is
    flattened. The angles are then duplicated pairwise to match the
    interleaved layout expected by ``rotate_half``.
    """
    if freq_ranges is not None:
        expanded = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(expanded, "... r f -> ... (r f)")

    paired = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(paired, t, start_index=start_index)


# classes
class RotaryEmbedding(Module):
    """Rotary position embedding (RoPE) frequency generator.

    ``forward`` maps positions ``t`` to angles ``t * freqs``, with each angle
    repeated twice along the last axis to match the interleaved pair layout
    rotated by ``rotate_half``. The helper methods apply those angles to
    query/key tensors.

    Args:
        dim: rotary feature dimension; ``dim // 2`` base frequencies are built
            for ``freqs_for="lang"`` and ``"pixel"``.
        custom_freqs: explicit frequency tensor; takes precedence over
            ``freqs_for``.
        freqs_for: ``"lang"`` (``1 / theta ** (2i / dim)``), ``"pixel"``
            (``linspace(1, max_freq / 2) * pi``) or ``"constant"``
            (``num_freqs`` ones).
        theta: base of the ``"lang"`` schedule.
        max_freq: upper bound used by ``"pixel"``.
        num_freqs: number of frequencies for ``"constant"``.
        learned_freq: make the frequencies trainable (disables caching).
        use_xpos: enable xPos scaling; ``rotate_queries_and_keys`` must then
            be used instead of ``rotate_queries_or_keys``.
        xpos_scale_base: xPos scale base.
        interpolate_factor: positions are divided by this factor (must be >= 1).
        theta_rescale_factor: rescales ``theta`` by
            ``factor ** (dim / (dim - 2))``.
        seq_before_head_dim: sequence axis is -3 if True, else -2.
        cache_if_possible: cache computed frequencies/scales in
            non-persistent buffers.

    Raises:
        ValueError: if no ``custom_freqs`` are given and ``freqs_for`` is not
            one of ``"lang"``, ``"pixel"``, ``"constant"``.
    """

    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[
            Literal["lang"], Literal["pixel"], Literal["constant"]
        ] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            # Without this branch `freqs` stays unbound and the failure surfaces
            # later as an opaque UnboundLocalError.
            raise ValueError(
                f"freqs_for must be 'lang', 'pixel' or 'constant', got {freqs_for!r}"
            )

        self.cache_if_possible = cache_if_possible

        self.tmp_store("cached_freqs", None)
        self.tmp_store("cached_scales", None)

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy buffer: its device tracks the module's device (see `device`)
        self.tmp_store("dummy", torch.tensor(0))

        # default sequence dimension
        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors
        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # add apply_rotary_emb as static method. This used to sit after the xpos
        # early return below, so it only existed when use_xpos=True.
        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

        # xpos
        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store("scale", None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = xpos_scale_base
        self.tmp_store("scale", scale)

    @property
    def device(self):
        """Device the module's buffers currently live on."""
        return self.dummy.device

    def tmp_store(self, key, value):
        """Register (or overwrite) a non-persistent buffer, excluded from state_dict."""
        self.register_buffer(key, value, persistent=False)

    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        """Positions ``offset .. offset + seq_len - 1`` divided by ``interpolate_factor``."""
        return (
            torch.arange(seq_len, device=device, dtype=dtype) + offset
        ) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
        """Rotate a single tensor by its sequence positions, starting at ``offset``.

        Not usable with xPos; use ``rotate_queries_and_keys`` instead.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert (
            not self.use_xpos
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        freqs = self.forward(
            self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
            seq_len=seq_len,
            offset=offset,
        )

        if seq_dim == -3:
            # insert a broadcast axis for the head dimension that follows seq
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """Rotate ``q`` and ``k`` when the queries align with the end of the keys.

        Queries are offset by ``k_len - q_len`` so that they share positions
        with the last ``q_len`` keys.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        rotated_q = self.rotate_queries_or_keys(
            q, seq_dim=seq_dim, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """xPos variant: rotate both, scaling queries by ``scale`` and keys by ``1 / scale``."""
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        """xPos scale ``self.scale ** ((t - len(t) // 2) / scale_base)``, duplicated on the last axis."""
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len)

        if (
            should_cache
            and exists(self.cached_scales)
            and (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = torch.cat((scale, scale), dim=-1)

        # The lookup above indexes the cache from position 0, so only store
        # tables that actually start there.
        if should_cache and offset == 0:
            self.tmp_store("cached_scales", scale)

        return scale

    def get_axial_freqs(self, *dims):
        """Angles for an N-D grid of sizes ``dims``, one block per axis concatenated on the last axis."""
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            freqs = self.forward(pos, seq_len=dim)

            # place this axis' positions on its own dim; broadcast over the others
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        """Return rotation angles of shape ``(*t.shape, 2 * len(self.freqs))`` for positions ``t``.

        When ``seq_len`` is given and caching applies (not learned, not
        ``"pixel"``), a previously cached table is sliced as
        ``cached[offset:offset + seq_len]``.
        """
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
        )

        if (
            should_cache
            and exists(self.cached_freqs)
            and (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        # Only cache a table computed from position 0. The lookup above slices
        # it by absolute offset, so caching one computed at offset > 0 would
        # return shifted angles on later calls.
        if should_cache and offset == 0:
            self.tmp_store("cached_freqs", freqs.detach())

        return freqs



class Rope2D:
    """Apply 2-D rotary embeddings over a patch grid, rebuilding angles when the grid changes.

    Half of the rotary dimension encodes the column (x) index and half the row
    (y) index. This is a plain object, not an ``nn.Module``, so its
    ``RotaryEmbedding`` is not registered as a submodule. ``update_grid``
    moves it to the requested device explicitly.

    Call order: ``init_tensors`` -> ``update_grid`` -> ``__call__``.
    """

    def __init__(self, dim, use_cls_token=False):
        self.dim = dim
        self.use_cls_token = use_cls_token
        self.grid_size = None
        self.freq = None

    def init_tensors(self):
        # Each spatial axis receives half of the per-head feature dimension.
        self.rope = RotaryEmbedding(self.dim // 2)

    def update_grid(self, device, grid_h, grid_w):
        """(Re)compute angles for a ``grid_h x grid_w`` patch grid and place them on ``device``."""
        if self.grid_size == (grid_h, grid_w):
            # Same grid: reuse the angles, only making sure of their device.
            self.freq = self.freq.to(device)
            return

        self.grid_size = (grid_h, grid_w)
        self.rope = self.rope.to(device)

        # With a cls token, patches start at position 1 so the cls token can
        # sit at (0, 0).
        shift = 1 if self.use_cls_token else 0
        rows = torch.arange(grid_h, device=device) + shift
        cols = torch.arange(grid_w, device=device) + shift

        row_angles = self.rope(rows)[:, None].expand(grid_h, grid_w, -1)
        col_angles = self.rope(cols)[None, :].expand(grid_h, grid_w, -1)
        angles = torch.cat([col_angles, row_angles], dim=-1).reshape(grid_h * grid_w, -1)

        if self.use_cls_token:
            # Zero angles == identity rotation for the cls token.
            cls_angles = torch.zeros(1, angles.shape[-1], device=device)
            angles = torch.cat([cls_angles, angles], dim=0)

        self.freq = angles[None, ...].to(device)

    def __call__(self, q, k):
        # q, k arrive as (batch, heads, seq, head_dim) from SelfAttention;
        # the inserted axis broadcasts the angles over heads.
        angles = self.freq[:, None, :, :]
        return apply_rotary_emb(angles, q), apply_rotary_emb(angles, k)







class LayerScale(nn.Module):
    """Multiply inputs by a learnable per-channel vector ``gamma`` of length ``dim``.

    ``gamma`` is created lazily by ``init_tensors`` (filled with
    ``init_values``). Calling ``forward`` before that raises AttributeError.
    With ``inplace=True`` the input tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.dim, self.init_values, self.inplace = dim, init_values, inplace

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma

    def init_tensors(self):
        self.gamma = nn.Parameter(torch.ones(self.dim) * self.init_values)


class AttentionPooling(nn.Module):
    """Pool a token sequence into ``num_probe`` vectors using learned query probes.

    A learned probe tensor of shape ``(1, num_probe, embed_dim)`` is repeated
    per batch element and used as the query of an ``nn.MultiheadAttention``
    over the input tokens. The result then passes through a residual
    pre-LayerNorm MLP. Note that the attention output itself is not added to
    any residual.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_probe: int = 1,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Creation order is kept as-is so random initialisation draws match.
        self.probe = nn.Parameter(torch.randn(1, num_probe, embed_dim))
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layernorm = norm_layer(embed_dim)

        self.mlp_width = int(embed_dim * mlp_ratio)
        mlp_layers = [
            ("c_fc", nn.Linear(embed_dim, self.mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(self.mlp_width, embed_dim)),
        ]
        self.mlp = nn.Sequential(OrderedDict(mlp_layers))

    def forward(self, x: torch.Tensor):
        """Pool ``x`` of shape ``(batch, seq, embed_dim)`` to ``(batch, num_probe, embed_dim)``."""
        batch_size, _, _ = x.shape

        queries = self.probe.repeat((batch_size, 1, 1)).to(x.dtype)
        pooled, _ = self.attn(queries, x, x, need_weights=False)
        return pooled + self.mlp(self.layernorm(pooled))


class SelfAttention(nn.Module):
    r"""
    Multi-head self-attention with optional RoPE and a Flash-Attention-2 path.

    Parameter names mirror ``nn.MultiheadAttention`` (packed
    ``in_proj_weight`` / ``in_proj_bias`` plus an ``out_proj`` linear). When
    ``attn_implementation == "flash_attention_2"`` and flash-attn is available,
    transformers' ``_flash_attention_forward`` is used. Otherwise
    ``F.scaled_dot_product_attention`` is used.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope: Optional[nn.Module] = None,
        attn_implementation: str = "eager",
    ):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # To make this compatible with nn.MultiheadAttention
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.rope = rope
        self.scale = self.head_dim ** (-0.5)
        self.attn_implementation = attn_implementation
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def init_tensors(self):
        """Xavier-uniform ``in_proj_weight`` and zero biases.

        ``out_proj.weight`` keeps ``nn.Linear``'s default initialisation.
        """
        xavier_uniform_(self.in_proj_weight)
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    def forward(self, x, attn_mask=None, output_attentions=False):
        """Attend over ``x`` of shape ``(batch, seq, embed_dim)``.

        Args:
            x: input tokens.
            attn_mask: optional mask. On the SDPA path it is passed straight
                to ``F.scaled_dot_product_attention`` (see its documentation
                for bool vs. additive-float semantics). On the flash path it
                is converted as below before being handed to
                ``_flash_attention_forward``.
            output_attentions: if True, return ``(output, None)``; attention
                weights are never materialised by either path.

        Returns:
            Tensor of shape ``(batch, seq, embed_dim)``, or a tuple as above.
        """
        batch, seq, embed_dim = x.shape
        proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)

        # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
        proj = (
            proj.unflatten(-1, (3, embed_dim))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
            .contiguous()
        )
        q, k, v = proj[0], proj[1], proj[2]

        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if self.rope:
            q, k = self.rope(q, k)

        if self.attn_implementation == "flash_attention_2" and is_flash_attn_2_available():
            # Flash Attention 2 implementation
            from transformers.modeling_flash_attention_utils import _flash_attention_forward

            q = rearrange(q, "b h s d -> b s h d", h=self.num_heads).contiguous()
            k = rearrange(k, "b h s d -> b s h d", h=self.num_heads).contiguous()
            v = rearrange(v, "b h s d -> b s h d", h=self.num_heads).contiguous()
            # Convert attention mask to the format expected by flash attention.
            # NOTE(review): this produces a 4-D (b, 1, 1, s)-style mask from a
            # (b, s) input; transformers documents `attention_mask` for
            # `_flash_attention_forward` as a 2-D (batch, seq_len) padding mask.
            # Confirm the expected shape before relying on masking here.
            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_mask = attn_mask.to(x.dtype)
                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)

            attn_output = _flash_attention_forward(
                q, k, v,
                attn_mask,
                seq,
                dropout=0.0,
                is_causal=False,
                use_top_left_mask=self._flash_attn_uses_top_left_mask,
            )
            attn_output = attn_output.reshape(batch, seq, embed_dim).contiguous()
            attn_weights = None
        else:
            # Standard attention implementation. The mask used to be replaced
            # by None here, silently disabling masking on this path.
            # NOTE(review): ResidualAttentionBlock forwards bool masks
            # unchanged to both this module and nn.MultiheadAttention, whose
            # bool conventions differ; confirm callers' mask polarity.
            attn = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.scale
            )
            attn_output = rearrange(attn, "b h s d -> b s (h d)")
            attn_weights = None

        output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)

        if output_attentions:
            return output, attn_weights
        return output


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` then ``x + mlp(ln_2(x))``.

    Each residual branch is optionally wrapped in LayerScale (when
    ``ls_init_value`` is not None) and DropPath (when ``drop_path > 0``).
    A truthy ``rope`` selects the custom ``SelfAttention``; otherwise the
    stock ``nn.MultiheadAttention`` is used.

    NOTE(review): ``DropPath`` is not imported in the visible part of this
    file; confirm it is defined before a block with ``drop_path > 0`` is built.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
        attn_implementation: str = "eager",
    ):
        super().__init__()

        # Submodules are registered in the original order (attn, ls, ln,
        # drop_path, mlp) so state_dict layout and init RNG draws are unchanged.
        self.attn = (
            SelfAttention(d_model, n_head, rope=rope, attn_implementation=attn_implementation)
            if rope
            else nn.MultiheadAttention(d_model, n_head, batch_first=True)
        )

        def make_layer_scale():
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        self.ls_1 = make_layer_scale()
        self.ls_2 = make_layer_scale()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        def make_drop_path():
            if drop_path > 0.0:
                return DropPath(drop_path)
            return nn.Identity()

        self.drop_path1 = make_drop_path()
        self.drop_path2 = make_drop_path()

        hidden_width = int(d_model * mlp_ratio)
        mlp_layers = [
            ("c_fc", nn.Linear(d_model, hidden_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(hidden_width, d_model)),
        ]
        self.mlp = nn.Sequential(OrderedDict(mlp_layers))

    def _call_attn(
        self,
        q_x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Dispatch to whichever attention module was built.

        Non-boolean masks are cast to ``q_x``'s dtype; boolean masks pass
        through unchanged. ``nn.MultiheadAttention`` returns a tuple
        ``(output, weights_or_None)``; ``SelfAttention`` returns a tensor
        unless ``output_attentions`` is set.
        """
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)

        if isinstance(self.attn, SelfAttention):
            return self.attn(q_x, attn_mask=attn_mask, output_attentions=output_attentions)
        return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=output_attentions)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Apply the block; returns ``x`` or ``(x, attn_weights)`` if ``output_attentions``."""
        result = self._call_attn(self.ln_1(x), attn_mask=attn_mask, output_attentions=output_attentions)
        attn_out, attn_weights = result if isinstance(result, tuple) else (result, None)

        x = x + self.drop_path1(self.ls_1(attn_out))
        x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))

        return (x, attn_weights) if output_attentions else x


class Transformer(nn.Module):
    """Stack of ``layers`` ResidualAttentionBlocks sharing one (optional) RoPE helper."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
        attn_implementation: str = "eager",
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    rope=rope,
                    attn_implementation=attn_implementation,
                )
                for _ in range(layers)
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle activation checkpointing for every block during ``forward``."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        self.layers = ((self.layers + layer_idx) % self.layers) + 1
        self.resblocks = nn.ModuleList(self.resblocks[:self.layers])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        layer_idx: int = -1,
        output_attentions: bool = False,
        skiplink_layers: Optional[List[int]] = None,
    ):
        """Run the blocks up to and including ``layer_idx``.

        Args:
            x: hidden states fed to the first block.
            attn_mask: optional mask forwarded unchanged to every block.
            layer_idx: index of the last block to run; negative values count
                from the end (``-1`` runs every block).
            output_attentions: if True, an ``"attentions"`` entry is added.
                Per-layer attentions are not collected yet, so it is always an
                empty tuple.
            skiplink_layers: layers whose outputs should also be returned.
                ``1..layers`` are block outputs (1-indexed) stored under the
                same key. A negative ``-k`` is the output of block
                ``layers - k + 1``, stored under ``-k``. ``0`` or
                ``-(layers + 1)`` is a copy of the input ``x``, stored under
                key ``0``. Layers beyond ``layer_idx`` are never reached.

        Returns:
            dict with ``"last_hidden_state"``, plus ``"attentions"`` when
            requested, and ``"skiplink_hidden_states"`` whenever
            ``skiplink_layers`` is not None (``None`` if it was empty).
        """
        stop_idx = (self.layers + layer_idx) % self.layers
        all_attentions = () if output_attentions else None

        # Dictionary to store outputs from skiplink layers
        skiplink_hidden_states = {}
        if skiplink_layers is not None:
            # Loop-invariant lookups, computed once instead of per layer.
            requested = set(skiplink_layers)
            negative_idxs = [idx for idx in skiplink_layers if idx < 0]
            # The input acts as the "0th" layer hidden state.
            if 0 in requested or -(self.layers + 1) in requested:
                skiplink_hidden_states[0] = x.clone()

        for i, block in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Pass the block itself, not a closure over the loop variable.
                # checkpoint re-invokes the function during backward, by which
                # time a closure would see the *last* block and recompute every
                # segment with the wrong weights.
                x = checkpoint(block, x, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)

            if skiplink_layers is not None:
                layer_no = i + 1  # 1-indexed layer number
                if layer_no in requested:
                    skiplink_hidden_states[layer_no] = x.clone()
                # e.g. with 12 layers, -1 maps to layer 12, -2 to layer 11, ...
                for neg_idx in negative_idxs:
                    if self.layers + neg_idx + 1 == layer_no:
                        skiplink_hidden_states[neg_idx] = x.clone()

            if i == stop_idx:
                break

        return_dict = {"last_hidden_state": x}
        if output_attentions:
            return_dict["attentions"] = all_attentions  # not implemented yet
        if skiplink_layers is not None:
            return_dict["skiplink_hidden_states"] = skiplink_hidden_states if skiplink_layers else None
        return return_dict


class PeViT(PreTrainedModel):
    """Perception Encoder vision tower exposed as a Hugging Face ``PreTrainedModel``.

    Pipeline (see :meth:`forward_features`):
    - A strided ``Conv2d`` patchifies the image.
    - An optional class token is prepended.
    - Optional absolute position embeddings are added. They are interpolated
      when the token grid differs from the configured grid.
    - The optional ``Rope2D`` grid is refreshed.
    - The tokens run through ``self.transformer``.

    :meth:`forward` then pools the tokens. When ``output_dim`` is set, it also
    projects the token features with ``self.proj``.

    Passing ``skiplink_layers=[...]`` (through ``**kwargs``) additionally
    returns the requested per-layer hidden states in ``skiplink_hidden_states``.
    """

    config_class = PEConfig
    base_model_prefix = "pe"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ResidualAttentionBlock"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: PEConfig,
        **kwargs
    ):
        super().__init__(config)
        self.patch_size = config.patch_size
        self.width = config.width
        self.layers = config.layers
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.output_dim = config.output_dim
        self.proj_dim = config.output_dim
        self.use_abs_posemb = config.use_abs_posemb
        self.use_cls_token = config.use_cls_token
        self.use_rope2d = config.use_rope2d
        self.image_size = config.image_size
        self.pool_type = config.pool_type

        # Non-overlapping patch embedding: kernel == stride == patch_size.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=self.width,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.rope = (
            Rope2D(
                dim=self.width // self.heads,
                use_cls_token=self.use_cls_token,
            )
            if self.use_rope2d
            else None
        )

        self.ln_pre = nn.LayerNorm(self.width) if config.use_ln_pre else nn.Identity()
        self.ln_post = nn.LayerNorm(self.width) if config.use_ln_post else nn.Identity()

        self.transformer = Transformer(
            self.width,
            self.layers,
            self.heads,
            self.mlp_ratio,
            ls_init_value=config.ls_init_value,
            drop_path=config.drop_path,
            rope=self.rope,
            attn_implementation=config._attn_implementation,
        )

        if self.pool_type == "attn":
            self.attn_pool = AttentionPooling(
                embed_dim=self.width,
                num_heads=config.attn_pooler_heads,
            )
        else:
            self.attn_pool = None

        self.init_tensors()

    def _init_weights(self, module):
        """Initialize the weights of one submodule (Hugging Face init hook).

        Linear/Conv2d weights use LeCun-normal init with zero bias.
        LayerNorm gets unit weight and zero bias; affine-free LayerNorms are skipped.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # elementwise_affine=False leaves weight/bias as None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def init_tensors(self):
        """Create and initialize the model's own parameters.

        First calls ``init_tensors()`` on every (recursive) child that defines it.
        Then it creates the class embedding, the absolute positional embedding
        and the output projection, depending on the configuration flags.
        """
        def init_submodule_tensors(module):
            for name, child in module.named_children():
                if hasattr(child, "init_tensors"):
                    logger.debug(f"Initializing tensors for submodule: {name}")
                    child.init_tensors()
                init_submodule_tensors(child)

        init_submodule_tensors(self)
        # NOTE(review): if Rope2D is an nn.Module it is also a registered child
        # and was already visited above; this explicit call is kept as before.
        if self.rope is not None:
            self.rope.init_tensors()

        # class embeddings and positional embeddings
        init_scale = self.width**-0.5

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            # One row per patch on the square training grid (+1 for the class token).
            self.positional_embedding = nn.Parameter(
                init_scale
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
                )
            )

        if self.proj_dim is not None:
            self.proj = nn.Parameter(
                init_scale * torch.randn(self.width, self.proj_dim)
            )

    @classmethod
    def from_config(cls, config: PEConfig, **kwargs):
        """Create a model instance from a PEConfig object.

        Args:
            config (PEConfig): The configuration object to create the model from.
            **kwargs: Additional keyword arguments for model initialization.
                - checkpoint_path (str, optional): Path to the checkpoint file to load weights from.

        Returns:
            PeViT: The initialized model. It is cast to bfloat16 when a checkpoint
            was loaded, ``config.use_bfloat16`` is set and CUDA supports bf16.
        """
        checkpoint_path = kwargs.pop("checkpoint_path", None)

        # Initialize model with config
        model = cls(config, **kwargs)

        # Load checkpoint if path is provided
        if checkpoint_path:
            model.load_ckpt(checkpoint_path)

            # Convert model to bf16 if specified in config
            if config.use_bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                model = model.to(torch.bfloat16)

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a pretrained model name or path.

        Args:
            pretrained_model_name_or_path (str): The name or path of the pretrained model.
            *model_args: Additional positional arguments for model initialization.
            **kwargs: Additional keyword arguments for model initialization.
                - config (PEConfig, optional): Configuration object. If not provided, will be created from pretrained_model_name_or_path.
                - checkpoint_path (str, optional): Path to the checkpoint file. If not provided, will be fetched using pretrained_model_name_or_path.

        Returns:
            PeViT: The loaded model.
        """
        config = kwargs.pop("config", None)
        checkpoint_path = kwargs.pop("checkpoint_path", None)

        # Initialize config from pretrained model name if not provided
        if config is None:
            config = PEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Initialize model with config
        model = cls(config, *model_args, **kwargs)

        # Load checkpoint if path is provided or pretrained model name is given
        if checkpoint_path or pretrained_model_name_or_path:
            ckpt_path = checkpoint_path or fetch_pe_checkpoint(pretrained_model_name_or_path)
            model.load_ckpt(ckpt_path)

            # Convert model to bf16 if specified in config
            if config.use_bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                model = model.to(torch.bfloat16)

        return model

    def load_ckpt(self, ckpt_path: str):
        """Load vision-encoder weights from ``ckpt_path`` (non-strict).

        Accepts a raw state dict, or one nested under ``"state_dict"`` or ``"weights"``.

        For backwards compatibility:
        - Every ``module.`` substring is removed from the keys.
        - If any key starts with ``visual.``, only keys containing ``visual`` are
          kept, with ``visual.`` removed.

        Missing and unexpected keys are logged and printed.
        """
        # map_location="cpu" lets checkpoints saved on GPU load on any host;
        # load_state_dict then copies the tensors into the model's parameters.
        _sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        # for backwards compatibility
        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
        if any(k.startswith("visual.") for k in _sd):
            _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}

        m, u = self.load_state_dict(_sd, strict=False)
        logger.info(f"Missing keys for loading vision encoder: {m}")
        logger.info(f"Unexpected keys for loading vision encoder: {u}")
        print(f"Missing keys for loading vision encoder: {m}")
        print(f"Unexpected keys for loading vision encoder: {u}")

    def forward_features(
        self,
        x: torch.Tensor,
        norm: bool = False,
        layer_idx: int = -1,
        strip_cls_token: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        **kwargs: Any,
    ) -> Union[Tuple, BaseModelOutput, PeSkipLinkModelOutput]:
        """Encode an image batch into token features (no pooling / projection).

        Args:
            x: Image batch, unpacked as ``(batch, channels, h, w)``.
            norm: Apply ``ln_post`` to the final token features.
            layer_idx: Forwarded to ``self.transformer`` (which layer to stop at).
            strip_cls_token: Drop the leading class token from the outputs
                (only when the model uses one).
            output_attentions: Forwarded to ``self.transformer``.
            output_hidden_states: Not supported; a truthy value raises.
            return_dict: Return a ``ModelOutput`` instead of a bare tensor/tuple.
            **kwargs: ``skiplink_layers`` (optional sequence of layer indices)
                selects intermediate hidden states to return. ``ln_post`` is always
                applied to them, and the class token is stripped from them as well.

        Returns:
            ``PeSkipLinkModelOutput`` when ``skiplink_layers`` is given, otherwise
            ``BaseModelOutput``. With ``return_dict=False`` it returns ``x``, or
            ``(x, attentions)`` when ``output_attentions``.

        Raises:
            NotImplementedError: if ``output_hidden_states`` is truthy.
        """
        if output_hidden_states:
            raise NotImplementedError("Output hidden states are not supported in this implementation.")

        # Convert input to the same dtype as model weights if using bf16
        if self.config.use_bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            x = x.to(torch.bfloat16)

        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        # (B, width, gh, gw) -> (B, gh*gw, width), row-major over the patch grid.
        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                dim=1,
            )

        if self.use_abs_posemb:
            x = x + self._sample_abs_posemb(grid_h, grid_w)

        if self.use_rope2d:
            self.rope.update_grid(x.device, grid_h, grid_w)

        x = self.ln_pre(x)

        skiplink_layers = kwargs.get("skiplink_layers", None)
        transformer_output = self.transformer(
            x,
            layer_idx=layer_idx,
            output_attentions=output_attentions,
            skiplink_layers=skiplink_layers,
        )

        x = transformer_output["last_hidden_state"]
        attentions = transformer_output.get("attentions", None)
        # None when skip-links were not requested, or when an empty selection
        # was requested (the transformer reports None in that case).
        skiplink_hidden_states = (
            transformer_output.get("skiplink_hidden_states", None)
            if skiplink_layers is not None
            else None
        )

        if norm:
            x = self.ln_post(x)
        if skiplink_hidden_states:
            # Skip-link states are normalized with ln_post regardless of `norm`.
            skiplink_hidden_states = {
                idx: self.ln_post(state) for idx, state in skiplink_hidden_states.items()
            }

        if strip_cls_token and self.use_cls_token:
            x = x[:, 1:, :]
            if skiplink_hidden_states:
                skiplink_hidden_states = {
                    idx: state[:, 1:, :] for idx, state in skiplink_hidden_states.items()
                }

        if not return_dict:
            return (x, attentions) if output_attentions else x
        if skiplink_layers is not None:
            return PeSkipLinkModelOutput(
                last_hidden_state=x,
                skiplink_hidden_states=skiplink_hidden_states,
                hidden_states=None,
                attentions=attentions,
            )
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=None,
            attentions=attentions,
        )

    def forward(
        self,
        x: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPooling, PeSkipLinkOutputWithPooling]:
        """Encode an image batch, pool the tokens and optionally project them.

        Token features are computed by :meth:`forward_features` with ``norm=True``.
        ``pooler_output`` is computed from the *unprojected* tokens. ``self.proj``
        (when ``output_dim`` is set) is applied to ``last_hidden_state`` only.
        Extra ``kwargs`` (e.g. ``skiplink_layers``) go to :meth:`forward_features`.

        Raises:
            NotImplementedError: if ``return_dict`` is falsy (tuple outputs are not
                implemented).
        """
        # Check up front: previously a falsy return_dict crashed with an
        # UnboundLocalError before this error could ever be reached.
        if not return_dict:
            raise NotImplementedError(
                "Return dict is not implemented for this case. Please use return_dict=True."
            )

        features = self.forward_features(
            x,
            norm=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs
        )

        last_hidden_state = features.last_hidden_state
        pooler_output = self._pool(last_hidden_state)

        if self.proj_dim is not None:
            last_hidden_state = last_hidden_state @ self.proj

        if isinstance(features, PeSkipLinkModelOutput):
            return PeSkipLinkOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooler_output,
                skiplink_hidden_states=features.skiplink_hidden_states,
                hidden_states=None,
                attentions=features.attentions,
            )

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=None,
            attentions=features.attentions,
        )

    def _sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Return absolute position embeddings for a ``grid_h x grid_w`` patch grid.

        The stored square embedding grid is bilinearly interpolated when the
        requested grid differs. The class-token row (if any) is kept unchanged.
        The result has a leading batch dimension of 1.
        """
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        pos_embed = F.interpolate(
            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]

    def _pool(self, x: torch.Tensor):
        """Pool token features according to ``self.pool_type``.

        - ``"tok"``: take the first token.
        - ``"avg"``: mean over dim 1.
        - ``"attn"``: use ``self.attn_pool``.
        - ``"none"``: return the tokens unchanged.

        Any other value raises ``NotImplementedError``.
        """
        if self.pool_type == "tok":
            return x[:, 0]
        elif self.pool_type == "avg":
            return x.mean(dim=1)
        elif self.pool_type == "attn":
            return self.attn_pool(x).squeeze(1)
        elif self.pool_type == "none":
            return x
        else:
            raise NotImplementedError


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    In training mode each sample in the batch (dim 0) is either zeroed entirely,
    with probability ``drop_prob``, or kept and scaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is ``None``/0, the input is returned unchanged.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # `not self.drop_prob` covers both None (the default) and 0.
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob <= 0:
            # drop_prob >= 1: every path is dropped; avoid 0/0 -> NaN from rescaling.
            return x * random_tensor
        output = x.div(keep_prob) * random_tensor
        return output

