import torch
import torch.nn as nn
from types import SimpleNamespace
from functools import partial

from .general_encoder import GeneralVisionTower
from transformers import AutoModel, CLIPImageProcessor, AutoConfig
from transformers import logging
logger = logging.get_logger(__name__)
import warnings


# Identifier handed to `image_processor_loader.from_pretrained(...)` in
# `RADIOVisionTower.load_model`; the resulting processor's resize/crop sizes
# are overridden afterwards by `RADIOVisionTower.set_imsize`.
DEFAULT_IMAGE_PROCESSOR = 'nvidia/RADIO-L'

class RADIOConfig:
    """Lightweight config view over a RADIO model.

    Stores the wrapped ``model`` and an ``image_size``; ``hidden_size`` and
    ``patch_size`` are not copied but read from ``model`` on every access, so
    they always reflect the model's current attribute values.
    """

    def __init__(self, model: nn.Module, image_size: int):
        """Keep a reference to ``model`` and record ``image_size``."""
        self.image_size = image_size
        self.model = model

    @property
    def patch_size(self):
        """Return ``patch_size`` as reported by the wrapped model."""
        return getattr(self.model, "patch_size")

    @property
    def hidden_size(self):
        """Return ``hidden_size`` as reported by the wrapped model."""
        return getattr(self.model, "hidden_size")

class RADIOVisionTower(GeneralVisionTower):
    """Vision tower backed by an NVlabs RADIO model loaded through ``torch.hub``.

    The class relies on attributes that are not assigned here and are
    presumably provided by ``GeneralVisionTower.__init__`` (verify against the
    base class): ``is_loaded``, ``vision_tower_name``, ``image_processor_loader``,
    ``force_image_size``, ``select_layer``, ``select_feature`` and ``cfg_only``.
    """

    def __init__(self, vision_tower, args, delay_load=False, loader=(AutoModel, CLIPImageProcessor, AutoConfig)):
        """Initialise the tower via the base class.

        Args:
            vision_tower: RADIO version string; later passed as ``version`` to
                ``torch.hub.load`` (through ``self.vision_tower_name``).
            args: Configuration object forwarded untouched to the base class.
            delay_load: Accepted for signature compatibility only.
            loader: Tuple of loader classes forwarded to the base class.
        """
        # NOTE(review): `delay_load` is intentionally left as in the original —
        # `False` is always forwarded, so the caller's value is ignored.
        # Confirm whether delayed loading should be supported here.
        super().__init__(vision_tower, args, delay_load=False, loader=loader)

    def set_imsize(self, imsize: int):
        """Set the working image size on both the config and the image processor.

        Enables resize and center-crop on the processor and sets the shortest
        edge and both crop dimensions to ``imsize``.
        """
        self._config.image_size = imsize
        processor = self.image_processor
        processor.do_center_crop = True
        processor.do_resize = True
        processor.size['shortest_edge'] = imsize
        processor.crop_size['width'] = imsize
        processor.crop_size['height'] = imsize

    def load_model(self, device_map=None):
        """Load the image processor and the RADIO backbone, then build the config.

        Does nothing (apart from printing a notice) when the model is already
        loaded.

        Args:
            device_map: ``None`` leaves the model where ``torch.hub.load`` put
                it. ``'auto'`` or ``'cuda'`` selects CUDA when available and
                CPU otherwise. Any other value is passed directly to ``.to()``.
        """
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = self.image_processor_loader.from_pretrained(DEFAULT_IMAGE_PROCESSOR, trust_remote_code=True)
        # NOTE(review): `use_meta` is never read afterwards. The line is kept
        # because constructing the Linear layer consumes global RNG state and
        # dropping it would shift subsequent random draws.
        use_meta = torch.nn.Linear(1, 1).weight.device == torch.device('meta')
        # Silence warnings emitted while the hub entry point builds the model.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.vision_tower = torch.hub.load('NVlabs/RADIO', 'radio_model', version=self.vision_tower_name, progress=True, skip_validation=True)

        # The backbone is frozen and used in eval mode only.
        self.vision_tower.requires_grad_(False)
        if device_map is not None:
            if device_map in ('auto', 'cuda'):
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
            else:
                device = device_map
            self.vision_tower.to(device)
        self.vision_tower.eval()

        if self.force_image_size is None:
            preferred_width = self.vision_tower.preferred_resolution.width
            logger.warning(f'force_image_size is None, using the preferred_resolution {preferred_width}')
            # Fixed: the original read the non-existent attribute `widthimg`,
            # which raised AttributeError on this path.
            self.force_image_size = preferred_width

        self._config = SimpleNamespace(
            image_size=self.force_image_size,
            hidden_size=self.vision_tower.model.num_features,
            patch_size=self.vision_tower.patch_size,
            summary_len=self.vision_tower.summary_idxs.shape[0],
        )
        # Propagate the image size to the image processor as well.
        self.set_imsize(self.force_image_size)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the features to return from the backbone's intermediate outputs.

        Only the first entry of ``image_forward_outs`` is used; it is unpacked
        as ``(summary, image_features)``.

        Returns:
            ``image_features`` unchanged for ``select_feature == 'patch'``; for
            ``'sum_patch'`` the summary is reshaped to
            ``(batch, -1, feature_dim)`` and concatenated in front of the patch
            features along dimension 1.

        Raises:
            ValueError: If ``select_feature`` is neither ``'patch'`` nor
                ``'sum_patch'``.
        """
        summary, image_features = image_forward_outs[0]
        if self.select_feature == 'patch':
            return image_features
        if self.select_feature == 'sum_patch':
            summary = summary.view(summary.shape[0], -1, image_features.shape[-1])
            return torch.cat([summary, image_features], dim=1)
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Extract features from one intermediate layer of the backbone.

        Args:
            images: Either a tensor passed to the backbone as-is, or a list of
                tensors; each list element gets a leading dimension added via
                ``unsqueeze(0)`` and is processed separately.

        Returns:
            A list of feature tensors when ``images`` is a list, otherwise a
            single feature tensor. Results are cast back to the input dtype.
        """
        # A negative `select_layer` counts from the end of the block list.
        if self.select_layer < 0:
            layer_index = len(self.vision_tower.model.blocks) + self.select_layer
        else:
            layer_index = self.select_layer

        feature_extract_func = partial(
            self.vision_tower.forward_intermediates,
            indices=[layer_index],
            output_fmt="NLC",
            return_prefix_tokens=True,
            norm=False,  # True or False?
            intermediates_only=True,
        )

        if isinstance(images, list):
            image_features = []
            for image in images:
                with torch.autocast('cuda', self.dtype):
                    image_forward_out = feature_extract_func(image.to(device=self.device).unsqueeze(0))
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
            return image_features

        with torch.autocast('cuda', self.dtype):
            image_forward_outs = feature_extract_func(images.to(device=self.device))
        return self.feature_select(image_forward_outs).to(images.dtype)

    @property
    def dtype(self):
        """Dtype of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the backbone's first parameter."""
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        """Config built by ``load_model`` once loaded, otherwise ``self.cfg_only``."""
        return self._config if self.is_loaded else self.cfg_only


# class RADIOVisionTower(nn.Module):
#     def __init__(self, vision_tower, args, delay_load=False):
#         super().__init__()

#         self.is_loaded = False

#         self.vision_tower_name = vision_tower
#         self.select_layer = args.mm_vision_select_layer
#         self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
#         self.resolution = getattr(args, 'mm_vision_resolution', 384)

#         if not delay_load:
#             self.load_model()
#         elif getattr(args, 'unfreeze_mm_vision_tower', False):
#             self.load_model()

#     def load_model(self, device_map=None):
#         if self.is_loaded:
#             print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
#             return

#         self.image_processor = CLIPImageProcessor.from_pretrained(DEFAULT_IMAGE_PROCESSOR)
#         self.image_processor.do_center_crop = True
#         self.image_processor.do_resize = True
#         self.image_processor.size['shortest_edge'] = self.resolution
#         self.image_processor.crop_size['width'] = self.resolution
#         self.image_processor.crop_size['height'] = self.resolution

#         model = torch.hub.load('NVlabs/RADIO', 'radio_model', version=self.vision_tower_name, progress=True, skip_validation=True)

#         for param in model.parameters():
#             param.requires_grad_(False)

#         self.vision_tower = model.cuda().eval()
#         self.is_loaded = True

#     @torch.no_grad()
#     def forward(self, images):
#         images_dtype = images.dtype
#         with torch.autocast('cuda', self.dtype):
#             summary, image_features = self.vision_tower(images)
#         return image_features.to(dtype=images_dtype)

#     @property
#     def dtype(self):
#         return self.vision_tower.model.patch_generator.embedder.weight.dtype

#     @property
#     def device(self):
#         return self.vision_tower.model.patch_generator.embedder.weight.device

#     @property
#     def hidden_size(self):
#         if self.vision_tower_name == 'radio_v2.5-b':
#             return 768
#         elif self.vision_tower_name == 'radio_v2.5-l':
#             return 1024
#         else:
#             raise NotImplementedError(f'hidden_size for {self.vision_tower_name} not implemented')
