from copy import deepcopy

import torch
import torch.nn as nn
import torch.jit

import PIL
import torchvision.transforms as transforms
import cotta_transforms as my_transforms
from inject_vida import inject_trainable_vida
from time import time
import logging

from torchvision.transforms import InterpolationMode
import math
import torch.nn.functional as F
from typing import Tuple


def get_tta_transforms(gaussian_std: float=0.005, soft=False, clip_inputs=False):
    """Build the random augmentation pipeline applied to inputs of the teacher.

    Args:
        gaussian_std: standard deviation passed to ``my_transforms.GaussianNoise``.
        soft: when true, the narrower colour-jitter / rotation / scale / blur
            ranges are used; otherwise the wider ones.
        clip_inputs: accepted for signature compatibility; not read by this
            function.

    Returns:
        A ``transforms.Compose`` that clips to [0, 1], colour-jitters, pads,
        applies a random affine, blurs, centre-crops back to 224 pixels,
        randomly flips, adds Gaussian noise and clips to [0, 1] again.
    """
    side = 224  # spatial size used for both the padding amount and the final crop

    # Select every strength-dependent range once instead of inline per argument.
    if soft:
        jitter = dict(
            brightness=[0.8, 1.2],
            contrast=[0.85, 1.15],
            saturation=[0.75, 1.25],
            hue=[-0.03, 0.03],
            gamma=[0.85, 1.15],
        )
        rotation = [-8, 8]
        zoom = (0.95, 1.05)
        blur_sigma = [0.001, 0.25]
    else:
        jitter = dict(
            brightness=[0.6, 1.4],
            contrast=[0.7, 1.3],
            saturation=[0.5, 1.5],
            hue=[-0.06, 0.06],
            gamma=[0.7, 1.3],
        )
        rotation = [-15, 15]
        zoom = (0.9, 1.1)
        blur_sigma = [0.001, 0.5]

    pipeline = [
        my_transforms.Clip(0.0, 1.0),
        my_transforms.ColorJitterPro(**jitter),
        # Edge-pad by half the image side before the affine warp; the centre
        # crop further down restores the original size.
        transforms.Pad(padding=int(side / 2), padding_mode='edge'),
        transforms.RandomAffine(
            degrees=rotation,
            translate=(1/16, 1/16),
            scale=zoom,
            shear=None,
            interpolation=InterpolationMode.BILINEAR,
            fill=0,
        ),
        transforms.GaussianBlur(kernel_size=5, sigma=blur_sigma),
        transforms.CenterCrop(size=side),
        transforms.RandomHorizontalFlip(p=0.5),
        my_transforms.GaussianNoise(0, gaussian_std),
        my_transforms.Clip(0.0, 1.0),
    ]
    return transforms.Compose(pipeline)


def update_ema_variables(ema_model, model, alpha_teacher, alpha_vida):
    """Blend the student's parameters into the teacher in place (exponential moving average).

    For every parameter pair, ``ema = alpha * ema + (1 - alpha) * student``.
    Parameters are paired positionally, so both models must expose their
    parameters in the same order.

    Args:
        ema_model: teacher module whose parameters are overwritten.
        model: student module providing the new values; its parameter names
            decide which momentum is used.
        alpha_teacher: momentum for parameters whose name does not contain
            ``"vida_"``.
        alpha_vida: momentum for parameters whose name contains ``"vida_"``.

    Returns:
        ``ema_model`` (the same object, updated in place).
    """
    with torch.no_grad():
        for ema_param, (name, param) in zip(ema_model.parameters(), model.named_parameters()):
            alpha = alpha_vida if "vida_" in name else alpha_teacher
            # In-place ops work for any shape, including 0-dim (scalar)
            # parameters, which cannot be indexed with ``[:]``.
            ema_param.mul_(alpha).add_(param.detach(), alpha=1 - alpha)
    return ema_model


class VIDA_MAE(nn.Module):
    """Test-time adaptation wrapper: teacher/student consistency plus masked reconstruction.

    ``self.model`` (the student) is updated by ``self.optimizer`` on every
    forward call.  ``self.model_ema`` (the teacher) starts as a deep copy of
    the student and is refreshed after each optimisation step through
    ``update_ema_variables``.  The student sees the input with a subset of
    patch tokens masked; the loss is a symmetric cross-entropy against the
    teacher's prediction and, when ``hogs`` is given, an MSE between the
    student's ``recon_proj`` output and the descriptors of the masked patches.
    """

    def __init__(self, model, optimizer, hogs=None, mask_token=None, hog_ratio=1,
                 block_size=16, mask_method="random", mask_ratio=0.5, steps=1,
                 episodic=False, ema=0.99, ema_vida=0.99, unc_thr=0.2):
        """Store the student, snapshot its state and create the teacher.

        Args:
            model: student network; called as
                ``model(x, mask_token, mask, return_norm=True)``.
            optimizer: optimizer stepping the student's trainable parameters.
            hogs: optional callable mapping an image batch to a descriptor map
                (see ``group_descriptors``); ``None`` disables the
                reconstruction loss.
            mask_token: value forwarded to the student as its second argument.
            hog_ratio: weight of the reconstruction loss.
            block_size: stored and forwarded to ``forward_and_adapt``, which
                does not read it.
            mask_method: ``"random"``, ``"uncertainty"`` or ``"saliency"``.
            mask_ratio: fraction of patch tokens to mask.
            steps: number of adaptation steps per ``forward`` call (> 0).
            episodic: if true, ``reset()`` is called at the start of every
                ``forward``.
            ema: teacher momentum for non-ViDA parameters.
            ema_vida: teacher momentum for parameters named ``*vida_*``.
            unc_thr: stored as ``self.thr``; only the disabled uncertainty
                scaling would use it.
        """
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "VIDA_MAE requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)
        self.transform = get_tta_transforms()
        self.alpha_teacher = ema
        self.alpha_vida = ema_vida
        self.thr = unc_thr

        self.block_size = block_size
        self.mask_ratio = mask_ratio
        self.mask_method = mask_method

        self.hogs = hogs
        self.mask_token = mask_token
        self.ratio = hog_ratio
        self.mse_func = nn.MSELoss(reduction="mean")

    def forward(self, x):
        """Run ``self.steps`` adaptation steps on ``x`` and return the last step's teacher output."""
        if self.episodic:
            self.reset()

        for _ in range(self.steps):
            outputs = self.forward_and_adapt(x, self.model, self.optimizer, self.block_size, self.mask_ratio, self.mask_token)

        return outputs

    def reset(self):
        """Restore student and optimizer to the saved snapshot and rebuild teacher/anchor.

        Raises:
            Exception: if no snapshot is available.
        """
        if self.model_state is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)
        # Re-copying after the restore also resets the teacher (model_ema) and
        # the anchor to the restored student weights.
        self.model_state, self.optimizer_state, self.model_ema, self.model_anchor = \
            copy_model_and_optimizer(self.model, self.optimizer)

    def group_descriptors(self, desc_map, mask_chosed, outputs_hog):
        """Pair the student's reconstruction predictions with descriptor targets at masked patches.

        Args:
            desc_map: ``[B, C, nbins, Hd, Wd]`` (HOG-style) or ``[B, D, Hd, Wd]``.
            mask_chosed: ``[B, L]`` mask, non-zero where a patch is masked.
            outputs_hog: ``[B, 1 + L, E]`` student tokens; index 0 is dropped
                before the projection.

        Returns:
            ``(preds, targets)``: rows of ``recon_proj(outputs_hog[:, 1:])``
            and of the per-patch descriptor vectors, selected by the mask.
        """
        # 0) unify the descriptor dimension into a single axis
        if desc_map.dim() == 5:
            # HOG-style: collapse C x nbins -> D
            B, C, nbins, Hd, Wd = desc_map.shape
            desc_map = desc_map.reshape(B, C * nbins, Hd, Wd)
        # now desc_map is [B, D, Hd, Wd] in both cases

        B, D, Hd, Wd = desc_map.shape
        L = outputs_hog.size(1) - 1
        P = int(math.sqrt(L))  # patches per side; assumes a square patch grid

        if Hd == P and Wd == P:
            # one descriptor cell per patch: [B, D, P, P] -> [B, P*P, D]
            desc = desc_map.flatten(2).permute(0, 2, 1)
        else:
            # several descriptor cells per patch: gather the gh x gw cells of
            # every patch into one vector.
            gh, gw = Hd // P, Wd // P
            dm = desc_map.permute(0, 2, 3, 1)            # [B, Hd, Wd, D]
            dm = dm.unfold(1, gh, gh).unfold(2, gw, gw)  # [B, P, P, D, gh, gw]
            # The two patch-grid axes are already leading and adjacent, so a
            # plain reshape yields token index row * P + col, matching the
            # row-major order of the branch above and of the mask.
            desc = dm.reshape(B, P * P, D * gh * gw)

        # pick the projection layer from the inner module when wrapped
        proj_layer = (self.model.module.recon_proj
                      if isinstance(self.model, torch.nn.DataParallel)
                      else self.model.recon_proj)

        preds = proj_layer(outputs_hog[:, 1:, :])
        mb = mask_chosed.flatten(1).bool().to(preds.device)
        desc = desc.to(preds.device)

        return preds[mb], desc[mb]

    def set_scale(self, update_model, high, low):
        """Write ``low`` / ``high`` into the ``scale1`` / ``scale2`` attributes of sub-modules.

        ``high`` and ``low`` must support ``.item()``.

        NOTE(review): because of the ``elif``, a module exposing both
        ``scale1`` and ``scale2`` only gets ``scale1`` updated -- confirm this
        is intended.  The method is currently not called from this class.
        """
        for name, module in update_model.named_modules():
            if hasattr(module, 'scale1'):
                module.scale1 = low.item()
            elif hasattr(module, 'scale2'):
                module.scale2 = high.item()

    @torch.enable_grad()  # ensure grads in possible no grad context for testing
    def forward_and_adapt(self, x, model, optimizer, block_size, mask_ratio, mask_token):
        """Perform one adaptation step and return the teacher's prediction.

        Args:
            x: input batch.
            model: unused; the student is always ``self.model``.
            optimizer: optimizer that is stepped and zeroed.
            block_size: unused.
            mask_ratio: fraction of patch tokens to mask.
            mask_token: forwarded to the student.

        Returns:
            The teacher's output on the un-augmented ``x``, computed before the
            teacher is updated.

        Raises:
            ValueError: if ``self.mask_method`` is not a known method.
        """
        self.model_ema.eval()

        # NOTE: the uncertainty-driven scale selection of the ViDA adapters
        # (augmentation-averaged teacher variance -> set_scale) is disabled.

        # 1) extract teacher patch features once (token 0 is dropped)
        with torch.no_grad():
            _, feats = self.model_ema(self.transform(x), return_norm=True)
        patch_feats = feats[:, 1:, :]  # [B, L, D]
        B, L, _ = patch_feats.shape

        # 2) pick mask indices (ids_dump) based on method
        if self.mask_method == "uncertainty":
            # repeat the augmented teacher forward to get per-token variance
            n_forward = 10
            outs = []
            with torch.no_grad():
                for _ in range(n_forward):
                    _, f = self.model_ema(self.transform(x), return_norm=True)
                    outs.append(f[:, 1:, :].mean(dim=2))  # [B, L]
            var = torch.stack(outs, dim=0).var(dim=0)  # [B, L]
            _, sorted_idx = torch.sort(var, dim=1, descending=True)
            top_k = int(L * mask_ratio)
            ids_dump = sorted_idx[:, :top_k]  # [B, top_k]

        elif self.mask_method == "random":
            # uniform random masking; the last tuple entry is ids_dump
            _, _, _, _, _, ids_dump = self.random_masking(
                patch_feats, mask_ratio
            )

        elif self.mask_method == "saliency":
            # saliency-based masking: weighted mix of attention and cosine maps
            att_e, cos_e = self.saliency_last_only(x)
            importance = (2.0*att_e  +1.0*cos_e)/3.0
            _, _, _, _, _, ids_dump = self.saliency_based_masking(
                patch_feats, importance, mask_ratio
            )

        else:
            raise ValueError(f"Unknown mask_method: {self.mask_method!r}")

        # 3) build binary mask [B, L] by scattering 1's at ids_dump
        mask_chosed = patch_feats.new_zeros((B, L))
        mask_chosed.scatter_(1, ids_dump, 1.0)

        # 4) masked student forward and un-masked teacher forward
        outputs, outputs_hog = self.model(
            x, mask_token, mask_chosed, return_norm=True
        )
        standard_ema = self.model_ema(x)
        # Student update: consistency with the (detached) teacher prediction
        loss_consistency = (softmax_entropy(outputs, standard_ema.detach())).mean(0)

        # 5) optional descriptor reconstruction loss on the masked patches
        if self.hogs is not None:
            preds_mask, labels_mask = self.group_descriptors(
                desc_map=self.hogs(x),
                mask_chosed=mask_chosed,
                outputs_hog=outputs_hog
            )
            desc_loss = self.mse_func(preds_mask, labels_mask)
            loss = loss_consistency + self.ratio * desc_loss
        else:
            loss = loss_consistency

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Teacher update
        self.model_ema = update_ema_variables(ema_model = self.model_ema, model = self.model, alpha_teacher= self.alpha_teacher, alpha_vida = self.alpha_vida)
        # NOTE: stochastic restore of student weights from self.model_state is
        # not applied here.
        return standard_ema

    def random_masking(self, x, mask_ratio):
        """Pick a random subset of tokens per sample by arg-sorting random noise.

        Args:
            x: ``[N, L, D]`` token sequence.
            mask_ratio: fraction of tokens to remove.

        Returns:
            ``(x_masked, mask, ids_restore, ids_shuffle, ids_keep, ids_dump)``
            where ``mask`` is ``[N, L]`` with 0 = keep and 1 = remove.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        ids_dump = ids_shuffle[:, len_keep:]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_shuffle, ids_keep, ids_dump

    def saliency_based_masking(self, x, importance, mask_ratio):
        """Mask the most important patches according to a saliency score.

        Args:
            x (torch.Tensor): Patch embeddings of shape [B, L, D]
            importance (torch.Tensor): Importance scores of shape [B, L]
            mask_ratio (float): Fraction of patches to be masked

        Returns:
            x_masked (torch.Tensor): Embeddings of the kept patches
            mask (torch.Tensor): Binary mask (0 = keep, 1 = mask)
            ids_restore (torch.Tensor): Indices for restoring original order
            ids_shuffle (torch.Tensor): Patch indices sorted by importance
            ids_keep (torch.Tensor): Indices of kept patches
            ids_dump (torch.Tensor): Indices of masked patches
        """
        B, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))  # how many patches to keep

        # Min-max normalise importance per sample (rank-preserving)
        importance = importance - importance.min(dim=1, keepdim=True)[0]
        importance = (importance / (importance.max(dim=1, keepdim=True)[0] + 1e-6)) + 1e-6

        # Sort patches by importance in ascending order (least important first)
        ids_shuffle = torch.argsort(importance, dim=1, descending=False)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # The first len_keep (least important) patches are kept; the rest,
        # i.e. the most important ones, are masked.
        ids_keep = ids_shuffle[:, :len_keep]
        ids_dump = ids_shuffle[:, len_keep:]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Create binary mask in original patch order
        mask = torch.ones([B, L], device=x.device)
        mask[:, :len_keep] = 0  # keep the least important patches
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_shuffle, ids_keep, ids_dump

    @torch.no_grad()
    def saliency_last_only(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cheap saliency estimate using only the teacher's patch embedding and last block.

        The patch embedding is normalised with the last block's ``norm1`` and
        fed straight into that block's ``qkv`` projection; no earlier block is
        executed and no CLS token is added.

        Returns
        -------
        att_e , cos_e : each [B , N]  (N = #patches), min-max normalised per
        sample.
        """
        # unwrap if DataParallel
        vit = self.model_ema.module if isinstance(self.model_ema,
                                                  nn.DataParallel) else self.model_ema
        B = img.size(0)

        # 1) patch -> tokens  (no CLS, no previous blocks)
        tokens = vit.patch_embed(img)  # [B , N , D]

        # 2) approximate the missing layer-norm with the last block's norm1
        last_blk = vit.blocks[-1]
        tokens = last_blk.norm1(tokens)  # [B , N , D]

        # 3) Q K from the last block only
        qkv = last_blk.attn.qkv(tokens)  # [B , N , 3*H*hd]
        H = last_blk.attn.num_heads
        hd = qkv.shape[-1] // (3 * H)

        q, k, _ = (qkv.view(B, -1, 3, H, hd)
                   .permute(2, 0, 3, 1, 4)
                   .unbind(0))  # each [B , H , N , hd]

        # 4) attention + cosine maps  [B , H , N , N]
        att = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        qn, kn = q.norm(dim=-1, keepdim=True), k.norm(dim=-1, keepdim=True)
        cos = (q @ k.transpose(-2, -1)) / (qn * kn + 1e-6)

        # helper: [B,H,N,N] -> [B,N] with min-max normalisation
        def reduce(hm: torch.Tensor) -> torch.Tensor:
            out = hm.mean(dim=1).sum(dim=1)  # [B , N]
            mn, mx = out.amin(1, keepdim=True), out.amax(1, keepdim=True)
            return (out - mn) / (mx - mn + 1e-6)

        att_e = reduce(att)  # [B , N]
        cos_e = reduce(cos)  # [B , N]
        return att_e, cos_e

    def saliency_importance(self,
                            img: torch.Tensor,
                            mode: str = "last"):
        """Per-patch saliency from the teacher's attention projections.

        For every encoder block the raw patch embedding of ``img`` is fed to
        that block's ``qkv`` projection; the resulting attention and cosine
        maps are reduced to one score per patch and combined according to
        ``mode``.

        Args:
          img: input image batch.
          mode: ``"avg"`` / ``"max"`` over all blocks, ``"last"``, ``"first"``,
                ``"first_last"`` (mean of first and last), or ``"prop_last"``
                (forward the tokens through all but the last block and use the
                last block's projection on its real input).

        Returns:
          att_e, cos_e   each [B, N], min-max normalised per sample.

        Raises:
          ValueError: if ``mode`` is not one of the values above.
        """
        # unwrap DataParallel if needed
        vit = (self.model_ema.module
               if isinstance(self.model_ema, nn.DataParallel)
               else self.model_ema)

        patch_feats = vit.patch_embed(img)  # [B, N, D]
        B, N = patch_feats.shape[0], patch_feats.shape[1]
        # buffers to collect per-block maps
        all_att_e, all_cos_e = [], []

        # inline helper to turn [B,H,N,N] -> [B,N]
        def process_map(hm: torch.Tensor) -> torch.Tensor:
            out = hm.mean(dim=1)  # [B, N, N]
            out = out.sum(dim=1)  # [B, N]
            mn = out.amin(dim=1, keepdim=True)
            mx = out.amax(dim=1, keepdim=True)
            return (out - mn) / (mx - mn + 1e-6)

        # loop over encoder blocks
        for block in vit.blocks:
            # Q/K/V from this block, applied to the raw patch embedding
            qkv = block.attn.qkv(patch_feats)  # [B,N,3*H*hd]
            B2, N2, C2 = qkv.shape
            H = block.attn.num_heads
            hd = (C2 // 3) // H
            qkv = (qkv.view(B2, N2, 3, H, hd)
                   .permute(2, 0, 3, 1, 4))  # [3,B,H,N,hd]
            q, k, _ = qkv.unbind(dim=0)  # each [B,H,N,hd]

            attn_map = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # [B,H,N,N]
            qn = q.norm(dim=-1, keepdim=True)
            kn = k.norm(dim=-1, keepdim=True)
            cos_map = (q @ k.transpose(-2, -1)) / (qn * kn + 1e-6)  # [B,H,N,N]

            all_att_e.append(process_map(attn_map))
            all_cos_e.append(process_map(cos_map))

        # now aggregate according to mode
        if mode == "avg":
            att_e = torch.stack(all_att_e, dim=0).mean(dim=0)  # [B,N]
            cos_e = torch.stack(all_cos_e, dim=0).mean(dim=0)
        elif mode == "max":
            att_e = torch.stack(all_att_e, dim=0).max(dim=0).values  # [B,N]
            cos_e = torch.stack(all_cos_e, dim=0).max(dim=0).values
        elif mode == "last":
            att_e = all_att_e[-1]
            cos_e = all_cos_e[-1]
        elif mode == "first":
            att_e = all_att_e[0]
            cos_e = all_cos_e[0]
        elif mode == "first_last":
            att_e = (all_att_e[0] + all_att_e[-1]) * .5
            cos_e = (all_cos_e[0] + all_cos_e[-1]) * .5
        elif mode == "prop_last":
            # Recompute the *input* to the final block by forwarding through
            # all preceding blocks.
            # 1) patch -> tokens (with cls + pos + drop)
            cls = vit.cls_token.expand(B, -1, -1)  # [B,1,D]
            x = torch.cat((cls, patch_feats), dim=1)  # [B,N+1,D]
            x = vit.pos_drop(x + vit.pos_embed)

            # 2) forward through all but the last block
            for block in vit.blocks[:-1]:
                x = block(x)

            # 3) now x is what the *last* block actually sees
            last = vit.blocks[-1]
            qkv = last.attn.qkv(x[:, 1:, :])  # drop CLS
            H = last.attn.num_heads

            hd = (qkv.shape[-1] // 3) // H
            q, k, _ = (qkv.view(B, N, 3, H, hd).permute(2, 0, 3, 1, 4).unbind(0))
            attn = (q @ k.transpose(-2, -1)).softmax(-1)

            qn, kn = q.norm(dim=-1, keepdim=True), k.norm(dim=-1, keepdim=True)
            cos = (q @ k.transpose(-2, -1)) / (qn * kn + 1e-6)

            att_e = process_map(attn)
            cos_e = process_map(cos)
        else:
            raise ValueError(
                f"Unknown mode {mode!r}, use 'avg', 'max', 'last', 'first', "
                f"'first_last' or 'prop_last'."
            )

        return att_e, cos_e

@torch.jit.script
def softmax_entropy(x, x_ema):# -> torch.Tensor:
    """Symmetric cross-entropy between two sets of logits, one value per row.

    Averages the cross-entropy of ``x`` under the softmax of ``x_ema`` with
    the cross-entropy of ``x_ema`` under the softmax of ``x``; both softmaxes
    are taken over dimension 1.
    """
    teacher_weighted = (x_ema.softmax(1) * x.log_softmax(1)).sum(1)
    student_weighted = (x.softmax(1) * x_ema.log_softmax(1)).sum(1)
    return -0.5 * teacher_weighted - 0.5 * student_weighted

def collect_params(model):
    """Split the model's parameters into backbone and adapter groups.

    A parameter goes into the adapter group when its name contains
    ``'vida_'``, ``'recon_proj'`` or ``'mask_token'``; every other parameter
    goes into the backbone group.  Order within each group follows
    ``model.named_parameters()``.

    Returns:
        ``(model_params_list, vida_params_list)`` -- two lists of parameters
        (names are not returned).
    """
    adapter_markers = ('vida_', 'recon_proj', 'mask_token')
    model_params_list, vida_params_list = [], []

    for name, param in model.named_parameters():
        is_adapter = any(marker in name for marker in adapter_markers)
        bucket = vida_params_list if is_adapter else model_params_list
        bucket.append(param)

    return model_params_list, vida_params_list


def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer and create teacher / anchor copies.

    Returns:
        ``(model_state, optimizer_state, ema_model, model_anchor)``: deep
        copies of both state dicts plus two independent deep copies of the
        model.  The parameters of ``ema_model`` are detached in place.
    """
    saved_weights = deepcopy(model.state_dict())
    saved_optim = deepcopy(optimizer.state_dict())

    anchor = deepcopy(model)
    teacher = deepcopy(model)
    for teacher_param in teacher.parameters():
        teacher_param.detach_()

    return saved_weights, saved_optim, teacher, anchor


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Load previously saved state dicts back into ``model`` and ``optimizer``.

    The model state is loaded strictly, so its keys must match the model's
    exactly.
    """
    model.load_state_dict(state_dict=model_state, strict=True)
    optimizer.load_state_dict(state_dict=optimizer_state)


'''
def configure_model(model, cfg,  head_dim, num_class):
    """Configure model for use with tent."""
    # train mode, because tent optimizes the model to minimize entropy
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.cpu()
    vida_params, vida_names = inject_trainable_vida(model = model, target_replace_module = ["CrossAttention", "Attention"], \
            r = cfg.TEST.vida_rank1, r2 = cfg.TEST.vida_rank2)
    #projection head
    embed_dim, recon_dim = head_dim, num_class
    model.recon_proj = nn.Linear(embed_dim, recon_dim, bias=True)
    nn.init.trunc_normal_(model.recon_proj.weight, std=0.02)
    nn.init.zeros_(model.recon_proj.bias)

    #if cfg.TEST.ckpt!=None:
    #    model = torch.nn.DataParallel(model) # make parallel
    #    checkpoint = torch.load(cfg.TEST.ckpt)
    #    model.load_state_dict(checkpoint, strict=True)

    if cfg.TEST.ckpt is not None:
        state = torch.load(cfg.TEST.ckpt)

        # if the file was saved with torch.save(model.state_dict(), …)
        # it is already a state-dict; otherwise take the ['model'] entry.
        if 'model' in state:
            state = state['model']

        # de-parallelise if it was saved from DataParallel
        if next(iter(state)).startswith('module.'):
            state = {k[7:]: v for k, v in state.items()}

        # you added recon_proj after the checkpoint was saved, so use strict=False
        missing, unexpected = model.load_state_dict(state, strict=False)
        print(f"✓ checkpoint loaded  (missing={len(missing)}, unexpected={len(unexpected)})")

    #model = torch.nn.DataParallel(model)#.cuda()

    model.to(device)
    model.train()
    return model
'''

def configure_model(model, cfg, head_dim, num_class):
    """Prepare a model for adaptation: adapters, reconstruction head, mask token, checkpoint.

    Args:
        model: backbone to be modified in place.
        cfg: config object; reads ``cfg.TEST.vida_rank1``,
            ``cfg.TEST.vida_rank2`` and ``cfg.TEST.ckpt``.
        head_dim: input width of ``recon_proj`` and width of ``mask_token``.
        num_class: output width of ``recon_proj``.

    Returns:
        The model wrapped in ``torch.nn.DataParallel``, moved to CUDA when
        available (CPU otherwise) and set to train mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.cpu()

    # 1) inject ViDA adapters
    inject_trainable_vida(model, target_replace_module=["CrossAttention", "Attention"],
                          r=cfg.TEST.vida_rank1, r2=cfg.TEST.vida_rank2)

    # 2) reconstruction head
    model.recon_proj = nn.Linear(head_dim, num_class, bias=True)
    nn.init.trunc_normal_(model.recon_proj.weight, std=0.02)
    nn.init.zeros_(model.recon_proj.bias)

    # 3) learnable mask token, registered before wrapping so DataParallel
    #    replicates it.  Created on CPU like the rest of the model; the final
    #    ``.to(device)`` moves everything together.
    model.mask_token = nn.Parameter(
        torch.zeros(1, 1, head_dim), requires_grad=True
    )

    # 4) load checkpoint (strict=False because the checkpoint may predate the
    #    modules added above)
    if cfg.TEST.ckpt is not None:
        # The model is on CPU at this point, so load there as well; this also
        # lets GPU-saved checkpoints load on CPU-only hosts.
        # NOTE(security): torch.load unpickles -- only load trusted files.
        state = torch.load(cfg.TEST.ckpt, map_location='cpu')
        if 'model' in state:
            state = state['model']
        # Strip a DataParallel 'module.' prefix per key, so empty or
        # mixed-prefix checkpoints are handled without corrupting key names.
        prefix = 'module.'
        state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }
        model.load_state_dict(state, strict=False)

    # 5) single DataParallel wrap
    model = torch.nn.DataParallel(model).to(device).train()
    return model


def check_model(model):
    """Assert that ``model`` is set up the way tent-style adaptation expects.

    Checks, in order: the model is in train mode; at least one parameter
    requires grad; not every parameter requires grad; at least one
    ``nn.BatchNorm2d`` module is present.

    Raises:
        AssertionError: on the first check that fails.
    """
    assert model.training, "tent needs train mode: call model.train()"

    grad_flags = [p.requires_grad for p in model.parameters()]
    assert any(grad_flags), "tent needs params to update: " \
                            "check which require grad"
    assert not all(grad_flags), "tent should not update all params: " \
                                "check which require grad"

    contains_bn = False
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            contains_bn = True
            break
    assert contains_bn, "tent needs normalization for its optimization"
