import copy

import torch
import numpy as np
import random
from torch import nn
from torch.nn import functional as F
# from xformers.components.feedforward import FusedMLP
from xformers.ops.swiglu_op import SwiGLU
from timm.layers.drop import DropPath
from torch.utils import checkpoint
from einops import rearrange, einsum
from ot.gaussian import empirical_bures_wasserstein_distance
from m_foundation_model import NormLayer, Tokenizer
from hyperattention.hyper_attention import HyperAttention
from flash_attn import flash_attn_func
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
from typing import Callable, Tuple

import torch

class Heaviside(torch.autograd.Function):
    """Hard step with a smooth surrogate gradient.

    Forward: 1 where the input is >= 0 and 0 elsewhere, in the input's dtype.
    Backward: the incoming gradient scaled by sigmoid(x) * (1 - sigmoid(x)),
    i.e. the derivative of the logistic function evaluated at the input.
    """

    @staticmethod
    def forward(ctx, x):
        # Keep the raw input; the surrogate gradient is evaluated on it.
        ctx.save_for_backward(x)
        is_non_negative = x >= 0
        return is_non_negative.to(x.dtype)

    @staticmethod
    def backward(ctx, grad):
        (saved_x,) = ctx.saved_tensors
        sig = torch.sigmoid(saved_x)
        return sig * (1 - sig) * grad


class StarReLU(nn.Module):
    """Squared-ReLU activation with a learnable scale and bias.

    Computes ``s * relu(x) ** 2 + b`` where ``s`` and ``b`` are scalar
    parameters initialised to 0.8944 and -0.4472 respectively.
    """

    def __init__(self):
        super().__init__()
        # Scalar parameters; registration order (s, then b) is kept.
        self.s = nn.Parameter(torch.as_tensor(0.8944), requires_grad=True)
        self.b = nn.Parameter(torch.as_tensor(-0.4472), requires_grad=True)

    def forward(self, x):
        rectified = F.relu(x)
        squared = rectified.pow(2)
        return squared * self.s + self.b


class Adaptor(nn.Module):
    """Residual bottleneck adapter: ``x + dropout(w_b(act(w_a(x))))``.

    ``w_a`` / ``w_b`` are caller-supplied projection modules and ``act`` the
    non-linearity between them. ``r`` is accepted for signature parity with
    ``LoRA`` but not stored; ``alpha`` and the learnable ``scale`` are stored
    but not used by ``forward``.
    """

    def __init__(self,
                 w_a,
                 w_b,
                 act,
                 r=8,
                 alpha=2):
        super().__init__()
        # Registration order (w_a, w_b, scale, act) is kept unchanged.
        self.adaptor_w_a = w_a
        self.adaptor_w_b = w_b
        self.alpha = alpha
        self.scale = nn.Parameter(torch.as_tensor(1.0), requires_grad=True)
        self.act = act

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'scale'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Keyword form of :meth:`no_weight_decay`."""
        return {'scale'}

    def forward(self, x, **kwargs):
        """Apply the bottleneck and add the input back (extra kwargs ignored)."""
        residual = x
        bottleneck = self.act(self.adaptor_w_a(x))
        projected = self.adaptor_w_b(bottleneck)
        # Dropout (p=0.1) is only active in training mode.
        projected = F.dropout(projected, p=0.1, training=self.training)
        return projected + residual


class LoRA(nn.Module):
    """Low-rank bypass added to the output of an existing layer.

    ``forward`` returns ``fn(x) + (alpha / r) * scale * dropout(w_b(w_a(x)))``.

    Args:
        fn: The wrapped (base) module.
        w_a: Down-projection module applied to the input.
        w_b: Up-projection module applied to the output of ``w_a``.
        r: Rank used in the ``alpha / r`` scaling factor.
        alpha: Numerator of the scaling factor.

    Both projections are re-initialised with Kaiming-normal weights and zero
    bias when they are ``nn.Linear`` instances.
    """

    def __init__(self,
                 fn,
                 w_a,
                 w_b,
                 r=8,
                 alpha=2):
        super().__init__()
        self.fn = fn
        self.lora_w_a = w_a
        self.lora_w_b = w_b
        self.r = r
        self.scale = nn.Parameter(torch.as_tensor(1.0), requires_grad=True)
        self.alpha = alpha

        # Same init order as before (w_a, then w_b) so RNG use is unchanged.
        for proj in (self.lora_w_a, self.lora_w_b):
            if isinstance(proj, nn.Linear):
                nn.init.kaiming_normal_(proj.weight)
                # Linear layers built with bias=False have no bias to reset.
                if proj.bias is not None:
                    nn.init.constant_(proj.bias, 0.)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'scale'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Keyword form of :meth:`no_weight_decay`."""
        return {'scale'}

    def forward(self, x):
        bypass = self.lora_w_a(x)
        bypass = self.lora_w_b(bypass)
        bypass = F.dropout(bypass, 0.1, training=self.training)
        # True division: floor division made the factor 0 whenever alpha < r
        # (e.g. the defaults alpha=2, r=8), silently disabling the bypass.
        scaling = self.alpha / self.r
        return self.fn(x) + scaling * self.scale * bypass


def apply_lora(model, r=8, alpha=8):
    """Wrap the QKV projection of every ``SelfAttention`` in ``model`` with LoRA.

    The model is patched in place and returned.

    ``named_modules()`` yields dotted paths for nested modules, and
    ``setattr(model, "a.b.c", ...)`` does not replace the module at that path,
    so the previous implementation never changed the layers actually executed.
    Each attention module is now patched directly through its own attribute.

    Args:
        model: Module tree to patch.
        r: Rank of the low-rank projections.
        alpha: LoRA scaling numerator (forwarded to ``LoRA``).

    Returns:
        The same ``model`` object.
    """
    # Snapshot the targets first; the module tree is mutated while patching.
    attention_modules = [m for m in model.modules()
                         if isinstance(m, SelfAttention)]
    for attn in attention_modules:
        base_layer = attn.shared_qkv_layer
        in_dim, out_dim = base_layer.in_features, base_layer.out_features
        attn.shared_qkv_layer = LoRA(base_layer,
                                     nn.Linear(in_dim, r),
                                     nn.Linear(r, out_dim),
                                     r,
                                     alpha)
    return model


def apply_adaptor(model, r=8, alpha=8):
    """Attach two bottleneck ``Adaptor`` modules to every ``MixerLayer``.

    The model is patched in place and returned. Adapter widths are taken from
    ``glu_layer.in_features`` / ``glu_layer.out_features`` of each layer.

    ``named_modules()`` yields dotted paths for nested modules, and
    ``setattr(model, "a.b.c", ...)`` does not replace the module at that path,
    so the previous implementation never changed the layers actually executed.
    Each mixer layer is now patched directly through its own attributes.

    Args:
        model: Module tree to patch.
        r: Bottleneck width of each adapter.
        alpha: Forwarded to ``Adaptor``.

    Returns:
        The same ``model`` object.
    """
    # Snapshot the targets first; the module tree is mutated while patching.
    mixer_layers = [m for m in model.modules() if isinstance(m, MixerLayer)]
    for layer in mixer_layers:
        in_dim = layer.glu_layer.in_features
        out_dim = layer.glu_layer.out_features
        layer.adaptor_layer_1 = Adaptor(
            nn.Linear(in_dim, r),
            nn.Linear(r, out_dim),
            nn.SiLU(),
            r,
            alpha)
        layer.adaptor_layer_2 = Adaptor(
            nn.Linear(in_dim, r),
            nn.Linear(r, out_dim),
            nn.SiLU(),
            r,
            alpha)
    return model


def do_nothing(x, mode=None):
    """Identity stand-in for a merge/unmerge callable; ``mode`` is ignored."""
    unchanged = x
    return unchanged


def bipartite_soft_matching(
        metric: torch.Tensor,
        r: int,
        class_token: bool = False,
        distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    """Build merge/unmerge closures that fold ``r`` tokens into their best match.

    Tokens are split along dim -2 into even-indexed (source candidates) and
    odd-indexed (destinations) sets. Similarity is the dot product of the
    L2-normalised ``metric`` rows; each even token is matched to its most
    similar odd token, and the ``r`` even tokens with the highest match score
    are merged into their destinations.

    Args:
        metric: Tensor whose dim 1 is the token axis (``metric.shape[1]`` is
            used as the token count) and whose last dim is the feature axis.
        r: Number of tokens to remove; capped at half of the unprotected
            tokens.
        class_token: If true, the first even-indexed token is never merged.
        distill_token: If true, no token is merged into the first
            odd-indexed token.

    Returns:
        ``(merge, unmerge)``. ``merge(x, mode)`` reduces the token count by
        ``r`` using ``scatter_reduce`` with ``reduce=mode``; ``unmerge(x)``
        restores the original token count. Both are ``do_nothing`` when the
        effective ``r`` is <= 0.
    """
    # Count special tokens that must not take part in the 50% budget.
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        # Normalise so the matrix product below is a cosine similarity.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        # -inf scores push protected tokens to the end of the ranking (row 0)
        # or make them unselectable as a destination (column 0).
        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        # Best destination per source token, then rank sources by that score.
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Return the unmerged even tokens followed by the (reduced) odd tokens."""
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        # Fold each selected source token into its matched destination.
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            # Keep the first unmerged and first destination token up front.
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Expand a merged sequence back to the original token count.

        Tokens that were merged away receive a copy of the destination token
        they were merged into.
        """
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        # Destinations go back to odd positions; indices into the even set
        # are doubled to address the full sequence.
        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge


def kth_bipartite_soft_matching(
        metric: torch.Tensor, k: int
) -> Tuple[Callable, Callable]:
    """Build merge/unmerge closures that keep every k-th token.

    The token axis (dim 1) is cut into groups of ``k``; trailing tokens that
    do not fill a whole group are dropped. The last token of each group is a
    destination, the other ``k - 1`` are sources. Every source is merged into
    the destination with the highest dot product of L2-normalised ``metric``
    rows.

    Returns ``(do_nothing, do_nothing)`` when ``k <= 1``.
    """
    if k <= 1:
        return do_nothing, do_nothing

    def split(x):
        # (batch, tokens, ch) -> sources (batch, groups*(k-1), ch),
        #                        destinations (batch, groups, ch)
        batch, channels = x.shape[0], x.shape[2]
        usable_tokens = (x.shape[1] // k) * k
        grouped = x[:, :usable_tokens, :].view(batch, -1, k, channels)
        sources = grouped[:, :, : (k - 1), :].contiguous().view(batch, -1, channels)
        destinations = grouped[:, :, (k - 1), :]
        return sources, destinations

    with torch.no_grad():
        normalised = metric / metric.norm(dim=-1, keepdim=True)
        src_metric, dst_metric = split(normalised)
        r = src_metric.shape[1]
        similarity = src_metric @ dst_metric.transpose(-1, -2)
        # Index of the best destination for each source, kept as (..., r, 1).
        dst_idx = similarity.max(dim=-1).indices[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Reduce ``x`` to one token per group using ``reduce=mode``."""
        src, dst = split(x)
        n, _, c = src.shape
        return dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Re-expand merged tokens; each source copies its destination."""
        n, _, c = x.shape
        copied = x.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype)
        copied = copied.view(n, -1, (k - 1), c)
        kept = x.view(n, -1, 1, c)
        regrouped = torch.cat([copied, kept], dim=-2)
        return regrouped.contiguous().view(n, -1, c)

    return merge, unmerge


def random_bipartite_soft_matching(
        metric: torch.Tensor, r: int
) -> Tuple[Callable, Callable]:
    """Build merge/unmerge closures using a random source/destination split.

    ``r`` randomly chosen tokens (per batch element) become sources and the
    remaining ``N - r`` become destinations. Each source is merged into the
    destination with the highest dot product of L2-normalised ``metric`` rows.

    Returns ``(do_nothing, do_nothing)`` when ``r <= 0``.
    """
    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        B, N, _ = metric.shape
        # A random permutation of token indices for every batch element.
        permutation = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
        a_idx, b_idx = permutation[:, :r, :], permutation[:, r:, :]

    def split(x):
        width = x.shape[-1]
        sources = torch.gather(x, 1, a_idx.expand(B, r, width))
        destinations = torch.gather(x, 1, b_idx.expand(B, N - r, width))
        return sources, destinations

    with torch.no_grad():
        src_metric, dst_metric = split(metric / metric.norm(dim=-1, keepdim=True))
        similarity = src_metric @ dst_metric.transpose(-1, -2)
        # Best destination per source, kept as (B, r, 1) for expand().
        dst_idx = similarity.max(dim=-1).indices[..., None]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Fold the random sources into their destinations (``reduce=mode``)."""
        src, dst = split(x)
        width = src.shape[-1]
        return dst.scatter_reduce(-2, dst_idx.expand(B, r, width), src, reduce=mode)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Scatter merged tokens back to their original positions."""
        width = x.shape[-1]
        copied = x.gather(dim=-2, index=dst_idx.expand(B, r, width))
        restored = torch.zeros(B, N, width, device=x.device, dtype=x.dtype)
        restored.scatter_(dim=-2, index=a_idx.expand(B, r, width), src=copied)
        restored.scatter_(dim=-2, index=b_idx.expand(B, N - r, width), src=x)
        return restored

    return merge, unmerge


def merge_wavg(
        merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge tokens as a size-weighted average.

    ``size`` holds one weight per token (defaults to all ones with a
    trailing singleton dim). Returns the merged tokens and the merged sizes.
    """
    if size is None:
        size = torch.ones_like(x[..., 0, None])
    weighted_total = merge(x * size, mode="sum")
    merged_size = merge(size, mode="sum")
    averaged = weighted_total / merged_size
    return averaged, merged_size


def merge_source(
        merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:
    """Track which original tokens ended up in each merged token.

    When ``source`` is missing it starts as a per-batch identity matrix sized
    by the token count of ``x``; it is then passed through ``merge`` with
    ``mode="amax"``.
    """
    if source is None:
        n, t, _ = x.shape
        identity = torch.eye(t, device=x.device)
        source = identity[None, ...].expand(n, t, t)
    return merge(source, mode="amax")


class SCELoss(nn.Module):
    """Symmetric cross-entropy: ``a * CE + b * reverse-CE``.

    The forward term is ``nn.CrossEntropyLoss`` with label smoothing 1e-4.
    The reverse term swaps the roles of prediction and target, clamping the
    one-hot target to [1e-4, 1] so its logarithm is finite.
    """

    def __init__(self, num_classes=2, a=1, b=0.1):
        super().__init__()
        self.a = a
        self.b = b
        self.num_classes = num_classes
        self.cross_entropy = nn.CrossEntropyLoss(label_smoothing=1e-4)

    def forward(self, preds, labels):
        """Return the scalar loss for logits ``preds`` and class ids ``labels``."""
        forward_ce = self.cross_entropy(preds, labels)

        probabilities = torch.clamp(preds.softmax(dim=-1), min=1e-7, max=1.0)
        one_hot = F.one_hot(labels, self.num_classes).float().to(probabilities.device)
        one_hot = torch.clamp(one_hot, min=1e-4, max=1.0)
        reverse_ce = -1 * torch.sum(probabilities * torch.log(one_hot), dim=-1)

        return self.a * forward_ce + self.b * reverse_ce.mean()


class MultiHeadFeatScale(nn.Module):
    """Per-head re-weighting of the mean and mean-removed feature components.

    For a 4-D input the mean over dim -2 is split off and the output is
    ``x + (x - mean) * lambd1 + mean * lambd2`` with learnable tensors of
    shape ``(1, heads, 1, head_features)`` initialised to ones.
    """

    def __init__(self, head_features, heads):
        super().__init__()
        self.lambd1 = nn.Parameter(torch.ones(1, heads, 1, head_features))
        self.lambd2 = nn.Parameter(torch.ones(1, heads, 1, head_features))

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'lambd1', 'lambd2'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Keyword form of :meth:`no_weight_decay`."""
        return {'lambd1', 'lambd2'}

    def forward(self, x):
        # Unpacking enforces a rank-4 input, as before.
        _batch, _heads, _length, _dim = x.shape
        mean_component = x.mean(dim=-2, keepdim=True)
        residual_component = x - mean_component
        return x + residual_component * self.lambd1 + mean_component * self.lambd2


class BidirectionalCrossAttention(nn.Module):
    """Cross-attention that updates a sequence and its context in one pass.

    One similarity matrix is computed between the shared query/key
    projections of ``x`` and ``context``. A softmax over the context axis
    lets ``x`` aggregate context values; a softmax over the sequence axis
    lets the context aggregate values from ``x``.
    """

    def __init__(
            self,
            dim=768,
            heads=16,
            dim_head=768 // 16,
            context_dim=None,
            dropout=0.,
            talking_heads=True,
    ):
        super().__init__()
        if context_dim is None:
            context_dim = dim
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(context_dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)
        self.context_dropout = nn.Dropout(dropout)

        # Shared query/key projections and separate value projections.
        self.to_qk = nn.Linear(dim, inner_dim, bias=False)
        self.context_to_qk = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.context_to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Linear(inner_dim, dim)
        self.context_to_out = nn.Linear(inner_dim, context_dim)

        # Optional 1x1 convolutions that mix attention maps across heads.
        if talking_heads:
            self.talking_heads = nn.Conv2d(heads, heads, 1, bias=False)
        else:
            self.talking_heads = nn.Identity()
        if talking_heads:
            self.context_talking_heads = nn.Conv2d(heads, heads, 1, bias=False)
        else:
            self.context_talking_heads = nn.Identity()

    def forward(
            self,
            x,
            context,
    ):
        """Return ``(out, context_out)`` for sequence ``x`` and ``context``."""
        h = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h=h)

        def merge_heads(t):
            return rearrange(t, 'b h n d -> b n (h d)')

        x = self.norm(x)
        context = self.context_norm(context)

        # Shared query/keys and values for sequence and context.
        qk = split_heads(self.to_qk(x))
        v = split_heads(self.to_v(x))
        context_qk = split_heads(self.context_to_qk(context))
        context_v = split_heads(self.context_to_v(context))

        sim = torch.einsum('b h i d, b h j d -> b h i j', qk, context_qk) * self.scale

        # Same similarity, normalised along each of its two token axes.
        attn = self.dropout(sim.softmax(dim=-1))
        context_attn = self.context_dropout(sim.softmax(dim=-2))

        attn = self.talking_heads(attn)
        context_attn = self.context_talking_heads(context_attn)

        # Sequence aggregates context values; context aggregates sequence values.
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, context_v)
        context_out = torch.einsum('b h j i, b h j d -> b h i d', context_attn, v)

        out = self.to_out(merge_heads(out))
        context_out = self.context_to_out(merge_heads(context_out))
        return out, context_out


class SelfAttention(nn.Module):
    """Multi-head self-attention with a prefix-bidirectional / causal mask.

    A single linear layer produces queries, keys and values; queries and keys
    are normalised per head with ``NormLayer``. In training mode, when
    ``prefix_len`` differs from the sequence length, attention runs through
    ``F.scaled_dot_product_attention`` with a boolean mask that is fully open
    inside the first ``prefix_len`` positions and lower-triangular elsewhere.
    In every other case ``flash_attn_func`` is called with ``causal=False``.

    ``drop_path`` and ``min_seq_len`` are accepted by the constructor; the
    ``DropPath`` module is created but not applied in ``forward_features``.
    """

    def __init__(self,
                 in_features: int = 512,
                 head_features: int = 64,
                 qkv_heads: int = 8,
                 drop_path: float = 0.,
                 drop_proj: float = 0.05,
                 min_seq_len: int = 1028):
        super().__init__()
        self.qkv_heads = qkv_heads
        self.head_features = head_features

        self.shared_qkv_layer = nn.Linear(
            in_features,
            head_features * self.qkv_heads * 3
        )
        self.out_layer = nn.Linear(self.qkv_heads * head_features,
                                   in_features)
        self.query_norm_layer = NormLayer(head_features)
        self.key_norm_layer = NormLayer(head_features)
        self.drop_path = DropPath(drop_path)
        self.drop_proj = nn.Dropout(drop_proj)
        self.multihead_featscale = MultiHeadFeatScale(head_features, qkv_heads)

    @staticmethod
    def _prefix_causal_mask(query_len, key_len, prefix_len, device):
        """Boolean (query_len, key_len) mask: causal, with an open prefix block.

        Built directly on ``device`` and left un-expanded; it broadcasts over
        the batch and head dimensions inside scaled_dot_product_attention.
        """
        mask = torch.ones(query_len, key_len,
                          dtype=torch.bool, device=device).tril()
        # A prefix_len of None opens the whole matrix (slice(None)).
        mask[:prefix_len, :prefix_len] = True
        return mask

    def forward_features(
            self,
            hidden_states: torch.Tensor,
            attention_mask: torch.Tensor = None,
            prefix_len: int = None,
            masked_prefix_token_ids: torch.Tensor = None,
    ):
        """Run attention over ``hidden_states`` (batch, length, features).

        Args:
            hidden_states: Input sequence.
            attention_mask: Optional extra mask; positions that are true here
                are allowed in addition to the prefix/causal pattern.
            prefix_len: Number of leading positions that attend to each other
                bidirectionally.
            masked_prefix_token_ids: Currently unused.

        Returns:
            Tensor with the same leading shape as the input.
        """
        hidden_states = checkpoint.checkpoint(self.shared_qkv_layer, hidden_states, use_reentrant=True)
        hidden_states = self.drop_proj(hidden_states)
        queries, keys, values = hidden_states.chunk(3, dim=-1)
        query_len, key_len = queries.size(-2), keys.size(-2)
        queries, keys, values = map(lambda t: rearrange(t,
                                                        "b l (h d) -> b h l d",
                                                        h=self.qkv_heads),
                                    (queries, keys, values))
        queries = self.query_norm_layer(queries)
        keys = self.key_norm_layer(keys)

        if self.training and prefix_len != query_len:
            # The mask is only needed on this path. The earlier code built a
            # (batch, heads, L, L) int64 tensor on the CPU, cloned it twice
            # and copied it to the device on every call.
            mask = self._prefix_causal_mask(query_len, key_len,
                                            prefix_len, queries.device)
            if attention_mask is not None:
                mask = torch.logical_or(mask, attention_mask)
            hidden_states = F.scaled_dot_product_attention(queries,
                                                           keys,
                                                           values,
                                                           mask
                                                           )
        else:
            # NOTE(review): in eval mode this branch is taken even when
            # prefix_len != query_len, so the causal part of the mask is not
            # applied at inference. Behaviour kept as-is; confirm intent.
            hidden_states = flash_attn_func(queries.transpose(1, 2),
                                            keys.transpose(1, 2), values.transpose(1, 2),
                                            causal=False).transpose(1, 2)
        hidden_states = self.multihead_featscale(hidden_states)
        hidden_states = rearrange(hidden_states, "b h l d -> b l (h d)")
        hidden_states = checkpoint.checkpoint(self.out_layer, hidden_states,
                                              use_reentrant=True)
        hidden_states = self.drop_proj(hidden_states)
        return hidden_states

    def forward(self, hidden_states, attention_mask=None, prefix_len=None,
                prefix_masked_token_ids=None):
        """Thin wrapper around :meth:`forward_features`."""
        hidden_states = self.forward_features(hidden_states,
                                              attention_mask,
                                              prefix_len,
                                              prefix_masked_token_ids)
        return hidden_states


class MixerLayer(nn.Module):
    """Pre-norm transformer block: token mixer followed by a SwiGLU MLP.

    Both sub-blocks are wrapped in ``DropPath`` and a residual connection.
    ``adaptor_layer_1`` / ``adaptor_layer_2`` start as ``None`` and are
    applied after the respective sub-block once they have been set.
    ``mixer_layer`` is only created for ``mixer_type == "self_attention"``;
    ``states`` and ``num_views`` are accepted but unused.
    """

    def __init__(self,
                 in_features=512,
                 head_features=32,
                 heads=16,
                 expand=4,
                 mixer_type="self_attention",
                 drop_path=0.,
                 drop=0.1,
                 states=64):
        super().__init__()
        self.norm_layer_1 = NormLayer(in_features)
        self.norm_layer_2 = NormLayer(in_features)
        self.drop_path = DropPath(drop_path)
        if mixer_type == "self_attention":
            self.mixer_layer = SelfAttention(in_features, head_features, heads)
        hidden_features = expand * in_features
        self.glu_layer = SwiGLU(in_features, hidden_features, in_features)
        self.adaptor_layer_1 = None
        self.adaptor_layer_2 = None
        self.drop = drop

    def forward(self, hidden_states, attention_mask=None,
                prefix_len=None, prefix_masked_token_ids=None, num_views=None):
        """Apply the mixer and MLP sub-blocks to ``hidden_states``."""
        # --- token-mixing sub-block ---
        block_input = hidden_states
        mixed = self.mixer_layer(
            self.norm_layer_1(hidden_states),
            attention_mask=attention_mask,
            prefix_len=prefix_len,
            prefix_masked_token_ids=prefix_masked_token_ids,
        )
        hidden_states = self.drop_path(mixed) + block_input
        if self.adaptor_layer_1 is not None:
            hidden_states = self.adaptor_layer_1(hidden_states) + block_input

        # --- feed-forward sub-block ---
        mlp_input = hidden_states.clone()
        mlp_out = self.glu_layer(self.norm_layer_2(hidden_states))
        mlp_out = F.dropout(mlp_out, self.drop, self.training)
        hidden_states = self.drop_path(mlp_out) + mlp_input
        if self.adaptor_layer_2 is not None:
            hidden_states = self.adaptor_layer_2(hidden_states) + mlp_input
        return hidden_states


def get_emb(sin_inp):
    """
    Builds a base embedding for one dimension with sin and cos interleaved
    along the last axis (sin, cos, sin, cos, ...).
    """
    sines = sin_inp.sin()
    cosines = sin_inp.cos()
    paired = torch.stack((sines, cosines), dim=-1)
    return paired.flatten(-2, -1)


class PositionalEncoding2D(nn.Module):
    """Fixed sinusoidal positional encoding for 4-D (batch, x, y, ch) inputs.

    The most recent encoding is cached and reused while the input's shape,
    device and dtype stay the same.
    """

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        # Channels per spatial axis, rounded up to an even number.
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        cached = self.cached_penc
        # Shape alone is not enough: a same-shaped input with another dtype
        # or device must not receive the stale cached encoding.
        if (cached is not None
                and cached.shape == tensor.shape
                and cached.device == tensor.device
                and cached.dtype == tensor.dtype):
            return cached

        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        # emb_x is (x, 1, channels) so it broadcasts across the y axis.
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)
        emb = torch.zeros(
            (x, y, self.channels * 2),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels: 2 * self.channels] = emb_y

        # Trim to the input's channel count and repeat over the batch.
        self.cached_penc = emb[None, :, :, :orig_ch].repeat(batch_size, 1, 1, 1)
        return self.cached_penc


class PositionalEncoding1D(nn.Module):
    """Fixed sinusoidal positional encoding for 3-D (batch, x, ch) inputs.

    The most recent encoding is cached and reused while the input's shape,
    device and dtype stay the same.
    """

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        # Round up to an even channel count (sin/cos pairs).
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        cached = self.cached_penc
        # Shape alone is not enough: a same-shaped input with another dtype
        # or device must not receive the stale cached encoding.
        if (cached is not None
                and cached.shape == tensor.shape
                and cached.device == tensor.device
                and cached.dtype == tensor.dtype):
            return cached

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device, dtype=tensor.dtype)
        emb[:, : self.channels] = emb_x

        # Trim to the input's channel count and repeat over the batch.
        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc


class BeatrixEncoder(nn.Module):
    """Stack of ``MixerLayer`` blocks with a learnable class token.

    ``forward`` (training) splits the sequence at a random point into a
    prefix and a suffix, randomly drops part of the prefix, and runs the
    layers with the class token plus kept prefix as the bidirectional prefix.
    ``inference`` feeds the whole sequence as prefix.
    """

    def __init__(self, depth=16, in_features=512, head_features=8, heads=8, expand=4):
        super().__init__()
        self.cls_token = nn.Parameter(0.02 * torch.randn(1, 1, in_features))
        self.mixer_layers = nn.ModuleList(
            [
                MixerLayer(in_features, head_features, heads, expand)
                for _ in range(depth)
            ]
        )

    #  self._init_params()
    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'cls_token'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Keyword form of :meth:`no_weight_decay`."""
        return {'cls_token'}

    def _init_params(self):
        """Re-initialise every ``nn.Linear``: truncated normal (std 0.02), zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # The second positional argument of trunc_normal_ is the
                # mean, so std must be passed by keyword.
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

    @staticmethod
    def prepare_prefix(x, ratio=0.75):
        """Randomly keep ``int(ratio * seq_len)`` tokens of each sequence.

        Args:
            x: Tensor of shape (batch, seq_len, features).
            ratio: Fraction of tokens to keep.

        Returns:
            ``(x_masked, mask, ids_restore)``: the kept tokens, a
            (batch, seq_len) mask in original order (0 = kept, 1 = removed),
            and the permutation that undoes the shuffle.
        """
        batch_size, seq_len, features = x.size()
        len_keep = int(ratio * seq_len)
        noise = torch.rand(batch_size, seq_len, device=x.device)  # noise in [0, 1]
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, features))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_len], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def inference(self,
                  x,
                  frames=1):
        """Encode ``x`` with the whole sequence as prefix (``frames`` is unused)."""
        actual_prefix_len = x.size(1)
        batch_size = x.size(0)
        x = torch.cat((self.cls_token.repeat(batch_size, 1, 1),
                       x), dim=-2)
        for layer in self.mixer_layers:
            # +1 accounts for the prepended class token.
            x = layer(x, prefix_len=actual_prefix_len + 1)
        cls_token, x = x[:, 0, :].unsqueeze(1), x[:, 1:, :]
        return {"cls_token": cls_token,
                "hidden_states": x}

    def forward(self, x: torch.Tensor, num_views: int = None):
        """Encode ``x`` with a random prefix/suffix split and prefix masking.

        ``random.randint(32, seq_len - 2)`` picks the split point, so inputs
        with fewer than 34 tokens raise ``ValueError``.

        Returns:
            Dict with ``cls_token``, ``prefix`` (kept prefix tokens),
            ``suffix``, ``mask`` and ``restore_ids`` (see
            :meth:`prepare_prefix`).
        """
        prefix_len = random.randint(32, x.size(-2) - 2)
        prefix, suffix = (
            x[:, 0:prefix_len, :],
            x[:, prefix_len:, :]
        )
        ratio = random.choice([0.75, 0.775, 0.80, 0.825, 0.85])
        prefix, mask, prefix_ids_restore = self.prepare_prefix(
            prefix,
            ratio,
        )
        actual_prefix_len = prefix.size(1)
        batch_size = prefix.size(0)
        x = torch.cat((self.cls_token.repeat(batch_size, 1, 1),
                       prefix, suffix), dim=-2)
        for layer in self.mixer_layers:
            # +1 accounts for the prepended class token.
            x = layer(x, prefix_len=actual_prefix_len + 1, num_views=num_views)
        cls_token, x = x[:, 0, :].unsqueeze(1), x[:, 1:, :]
        prefix, suffix = (
            x[:, 0:actual_prefix_len, :],
            x[:, actual_prefix_len:, :]
        )
        return {
            "cls_token": cls_token,
            "prefix": prefix,
            "suffix": suffix,
            "mask": mask,
            "restore_ids": prefix_ids_restore,
        }


class BeatrixDecoder(nn.Module):
    """Decoder that re-inserts mask tokens into the prefix and refines it.

    Encoder outputs are projected to the decoder width and normalised; the
    prefix is padded with a learnable mask token and un-shuffled with
    ``restore_ids``; a 1-D sinusoidal position encoding is added before the
    ``MixerLayer`` stack.
    """

    def __init__(self, depth=4, enc_features: int = 768, in_features=512, head_features=8, heads=8, expand=4):
        super().__init__()

        self.mask_token = nn.Parameter(0.02 * torch.randn(1, 1, in_features))

        self.mixer_layers = nn.ModuleList(
            [
                MixerLayer(in_features, head_features, heads, expand)
                for _ in range(depth)
            ]
        )
        self.dec_emb = nn.Linear(enc_features, in_features)
        self.dec_norm = NormLayer(in_features)
        self.dec_pos_emb = PositionalEncoding1D(in_features)

    #     self._init_params()

    def _init_params(self):
        """Re-initialise Linear/Conv1d layers: truncated normal (std 0.02), zero bias."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                # The second positional argument of trunc_normal_ is the
                # mean, so std must be passed by keyword.
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

    def forward(self,
                cls_token,
                prefix,
                suffix,
                restore_ids,
                ):
        """Decode encoder outputs.

        Args:
            cls_token: Encoder class token.
            prefix: Encoder outputs for the kept prefix tokens.
            suffix: Encoder outputs for the suffix tokens.
            restore_ids: Permutation that restores the original prefix order.

        Returns:
            Dict with ``cls_token``, ``prefix`` (full-length, un-shuffled)
            and ``suffix`` after the decoder layers.
        """
        prefix = self.dec_emb(prefix)
        suffix = self.dec_emb(suffix)
        prefix = self.dec_norm(prefix)
        suffix = self.dec_norm(suffix)
        cls_token = self.dec_norm(self.dec_emb(cls_token))
        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(prefix.shape[0], restore_ids.shape[1] + 1 - prefix.shape[1], 1)
        # NOTE(review): prefix[:, 1:] drops the first prefix token as if it
        # were a class token, yet cls_token is a separate argument here.
        # Confirm that callers pass a prefix that still includes it.
        x_ = torch.cat([prefix[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=restore_ids.unsqueeze(-1).repeat(1, 1,
                                                                            prefix.shape[2]))  # unshuffle
        prefix = torch.cat((cls_token, x_), dim=-2)
        actual_prefix_len = prefix.size(-2)
        x = torch.cat((prefix, suffix), dim=-2)
        x = x + self.dec_pos_emb(x)
        for layer in self.mixer_layers:
            x = layer(x, prefix_len=actual_prefix_len)

        cls_token, prefix, suffix = (
            x[:, :1, :],
            x[:, 1:actual_prefix_len, :],
            x[:, actual_prefix_len:, ]
        )

        return {
            "cls_token": cls_token,
            "prefix": prefix,
            "suffix": suffix,
        }


def load_large_beatrix(ckpt=None):
    """Build the base model, optionally load weights, then deepen it.

    After the (non-strict) state-dict load, the last spatial encoder layer is
    duplicated 8 times and the last multi-view layer 4 times, each copy being
    appended to its stack.

    Args:
        ckpt: Optional path to a checkpoint readable by ``torch.load``.

    Returns:
        The enlarged ``FactorizedBeatrixPretrainedModel``.
    """
    model = FactorizedBeatrixPretrainedModel()
    if ckpt is not None:
        # map_location="cpu" lets checkpoints saved on a GPU load on any host;
        # load_state_dict copies the values into the existing parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(ckpt, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

    spatial_transformer = copy.deepcopy(model.spatial_encoder.mixer_layers[-1])
    for _ in range(8):
        model.spatial_encoder.mixer_layers.append(copy.deepcopy(
            spatial_transformer))
    view_transformer = copy.deepcopy(model.multview_encoder[-1])
    for _ in range(4):
        model.multview_encoder.append(copy.deepcopy(view_transformer))
    return model


def load_base_beatrix(ckpt="/data/users/zhengruizhe/34bckpt_4.pth"):
    """Build the base model and optionally load weights (non-strict).

    Args:
        ckpt: Path to a checkpoint readable by ``torch.load``; pass ``None``
            to keep the random initialisation.

    Returns:
        A ``FactorizedBeatrixPretrainedModel``.
    """
    model = FactorizedBeatrixPretrainedModel()
    if ckpt is not None:
        # map_location="cpu" lets checkpoints saved on a GPU load on any host;
        # load_state_dict copies the values into the existing parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(ckpt, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
    return model


class FactorizedBeatrixPretrainedModel(nn.Module):
    """Pretraining model: tokenizer -> spatial encoder -> multi-view mixer -> decoder.

    ``forward`` tokenizes the input without gradients, projects the normalised
    tokens to the encoder width, runs the spatial encoder, mixes across the
    view axis, decodes, and returns a reconstruction loss against the
    normalised tokenizer outputs.

    Args:
        enc_depth: Depth passed to ``BeatrixEncoder``.
        token_dim: Feature size of the tokenizer output / prediction target.
        enc_dim: Encoder width.
        dec_dim: Decoder width.
        dec_depth: Depth passed to ``BeatrixDecoder``.
        token_merging_fold: Accepted for interface compatibility; unused here.
        head_features: Accepted for interface compatibility; unused here
            (per-head sizes are derived as ``dim // heads``).
        heads: Number of attention heads.
        expand: Expansion factor forwarded to the encoder/decoder/mixer layers.
    """

    def __init__(self,
                 enc_depth=12,
                 token_dim=512,
                 enc_dim=768,
                 dec_dim=512,
                 dec_depth=8,
                 token_merging_fold=4,
                 head_features=64,
                 heads=16,
                 expand=4):
        super().__init__()
        self.tokenizer = Tokenizer()
        self.encoder_pos_emb = PositionalEncoding2D(enc_dim)
        self.token_projector = nn.Linear(token_dim, enc_dim)
        self.token_projector_norm = NormLayer(enc_dim)

        self.spatial_encoder = BeatrixEncoder(
            enc_depth,
            enc_dim,
            enc_dim // heads,
            heads,
            expand
        )
        self.multiview_token_projector = nn.Linear(enc_dim, enc_dim)
        # NOTE(review): despite its name this is an nn.Linear, not a norm
        # layer. It is kept as-is because the fine-tuning classes load
        # checkpoints of this model strictly, so the parameter shapes must
        # not change.
        self.multiview_token_projector_norm = nn.Linear(enc_dim, enc_dim)

        self.multview_encoder = nn.ModuleList(
            [
                MixerLayer(enc_dim,
                           enc_dim // heads,
                           heads,
                           expand)
                for _ in range(2)
            ]
        )

        self.decoder = BeatrixDecoder(
            dec_depth,
            enc_dim,
            dec_dim,
            dec_dim // heads,
            heads,
            expand
        )

        self.token_predictor_norm = NormLayer(dec_dim)
        self.token_predictor = nn.Linear(dec_dim, token_dim)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names this model reports as exempt from weight decay."""
        return {'token_predictor',
                'token_projector',
                'multiview_token_projector'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the keywords this model reports as exempt from weight decay."""
        return {'token_predictor',
                'token_projector',
                'multiview_token_projector'}

    @staticmethod
    def _reconstruction_loss(prefix, suffix, targets, mask):
        """Combine the mask-weighted prefix loss with the summed suffix loss.

        Args:
            prefix: Predicted prefix tokens, ``(batch, prefix_len, dim)``.
            suffix: Predicted suffix tokens; all but the last position are
                compared with ``targets[:, prefix_len + 1:]``.
            targets: Target tokens, ``(batch, length, dim)``.
            mask: Per-token weights for the prefix, ``(batch, prefix_len)``.

        Returns:
            Scalar tensor ``prefix_loss + suffix_loss``.
        """
        prefix_len = prefix.size(-2)
        prefix_err = ((prefix - targets[:, :prefix_len, :]) ** 2).mean(dim=-1)
        # Mask-weighted mean. The previous code divided by mask.sum() a
        # second time, shrinking the prefix term by an extra 1 / mask.sum().
        prefix_loss = (prefix_err * mask).sum() / mask.sum()
        suffix_err = (suffix[:, :-1, :] - targets[:, prefix_len + 1:, ]) ** 2
        suffix_loss = suffix_err.mean(dim=-1).sum()
        return prefix_loss + suffix_loss

    def forward(self, x: torch.Tensor):
        """Compute the pretraining loss for a batch of raw inputs.

        Args:
            x: Either a 5-D tensor unpacked as ``(B, T, C, H, W)`` or a 6-D
                tensor unpacked as ``(B, T, P, C, H, W)``.

        Returns:
            ``{"loss": tensor}``.

        Raises:
            ValueError: If ``x`` is neither 5-D nor 6-D.
        """
        if x.ndim == 5:
            B, T, C, H, W = x.shape
            x = x.reshape(-1, C, H, W)
            with torch.no_grad():
                x = self.tokenizer.encode(x)
            _, H, W, D = x.shape
            x = x.reshape(B, T, H, W, D)
            x = rearrange(x, "B T H W D -> B H (T W) D")
            P = 1
        elif x.ndim == 6:
            B, T, P, C, H, W = x.shape
            with torch.no_grad():
                x = self.tokenizer.encode(x.reshape(-1, C, H, W))
            _, H, W, D = x.shape
            x = x.reshape(B, T, P, H, W, D)
            x = rearrange(x, "B T F H W D -> (B F) H (T W) D")
        else:
            raise ValueError(
                f"expected a 5-D or 6-D input tensor, got {x.ndim} dimensions")
        x.requires_grad = True
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            H, W = x.size(-3), x.size(-2)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", F=P)
            x = F.normalize(x, dim=-1)
            # The normalised tokenizer outputs are the reconstruction targets.
            targets = x.clone()
            x = self.token_projector(x)
            x = F.dropout(x, 0.1, self.training)
            x = self.token_projector_norm(x)
            # The token axis has length H * W here, so F is inferred as 1 and
            # this only restores the 2-D grid for the positional embedding.
            x = rearrange(x, "B (H W F) D -> (B F) H W D", H=H, W=W)
            x = x + self.encoder_pos_emb(x)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", H=H, W=W, F=P)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = x.to(torch.float16)
            enc_dict = self.spatial_encoder(x)
            cls_token = enc_dict["cls_token"]
            prefix = enc_dict["prefix"]
            suffix = enc_dict["suffix"]
            restore_ids = enc_dict["restore_ids"]
            mask = enc_dict["mask"]
            prefix_len = prefix.size(-2)
            x = torch.cat((cls_token, prefix, suffix), dim=-2)
            # use_reentrant=True is the historical default; it is passed
            # explicitly because calling checkpoint without it is deprecated.
            x = checkpoint.checkpoint(self.multiview_token_projector, x,
                                      use_reentrant=True)
            x = F.dropout(x, 0.1, self.training)
            x = self.multiview_token_projector_norm(x)
            L = x.size(1)
            # Mix across the view axis F independently for every token position.
            x = rearrange(x, "(B F) L D -> (B L) F D", F=P)
            for layer in self.multview_encoder:
                x = layer(x, prefix_len=x.size(-2))
            x = rearrange(x, "(B L) F D -> (B F) L D", L=L)
            cls_token = x[:, :1, :]
            prefix = x[:, 1:1 + prefix_len, :]
            suffix = x[:, 1 + prefix_len:, :]
            dec_dict = self.decoder(
                cls_token,
                prefix,
                suffix,
                restore_ids,
            )
        prefix, suffix = dec_dict["prefix"], dec_dict["suffix"]
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            prefix = self.token_predictor_norm(prefix.float())
            suffix = self.token_predictor_norm(suffix.float())
            prefix = F.normalize(self.token_predictor(prefix), dim=-1)
            suffix = F.normalize(self.token_predictor(suffix), dim=-1)
            loss = self._reconstruction_loss(prefix, suffix, targets, mask)

        return {"loss": loss}


class PromptGenerator(nn.Module):
    """Produce a fixed number of prompt tokens by attending over a context.

    Learnable prompt embeddings act as attention queries; keys and values are
    projected from the convolved and normalised context sequence.

    Args:
        hidden_features: Feature size of the context and of the output.
        heads: Number of attention heads.
        head_features: Feature size per head.
        num_prompts: Number of prompt tokens produced per sample.
        kernel_size: Kernel size of the 1-D convolution over the sequence axis.
    """

    def __init__(self, hidden_features: int = 768,
                 heads=16,
                 head_features=768 // 16,
                 num_prompts=32,
                 kernel_size=4):
        super().__init__()
        inner_features = heads * head_features
        self.prompts = nn.Parameter(
            0.02 * torch.randn(1, num_prompts, inner_features),
            requires_grad=True)
        self.to_kv = nn.Linear(hidden_features, inner_features * 2)
        self.heads = heads
        self.norm = NormLayer(hidden_features)
        self.out_layer = nn.Linear(inner_features, hidden_features)
        self.conv = nn.Conv1d(hidden_features,
                              hidden_features,
                              kernel_size=(kernel_size,))

    def forward(self, context):
        """Return prompt tokens computed from ``context``.

        Args:
            context: 3-D tensor; it is permuted to channels-first for the
                convolution, so its last axis must equal ``hidden_features``.

        Returns:
            Tensor with ``num_prompts`` tokens per sample and
            ``hidden_features`` features per token.
        """
        # Conv1d wants channels first: swap in, convolve, swap back.
        channels_first = context.permute(0, 2, 1)
        context = self.norm(self.conv(channels_first).permute(0, 2, 1))
        keys, values = self.to_kv(context).chunk(2, dim=-1)
        batch, _, _ = context.shape
        queries = self.prompts.repeat(batch, 1, 1)

        def split_heads(t):
            return rearrange(t, "b l (h d) -> b h l d", h=self.heads)

        queries = split_heads(queries)
        keys = split_heads(keys)
        values = split_heads(values)
        # NOTE(review): F.normalize defaults to dim=1, which is the head axis
        # of these (b, h, l, d) tensors rather than the feature axis. Kept
        # unchanged to preserve trained behaviour - confirm this is intended.
        keys = F.normalize(keys)
        values = F.normalize(values)
        attended = F.scaled_dot_product_attention(queries, keys, values)
        merged = rearrange(attended, "b h l d -> b l (h d)")
        return self.out_layer(merged)


class InstanceSpecificPredictor(nn.Module):
    """Two-layer MLP whose output is split into two equal halves.

    Args:
        in_features: Input feature size.
        hidden_features: Width of the hidden layer.
        out_features: Feature size of each returned half; the final linear
            layer emits ``2 * out_features`` values.
    """

    def __init__(self, in_features=768,
                 hidden_features=768,
                 out_features=768):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.activation = nn.SiLU()
        self.fc2 = nn.Linear(hidden_features, out_features * 2)

    def forward(self, x):
        """Return ``(mu, sigma)``, the two halves of the MLP output.

        The activation and the second linear layer run under activation
        checkpointing, so their intermediates are recomputed in backward.
        """
        x = self.fc1(x)
        # use_reentrant=True is the historical default; it is passed
        # explicitly because calling checkpoint without it is deprecated.
        x = checkpoint.checkpoint(self.activation, x, use_reentrant=True)
        x = checkpoint.checkpoint(self.fc2, x, use_reentrant=True)
        mu, sigma = x.chunk(2, dim=-1)
        return mu, sigma


class FactorizedBeatrixFinetuneDomainGeneralizableModel(nn.Module):
    """Classifier fine-tuned from a pretrained Beatrix model with prompts.

    The pretrained tokenizer, projections and encoders are reused. Optional
    learnable prompts (``prompt_probing``) are combined with instance-specific
    prompts generated from the token sequence and from connectivity tokens.
    The pooled feature is modulated by an ``InstanceSpecificPredictor`` before
    the classification head.

    Args:
        enc_dim: Encoder width.
        num_classes: Number of output classes.
        functional_connectivity_features: Input size of the connectivity
            projector; ``None`` skips building the connectivity branch.
        drop: Dropout probability of ``self.drop``.
        lora: Whether to wrap copies of the encoders with ``apply_lora`` and
            ``apply_adaptor``; otherwise the pretrained encoders are used
            directly.
        r: Rank forwarded to ``apply_lora`` / ``apply_adaptor``.
        prompt_probing: Freeze the pretrained encoders and projections (the
            two ``*_norm`` modules stay trainable) and prepend learnable
            prompts.
        ckpt: Path of the pretrained state dict (loaded strictly, on CPU).
            ``None`` skips loading.
    """

    def __init__(self,
                 enc_dim=768,
                 num_classes=2,
                 functional_connectivity_features: int = 666,
                 drop=0.1,
                 lora=True,
                 r=16,
                 prompt_probing=True,
                 ckpt="/data/users/34bckpt_4.pth"):
        super().__init__()
        model = FactorizedBeatrixPretrainedModel()
        if ckpt is not None:
            model.load_state_dict(torch.load(ckpt, map_location="cpu"))
        self.tokenizer = model.tokenizer.eval()
        self.encoder_pos_emb = model.encoder_pos_emb
        self.token_projector = model.token_projector
        self.token_projector_norm = model.token_projector_norm
        spatial_encoder = model.spatial_encoder
        multview_encoder = model.multview_encoder
        self.multiview_token_projector = model.multiview_token_projector
        self.multiview_token_projector_norm = model.multiview_token_projector_norm
        self.classification_head = nn.Linear(enc_dim, num_classes)
        self.classification_norm = NormLayer(enc_dim)
        self.drop = nn.Dropout(drop)
        self.ce_loss_func = nn.CrossEntropyLoss(label_smoothing=5e-5)
        self.prompt_probing = prompt_probing
        self.lora = lora
        if self.prompt_probing:
            trainable_by_module = (
                (spatial_encoder, False),
                (multview_encoder, False),
                (self.multiview_token_projector_norm, True),
                (self.multiview_token_projector, False),
                (self.token_projector, False),
                (self.token_projector_norm, True),
            )
            for module, trainable in trainable_by_module:
                for p in module.parameters():
                    p.requires_grad = trainable
            self.prompts = nn.Parameter(0.02 * torch.randn(1, 32, enc_dim),
                                        requires_grad=True)
        self.spatial_filtering = None
        if self.lora:
            spatial_encoder = apply_lora(copy.deepcopy(spatial_encoder),
                                         r=r, alpha=r)
            multview_encoder = apply_lora(copy.deepcopy(multview_encoder),
                                          r=r, alpha=r)
            spatial_encoder = apply_adaptor(spatial_encoder, r=r)
            multview_encoder = apply_adaptor(multview_encoder, r=r)
        # Always register the encoders: previously they were only assigned
        # when lora=True, so forward() failed with lora=False.
        self.spatial_encoder = spatial_encoder
        self.multview_encoder = multview_encoder
        if self.lora:
            # Norm layers inside the adapted encoders stay trainable.
            for encoder in (self.spatial_encoder, self.multview_encoder):
                for m in encoder.modules():
                    if isinstance(m, NormLayer):
                        for p in m.parameters():
                            p.requires_grad = True
        heads = 16
        self.instance_specific_predictor = InstanceSpecificPredictor()
        self.prompt_generator_1 = PromptGenerator(enc_dim)
        self.prompt_generator_2 = PromptGenerator(enc_dim)
        if functional_connectivity_features is not None:
            self.conn_projector = nn.Linear(functional_connectivity_features,
                                            enc_dim)
            self.conn_projector_norm = NormLayer(enc_dim)
            self.conn_encoder = MixerLayer(enc_dim, enc_dim // heads, heads)

    def no_weight_decay(self):
        """Return the names this model reports as exempt from weight decay."""
        if not self.prompt_probing:
            return {'token_predictor',
                    'token_projector',
                    'multiview_token_projector'}
        return {"prompts",
                'token_predictor',
                'token_projector',
                'multiview_token_projector'}

    def conn_encode(self, conn):
        """Project, normalise and encode connectivity tokens under fp16 autocast."""
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            conn = self.conn_projector(conn)
            conn = self.conn_projector_norm(conn)
            conn = self.conn_encoder(conn)
        return conn

    def encode(self, x):
        """Tokenize raw inputs without gradients.

        Args:
            x: 5-D ``(B, T, C, H, W)`` or 6-D ``(B, T, P, C, H, W)`` tensor.

        Returns:
            ``(tokens, P)``. For 6-D input the tokens are
            ``(B, T, P, H, W, D)``; for 5-D input they are rearranged to
            ``B H (T W) D`` and ``P`` is 1.

        Raises:
            ValueError: If ``x`` is neither 5-D nor 6-D.
        """
        if x.ndim == 5:
            B, T, C, H, W = x.shape
            x = x.reshape(-1, C, H, W)
            with torch.no_grad():
                x = self.tokenizer.encode(x)
            _, H, W, D = x.shape
            x = x.reshape(B, T, H, W, D)
            # NOTE(review): forward() rearranges encode()'s result with a
            # six-axis pattern, which this four-axis tensor does not match.
            x = rearrange(x, "B T H W D -> B H (T W) D")
            P = 1
        elif x.ndim == 6:
            B, T, P, C, H, W = x.shape
            with torch.no_grad():
                x = self.tokenizer.encode(x.reshape(-1, C, H, W))
            _, H, W, D = x.shape
            x = x.reshape(B, T, P, H, W, D)
        else:
            raise ValueError(
                f"expected a 5-D or 6-D input tensor, got {x.ndim} dimensions")
        return x, P

    def forward(self,
                x: torch.Tensor,
                conn_tokens: torch.Tensor,
                labels: torch.LongTensor = None,
                pretokenized=True):
        """Classify a batch.

        Args:
            x: Pre-tokenized features laid out as ``B T F H W D`` when
                ``pretokenized`` is true, otherwise raw inputs for ``encode``.
            conn_tokens: Connectivity tokens, one entry per sample, or
                ``None`` to skip the connectivity-conditioned prompts.
            labels: Class indices. When ``None`` no loss is computed.
            pretokenized: Whether ``x`` already holds tokenizer outputs.

        Returns:
            ``{"logits": tensor, "loss": tensor or None}``.
        """
        if not pretokenized:
            if self.spatial_filtering is not None:
                x = rearrange(x, "B T F C H W -> B T C H W F")
                x = self.spatial_filtering(x)
                x = rearrange(x, "B T C H W F -> B T F C H W")
            x, P = self.encode(x)
        else:
            if self.spatial_filtering is not None:
                x = rearrange(x, "B T F H W D -> B T H W D F")
                x = self.spatial_filtering(x)
                x = rearrange(x, "B T H W D F -> B T F H W D")
            P = x.size(2)
        x = rearrange(x, "B T F H W D -> (B F) H (T W) D")

        if not x.requires_grad:
            x.requires_grad = True
        if conn_tokens is not None:
            if not conn_tokens.requires_grad:
                conn_tokens.requires_grad = True
            conn_tokens = self.conn_encode(conn_tokens)

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            H, W = x.size(-3), x.size(-2)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", F=P)
            x = F.normalize(x, dim=-1)
            x = self.token_projector(x)
            x = self.token_projector_norm(x)
            # The token axis has length H * W here, so F is inferred as 1 and
            # this only restores the 2-D grid for the positional embedding.
            x = rearrange(x, "B (H W F) D -> (B F) H W D", H=H, W=W)
            x = x + self.encoder_pos_emb(x)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", H=H, W=W, F=P)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = x.to(torch.float16)
            if self.prompt_probing:
                prompts = self.prompts.repeat(x.size(0), 1, 1)
                token_prompts = self.prompt_generator_1(x)
                prompts = prompts + token_prompts
                if conn_tokens is not None:
                    conn_prompts = self.prompt_generator_2(conn_tokens)
                    # Rows of x are ordered "(B F)", i.e. all views of sample
                    # 0, then all views of sample 1, ... repeat_interleave
                    # keeps each sample's connectivity prompts next to its
                    # own views; the previous repeat() tiled the whole batch
                    # and misaligned samples when B > 1 and F > 1.
                    views = token_prompts.size(0) // conn_prompts.size(0)
                    conn_prompts = conn_prompts.repeat_interleave(views, dim=0)
                    prompts = prompts + conn_prompts
                x = torch.cat((prompts, x), dim=-2)
            enc_dict = self.spatial_encoder.inference(x)
            cls_token, x = enc_dict["cls_token"], enc_dict["hidden_states"]
            x = torch.cat((cls_token, x), dim=-2)
            x = self.multiview_token_projector(x)
            x = self.multiview_token_projector_norm(x)
            L = x.size(1)
            # Mix across the view axis F independently for every token position.
            x = rearrange(x, "(B F) L D -> (B L) F D", F=P)
            for layer in self.multview_encoder:
                x = layer(x, prefix_len=x.size(-2))
            x = rearrange(x, "(B L) F D -> B F L D", L=L, F=P)
        # Cast explicitly (as FactorizedBeatrixFinetunedModel does) so the
        # float32 layers below do not depend on autocast to convert fp16.
        x = x.to(torch.float32)
        x_mean = x.mean(dim=(1, 2))
        # Token 0 is the cls token; the rest follow. The previous code sliced
        # ":1" for both, so the "prompts" below were the cls token again.
        cls_token, x = (
            x[:, :, :1, :],
            x[:, :, 1:, :]
        )
        if self.prompt_probing:
            prompts = x[:, :, 0:self.prompts.size(1), :]
            cls_token = torch.cat((cls_token, prompts), dim=-2)
            # Averaging over views and tokens already yields (B, D); a bare
            # squeeze() would also drop the batch axis when B == 1.
            cls_token = cls_token.float().mean(dim=(1, 2))
        else:
            cls_token = cls_token.float().mean(dim=1).squeeze(1)
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            cls_token = self.classification_norm(cls_token)
            mu, var = self.instance_specific_predictor(x_mean)
            # Instance-specific scale and shift of the pooled feature.
            cls_token = cls_token * var + mu
            logits = self.classification_head(cls_token)
            loss = None
            if labels is not None:
                loss = self.ce_loss_func(logits, labels.long())

        return {"logits": logits,
                "loss": loss}


class FactorizedBeatrixFinetunedModel(nn.Module):
    """Classifier fine-tuned from a pretrained Beatrix model.

    The pretrained tokenizer, projections and encoders are reused. Learnable
    prompts (``prompt_probing``) and prompts generated from a time series
    (``temporal_prompt``) can be prepended to the token sequence.

    Args:
        enc_dim: Encoder width.
        num_classes: Number of output classes.
        drop: Dropout probability applied to the pooled tokens.
        lora: Whether to wrap copies of the encoders with ``apply_lora``;
            otherwise the pretrained encoders are used directly.
        r: Rank forwarded to ``apply_lora`` / ``apply_adaptor``.
        prompt_probing: Freeze the pretrained encoders and projections (the
            two ``*_norm`` modules stay trainable) and prepend learnable
            prompts.
        temporal_prompt: Build a ``WaveletTreeNet`` prompt generator and a
            projection head used for the auxiliary alignment/uniformity loss.
        freeze_norms: With ``lora`` enabled, freeze (instead of unfreeze) the
            parameters of every ``NormLayer`` inside the encoders.
        freeze_adaptors: With ``lora`` enabled, skip ``apply_adaptor``.
        num_queries: Number of learnable prompts (also forwarded to
            ``WaveletTreeNet``).
        ckpt: Path of the pretrained state dict (loaded strictly, on CPU).
            ``None`` skips loading.
    """

    def __init__(self, enc_dim=768,
                 num_classes=2,
                 drop=0.1,
                 lora=True,
                 r=16,
                 prompt_probing=True,
                 temporal_prompt=False,
                 freeze_norms=False,
                 freeze_adaptors=False,
                 num_queries=32,
                 ckpt="/data/users/35bckpt_.pth"):
        super().__init__()
        self.temporal_prompt = temporal_prompt
        self.num_queries = num_queries
        if self.temporal_prompt:
            # Imported only where it is needed so the model can be built
            # without this module when temporal prompts are disabled.
            from m_wavelet_tree_network import WaveletTreeNet
            self.prompt_generator = WaveletTreeNet(num_queries=num_queries)
            self.projection_head = nn.Sequential(
                nn.Linear(768, 768),
                nn.BatchNorm1d(768),
                nn.ReLU(),
                nn.Linear(768, 768),
                nn.BatchNorm1d(768),
                nn.ReLU(),
                nn.Linear(768, 128),
            )
        model = FactorizedBeatrixPretrainedModel()
        if ckpt is not None:
            model.load_state_dict(torch.load(ckpt, map_location="cpu"))
        self.tokenizer = model.tokenizer.eval()
        self.encoder_pos_emb = model.encoder_pos_emb
        self.token_projector = model.token_projector
        self.token_projector_norm = model.token_projector_norm
        spatial_encoder = model.spatial_encoder
        multview_encoder = model.multview_encoder
        self.multiview_token_projector = model.multiview_token_projector
        self.multiview_token_projector_norm = model.multiview_token_projector_norm
        self.classification_head = nn.Linear(enc_dim, num_classes)
        self.classification_norm = NormLayer(enc_dim)
        self.drop = nn.Dropout(drop)
        self.ce_loss_func = nn.CrossEntropyLoss(label_smoothing=5e-5)
        self.prompt_probing = prompt_probing
        self.lora = lora
        if self.prompt_probing:
            trainable_by_module = (
                (spatial_encoder, False),
                (multview_encoder, False),
                (self.multiview_token_projector_norm, True),
                (self.multiview_token_projector, False),
                (self.token_projector, False),
                (self.token_projector_norm, True),
            )
            for module, trainable in trainable_by_module:
                for p in module.parameters():
                    p.requires_grad = trainable
            self.prompts = nn.Parameter(
                0.02 * torch.randn(1, num_queries, enc_dim),
                requires_grad=True)
        self.spatial_filtering = None
        if self.lora:
            spatial_encoder = apply_lora(copy.deepcopy(spatial_encoder),
                                         r=r, alpha=r)
            multview_encoder = apply_lora(copy.deepcopy(multview_encoder),
                                          r=r, alpha=r)
            if not freeze_adaptors:
                spatial_encoder = apply_adaptor(spatial_encoder, r=r)
                multview_encoder = apply_adaptor(multview_encoder, r=r)
        # Always register the encoders: previously they were only assigned
        # when lora=True, so forward() failed with lora=False.
        self.spatial_encoder = spatial_encoder
        self.multview_encoder = multview_encoder
        if self.lora:
            # Set every NormLayer parameter in one pass. The previous code
            # applied freeze_norms after the inner loop, so only the last
            # parameter of each norm layer was actually frozen.
            for encoder in (self.spatial_encoder, self.multview_encoder):
                for m in encoder.modules():
                    if isinstance(m, NormLayer):
                        for p in m.parameters():
                            p.requires_grad = not freeze_norms

    @staticmethod
    def uniformity(x, t=0.005):  # t=0.02
        """Return ``log(mean(exp(-t * d^2)))`` over all pairwise distances of ``x``."""
        sq_dist = torch.pdist(x, p=2).pow(2)
        return sq_dist.mul(-t).exp().mean().log()

    @staticmethod
    def wasserstein_uniformity(x, y):
        """Return the empirical Bures-Wasserstein distance between two sample sets.

        Args:
            x: Samples of the first set, ``(n_samples, dim)``.
            y: Samples of the second set, ``(n_samples, dim)``.
        """
        # empirical_bures_wasserstein_distance estimates the means and
        # covariances from the samples itself, so it takes the raw samples.
        # The previous call passed hand-computed means/covariances, which
        # its parameters (xs, xt, reg, ws, ...) do not accept.
        return empirical_bures_wasserstein_distance(x, y)

    def no_weight_decay(self):
        """Return the names this model reports as exempt from weight decay."""
        if not self.prompt_probing:
            return {'token_predictor',
                    'token_projector',
                    'multiview_token_projector'}
        return {"prompts",
                'token_predictor',
                'token_projector',
                'multiview_token_projector'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the keywords this model reports as exempt from weight decay."""
        if not self.prompt_probing:
            return {'token_predictor',
                    'token_projector',
                    'multiview_token_projector'}
        return {"prompts",
                'token_predictor',
                'token_projector',
                'multiview_token_projector'}

    def encode(self, x):
        """Tokenize raw inputs without gradients.

        Args:
            x: 5-D ``(B, T, C, H, W)`` or 6-D ``(B, T, P, C, H, W)`` tensor.

        Returns:
            ``(tokens, P)``. For 6-D input the tokens are
            ``(B, T, P, H, W, D)``; for 5-D input they are rearranged to
            ``B H (T W) D`` and ``P`` is 1.

        Raises:
            ValueError: If ``x`` is neither 5-D nor 6-D.
        """
        if x.ndim == 5:
            B, T, C, H, W = x.shape
            x = x.reshape(-1, C, H, W)
            with torch.no_grad():
                x = self.tokenizer.encode(x)
            _, H, W, D = x.shape
            x = x.reshape(B, T, H, W, D)
            # NOTE(review): forward() rearranges encode()'s result with a
            # six-axis pattern, which this four-axis tensor does not match.
            x = rearrange(x, "B T H W D -> B H (T W) D")
            P = 1
        elif x.ndim == 6:
            B, T, P, C, H, W = x.shape
            with torch.no_grad():
                x = self.tokenizer.encode(x.reshape(-1, C, H, W))
            _, H, W, D = x.shape
            x = x.reshape(B, T, P, H, W, D)
        else:
            raise ValueError(
                f"expected a 5-D or 6-D input tensor, got {x.ndim} dimensions")
        return x, P

    def forward(self, x: torch.Tensor,
                labels: torch.LongTensor = None,
                ts: torch.Tensor = None,
                pretokenized=True,
                use_no_contrastive_loss=False):
        """Classify a batch.

        Args:
            x: Pre-tokenized features laid out as ``B T F H W D`` when
                ``pretokenized`` is true, otherwise raw inputs for ``encode``.
            labels: Class indices. When ``None`` no loss is computed.
            ts: Input for the temporal prompt generator; only used when the
                model was built with ``temporal_prompt=True``.
            pretokenized: Whether ``x`` already holds tokenizer outputs.
            use_no_contrastive_loss: Skip the auxiliary alignment/uniformity
                term and return the plain cross-entropy loss.

        Returns:
            ``{"logits": tensor, "features": detached pooled features,
            "loss": tensor or None}``.
        """
        if self.temporal_prompt and ts is not None:
            temporal_prompts = self.prompt_generator(ts)
        else:
            temporal_prompts = None

        if not pretokenized:
            if self.spatial_filtering is not None:
                x = rearrange(x, "B T F C H W -> B T C H W F")
                x = self.spatial_filtering(x)
                x = rearrange(x, "B T C H W F -> B T F C H W")
            x, P = self.encode(x)
        else:
            if self.spatial_filtering is not None:
                x = rearrange(x, "B T F H W D -> B T H W D F")
                x = self.spatial_filtering(x)
                x = rearrange(x, "B T H W D F -> B T F H W D")
            P = x.size(2)
        x = rearrange(x, "B T F H W D -> (B F) H (T W) D")
        if not x.requires_grad:
            x.requires_grad = True

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            H, W = x.size(-3), x.size(-2)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", F=P)
            x = F.normalize(x, dim=-1)
            x = self.token_projector(x)
            x = self.token_projector_norm(x)
            # The token axis has length H * W here, so F is inferred as 1 and
            # this only restores the 2-D grid for the positional embedding.
            x = rearrange(x, "B (H W F) D -> (B F) H W D", H=H, W=W)
            x = x + self.encoder_pos_emb(x)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", H=H, W=W, F=P)

        if temporal_prompts is not None:
            # Give every view of a sample the same temporal prompts, laid out
            # "(B F)" like x.
            temporal_prompts = temporal_prompts.unsqueeze(1).repeat(1, P, 1, 1)
            temporal_prompts = rearrange(temporal_prompts, "B F L D -> (B F) L D")
            x = torch.cat((temporal_prompts, x), dim=1)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = x.to(torch.float16)
            if self.prompt_probing:
                prompts = self.prompts.repeat(x.size(0), 1, 1)
                x = torch.cat((prompts, x), dim=-2)
            enc_dict = self.spatial_encoder.inference(x)
            cls_token, x = enc_dict["cls_token"], enc_dict["hidden_states"]
            x = torch.cat((cls_token, x), dim=-2)
            x = self.multiview_token_projector(x)
            x = self.multiview_token_projector_norm(x)
            L = x.size(1)
            # Mix across the view axis F independently for every token position.
            x = rearrange(x, "(B F) L D -> (B L) F D", F=P)
            for layer in self.multview_encoder:
                x = layer(x, prefix_len=x.size(-2))
            x = rearrange(x, "(B L) F D -> B F L D", L=L, F=P)
        x = x.to(torch.float32)
        cls_token, x = (
            x[:, :, :1, :],
            x[:, :, 1:, :]
        )

        # Stay None unless the temporal branch below computes them, so the
        # loss combination never reads an unbound name.
        alignment_loss = None
        uniformity_loss = None
        if self.prompt_probing:
            if self.temporal_prompt:
                prompts = x[:, :, 0:self.num_queries * 2, :]
                temp_feat, freq_feat = prompts.chunk(2, dim=-2)
                temp_feat, freq_feat = (
                    F.normalize(t.mean(dim=(1, 2)), dim=-1)
                    for t in (temp_feat, freq_feat)
                )
                temp_feat = self.projection_head(temp_feat)
                freq_feat = self.projection_head(freq_feat)
                # 2-D tensors: the last axis is F.normalize's default dim=1.
                temp_feat = F.normalize(temp_feat, dim=-1)
                freq_feat = F.normalize(freq_feat, dim=-1)
                alignment_loss = F.mse_loss(temp_feat, freq_feat)
                uniformity_loss = 0.5 * (self.uniformity(temp_feat)
                                         + self.uniformity(freq_feat))
            else:
                prompts = x[:, :, 0:self.num_queries, :]

            cls_token = torch.cat((cls_token, prompts), dim=-2)
            cls_token = self.drop(cls_token)
            # Averaging over views and tokens already yields (B, D); a bare
            # squeeze() would also drop the batch axis when B == 1.
            cls_token = cls_token.float().mean(dim=(1, 2))
        else:
            cls_token = cls_token.float().mean(dim=1).squeeze(1)

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            cls_token = self.classification_norm(cls_token)
            logits = self.classification_head(cls_token)
            loss = None
            if labels is not None:
                loss = self.ce_loss_func(logits, labels.long())
                add_contrastive = (self.temporal_prompt
                                   and not use_no_contrastive_loss
                                   and alignment_loss is not None)
                if add_contrastive:
                    loss = 2.0 * loss + 0.5 * (alignment_loss
                                               + 0.02 * uniformity_loss)
        return {"logits": logits,
                "features": cls_token.detach(),
                "loss": loss}


class FactorizedBeatrixFinetunedDomainInvariantModel(nn.Module):
    """Classifier fine-tuned on top of a ``FactorizedBeatrixPretrainedModel``.

    The pretrained tokenizer, token projectors and the spatial / multi-view
    encoders are reused.  Depending on the flags, the encoders are wrapped
    with LoRA + adaptor layers (``lora``), a bank of learnable prompt tokens
    is prepended to the token sequence (``prompt_probing``), and an auxiliary
    environment-weighted contrastive objective with a gradient-agreement
    term is computed during training (``temporal_prompt``).

    Args:
        enc_dim: Feature width of the encoder output; also the width of the
            auxiliary MLPs and of the classification head input.
        num_classes: Number of output classes.
        drop: Dropout probability applied to the pooled tokens.
        lora: Wrap deep copies of both encoders with ``apply_lora`` and
            ``apply_adaptor``.  When ``False`` the pretrained encoders are
            used as they are.
        r: Rank passed to ``apply_lora`` (also used as its ``alpha``) and to
            ``apply_adaptor``.
        prompt_probing: Freeze the pretrained encoders and token projectors
            (projector norms stay trainable) and learn ``num_queries`` prompt
            tokens instead.
        temporal_prompt: Build the prompt generator, mask generator,
            projection heads (with EMA targets) and environment predictors.
        freeze_norms: Freeze *every* parameter of the ``NormLayer`` modules
            inside the LoRA-wrapped encoders instead of training them.
        num_environments: Output width of the environment predictors.
        num_queries: Number of learnable prompt tokens.
        checkpoint_path: Location of the pretrained weights.
    """

    def __init__(self, enc_dim=768,
                 num_classes=2,
                 drop=0.1,
                 lora=True,
                 r=16,
                 prompt_probing=True,
                 temporal_prompt=False,
                 freeze_norms=False,
                 num_environments=4,
                 num_queries=32,
                 checkpoint_path="/data/users/35bckpt_.pth"):
        super().__init__()
        from m_wavelet_tree_network import WaveletTreeNet
        self.temporal_prompt = temporal_prompt
        self.num_queries = num_queries
        if self.temporal_prompt:
            # Learnable bias and log-temperature of the contrastive logits.
            self.b = nn.Parameter(torch.as_tensor(0.))
            self.tau = nn.Parameter(torch.as_tensor(0.02))
            self.prompt_generator = WaveletTreeNet(num_queries=num_queries)
            # Emits 2 * enc_dim scores that are binarised into two masks.
            self.mask_generator = self._make_mlp(enc_dim, enc_dim, enc_dim * 2)
            self.projection_head_1 = self._make_mlp(enc_dim, enc_dim, 128)
            self.projection_head_2 = copy.deepcopy(self.projection_head_1)

            self.heaviside = Heaviside.apply
            from timm.utils.model_ema import ModelEmaV2
            self.target_projection_head_1 = ModelEmaV2(self.projection_head_1, 0.99)
            self.target_projection_head_2 = ModelEmaV2(self.projection_head_2, 0.99)

            self.environment_predictor_1 = self._make_mlp(
                enc_dim, enc_dim, num_environments)
            self.environment_predictor_2 = copy.deepcopy(self.environment_predictor_1)

        model = FactorizedBeatrixPretrainedModel()
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        self.tokenizer = model.tokenizer.eval()
        self.encoder_pos_emb = model.encoder_pos_emb
        self.token_projector = model.token_projector
        self.token_projector_norm = model.token_projector_norm
        spatial_encoder = model.spatial_encoder
        multview_encoder = model.multview_encoder
        self.multiview_token_projector = model.multiview_token_projector
        self.multiview_token_projector_norm = model.multiview_token_projector_norm
        self.classification_head = nn.Linear(enc_dim,
                                             num_classes)
        self.classification_norm = NormLayer(enc_dim)
        self.drop = nn.Dropout(drop)
        # NOTE(review): a class-weighted FocalLoss (gamma=2, alpha from the
        # counts [58587, 18893, 60592, 11300, 22763]) used to be built here
        # for num_classes > 2 but was always overwritten by the line below,
        # so only the effective loss is kept.  Confirm which was intended.
        self.ce_loss_func = nn.CrossEntropyLoss(label_smoothing=5e-5)
        self.prompt_probing = prompt_probing
        self.lora = lora
        if self.prompt_probing:
            frozen = (spatial_encoder,
                      multview_encoder,
                      self.multiview_token_projector,
                      self.token_projector)
            trainable = (self.multiview_token_projector_norm,
                         self.token_projector_norm)
            for module in frozen:
                for p in module.parameters():
                    p.requires_grad = False
            for module in trainable:
                for p in module.parameters():
                    p.requires_grad = True
            self.prompts = nn.Parameter(0.02 * torch.randn(1, num_queries,
                                                           enc_dim),
                                        requires_grad=True)
        # Optional hook applied to the raw input in forward(); never set here.
        self.spatial_filtering = None
        if self.lora:
            spatial_encoder = apply_lora(copy.deepcopy(spatial_encoder),
                                         r=r, alpha=r)
            multview_encoder = apply_lora(copy.deepcopy(multview_encoder),
                                          r=r, alpha=r)
            self.spatial_encoder = apply_adaptor(spatial_encoder, r=r)
            self.multview_encoder = apply_adaptor(multview_encoder, r=r)
            # Every parameter of every norm layer gets the same flag (the
            # previous code only froze the last parameter it had visited).
            self._set_norm_trainable(self.spatial_encoder, not freeze_norms)
            self._set_norm_trainable(self.multview_encoder, not freeze_norms)
        else:
            # Without LoRA the pretrained encoders are used directly, so that
            # forward() always finds both attributes.
            self.spatial_encoder = spatial_encoder
            self.multview_encoder = multview_encoder

    @staticmethod
    def _make_mlp(in_dim, hidden_dim, out_dim):
        """Build the Linear-BN-ReLU-Linear-BN-ReLU-Linear stack used by all auxiliary heads."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    @staticmethod
    def _set_norm_trainable(encoder, trainable):
        """Set ``requires_grad`` on all parameters of every ``NormLayer`` in ``encoder``."""
        for module in encoder.modules():
            if isinstance(module, NormLayer):
                for p in module.parameters():
                    p.requires_grad = trainable

    @staticmethod
    def uniformity(x, t=0.01):  # t=0.02
        """Return ``log(mean(exp(-t * d^2)))`` over all pairwise distances ``d`` of the rows of ``x``."""
        sq_dist = torch.pdist(x, p=2).pow(2)
        return sq_dist.mul(-t).exp().mean().log()

    @staticmethod
    def wasserstein_uniformity(x, y):
        """Bures-Wasserstein distance between Gaussians fitted to ``x`` and ``y``.

        ``empirical_bures_wasserstein_distance`` estimates the means and
        covariances from the samples itself, so the raw sample matrices are
        passed (the previous call handed it pre-computed statistics, which
        the function would have read as ``reg`` and sample weights).

        Args:
            x: Source samples, one sample per row.
            y: Target samples, one sample per row.
        """
        return empirical_bures_wasserstein_distance(x, y)

    def _no_decay_names(self):
        """Names excluded from weight decay; includes the prompts when prompt probing."""
        names = {'token_predictor',
                 'token_projector',
                 'multiview_token_projector'}
        if self.prompt_probing:
            names.add("prompts")
        return names

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names that must not receive weight decay."""
        return self._no_decay_names()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the set of name keywords that must not receive weight decay."""
        return self._no_decay_names()

    def encode(self, x):
        """Tokenize raw inputs with the frozen tokenizer.

        Args:
            x: Either a 5-D tensor ``(B, T, C, H, W)`` or a 6-D tensor
                ``(B, T, P, C, H, W)``.

        Returns:
            ``(tokens, P)``.  For 5-D input the tokens are already flattened
            to ``(B, H, T * W, D)`` and ``P`` is 1; for 6-D input they keep
            the layout ``(B, T, P, H, W, D)``.

        Raises:
            ValueError: If ``x`` is neither 5-D nor 6-D.
        """
        if x.ndim == 5:
            B, T, C, H, W = x.shape
            with torch.no_grad():
                x = self.tokenizer.encode(x.reshape(-1, C, H, W))
            _, H, W, D = x.shape
            x = x.reshape(B, T, H, W, D)
            x = rearrange(x, "B T H W D -> B H (T W) D")
            return x, 1
        if x.ndim == 6:
            B, T, P, C, H, W = x.shape
            with torch.no_grad():
                x = self.tokenizer.encode(x.reshape(-1, C, H, W))
            _, H, W, D = x.shape
            return x.reshape(B, T, P, H, W, D), P
        raise ValueError(
            f"encode expects a 5-D or 6-D tensor, got {x.ndim} dimensions")

    def _invariance_losses(self, cls_token, prompts):
        """Compute the auxiliary contrastive loss and gradient-agreement term.

        Only meaningful in training mode with ``temporal_prompt`` enabled,
        because it uses the modules created under that flag and updates the
        EMA target heads.

        Args:
            cls_token: Class-token slice of the encoder output, ``(B, F, 1, D)``.
            prompts: Prompt-token slice of the encoder output; it is split
                into two halves along the token axis.

        Returns:
            ``(loss, grad_reg)``: the summed InfoNCE loss over all
            online/target head pairs, and the accumulated product of the
            gradients (w.r.t. a dummy scale) of the environment-weighted loss
            on the even- and odd-indexed halves of the batch.
        """
        device = cls_token.device
        # Binary feature masks derived from the pooled class token.
        inferred_env = cls_token.mean(dim=(1, 2))
        mask = self.mask_generator(inferred_env)
        mask1, mask2 = self.heaviside(mask).chunk(2, dim=-1)

        # Dummy multiplier whose gradient is inspected below.
        # NOTE(review): the online projections are L2-normalised before use,
        # which is invariant to a positive scale, so these gradients are
        # expected to be (numerically) zero - confirm this is intended.
        scale = torch.ones((), device=device, requires_grad=True)

        temp_feat, freq_feat = prompts.chunk(2, dim=-2)
        z1 = temp_feat.mean(dim=(1, 2))
        z2 = freq_feat.mean(dim=(1, 2))
        # Environment weights come from the masked-out features ...
        env1 = self.environment_predictor_1(z1 * (1 - mask1)).softmax(dim=-1)
        env2 = self.environment_predictor_2(z2 * (1 - mask2)).softmax(dim=-1)
        # ... while the heads only see the features the masks keep.
        z1 = z1 * mask1
        z2 = z2 * mask2
        o1 = scale * self.projection_head_1(z1)
        o2 = scale * self.projection_head_2(z2)
        with torch.no_grad():
            self.target_projection_head_2.update(self.projection_head_2)
            self.target_projection_head_1.update(self.projection_head_1)
            t1 = self.target_projection_head_1.module(z1).detach()
            t2 = self.target_projection_head_2.module(z2).detach()

        targets = [F.normalize(t, dim=-1) for t in (t1, t2)]
        # The second online branch is weighted by env2 (the previous code
        # reused env1 for both branches and left env2 unused).
        branches = ((F.normalize(o1, dim=-1), env1),
                    (F.normalize(o2, dim=-1), env2))

        loss = 0.
        # Accumulated out-of-place: an in-place add on a leaf tensor that
        # requires grad raises a RuntimeError.
        grad_reg = torch.zeros((), device=device)
        for online, env in branches:
            for target in targets:
                proxy_logits = torch.mm(online, target.t()) \
                               * self.tau.exp().clamp(0, 5.0) + self.b
                proxy_labels = torch.arange(proxy_logits.size(0),
                                            device=proxy_logits.device)
                nce_loss = F.cross_entropy(proxy_logits, proxy_labels,
                                           reduction="none")
                # Per-sample loss weighted by each environment probability.
                multi_loss = nce_loss.unsqueeze(-1) * env
                for e in range(multi_loss.size(-1)):
                    g1 = torch.autograd.grad(
                        multi_loss[:, e][0::2].mean(),
                        [scale],
                        create_graph=True
                    )[0]
                    g2 = torch.autograd.grad(
                        multi_loss[:, e][1::2].mean(),
                        [scale],
                        create_graph=True
                    )[0]
                    grad_reg = grad_reg + (g1 * g2).mean()
                loss = loss + nce_loss.mean()
        return loss, grad_reg

    def forward(self, x: torch.Tensor,
                labels: torch.LongTensor = None,
                ts: torch.Tensor = None,
                pretokenized=True):
        """Encode the input, classify it and compute the training losses.

        Args:
            x: With ``pretokenized=True`` a token tensor laid out as
                ``(B, T, F, H, W, D)``; otherwise a raw input accepted by
                :meth:`encode`.
            labels: Class indices.  When ``None`` no cross-entropy loss is
                computed and ``"ce_loss"`` is ``None``.
            ts: Input of the prompt generator; only used when
                ``temporal_prompt`` is enabled.
            pretokenized: Whether ``x`` already holds tokenizer outputs.

        Returns:
            dict with ``"logits"``, detached pooled ``"features"``,
            ``"ce_loss"``, ``"gradient_penalty"`` (negated gradient-agreement
            term) and the auxiliary ``"loss"``.  The last two are ``0``
            unless the model is training with both ``prompt_probing`` and
            ``temporal_prompt`` enabled.
        """
        if self.temporal_prompt and ts is not None:
            temporal_prompts = self.prompt_generator(ts)
        else:
            temporal_prompts = None

        if not pretokenized:
            if self.spatial_filtering is not None:
                x = rearrange(x, "B T F C H W -> B T C H W F")
                x = self.spatial_filtering(x)
                x = rearrange(x, "B T C H W F -> B T F C H W")
            x, P = self.encode(x)
            # encode() already returns the flattened layout for 5-D input.
            if x.ndim == 6:
                x = rearrange(x, "B T F H W D -> (B F) H (T W) D")
        else:
            if self.spatial_filtering is not None:
                x = rearrange(x, "B T F H W D -> B T H W D F")
                x = self.spatial_filtering(x)
                x = rearrange(x, "B T H W D F -> B T F H W D")
            P = x.size(2)
            x = rearrange(x, "B T F H W D -> (B F) H (T W) D")
        if not x.requires_grad:
            x.requires_grad = True

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            H, W = x.size(-3), x.size(-2)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", F=P)
            x = F.normalize(x, dim=-1)
            x = self.token_projector(x)
            x = self.token_projector_norm(x)
            x = rearrange(x, "B (H W F) D -> (B F) H W D", H=H, W=W)
            x = x + self.encoder_pos_emb(x)
            x = rearrange(x, "(B F) H W D -> (B F) (H W) D", H=H, W=W, F=P)

        if temporal_prompts is not None:
            # Share the generated prompts across the F axis and prepend them.
            temporal_prompts = temporal_prompts.unsqueeze(1).repeat(1, P, 1, 1)
            temporal_prompts = rearrange(temporal_prompts, "B F L D -> (B F) L D")
            x = torch.cat((temporal_prompts, x), dim=1)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            x = x.to(torch.float16)
            if self.prompt_probing:
                learned_prompts = self.prompts.repeat(x.size(0), 1, 1)
                x = torch.cat((learned_prompts, x), dim=-2)
            enc_dict = self.spatial_encoder.inference(x)
            cls_token, x = enc_dict["cls_token"], enc_dict["hidden_states"]
            x = torch.cat((cls_token, x), dim=-2)
            x = self.multiview_token_projector(x)
            x = self.multiview_token_projector_norm(x)
            _, L, D = x.shape
            x = rearrange(x, "(B F) L D -> (B L) F D", F=P)
            for layer in self.multview_encoder:
                x = layer(x, prefix_len=x.size(-2))
            x = rearrange(x, "(B L) F D -> B F L D", L=L, F=P)
        x = x.to(torch.float32)
        cls_token, x = x[:, :, :1, :], x[:, :, 1:, :]

        loss = 0.
        grad_reg = 0
        prompts = None
        if self.prompt_probing:
            num_prompts = self.num_queries * 2 if self.temporal_prompt \
                else self.num_queries
            prompts = x[:, :, 0:num_prompts, :]
            # The auxiliary modules only exist when temporal_prompt is set.
            if self.training and self.temporal_prompt:
                loss, grad_reg = self._invariance_losses(cls_token, prompts)

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            if prompts is not None:
                pooled = torch.cat((cls_token, prompts), dim=-2)
            else:
                pooled = cls_token
            pooled = self.drop(pooled)
            # No squeeze(): it would drop the batch axis for a batch of one.
            pooled = pooled.float().mean(dim=(1, 2))
            pooled = self.classification_norm(pooled)
            logits = self.classification_head(pooled)
            ce_loss = None
            if labels is not None:
                ce_loss = self.ce_loss_func(logits, labels.long())

        return {"logits": logits,
                "features": pooled.detach(),
                "ce_loss": ce_loss,
                "gradient_penalty": -grad_reg,
                "loss": loss}













