from logging import warn
import torch
import timm
import timm.data
import re
from types import SimpleNamespace
from functools import partial

from .general_encoder import GeneralVisionTower
from .dvt.models.vit_wrapper import PretrainedViTWrapper
# Example timm data config (the kind of dict returned by timm.data.resolve_model_data_config):
# {'input_size': (3, 518, 518), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 1.0, 'crop_mode': 'center'}

class TimmImageProcessor:
    """Minimal HF-style image-processor facade around a timm-style transform.

    Args:
        processor: callable that maps one image to a ``torch.Tensor``; a
            leading batch dimension of size 1 is added to its output.
        config: data-config dict. ``image_mean`` reads its ``"mean"`` entry
            and ``crop_size`` reads its ``"input_size"`` entry
            (``(C, H, W)``).
    """

    def __init__(self, processor, config=None):
        self.processor = processor
        self.config = config

    def preprocess(self, images, return_tensors='pt'):
        """Transform ``images`` and wrap the result as ``{'pixel_values': ...}``.

        Args:
            images: a single image accepted by ``self.processor``.
            return_tensors: ``'np'`` converts the batch to a NumPy array; any
                other value (default ``'pt'``) keeps the ``torch.Tensor``.

        Returns:
            dict with key ``'pixel_values'`` holding the batch of one image.
        """
        result = self.processor(images).unsqueeze(0)
        if return_tensors == 'np':
            result = result.numpy()
        return {'pixel_values': result}

    def __call__(self, images, return_tensors='pt'):
        """Alias for :meth:`preprocess`; ``return_tensors`` is honoured."""
        # Previously ``return_tensors`` was dropped here, so 'np' was ignored.
        return self.preprocess(images, return_tensors=return_tensors)

    @property
    def image_mean(self):
        """Per-channel normalisation mean from the data config."""
        return self.config["mean"]

    @property
    def crop_size(self):
        """Model input resolution as ``{"height": H, "width": W}``."""
        H, W = self.config["input_size"][1:]
        return {"height": H, "width": W}

class TimmModelVisionTower(GeneralVisionTower):
    """Vision tower backed by a timm ViT wrapped in ``PretrainedViTWrapper``.

    Tokens are taken from one transformer block chosen by
    ``args.mm_vision_select_layer`` (negative values count back from the last
    block). :meth:`feature_select` then post-processes them according to
    ``args.mm_vision_select_feature``, which is one of ``'patch'``,
    ``'cls_patch'``, ``'reg_patch'`` or ``'cls_reg_patch'``.

    Args:
        vision_tower: model name, optionally prefixed with ``'timm/'`` and/or
            suffixed with ``'.pth'``, ``'.pt'`` or ``'.bin'``. The stripped
            name is used as the timm model identifier and must contain
            ``patch<N>``.
        args: config namespace. It must provide ``mm_vision_select_layer``.
            It may also provide ``mm_vision_select_feature`` (default
            ``'patch'``), ``mm_force_imsize`` and ``unfreeze_mm_vision_tower``.
        delay_load: if true, skip :meth:`load_model` in the constructor.
            Loading still happens eagerly if ``args.unfreeze_mm_vision_tower``
            is set.
        loader: unused; kept for signature compatibility.
        ckpt_path: optional local checkpoint. Its ``"model"`` state dict is
            loaded; if that key is absent, its ``"denoiser"`` state dict is
            loaded instead.
    """

    def __init__(self, vision_tower, args, delay_load=False, loader=None, ckpt_path=None):
        # NOTE(review): ``super(GeneralVisionTower, self)`` resolves to the class
        # *after* GeneralVisionTower in the MRO, so GeneralVisionTower.__init__
        # itself is skipped. Kept as in the original; confirm this is intended.
        super(GeneralVisionTower, self).__init__()

        self.is_loaded = False
        self.ckpt_path = ckpt_path
        # Always defined so ``config`` has a fallback even if loading fails.
        self.cfg_only = None

        self.vision_tower_name = vision_tower
        if vision_tower.startswith('timm/'):
            vision_tower = vision_tower[len('timm/'):]
        for tail in ('.pth', '.pt', '.bin'):
            if vision_tower.endswith(tail):
                vision_tower = vision_tower[:-len(tail)]
        self.model_identifier = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.force_image_size = getattr(args, 'mm_force_imsize', None)

        # Load eagerly unless delayed; an unfrozen tower is always loaded eagerly.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()

    def set_imsize(self, imsize: int):
        # maybe modify the data config, then rebuild the transform
        raise NotImplementedError('Not implemented yet.')

    def load_model(self, device_map=None):
        """Build the backbone, optionally load a checkpoint, and set up preprocessing.

        Args:
            device_map: ``None`` leaves the model where it was created.
                ``'auto'`` or ``'cuda'`` moves it to CUDA when available and to
                CPU otherwise. Any other value is passed to ``.to()``.

        Raises:
            ValueError: if no ``patch<N>`` component can be found in
                ``self.model_identifier``.
            KeyError: if ``ckpt_path`` is set and the checkpoint has neither a
                ``"model"`` nor a ``"denoiser"`` entry.
            NotImplementedError: if ``mm_force_imsize`` was configured
                (see :meth:`set_imsize`).
        """
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        patch_match = re.search(r"patch(\d+)", self.model_identifier)
        if patch_match is None:
            raise ValueError(
                f'Cannot infer patch size from timm model identifier {self.model_identifier!r}: '
                'expected a "patch<N>" component.'
            )
        patch_size = int(patch_match.group(1))
        self.vision_tower = PretrainedViTWrapper(model_identifier=self.model_identifier, stride=patch_size)

        if self.ckpt_path is not None:
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            # map_location='cpu' lets GPU-saved checkpoints load on any machine;
            # load_state_dict copies the values into the existing parameters.
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            # Choose the key explicitly instead of a bare ``except:``. The old
            # code hid genuine state-dict mismatches behind the fallback.
            from_denoiser = 'model' not in ckpt
            msg = self.vision_tower.load_state_dict(ckpt['denoiser'] if from_denoiser else ckpt['model'])
            if from_denoiser:
                from logging import warning
                warning('Why loading denoiser? distilled DVT should be a distilled model.')
            print(f'Loaded Timm model from a local path {self.ckpt_path}, {msg}')
        self.vision_tower.requires_grad_(False)

        if device_map is not None:
            if device_map in ('auto', 'cuda'):
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
            else:
                device = device_map
            self.vision_tower.to(device)

        data_config = timm.data.resolve_model_data_config(model=self.vision_tower.model)
        crop_pct = data_config.get('crop_pct', 1.0)
        if crop_pct < 1.0:
            from logging import warning
            warning(f'Crop pct is {crop_pct}, take care.')
        self.image_processor = TimmImageProcessor(self.vision_tower.transformation, data_config)
        self._config = SimpleNamespace(
            image_size=data_config['input_size'][-1],
            hidden_size=self.vision_tower.n_output_dims,
            patch_size=patch_size,
        )
        if self.force_image_size is not None:
            self.set_imsize(self.force_image_size)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Assemble the output tokens according to ``self.select_feature``.

        Args:
            image_forward_outs: output of ``get_intermediate_layers`` for a
                single layer. ``image_forward_outs[0]`` is a
                ``(patch_tokens, prefix_tokens)`` pair. The prefix tokens hold
                the CLS token first, followed by any register tokens. Tensors
                are assumed to be ``(batch, tokens, dim)``; TODO confirm.

        Raises:
            ValueError: for an unknown ``select_feature``.
        """
        image_features, prefix_tokens = image_forward_outs[0]
        # Split unconditionally so ``reg_features`` is always bound. It is an
        # empty slice when the model has no register tokens (previously this
        # raised UnboundLocalError for 'reg_patch' / 'cls_reg_patch').
        cls_features, reg_features = prefix_tokens[:, 0:1], prefix_tokens[:, 1:]
        if self.select_feature == 'patch':
            return image_features
        if self.select_feature == 'cls_patch':
            return torch.cat([cls_features, image_features], dim=1)
        if self.select_feature == 'reg_patch':
            return torch.cat([reg_features, image_features], dim=1)
        if self.select_feature == 'cls_reg_patch':
            return torch.cat([cls_features, reg_features, image_features], dim=1)
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode images into token features without tracking gradients.

        Args:
            images: a batched image tensor, or a list of per-image tensors.
                Each list element gets a batch dimension of 1 added.

        Returns:
            For a tensor input, the selected features cast back to
            ``images.dtype``. For a list input, a list with one such result
            per image.
        """
        # A negative select_layer counts back from the last block.
        layer = self.select_layer
        if layer < 0:
            layer += self.vision_tower.num_blocks
        extract = partial(
            self.vision_tower.get_intermediate_layers,
            n=[layer],
            reshape=False,
            return_prefix_tokens=True,
            norm=True,  # NOTE(review): the original author was unsure whether norm should be True or False.
        )
        # Hoisted: each property lookup walks the parameter iterator.
        device, dtype = self.device, self.dtype

        def encode(batch):
            out = extract(batch.to(device=device, dtype=dtype))
            return self.feature_select(out).to(batch.dtype)

        if isinstance(images, list):
            return [encode(image.unsqueeze(0)) for image in images]
        return encode(images)

    @property
    def dtype(self):
        """dtype of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        """``SimpleNamespace(image_size, hidden_size, patch_size)`` once loaded, else ``cfg_only``."""
        if self.is_loaded:
            return self._config
        return self.cfg_only
