from copy import deepcopy

import torch
import torch.nn as nn
import torch.jit

import PIL
import torchvision.transforms as transforms
import cotta_transforms as my_transforms
from inject_vida import inject_trainable_vida
from time import time
import logging

from torchvision.transforms import InterpolationMode
import math
import torch.nn.functional as F
from typing import Tuple

def get_tta_transforms(gaussian_std: float=0.005, soft=False, clip_inputs=False):
    """Build the two augmentation pipelines used to create the "u" and "v" views.

    Args:
        gaussian_std: Currently unused by the returned pipelines; kept so
            existing call sites remain valid.
        soft: When True, the colour-jitter ranges of the second pipeline are
            narrower (milder augmentation).
        clip_inputs: Currently unused; kept for call-site compatibility.

    Returns:
        A ``(view_1, view_2)`` tuple of ``transforms.Compose`` objects:

        * ``view_1`` applies only a bicubic ``RandomResizedCrop`` to 224 px
          with scale in ``(0.9, 1.0)``.
        * ``view_2`` applies ``ColorJitterPro``, ``RandomGrayscale(p=0.2)`` and
          then the same kind of ``RandomResizedCrop``.
    """
    n_pixels = 224  # side length of both output crops

    def make_crop():
        # A fresh transform instance per pipeline, mirroring the original
        # construction (the two pipelines never shared a crop object).
        return transforms.RandomResizedCrop(
            size=n_pixels,
            scale=(0.9, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        )

    # View 1: geometric perturbation only.
    ibot_transforms_1 = transforms.Compose([
        make_crop(),
    ])

    # View 2: photometric perturbation followed by the same mild crop.
    ibot_transforms_2 = transforms.Compose([
        my_transforms.ColorJitterPro(
            brightness=[0.8, 1.2] if soft else [0.6, 1.4],
            contrast=[0.85, 1.15] if soft else [0.7, 1.3],
            saturation=[0.75, 1.25] if soft else [0.5, 1.5],
            hue=[-0.03, 0.03] if soft else [-0.06, 0.06],
            gamma=[0.85, 1.15] if soft else [0.7, 1.3]
        ),
        transforms.RandomGrayscale(p=0.2),
        make_crop(),
    ])

    return ibot_transforms_1, ibot_transforms_2


def update_ema_variables(ema_model, model, alpha_teacher, alpha_vida):#, iteration):
    """Update the teacher's parameters in place as an EMA of the student's.

    Parameters are paired positionally (``ema_model.parameters()`` against
    ``model.named_parameters()``), so both models must expose their parameters
    in the same order.

    Args:
        ema_model: Teacher model; its parameter data is overwritten in place.
        model: Student model providing the fresh parameter values.
        alpha_teacher: EMA decay used for every parameter not matched below.
        alpha_vida: EMA decay used for parameters whose name contains
            ``'vida_'``, ``'proj'`` or ``'qkv'``.

    Returns:
        ``ema_model`` (the same object that was passed in).
    """
    for ema_param, (name, param) in zip(ema_model.parameters(), model.named_parameters()):
        uses_vida_rate = ('vida_' in name) or ('proj' in name) or ('qkv' in name)
        alpha = alpha_vida if uses_vida_rate else alpha_teacher
        # copy_ instead of slice assignment: slicing with [:] raises on 0-dim
        # (scalar) parameters, while copy_ works for every shape.
        ema_param.data.copy_(alpha * ema_param.data + (1 - alpha) * param.data)
    return ema_model


class VIDA_MAE_IBOT(nn.Module):
    """Test-time adaptation wrapper: masked-token distillation plus teacher/student consistency.

    Every ``forward`` call adapts the student (``model``) on the incoming
    batch for ``steps`` optimisation steps and then refreshes an EMA teacher
    (``model_ema``) from the student.  The value returned to the caller is the
    teacher's output on the un-augmented batch, computed before the update of
    that step.
    """
    def __init__(self, model, optimizer, hogs=None, mask_token=None, hog_ratio=1, block_size=16, mask_method="random", mask_ratio=0.5,steps=1, episodic=False, ema=0.99, ema_vida = 0.99, unc_thr = 0.2, embed_dim=768, proj_dim=256):
        """Wrap ``model`` for online adaptation.

        Args:
            model: Student network.  A linear projection head is attached to
                it as ``model.proj`` (the passed object is mutated).
            optimizer: Optimiser stepped once per adaptation step.
            hogs: Stored as ``self.hogs``; not used inside this class.
            mask_token: Stored as ``self.mask_token`` and forwarded to
                ``forward_and_adapt``; the adaptation step itself reads the
                mask token from the student model.
            hog_ratio: Weight of the masked-token distillation loss.
            block_size: Stored and forwarded; not used by the adaptation step.
            mask_method: ``"random"`` or ``"saliency"``.
            mask_ratio: Fraction of patch tokens to mask.
            steps: Number of adaptation steps per ``forward`` call (>= 1).
            episodic: If True, ``reset()`` is called at the start of every
                ``forward``.
            ema: EMA decay for regular teacher parameters.
            ema_vida: EMA decay for ``vida_``/``proj``/``qkv`` parameters.
            unc_thr: Stored as ``self.thr``; not used inside this class.
            embed_dim: Input width of the projection head (token dimension).
            proj_dim: Output width of the projection head.
        """
        super().__init__()

        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "VIDA_MAE requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # Projection head applied to patch tokens before the distillation loss.
        # It is attached before the snapshot below so the teacher copy has it too.
        model.proj = nn.Linear(embed_dim, proj_dim, bias=True)
        nn.init.trunc_normal_(model.proj.weight, std=0.02)
        nn.init.zeros_(model.proj.bias)
        self.model = model

        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)
        self.transform_u, self.transform_v = get_tta_transforms(soft=True)
        self.alpha_teacher = ema
        self.alpha_vida = ema_vida
        self.thr = unc_thr

        self.block_size = block_size
        self.mask_ratio = mask_ratio
        self.mask_method = mask_method

        self.hogs = hogs
        self.mask_token = mask_token
        self.ratio = hog_ratio
        self.mse_func = nn.MSELoss(reduction="mean")

        # Centering/temperature state; registered and stored but not consumed
        # by the current loss.
        self.register_buffer("teacher_cent", torch.zeros(1, proj_dim))
        self.teacher_temp = 0.02  # config
        self.student_temp = 0.1  # config
        self.center_momentum = 0.9

    def forward(self, x):
        """Adapt on batch ``x`` for ``self.steps`` steps and return the last step's output."""
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs = self.forward_and_adapt(x, self.model, self.optimizer, self.block_size, self.mask_ratio, self.mask_token)

        return outputs

    def reset(self):
        """Restore student and optimiser to their saved state and rebuild the teacher.

        Raises:
            Exception: If no model/optimiser snapshot is available.
        """
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        # Re-snapshot so the teacher (and anchor) are reset together with the
        # student instead of carrying over EMA state from before the reset.
        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def set_scale(self, update_model, high, low):
        """Write ``low``/``high`` into the ``scale1``/``scale2`` attributes of sub-modules.

        ``high`` and ``low`` must support ``.item()``.
        """
        # NOTE(review): because of the if/elif, a module exposing BOTH scale1
        # and scale2 only gets scale1 updated.  Left unchanged because the
        # adapter module layout is not visible here - confirm this is intended.
        for name, module in update_model.named_modules():
            if hasattr(module, 'scale1'):
                module.scale1 = low.item()
            elif hasattr(module, 'scale2'):
                module.scale2 = high.item()

    def random_masking(self, x, mask_ratio):
        """Pick a uniformly random subset of tokens to mask, independently per sample.

        The random permutation is obtained by arg-sorting uniform noise.

        Args:
            x: Token sequence of shape ``[N, L, D]``.
            mask_ratio: Fraction of the ``L`` tokens to mask.

        Returns:
            ``(x_masked, mask, ids_restore, ids_shuffle, ids_keep, ids_dump)``
            where ``mask`` is ``[N, L]`` with 0 = keep and 1 = masked, and
            ``ids_dump`` holds the indices of the masked tokens.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # ascending sort: the first len_keep entries are kept, the rest masked
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        ids_dump = ids_shuffle[:, len_keep:]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # binary mask in shuffled order, then un-shuffled back to token order
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_shuffle, ids_keep, ids_dump

    @staticmethod
    def _masked_distill_loss(z_teacher, z_student, mask_bool):
        """Cross-entropy between teacher and student token distributions on masked tokens."""
        if not mask_bool.any():
            # Nothing masked: a mean over zero tokens would be NaN and would
            # poison the update, so contribute a graph-connected zero instead.
            return z_student.sum() * 0.0
        target = F.softmax(z_teacher[mask_bool], -1).detach()
        log_pred = F.log_softmax(z_student[mask_bool], -1)
        return -(target * log_pred).sum(-1).mean()

    @torch.enable_grad()  # adaptation needs gradients even inside a no_grad evaluation loop
    def forward_and_adapt(self, x, model, optim, block_size, mask_ratio, mask_token):
        """Run one adaptation step on batch ``x``.

        The ``model``, ``block_size``, ``mask_ratio`` and ``mask_token``
        arguments are accepted for interface compatibility; the step uses
        ``self.model``, ``self.mask_ratio`` and the student's own
        ``mask_token`` attribute.

        Args:
            x: Input batch.
            optim: Optimiser to step.

        Returns:
            The teacher's first output on the un-augmented ``x`` (computed
            before the parameter update).
        """
        self.model_ema.eval()

        # Teacher prediction on the clean batch is only an output, never a
        # training signal, so no graph is needed.
        with torch.no_grad():
            logT_x, _ = self.model_ema(x, None, None, True)

        # two augmented views of the same batch
        img_u = self.transform_u(x)
        img_v = self.transform_v(x)

        def gen_mask(img):
            """Return the float mask [B, N] (1 = masked) and its boolean form."""
            with torch.no_grad():
                _, f = self.model_ema(img, return_norm=True)
            toks = f[:, 1:]  # drop the first (CLS) token
            if self.mask_method == "random":
                ids = self.random_masking(toks, self.mask_ratio)[-1]
            elif self.mask_method == "saliency":
                att, cos = self.saliency_last_only(img)
                imp = (2 * att + cos) / 3
                ids = self.saliency_based_masking(toks, imp, self.mask_ratio)[-1]
            else:
                raise ValueError(f"Unknown mask_method: {self.mask_method!r}")
            m = torch.zeros_like(toks[..., 0])
            m.scatter_(1, ids, 1.)
            return m, m.bool()

        mask_u, mb_u = gen_mask(img_u)
        mask_v, mb_v = gen_mask(img_v)

        # The learnable mask token lives on the underlying network; unwrap a
        # (Distributed)DataParallel-style wrapper when there is one.
        student_net = getattr(self.model, "module", self.model)
        learnable_mask = student_net.mask_token

        logS_u, tokS_u = self.model(img_u, learnable_mask, mask_u, True)
        logS_v, tokS_v = self.model(img_v, learnable_mask, mask_v, True)

        proj_S = self.model.proj.to(device=x.device)
        proj_T = self.model_ema.proj.to(device=x.device)

        with torch.no_grad():
            logT_u, tokT_u = self.model_ema(img_u, None, None, True)
            logT_v, tokT_v = self.model_ema(img_v, None, None, True)
            z_t_u, z_t_v = proj_T(tokT_u[:, 1:]), proj_T(tokT_v[:, 1:])

        z_s_u, z_s_v = proj_S(tokS_u[:, 1:]), proj_S(tokS_v[:, 1:])

        # masked-token distillation, teacher -> student within the same view
        L_mim_u = self._masked_distill_loss(z_t_u, z_s_u, mb_u)
        L_mim_v = self._masked_distill_loss(z_t_v, z_s_v, mb_v)
        L_mim = (L_mim_u + L_mim_v) * 0.5

        # cross-view consistency on the logits (student of one view vs.
        # teacher of the other)
        L_cons = (softmax_entropy(logS_u, logT_v.detach()).mean() +
                  softmax_entropy(logS_v, logT_u.detach()).mean()) * 0.5

        loss = self.ratio * L_mim + L_cons
        loss.backward()
        optim.step()
        optim.zero_grad()

        update_ema_variables(self.model_ema, self.model,
                             self.alpha_teacher, self.alpha_vida)

        return logT_x

    def saliency_based_masking(self, x, importance, mask_ratio):
        """Mask the tokens with the highest importance scores.

        Args:
            x (torch.Tensor): Patch embeddings of shape [B, L, D]
            importance (torch.Tensor): Importance scores of shape [B, L]
            mask_ratio (float): Fraction of patches to be masked

        Returns:
            x_masked (torch.Tensor): Embeddings of the kept patches
            mask (torch.Tensor): Binary mask (0 = keep, 1 = mask)
            ids_restore (torch.Tensor): Indices for restoring original order
            ids_shuffle (torch.Tensor): Token indices sorted by ascending importance
            ids_keep (torch.Tensor): Indices of kept patches
            ids_dump (torch.Tensor): Indices of masked patches
        """
        B, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))  # how many patches to keep

        # min-max normalise the scores per sample (order-preserving)
        importance = importance - importance.min(dim=1, keepdim=True)[0]
        importance = (importance / (importance.max(dim=1, keepdim=True)[0] + 1e-6)) + 1e-6

        # Ascending sort: least important tokens come first and are kept,
        # the most important ones end up at the tail and are masked.
        ids_shuffle = torch.argsort(importance, dim=1, descending=False)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        ids_dump = ids_shuffle[:, len_keep:]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # binary mask in sorted order, then mapped back to token order
        mask = torch.ones([B, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_shuffle, ids_keep, ids_dump

    @staticmethod
    def _attention_and_cosine(qkv, num_heads):
        """Split a fused qkv projection [B, N, 3*H*hd] and return (attention, cosine) maps, each [B, H, N, N]."""
        B, N, C = qkv.shape
        hd = C // (3 * num_heads)
        q, k, _ = (qkv.view(B, N, 3, num_heads, hd)
                   .permute(2, 0, 3, 1, 4)
                   .unbind(0))  # each [B, H, N, hd]

        scores = q @ k.transpose(-2, -1)  # [B, H, N, N]
        att = scores.softmax(dim=-1)

        qn = q.norm(dim=-1, keepdim=True)  # [B, H, N, 1]
        kn = k.norm(dim=-1, keepdim=True)  # [B, H, N, 1]
        # Entry (i, j) must be divided by |q_i| * |k_j|; kn is transposed so
        # the key norms broadcast along the key axis rather than the query axis.
        cos = scores / (qn * kn.transpose(-2, -1) + 1e-6)
        return att, cos

    @staticmethod
    def _reduce_map(hm):
        """Collapse a [B, H, N, N] map to a per-sample min-max normalised [B, N] score."""
        out = hm.mean(dim=1).sum(dim=1)  # mean over heads, sum over queries
        mn, mx = out.amin(1, keepdim=True), out.amax(1, keepdim=True)
        return (out - mn) / (mx - mn + 1e-6)

    @torch.no_grad()
    def saliency_last_only(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cheap saliency estimate that uses only the teacher's last block.

        Patch embeddings are fed straight into the last block's ``norm1`` and
        ``qkv`` projection; no earlier block is executed and no CLS token is
        added.

        Returns
        -------
        att_e , cos_e : each [B , N]  (N = #patches)
        """
        # unwrap if DataParallel
        vit = self.model_ema.module if isinstance(self.model_ema,
                                                  nn.DataParallel) else self.model_ema

        tokens = vit.patch_embed(img)  # [B , N , D]

        # the last block's own norm1 stands in for the skipped earlier layers
        last_blk = vit.blocks[-1]
        tokens = last_blk.norm1(tokens)

        att, cos = self._attention_and_cosine(last_blk.attn.qkv(tokens),
                                              last_blk.attn.num_heads)
        return self._reduce_map(att), self._reduce_map(cos)

    @torch.no_grad()
    def saliency_importance(self,
                            img: torch.Tensor,
                            mode: str = "last"):
        """Per-patch saliency scores from the teacher's attention projections.

        Args:
          img: Input batch passed to the teacher's ``patch_embed``.
          mode: How block-level maps are combined:
                "avg" / "max"  - mean / max over all blocks;
                "last" / "first" - only the final / first block;
                "first_last"   - average of first and final block;
                "prop_last"    - final block, fed with tokens propagated
                                 through all preceding blocks.
                Except for "prop_last", every block receives the raw patch
                embeddings as input.

        Returns:
          att_e, cos_e   each [B, N]

        Raises:
          ValueError: If ``mode`` is not one of the values above.
        """
        valid_modes = ("avg", "max", "last", "first", "first_last", "prop_last")
        if mode not in valid_modes:
            raise ValueError(f"Unknown mode {mode!r}, use one of {valid_modes}.")

        # unwrap DataParallel if needed
        vit = (self.model_ema.module
               if isinstance(self.model_ema, nn.DataParallel)
               else self.model_ema)

        patch_feats = vit.patch_embed(img)  # [B, N, D]

        if mode == "prop_last":
            # Rebuild the input the last block actually sees: CLS + positional
            # embedding, then every block except the last.
            batch = patch_feats.shape[0]
            cls = vit.cls_token.expand(batch, -1, -1)  # [B, 1, D]
            x = torch.cat((cls, patch_feats), dim=1)  # [B, N+1, D]
            x = vit.pos_drop(x + vit.pos_embed)
            for block in vit.blocks[:-1]:
                x = block(x)

            last = vit.blocks[-1]
            att, cos = self._attention_and_cosine(last.attn.qkv(x[:, 1:, :]),  # drop CLS
                                                  last.attn.num_heads)
            return self._reduce_map(att), self._reduce_map(cos)

        def block_maps(block):
            """(att, cos) scores [B, N] of one block applied to the raw patch embeddings."""
            att, cos = self._attention_and_cosine(block.attn.qkv(patch_feats),
                                                  block.attn.num_heads)
            return self._reduce_map(att), self._reduce_map(cos)

        if mode in ("avg", "max"):
            per_block = [block_maps(block) for block in vit.blocks]
            att_stack = torch.stack([a for a, _ in per_block], dim=0)
            cos_stack = torch.stack([c for _, c in per_block], dim=0)
            if mode == "avg":
                return att_stack.mean(dim=0), cos_stack.mean(dim=0)
            return att_stack.max(dim=0).values, cos_stack.max(dim=0).values

        if mode == "last":
            return block_maps(vit.blocks[-1])
        if mode == "first":
            return block_maps(vit.blocks[0])

        # mode == "first_last"
        att_first, cos_first = block_maps(vit.blocks[0])
        att_last, cos_last = block_maps(vit.blocks[-1])
        return (att_first + att_last) * .5, (cos_first + cos_last) * .5

@torch.jit.script
def softmax_entropy(x, x_ema):# -> torch.Tensor:
    """Symmetric cross-entropy between two sets of logits, per sample.

    Averages the cross-entropy of ``x`` against ``softmax(x_ema)`` and the
    cross-entropy of ``x_ema`` against ``softmax(x)``; the class dimension
    (dim 1) is reduced.
    """
    ce_teacher_to_student = -(x_ema.softmax(1) * x.log_softmax(1)).sum(1)
    ce_student_to_teacher = -(x.softmax(1) * x_ema.log_softmax(1)).sum(1)
    return 0.5 * ce_teacher_to_student + 0.5 * ce_student_to_teacher

def collect_params(model):
    """Split the model's parameters into three groups by name.

    A parameter goes to the first group whose marker appears in its name:

    * adapter group: ``'vida_'``, ``'recon_proj'`` or ``'mask_token'``
    * projection group: ``'proj'`` or ``'qkv'``
    * everything else: base-model group

    Returns:
        ``(model_params, vida_params, proj_params)`` - three lists of
        parameters, each in ``named_parameters()`` order.
    """
    adapter_markers = ('vida_', 'recon_proj', 'mask_token')
    projection_markers = ('proj', 'qkv')

    groups = {'base': [], 'adapter': [], 'projection': []}
    for param_name, tensor in model.named_parameters():
        # adapter markers are checked first, so 'recon_proj' never falls into
        # the generic 'proj' group
        if any(marker in param_name for marker in adapter_markers):
            bucket = 'adapter'
        elif any(marker in param_name for marker in projection_markers):
            bucket = 'projection'
        else:
            bucket = 'base'
        groups[bucket].append(tensor)

    return groups['base'], groups['adapter'], groups['projection']


def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer so adaptation can later be undone.

    Returns:
        ``(model_state, optimizer_state, ema_model, model_anchor)`` where the
        first two are deep copies of the respective ``state_dict()`` and the
        last two are independent deep copies of ``model``.  The parameters of
        ``ema_model`` are detached in place from autograd.
    """
    saved_weights = deepcopy(model.state_dict())
    anchor_copy = deepcopy(model)
    saved_optim = deepcopy(optimizer.state_dict())

    teacher_copy = deepcopy(model)
    for weight in teacher_copy.parameters():
        weight.detach_()

    return saved_weights, saved_optim, teacher_copy, anchor_copy


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Load previously saved state dicts back into ``model`` and ``optimizer``.

    The model state is loaded strictly, so its keys must match the model's.
    """
    restore_model = model.load_state_dict
    restore_optimizer = optimizer.load_state_dict

    restore_model(model_state, strict=True)
    restore_optimizer(optimizer_state)




def configure_model(model, cfg, head_dim, num_class):
    """Attach the adaptation heads to ``model`` and wrap it for training.

    Steps, in order: move the model to CPU, add a linear ``recon_proj`` head
    (``head_dim`` -> ``num_class``), add a zero-initialised learnable
    ``mask_token`` of shape ``(1, 1, head_dim)``, wrap everything in
    ``DataParallel``, move it to CUDA when available (CPU otherwise) and
    switch to train mode.

    ``cfg`` is accepted but not read here; adapter injection and checkpoint
    loading are currently disabled.

    Returns:
        The ``DataParallel``-wrapped model in train mode.
    """
    target = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.cpu()

    # reconstruction head: truncated-normal weights, zero bias
    head = nn.Linear(head_dim, num_class, bias=True)
    nn.init.trunc_normal_(head.weight, std=0.02)
    nn.init.zeros_(head.bias)
    model.recon_proj = head

    # the mask token must be registered before wrapping so the wrapper's
    # underlying module owns it
    token_init = torch.zeros(1, 1, head_dim, device=target)
    model.mask_token = nn.Parameter(token_init, requires_grad=True)

    wrapped = torch.nn.DataParallel(model)
    wrapped = wrapped.to(target)
    return wrapped.train()


def check_model(model):
    """Check that ``model`` is set up for tent-style adaptation.

    Raises:
        AssertionError: If the model is not in train mode, has no trainable
            parameter, has only trainable parameters, or contains no
            ``BatchNorm2d`` layer.
    """
    assert model.training, "tent needs train mode: call model.train()"

    trainable_flags = [weight.requires_grad for weight in model.parameters()]
    assert any(trainable_flags), "tent needs params to update: " \
                                 "check which require grad"
    assert not all(trainable_flags), "tent should not update all params: " \
                                     "check which require grad"

    found_batchnorm = False
    for layer in model.modules():
        if isinstance(layer, nn.BatchNorm2d):
            found_batchnorm = True
            break
    assert found_batchnorm, "tent needs normalization for its optimization"
