import copy
import math
from math import pi
import numpy as np
import os
from functools import partial
import pywt
import ptwt
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import logging
from flash_attn import flash_attn_qkvpacked_func
from vector_quantize_pytorch import GroupedResidualVQ, GroupedResidualLFQ

try:
    from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except:
    from timm.layers import drop_path, to_2tuple, trunc_normal_

if os.getenv('ENV_TYPE') == 'deepspeed':
    try:
        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
    except:
        from torch.utils.checkpoint import checkpoint
else:
    from torch.utils.checkpoint import checkpoint

try:
    import xformers.ops as xops
except ImportError:
    xops = None
    print("Please 'pip install xformers'")


class ConvBlock(nn.Module):
    """Residual conv block: LayerNorm -> conv -> depthwise conv -> 1x1 conv, plus a shortcut.

    The input is a (B, C, H, W) feature map. The main branch is normalised over channels with
    ``LayerNorm`` (applied channels-last), then passed through a dense conv, a depthwise conv
    and a 1x1 output conv. The shortcut is the identity when no reshaping is needed, otherwise
    a 1x1 projection conv.

    Args:
        in_chans: channels of the input.
        out_chans: channels of the output.
        kernel_size: kernel of the first (dense) conv; padded with ``kernel_size // 2``.
        dw_kernel_size: kernel of the depthwise conv; padded with ``dw_kernel_size // 2``.
        expand: accepted for interface compatibility; currently unused.
        stride: stride of the first conv. The shortcut uses the same stride so both branches
            produce the same spatial size.
    """

    def __init__(self, in_chans, out_chans,
                 kernel_size=1,
                 dw_kernel_size=3,
                 expand=2,
                 stride=1,
                 ):
        super().__init__()
        self.norm = nn.LayerNorm(in_chans)
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size, stride,
                      kernel_size // 2, ),
            # Normalises the conv *output*, which has out_chans channels (was in_chans).
            nn.InstanceNorm2d(out_chans),
            nn.SiLU(),
        )
        self.cheap_operation = nn.Sequential(
            # Depthwise conv: one filter per channel (groups == channels).
            nn.Conv2d(
                out_chans,
                out_chans,
                dw_kernel_size,
                1,
                dw_kernel_size // 2,
                groups=out_chans,
            ),
            nn.InstanceNorm2d(out_chans),
            nn.SiLU(),
        )
        self.output_conv = nn.Conv2d(
            out_chans, out_chans, kernel_size=(1, 1),
            padding=0,
        )
        # The shortcut must match the main branch in channels *and* spatial size, so it
        # projects whenever the channel count changes or the primary conv is strided.
        strides = stride if isinstance(stride, (tuple, list)) else (stride,)
        needs_projection = in_chans != out_chans or any(s != 1 for s in strides)
        self.shortcut = nn.Conv2d(in_chans, out_chans, kernel_size=(1, 1), stride=stride,
                                  padding=0) if needs_projection else nn.Identity()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.trunc_normal_(m.weight, 0, 0.02)

    def forward_features(self, x):
        """Main (non-residual) branch; ``x`` is (B, C, H, W)."""
        # LayerNorm normalises the last axis, so move channels last and back.
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()
        x = self.primary_conv(x)
        x = self.cheap_operation(x)
        x = self.output_conv(x)
        return x

    def forward(self, x):
        """Return ``forward_features(x) + shortcut(x)``, recomputing activations in backward."""
        x = checkpoint(self.forward_features, x, use_reentrant=False) + checkpoint(
            self.shortcut, x, use_reentrant=False)
        return x


class DropPath(nn.Module):
    """Drop paths (stochastic depth) per sample, when applied in the main path of residual blocks.

    Args:
        drop_prob: probability of dropping a sample's branch. ``None`` (the default) and ``0``
            both disable dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity when disabled or in eval mode. Short-circuiting here also keeps a ``None``
        # probability away from timm's ``drop_path``, which expects a float when training.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


def broadcat(tensors, dim=-1):
    """Concatenate ``tensors`` along ``dim`` after broadcasting every other axis.

    All tensors must share the same rank. On each non-concatenation axis the tensors may
    disagree in at most two distinct sizes; each tensor is ``expand``-ed to the largest size
    on that axis (so the smaller one is expected to be 1) before ``torch.cat``.
    """
    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    if dim < 0:
        dim = dim + rank

    # sizes_by_axis[a] holds the size of axis ``a`` for every tensor, in input order.
    sizes_by_axis = list(zip(*(tuple(t.shape) for t in tensors)))
    assert all(
        len(set(sizes)) <= 2
        for axis, sizes in enumerate(sizes_by_axis)
        if axis != dim
    ), 'invalid dimensions for broadcastable concatentation'

    target_shapes = []
    for idx in range(len(tensors)):
        shape = []
        for axis, sizes in enumerate(sizes_by_axis):
            # Keep each tensor's own length on the concat axis; broadcast all others.
            shape.append(sizes[idx] if axis == dim else max(sizes))
        target_shapes.append(shape)

    expanded = [t.expand(*shape) for t, shape in zip(tensors, target_shapes)]
    return torch.cat(expanded, dim=dim)


def rotate_half(x):
    """Rotate consecutive feature pairs on the last axis: ``(a, b) -> (-b, a)``.

    The last axis is read as ``[x0, x1, x2, x3, ...]`` and returned as
    ``[-x1, x0, -x3, x2, ...]``; this is the companion term used by the rotary embeddings.
    """
    paired = rearrange(x, '... (d r) -> ... d r', r=2)
    even, odd = paired[..., 0], paired[..., 1]
    rotated = torch.stack((-odd, even), dim=-1)
    return rearrange(rotated, '... d r -> ... (d r)')


class VisionRotaryEmbedding(nn.Module):
    """2-D rotary position embedding on a square ``ft_seq_len x ft_seq_len`` grid.

    Row and column frequencies are built identically and concatenated on the feature axis,
    giving ``freqs_cos``/``freqs_sin`` buffers of shape ``(ft_seq_len, ft_seq_len, rot_dim)``
    with ``rot_dim = 2 * 2 * len(freqs)``. ``forward`` rotates the feature slice
    ``[start_index, start_index + rot_dim)`` of its input and leaves the rest untouched.

    Args:
        dim: controls the number of base frequencies (``dim // 2`` for ``'lang'``/``'pixel'``).
        pt_seq_len: grid side used at pre-training; positions are rescaled into this range.
        ft_seq_len: grid side actually used (defaults to ``pt_seq_len``).
        custom_freqs: explicit 1-D frequencies (tensor or sequence); overrides ``freqs_for``.
        freqs_for: ``'lang'``, ``'pixel'`` or ``'constant'`` frequency schedule.
        theta, max_freq, num_freqs: parameters of the respective schedules.
    """

    def __init__(
            self,
            dim,
            pt_seq_len,
            ft_seq_len=None,
            custom_freqs=None,
            freqs_for='lang',
            theta=10000,
            max_freq=10,
            num_freqs=1,
    ):
        super().__init__()
        if custom_freqs is not None:
            # ``if custom_freqs:`` raised for any multi-element tensor (ambiguous truth value).
            freqs = torch.as_tensor(custom_freqs)
            if not freqs.is_floating_point():
                freqs = freqs.float()
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Rows and columns use the same positions and frequencies, so compute them once.
        freqs_axis = torch.einsum('..., f -> ... f', t, freqs)
        freqs_axis = repeat(freqs_axis, '... n -> ... (n r)', r=2)

        freqs = broadcat((freqs_axis[:, None, :], freqs_axis[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')

    def forward(self, t, start_index=0):
        """Rotate ``t[..., start_index:start_index + rot_dim]``; other features pass through."""
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[
            -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)


class VisionRotaryEmbeddingFast(nn.Module):
    """2-D rotary position embedding with the grid flattened to a token axis.

    Same frequency construction as a square ``ft_seq_len x ft_seq_len`` 2-D RoPE, but the
    ``freqs_cos``/``freqs_sin`` buffers are stored as ``(ft_seq_len ** 2, rot_dim)`` so they
    broadcast directly against a token sequence. ``forward`` rotates the whole last axis.

    Args:
        dim, pt_seq_len, ft_seq_len, custom_freqs, freqs_for, theta, max_freq, num_freqs:
            frequency schedule, as in the non-fast variant.
        patch_dropout: stored for callers; not read by this module.
    """

    def __init__(
            self,
            dim,
            pt_seq_len,
            ft_seq_len=None,
            custom_freqs=None,
            freqs_for='lang',
            theta=10000,
            max_freq=10,
            num_freqs=1,
            patch_dropout=0.
    ):
        super().__init__()
        if custom_freqs is not None:
            # ``if custom_freqs:`` raised for any multi-element tensor (ambiguous truth value).
            freqs = torch.as_tensor(custom_freqs)
            if not freqs.is_floating_point():
                freqs = freqs.float()
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        # Flatten the (row, col) grid into one token axis.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.patch_dropout = patch_dropout

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')

    def forward(self, t, patch_indices_keep=None):
        """Apply the rotation to ``t``.

        Without ``patch_indices_keep`` the (tokens, rot_dim) tables broadcast against ``t``.
        With it, per-sample rows are gathered for the kept patches only; the repeats below
        treat ``t`` as (batch, heads, kept_tokens, rot_dim) — TODO confirm with the caller.
        """
        if patch_indices_keep is not None:
            # Build the batch index on the same device as the patch index it is paired with.
            batch_indices = torch.arange(t.size()[0], device=patch_indices_keep.device)[..., None]

            freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
            freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])

            freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
            freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
            freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
            freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')

            return t * freqs_cos + rotate_half(t) * freqs_sin

        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin


class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> act -> [norm] -> fc2 -> dropout``.

    Args:
        in_features: input width.
        hidden_features: hidden width (defaults to ``in_features``).
        out_features: output width (defaults to ``in_features``).
        act_layer: activation class, instantiated with no arguments.
        norm_layer: normalisation class applied to the hidden activations when ``subln``.
        drop: dropout probability, applied to the output only.
        subln: insert ``norm_layer(hidden_features)`` between the activation and ``fc2``.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            drop=0.,
            subln=False,
    ):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(width_hidden)
        else:
            self.ffn_ln = nn.Identity()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Dropout is intentionally skipped after the activation (original BERT layout) and
        # applied only to the output.
        hidden = self.ffn_ln(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class SwiGLU(nn.Module):
    """Gated feed-forward block: ``drop(w3(norm(act(w1 x) * w2 x)))``.

    Args:
        in_features: input width.
        hidden_features: width of the gate/value projections (defaults to ``in_features``).
        out_features: output width (defaults to ``in_features``).
        act_layer: gate activation class (SiLU gives the standard SwiGLU).
        drop: dropout probability on the output.
        norm_layer: normalisation class applied to the gated hidden state when ``subln``.
        subln: insert ``norm_layer(hidden_features)`` before ``w3``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False):
        super().__init__()
        width = hidden_features or in_features

        self.w1 = nn.Linear(in_features, width)  # gate projection
        self.w2 = nn.Linear(in_features, width)  # value projection

        self.act = act_layer()
        self.ffn_ln = norm_layer(width) if subln else nn.Identity()
        self.w3 = nn.Linear(width, out_features or in_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gate = self.w1(x)
        value = self.w2(x)
        gated = self.ffn_ln(self.act(gate) * value)
        return self.drop(self.w3(gated))


class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.1,
            proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False,
            norm_layer=nn.LayerNorm, featscale=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.subln = subln
        if self.subln:
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        # self.proj = nn.Linear(all_head_dim, all_head_dim)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

        self.featscale = featscale

        if self.featscale:
            self.lambd1 = nn.Parameter(0.02 * torch.randn(all_head_dim), requires_grad=True)
            self.lambd2 = nn.Parameter(0.02 * torch.randn(all_head_dim), requires_grad=True)

        self.rope = rope

    def haar_featscale_forward(self, x):
        input_dtype = x.dtype
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            res = ptwt.wavedec(x.to(torch.float32), axis=-2, level=1, mode="zero",
                               wavelet=pywt.Wavelet("haar"))
            [low, high] = res
        low = low.to(input_dtype)
        high = high.to(input_dtype)
        lambd1, lambd2 = self.lambd1.view(1, 1, -1), self.lambd2.view(1, 1, -1)
        low.mul(lambd1)
        high.mul(lambd2)
        with torch.autocast(device_type="cuda", dtype=torch.float32):
          #  print(x.shape, "x")
            rescaled = ptwt.waverec([low.to(torch.float32),
                                     high.to(torch.float32)],
                                    axis=-2,
                                    wavelet=pywt.Wavelet("haar"))
          #  print(rescaled.shape)
        x.add(rescaled.to(input_dtype))
        return x

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        B, N, C = x.shape
        if self.subln:
            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        else:

            qkv_bias = None
            if self.q_bias is not None:
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))

            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]

        if self.rope:
            # slightly fast impl
            q_t = q[:, :, 1:, :]
            ro_q_t = self.rope(q_t)
            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

            k_t = k[:, :, 1:, :]
            ro_k_t = self.rope(k_t)
            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            x = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2)
            ).transpose(1, 2)
         #   x = flash_attn_qkvpacked_func(qkv=torch.stack((q, k, v), dim=2),
          #                                causal=False,
                 #                         dropout_p=self.xattn_drop,
           #                               softmax_scale=self.scale)
            # x = xops.memory_efficient_attention(
            #     q, k, v,
            #      p=self.xattn_drop,
            #      scale=self.scale,
            # 3   )
            x = x.reshape(B, N, -1)
            x = self.inner_attn_ln(x)
         #   if self.featscale:
            #    x = self.haar_featscale_forward(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if self.relative_position_bias_table is not None:
                relative_position_bias = \
                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                        self.window_size[0] * self.window_size[1] + 1,
                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)

            if rel_pos_bias is not None:
                attn = attn + rel_pos_bias.type_as(attn)

            if attn_mask is not None:
                attn_mask = attn_mask.bool()
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Transformer block: self-attention and a feed-forward network, each in a residual branch.

    With ``postnorm=False`` (default) each branch is ``x + drop_path(gamma * f(norm(x)))``.
    With ``postnorm=True`` it is ``x + drop_path(gamma * norm(f(x)))``. ``gamma_1``/``gamma_2``
    are per-channel layer-scale parameters created only when ``init_values > 0``; otherwise
    branch outputs are used unscaled. ``naiveswiglu`` selects ``SwiGLU`` over ``Mlp``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
                 subln=False, naiveswiglu=True):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
            xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        # Both MLP variants receive the same dropout rate and norm layer as the rest of the block
        # (previously SwiGLU ignored ``drop`` and Mlp ignored ``norm_layer``).
        if naiveswiglu:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer,
                drop=drop,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                norm_layer=norm_layer,
                subln=subln,
                drop=drop
            )

        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

        self.postnorm = postnorm

    @staticmethod
    def _layer_scale(gamma, branch):
        """Multiply ``branch`` by the layer-scale ``gamma`` when one exists."""
        return branch if gamma is None else gamma * branch

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Apply the attention and MLP residual branches; ``x`` keeps its shape."""
        if self.postnorm:
            attn_out = self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
        else:
            attn_out = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
        x = x + self.drop_path(self._layer_scale(self.gamma_1, attn_out))

        mlp_out = self.norm2(self.mlp(x)) if self.postnorm else self.mlp(self.norm2(x))
        x = x + self.drop_path(self._layer_scale(self.gamma_2, mlp_out))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A ``Conv2d`` whose kernel and stride both equal ``patch_size`` maps a (B, C, H, W) image
    of exactly ``img_size`` to a (B, num_patches, embed_dim) token sequence.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_h * grid_w

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        _, _, height, width = x.shape
        # FIXME look at relaxing size constraints
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        patches = self.proj(x)
        # (B, D, gh, gw) -> (B, gh*gw, D)
        return patches.flatten(start_dim=2).transpose(1, 2)


class RelativePositionBias(nn.Module):
    """Learned relative position bias for a Wh x Ww window plus one leading cls token.

    ``forward`` returns a (num_heads, Wh*Ww+1, Wh*Ww+1) bias gathered from a table holding
    one row per relative (dh, dw) offset plus 3 rows for cls->token, token->cls and cls->cls.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        # Explicit "ij" indexing (the historical default) avoids the deprecation warning.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        """Return the bias as (num_heads, Wh*Ww+1, Wh*Ww+1)."""
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


class TransformerBlock(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    A ``Conv2d`` embeds the (B, in_chans, H, W) input. The result is flattened to a
    (B, H' * W', embed_dim) token sequence and passed through ``depth`` ``Block``s. The
    output stays in token layout.
    """

    def __init__(self, in_chans=1, embed_dim=768, depth=12, kernel_size=(4, 4), stride=(2, 2), padding=1,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
                 rope=True,
                 init_scale=0.001,
                 grad_checkpointing=False, xattn=True, postnorm=False, intp_freq=False, naiveswiglu=True, subln=False):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size, stride, padding)

        self.pos_drop = nn.Dropout(p=drop_rate)
        # NOTE: RoPE is currently disabled regardless of ``rope``; ``patch_dropout``,
        # ``init_scale`` and ``intp_freq`` are accepted for interface compatibility only.
        self.rope = None
        self.naiveswiglu = naiveswiglu

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=None, xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln,
                naiveswiglu=naiveswiglu)
            for i in range(depth)])

        self.apply(self._init_weights)
        self.fix_init_weight()

        self.grad_checkpointing = grad_checkpointing

    def _mlp_out_proj(self, block):
        """Return the MLP output projection of ``block`` (``w3`` for SwiGLU, ``fc2`` for Mlp)."""
        return block.mlp.w3 if self.naiveswiglu else block.mlp.fc2

    def fix_init_weight(self):
        """Divide each block's residual output projections by sqrt(2 * layer_index)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(self._mlp_out_proj(layer).weight.data, layer_id + 1)

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's MLP output projection (works for SwiGLU and Mlp)."""
        return self._mlp_out_proj(self.blocks[0]).weight.dtype

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze every parameter (partial locking is not supported)."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward_features(self, x):
        """Embed ``x`` (B, C, H, W) and run all blocks; returns (B, H' * W', embed_dim)."""
        x = self.patch_embed(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()

        for blk in self.blocks:
            if self.grad_checkpointing:
                # Pass only x: an extra positional argument would land in ``rel_pos_bias``.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x, rel_pos_bias=None)
        return x

    def forward(self, x):
        return self.forward_features(x)


class LearnableFourierFeaturePositionalEmebedding2d(nn.Module):
    """Add a learned Fourier-feature positional code to a (B, dim, H, W) feature map.

    A normalised (row, col) grid in [-1, 1]^2 goes through three ``Linear -> sin(omega * .)``
    stages and a final ``Linear`` to ``out_dim``. It is then added to a linear projection of
    the features. With ``modulation=True`` each stage's pre-activation is additionally scaled
    and shifted by values predicted from the features (``v1``/``v2``/``v3``).
    """

    def __init__(self, dim, omega=2.0, out_dim=None, modulation=False):
        super().__init__()
        self.out_dim = out_dim if out_dim is not None else dim
        self.modulation = modulation
        self.proj = nn.Linear(dim, self.out_dim)
        self.w1 = nn.Linear(2, dim)
        self.w2 = nn.Linear(dim, dim)
        self.w3 = nn.Linear(dim, dim)
        self.w4 = nn.Linear(dim, self.out_dim)
        if self.modulation:
            # Each predictor emits a (scale, shift) pair concatenated on the last axis.
            self.v1 = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 2))
            self.v2 = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 2))
            self.v3 = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 2))
        self.omega = omega

    def forward_features(self, x):
        _, _, height, width = x.shape
        rows = torch.linspace(-1, 1, height).to(x)
        cols = torch.linspace(-1, 1, width).to(x)
        grid = torch.stack(torch.meshgrid([rows, cols], indexing="ij"), dim=-1)[None]  # 1, H, W, 2
        feats = x.permute(0, 2, 3, 1).contiguous()  # B, H, W, dim

        modulators = (self.v1, self.v2, self.v3) if self.modulation else (None, None, None)
        pos = grid
        for linear, modulator in zip((self.w1, self.w2, self.w3), modulators):
            pos = linear(pos)
            if modulator is not None:
                scale, shift = modulator(feats).chunk(2, dim=-1)
                pos = pos * scale + shift
            pos = torch.sin(pos * self.omega)
        pos = self.w4(pos)

        out = pos + self.proj(feats)
        return out.permute(0, 3, 1, 2).contiguous()

    def forward(self, x):
        # Recompute activations during backward instead of storing them.
        return checkpoint(self.forward_features, x, use_reentrant=False)


class DownsamplingBlock(nn.Module):
    """Halve the spatial size with a learnable 4x4 convolution (stride 2, padding 1)."""

    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.conv = nn.Conv2d(
            in_chans,
            out_chans,
            kernel_size=(4, 4),
            stride=(2, 2),
            padding=1,
        )

    def forward(self, x):
        return self.conv(x)


class UpsamplingBlock(nn.Module):
    """Double the spatial size: nearest-neighbour resize, 3x3 conv, then Fourier positional mixing."""

    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=(3, 3), stride=(1, 1), padding=1)
        self.fourier = LearnableFourierFeaturePositionalEmebedding2d(out_chans, modulation=True)

    def forward_features(self, x):
        _, _, height, width = x.shape
        upsampled = F.interpolate(x, size=(2 * height, 2 * width), mode="nearest")
        return self.fourier(self.conv(upsampled))

    def forward(self, x):
        # Recompute activations during backward instead of storing them.
        return checkpoint(self.forward_features, x, use_reentrant=False)


class Encoder(nn.Module):
    """Hierarchical convolution + attention encoder.

    There is one stage per entry of ``dim_mult``. Each stage:

    1. Halves H and W with a ``DownsamplingBlock`` to ``dim_mult[i] * embed_dim``
       channels.
    2. Applies two ``ConvBlock``s, a ``TransformerBlock`` (``heads[i]`` heads) and
       two more ``ConvBlock``s.
    3. Adds a residual connection from the downsampled input.

    A LayerNorm + Linear head then maps the last stage's channels to ``z_dim``.
    The output is NCHW.
    """

    def __init__(self, in_chans=1, embed_dim=128, z_dim=512, depth=1, dim_mult=(1, 2, 2, 4),
                 heads=(2, 4, 4, 8)):
        super().__init__()
        self.dim_mult = list(dim_mult)
        self.heads = list(heads)
        self.depth = depth
        self.embed_dim = embed_dim
        self.blocks = nn.ModuleList()

        stage_in = in_chans
        for stage_idx, mult in enumerate(self.dim_mult):
            stage_out = mult * embed_dim
            self.blocks.append(self._build_stage(stage_in, stage_out, self.heads[stage_idx]))
            stage_in = stage_out

        last_dim = embed_dim * self.dim_mult[-1]
        self.head_norm = nn.LayerNorm(last_dim)
        self.head = nn.Linear(last_dim, z_dim)

    def _build_stage(self, in_dim, out_dim, num_heads):
        """Create one stage; submodules are created/registered in a fixed order."""
        stage = nn.Module()
        stage.downsampler = DownsamplingBlock(in_dim, out_dim)
        stage.convblock_1 = ConvBlock(out_dim, out_dim, kernel_size=1, dw_kernel_size=3)
        stage.convblock_2 = ConvBlock(out_dim, out_dim, kernel_size=3, dw_kernel_size=5)
        stage.transformer = TransformerBlock(out_dim, out_dim,
                                             self.depth, num_heads=num_heads,
                                             kernel_size=(3, 3), stride=(1, 1), padding=1)
        stage.convblock_3 = ConvBlock(out_dim, out_dim, kernel_size=1, dw_kernel_size=3)
        stage.convblock_4 = ConvBlock(out_dim, out_dim, kernel_size=3, dw_kernel_size=5)
        return stage

    def forward(self, x):
        for stage in self.blocks:
            x = stage.downsampler(x)
            # Spatial size after the stride-2 conv (== input size // 2).
            batch, _, height, width = x.shape
            residual = x.clone()
            x = stage.convblock_2(stage.convblock_1(x))
            x = stage.transformer(x)
            # The transformer output is treated as channels-last: C is read from
            # the last axis and the tokens are folded back into an NCHW map.
            channels = x.size(-1)
            x = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
            x = stage.convblock_4(stage.convblock_3(x)) + residual
        x = self.head_norm(x.permute(0, 2, 3, 1).contiguous())
        return self.head(x).permute(0, 3, 1, 2).contiguous()


class Decoder(nn.Module):
    """Mirror of ``Encoder``: map a ``z_dim`` latent map back to image space.

    Stages are built in reverse ``dim_mult`` order. Each stage:

    1. Doubles H and W with an ``UpsamplingBlock``.
    2. Applies two ``ConvBlock``s, a ``TransformerBlock`` and two more ``ConvBlock``s.
    3. Adds a residual connection from the upsampled input.

    The final feature map (``embed_dim * dim_mult[0]`` channels) goes through a
    LayerNorm over channels. It then goes through a modulated
    ``LearnableFourierFeaturePositionalEmebedding2d`` configured with
    ``out_dim=out_chans``.

    Note: this class previously subclassed ``nn.ModuleList`` by mistake.
    ``ModuleList`` would expose its container API (``len()``, integer indexing,
    iteration) over the named child modules. Parameter names / ``state_dict``
    keys are unaffected by the change.
    """

    def __init__(self,
                 out_chans=1,
                 embed_dim=128,
                 z_dim=512,
                 depth=1,
                 dim_mult=(1, 2, 2, 4),
                 heads=(2, 4, 4, 8)):
        super().__init__()
        self.dim_mult = list(dim_mult)
        self.heads = list(heads)
        self.depth = depth
        in_dim = z_dim
        self.embed_dim = embed_dim
        self.blocks = nn.ModuleList()
        N = len(self.dim_mult)
        for i in range(N):
            # Widths are consumed in reverse so the decoder mirrors the encoder.
            out_dim = self.dim_mult[N - 1 - i] * self.embed_dim
            blocks = nn.Module()
            blocks.upsampler = UpsamplingBlock(in_dim, out_dim)
            blocks.convblock_1 = ConvBlock(out_dim, out_dim,
                                           kernel_size=1,
                                           dw_kernel_size=3)
            blocks.convblock_2 = ConvBlock(out_dim, out_dim,
                                           kernel_size=3,
                                           dw_kernel_size=5)
            # NOTE(review): unlike the widths above, heads are NOT reversed, so the
            # widest decoder stage gets heads[0]. The Encoder pairs dim_mult[i]
            # with heads[i]; confirm whether heads[N - 1 - i] was intended.
            # Kept as-is to preserve current behaviour.
            blocks.transformer = TransformerBlock(out_dim, out_dim,
                                                  self.depth, num_heads=self.heads[i],
                                                  kernel_size=(3, 3), stride=(1, 1), padding=1)
            blocks.convblock_3 = ConvBlock(out_dim, out_dim,
                                           kernel_size=1,
                                           dw_kernel_size=3)
            blocks.convblock_4 = ConvBlock(out_dim, out_dim,
                                           kernel_size=3,
                                           dw_kernel_size=5)
            self.blocks.append(blocks)
            in_dim = out_dim
        self.head_norm = nn.LayerNorm(embed_dim * self.dim_mult[0])
        self.head = LearnableFourierFeaturePositionalEmebedding2d(embed_dim * self.dim_mult[0],
                                                                  out_dim=out_chans,
                                                                  modulation=True)

    def forward(self, x):
        for blk in self.blocks:
            x = blk.upsampler(x)
            B, _, H, W = x.shape
            skip = x.clone()
            x = blk.convblock_1(x)
            x = blk.convblock_2(x)
            x = blk.transformer(x)
            # Transformer output is treated as channels-last; fold back to NCHW.
            C = x.size(-1)
            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
            x = blk.convblock_3(x)
            x = blk.convblock_4(x) + skip
        # LayerNorm normalizes the last axis, so normalize channels-last then restore NCHW.
        x = self.head_norm(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
        x = self.head(x)
        return x


import random


class GRQVAE(nn.Module):
    """Grouped residual LFQ autoencoder with an auxiliary self-supervised loss.

    ``forward`` performs these steps:

    1. Encodes the input with ``Encoder``.
    2. Quantizes the flattened latent tokens with ``GroupedResidualLFQ``.
    3. Decodes the quantized map with ``Decoder``.
    4. Computes a VICReg-style regulariser on a globally pooled projection of
       the quantized features.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.ema_decay = 0.999
        # Created on demand by add_ema_encoder().
        self.ema_encoder = None
        self.bottleneck = GroupedResidualLFQ(dim=512,
                                             groups=4,
                                             codebook_size=1024 * 16,
                                             num_quantizers=8)
        # Built after the bottleneck, keeping parameter-init and registration order.
        self.add_contrastive_head()

    def add_contrastive_head(self):
        """(Re)create the 512 -> 512 -> 512 -> 128 projection MLP used by the SSL losses."""
        self.contrastive_head = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 128))

    def off_diagonal(self, x):
        """Return a flattened view of the off-diagonal elements of a square matrix."""
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].reshape(n - 1, n + 1)[:, 1:].flatten()

    def barlow_twins_loss(self, y1, y2, lambd: float = 0.002):
        """Barlow Twins redundancy-reduction loss between two batches of features.

        Each projected feature is L2-normalized across the batch (``dim=0``), so
        ``c[i, j]`` is the normalized cross-correlation between feature ``i`` of
        the first view and feature ``j`` of the second (Barlow Twins, Eq. 2).
        Previously features were normalized per sample (``dim=-1``), which does
        not give a correlation matrix.

        Args:
            y1, y2: (N, 512) inputs to ``contrastive_head``, one per view.
            lambd: weight of the off-diagonal (redundancy) term.
        """
        z1 = self.contrastive_head(y1)
        z2 = self.contrastive_head(y2)
        c = F.normalize(z1, dim=0).T @ F.normalize(z2, dim=0)
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = self.off_diagonal(c).pow_(2).sum()
        return on_diag + lambd * off_diag

    @staticmethod
    def _vicreg_loss(z1, z2, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
        """VICReg-style loss between two (N, D) embedding batches.

        The total is ``cov_coeff * covariance + sim_coeff * invariance +
        std_coeff * variance``, where:

        - variance: hinge pushing each feature's std (over the batch) to >= 1.
        - invariance: 0.5 * L1 + 0.5 * MSE between ``z1`` and ``z2``.
        - covariance: squared off-diagonal covariance entries, divided by D.
        """
        n, d = z1.shape
        std_z1 = torch.sqrt(z1.var(dim=0) + eps)
        std_z2 = torch.sqrt(z2.var(dim=0) + eps)
        loss_variance = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))
        # Bug fix: the MSE term used to be mse_loss(z2, z2), which is always 0.
        loss_invariance = 0.5 * F.l1_loss(z1, z2) + 0.5 * F.mse_loss(z1, z2)
        z1 = z1 - z1.mean(dim=0, keepdim=True)
        z2 = z2 - z2.mean(dim=0, keepdim=True)
        cov_z1 = torch.square((z1.t() @ z1) / (n - 1))
        cov_z2 = torch.square((z2.t() @ z2) / (n - 1))
        loss_cov1 = (cov_z1.sum() - cov_z1.diagonal().sum()) / d
        loss_cov2 = (cov_z2.sum() - cov_z2.diagonal().sum()) / d
        loss_covariance = loss_cov1 + loss_cov2
        return cov_coeff * loss_covariance + sim_coeff * loss_invariance + std_coeff * loss_variance

    @torch.no_grad()
    def add_ema_encoder(self):
        """Create a frozen deep copy of ``self.encoder`` to be tracked by EMA."""
        self.ema_encoder = copy.deepcopy(self.encoder)
        for param in self.ema_encoder.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_ema_encoder(self):
        """In place: ``ema = ema_decay * ema + (1 - ema_decay) * online``, per parameter.

        Raises:
            RuntimeError: if ``add_ema_encoder()`` has not been called.
        """
        if self.ema_encoder is None:
            raise RuntimeError("add_ema_encoder() must be called before update_ema_encoder()")
        # Bug fix: the parameter iterators were previously wrapped in a single tuple
        # (zip((a, b))), so the two-name unpacking below always failed.
        for param_q, param_k in zip(self.encoder.parameters(), self.ema_encoder.parameters()):
            param_k.data.mul_(self.ema_decay).add_(param_q.data, alpha=1 - self.ema_decay)

    def contrastive_loss(self, x):
        """Unfinished momentum-encoder contrastive objective; always returns ``None``.

        NOTE(review): projected embeddings for the two batch halves are computed
        but no loss is formed. The method also requires ``add_ema_encoder()``
        first. It feeds the 4-D encoder output straight into the bottleneck,
        whereas ``forward`` rearranges to tokens first. TODO: finish before use.
        """
        k, q = x.chunk(2, dim=0)
        q = self.encoder(q)
        with torch.no_grad():
            k = self.ema_encoder(k)
        q, k = self.bottleneck(q)[0], self.bottleneck(k)[0]
        # Pool both branches over the same axis (k previously used dim=-1).
        q, k = q.mean(dim=1), k.mean(dim=1)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        q, k = self.contrastive_head(q), self.contrastive_head(k)
        return None

    def forward(self, x):
        """Run encoder -> quantizer -> decoder plus the SSL regulariser.

        Args:
            x: image batch passed straight to ``self.encoder``. With the default
                Encoder this is (B, 1, H, W); TODO confirm for other configs.

        Returns:
            dict with:

            - ``"loss"``: the quantizer's returned loss tensor, summed.
            - ``"ssl_loss"``: VICReg-style loss on the pooled projection.
            - ``"recon"``: the decoder output.
            - ``"indices"``: the indices returned by the quantizer.
        """
        x = self.encoder(x)
        h, w = x.shape[-2:]
        x = rearrange(x, "b c h w -> b (h w) c")
        dtype = x.dtype
        # Quantize in float32 even inside an outer fp16 autocast region.
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            x, indices, loss = self.bottleneck(x.float())
        x = x.to(dtype)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            # Global average pool over H, W -> one vector per sample.
            z = x.mean(dim=(-2, -1)).to(torch.float32)
            z = self.contrastive_head(z)
            # NOTE(review): the two "views" are the two halves of the feature
            # axis of the same projection, not two halves of the batch. Confirm
            # this is intended.
            z1, z2 = z.chunk(2, dim=-1)
            contrastive_loss = self._vicreg_loss(z1, z2)
        x = self.decoder(x)
        return {"loss": loss.sum(),
                "ssl_loss": contrastive_loss,
                "recon": x,
                "indices": indices}


def test():
    """Smoke test: one forward pass of ``GRQVAE`` on random CUDA input under fp16 autocast.

    Requires a CUDA device. Checks that the reconstruction has the same shape
    as the input and returns the model's output dict for inspection.
    """
    x = torch.randn(2, 1, 256, 256).cuda()
    f = GRQVAE().cuda()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = f(x)
    # Previously the result was discarded and nothing was checked.
    assert y["recon"].shape == x.shape, (y["recon"].shape, x.shape)
    return y
