import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import models_vit
from . import models_mae
from .swin_transformer import swin_t
from torchvision.models import Swin_T_Weights
from .simmim import build_simmim
import einops
import logging

logger = logging.getLogger(__name__)

# Registry of supported visual backbones, keyed by ``config.backbone``.
# Each value is ``[constructor, checkpoint]``:
#   - constructor: called with ``config`` for 'simmim', with no arguments otherwise
#     (see UTSClassification.__init__);
#   - checkpoint: a default checkpoint filename joined onto ``config.ckpt`` for the
#     locally stored models, or a torchvision weights enum for 'swin_tiny'.
ViT_MODELS = dict(
    vit_base=[models_vit.vit_base_patch16, "vit_base.pth"],
    vit_large=[models_vit.vit_large_patch16, "vit_large.pth"],
    vit_huge=[models_vit.vit_huge_patch14, "vit_huge.pth"],
    swin_tiny=[swin_t, Swin_T_Weights.DEFAULT],
    mae_base=[models_mae.mae_vit_base_patch16, "mae_visualize_vit_base.pth"],
    simmim=[build_simmim, "simmim_pretrain__swin_base__img192_window6__800ep.pth"],
)
    
class UTSClassification(nn.Module):
    """Classify time series with a (pre-trained) vision backbone plus a linear head.

    ``forward`` takes a tensor of shape ``(b, n, h, w)``. It treats every one of
    the ``b * n`` ``(h, w)`` slices as a single-channel image, replicates it to
    3 channels, and runs ``self.vm.forward_features``. It then concatenates the
    ``n`` per-slice feature vectors of each sample and feeds them to a lazily
    sized ``nn.LazyLinear`` head.

    Args:
        config: Object providing the attributes ``ckpt``, ``load_ckpt``,
            ``backbone`` (a key of ``ViT_MODELS``), ``num_class``, ``head_path``,
            ``save_vm``, ``vm_path`` and ``ft_type``.

    Raises:
        FileNotFoundError: If checkpoint loading is requested for a non-swin
            backbone and the resolved checkpoint file does not exist.
    """

    def __init__(self, config):
        super().__init__()
        self.ckpt = config.ckpt
        self.load_ckpt = config.load_ckpt
        # Only populated for the 'swin_tiny' backbone; defined up front so that
        # reading it never raises AttributeError for the other backbones.
        self.vm_transform = None

        builder = ViT_MODELS[config.backbone][0]
        # The SimMIM builder is configured from ``config``; the others take no args.
        self.vm = builder(config) if config.backbone == 'simmim' else builder()

        self.head = nn.LazyLinear(out_features=config.num_class)
        if config.head_path and not config.save_vm:
            head_ckpt = torch.load(config.head_path, map_location='cpu')
            self.head.load_state_dict(head_ckpt, strict=False)

        if self.load_ckpt:
            self._load_vm_checkpoint(config)

        self._apply_finetune_type(config.ft_type)

    def _load_vm_checkpoint(self, config):
        """Load pre-trained weights into ``self.vm`` (non-strict).

        For 'swin_tiny' the weights come from the torchvision weights enum
        stored in ``ViT_MODELS``, and its preprocessing transform is kept in
        ``self.vm_transform``. For every other backbone the weights are read
        from ``config.vm_path`` (when set and ``config.save_vm`` is false) or
        from the default filename under ``self.ckpt``.

        Raises:
            FileNotFoundError: If the resolved checkpoint file does not exist.
        """
        if config.backbone == 'swin_tiny':
            weight = ViT_MODELS[config.backbone][1]
            self.vm.load_state_dict(weight.get_state_dict(progress=True), strict=False)
            self.vm_transform = weight.transforms()
            return

        # vit / mae / simmim: load from a checkpoint file on disk.
        use_custom_path = bool(config.vm_path) and not config.save_vm
        if use_custom_path:
            ckpt_path = config.vm_path
        else:
            ckpt_path = os.path.join(self.ckpt, ViT_MODELS[config.backbone][1])
        if not os.path.isfile(ckpt_path):
            # Explicit raise: an ``assert`` would be stripped under ``python -O``.
            raise FileNotFoundError(f"Checkpoint file {ckpt_path} not found. ")

        try:
            print('loading visual model checkpoint from {}'.format(ckpt_path))
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            # A user-supplied vm_path is loaded as a state dict directly; the
            # default checkpoint files are read from their 'model' entry.
            state_dict = checkpoint if use_custom_path else checkpoint['model']
            self.vm.load_state_dict(state_dict, strict=False)
        except Exception:
            # Best-effort, as before: report and continue with the freshly
            # initialised backbone. Catching Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate, and the logged traceback
            # shows the actual cause.
            print("ERROR - " * 10)
            print(f"Bad checkpoint file. Please delete {ckpt_path} and redownload!")
            print("ERROR - " * 10)
            logger.exception("Failed to load visual model checkpoint %s", ckpt_path)

    def _apply_finetune_type(self, finetune_type):
        """Set ``requires_grad`` on backbone parameters according to ``finetune_type``.

        'full' leaves every parameter untouched. 'ln', 'full_wo_cls_mask',
        'mlp+ln' and 'zero-shot' select trainable parameters by name. Any
        other value leaves the parameters untouched (same as 'full') and emits
        a warning. The head is never modified here.
        """
        print(f"Finetune type: {finetune_type}")
        if finetune_type == 'full':
            return

        mlp_ln_keys = ('mlp', 'norm', 'patch_embed', 'decoder_embed', 'decoder_pred')
        rules = {
            'ln': lambda name: 'norm' in name,
            'full_wo_cls_mask': lambda name: 'mask_token' not in name and 'cls_token' not in name,
            # Without attention parameters.
            'mlp+ln': lambda name: any(key in name for key in mlp_ln_keys),
            'zero-shot': lambda name: False,
        }
        is_trainable = rules.get(finetune_type)
        if is_trainable is None:
            logger.warning(
                "Unknown ft_type %r; all backbone parameters keep their current "
                "requires_grad (equivalent to 'full').", finetune_type)
            return

        for name, param in self.vm.named_parameters():
            param.requires_grad = is_trainable(name)

    def _vm_forward(self, x):
        """Return ``self.vm.forward_features(x)`` for images ``x`` of shape (B, C, H, W)."""
        # NOTE(review): self.vm_transform (only set for swin_tiny) is not applied here.
        return self.vm.forward_features(x)

    def forward(self, x):
        """Classify a batch of inputs of shape (b, n, h, w).

        Returns the head output for the ``n`` concatenated per-slice features of
        each sample; the expected output shape is (b, num_class).
        """
        B = x.shape[0]
        x = einops.rearrange(x, 'b n h w -> (b n) h w')
        # Single-channel slices replicated to the 3 channels the backbone expects.
        x = x.unsqueeze(1).repeat(1, 3, 1, 1)
        vm_output = self._vm_forward(x)  # assumed (b*n, d); rearrange below requires 2-D
        vm_output = einops.rearrange(vm_output, '(b n) d -> b n d', b=B)
        vm_output = vm_output.flatten(1)
        return self.head(vm_output)