from email.mime import image
import open_clip.timm_model
import torch
import os
import torch.nn as nn
from types import SimpleNamespace

from transformers import AutoModel, AutoImageProcessor, AutoConfig
import open_clip
from huggingface_hub import repo_exists
from .general_encoder import GeneralVisionTower

def get_device(device_map):
    """Resolve a ``device_map`` argument to a single ``torch.device``.

    Args:
        device_map: ``None`` or ``'auto'`` (use CUDA when available, otherwise CPU),
            anything ``torch.device`` accepts (e.g. ``'cpu'``, ``'cuda:1'``, ``0``),
            or a Hugging Face-style dict that places everything on one device,
            e.g. ``{'': 'cuda:0'}``.

    Returns:
        torch.device: the single device to load onto.

    Raises:
        ValueError: if ``device_map`` is a dict that spreads modules over several
            devices (or is empty); only one device can be returned here.
    """
    if isinstance(device_map, dict):
        if '' in device_map:
            # HF convention: the '' key maps the whole model to one device.
            device_map = device_map['']
        else:
            targets = set(device_map.values())
            if len(targets) != 1:
                raise ValueError(
                    f'device_map must target exactly one device, got {device_map!r}'
                )
            device_map = targets.pop()
    if device_map is None or device_map == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(device_map)

class OpenCLIPImageProcessor:
    """Adapt an open_clip preprocessing callable to a minimal HF image-processor API.

    ``processor`` is applied to one image at a time and is expected to return a
    tensor (presumably ``(C, H, W)`` — the open_clip transform; confirm with caller).
    ``config`` is the preprocess config dict read by ``image_mean`` and ``crop_size``.
    """

    def __init__(self, processor, config=None):
        self.processor = processor
        self.config = config

    def preprocess(self, images, return_tensors='pt'):
        """Transform ``images`` into a batch stored under ``'pixel_values'``.

        Args:
            images: a single image, or a list/tuple of images. Each image is passed
                to ``self.processor`` individually.
            return_tensors: ``'pt'`` keeps the torch tensor; ``'np'`` converts it to
                a numpy array.

        Returns:
            dict: ``{'pixel_values': batch}`` where the leading dimension is the
            number of images (1 for a single image).

        Raises:
            ValueError: if ``images`` is an empty list/tuple.
        """
        if isinstance(images, (list, tuple)):
            if not images:
                raise ValueError('preprocess() received an empty list of images')
            # Batch input: transform each image, then stack along a new leading dim.
            result = torch.stack([self.processor(image) for image in images])
        else:
            result = self.processor(images).unsqueeze(0)
        if return_tensors == 'np':
            result = result.numpy()
        return {'pixel_values': result}

    def __call__(self, images, return_tensors='pt'):
        """Alias for ``preprocess``; ``return_tensors`` is forwarded unchanged."""
        return self.preprocess(images, return_tensors=return_tensors)

    @property
    def image_mean(self):
        """Per-channel normalization mean from the preprocess config (``config['mean']``)."""
        return self.config["mean"]

    @property
    def crop_size(self):
        """Target image size from the preprocess config (``config['image_size']``)."""
        return self.config["image_size"]

class OpenCLIPConvNeXt(nn.Module):
    """Thin module around the timm trunk of an open_clip vision tower.

    Only ``visual.trunk`` is retained (as ``self.model``). ``self.config`` is a
    ``SimpleNamespace`` exposing ``image_size``, ``hidden_size`` and
    ``patch_size`` for the surrounding vision-tower code.
    """

    def __init__(self, visual: open_clip.timm_model.TimmModel):
        super().__init__()
        trunk = visual.trunk
        self.model = trunk

        # image_size may be stored as a scalar or as a (height, width) sequence;
        # only the first entry is kept.
        size = visual.image_size
        if isinstance(size, (list, tuple)):
            size = size[0]

        self.config = SimpleNamespace(
            image_size=size,
            hidden_size=trunk.num_features,
            patch_size=self.get_patch_size(trunk),
        )

    @staticmethod
    def get_patch_size(model: nn.Module):
        """Estimate the total spatial downsampling of ``model``.

        Multiplies the strides of every ``nn.Conv2d`` found in ``model.modules()``.
        This assumes the strided convolutions lie on one sequential path (as in
        ConvNeXt); strided convs on parallel branches would all be counted.

        Returns:
            int when the height and width factors agree, otherwise the tuple
            ``(factor_h, factor_w)``.
        """
        factor_h, factor_w = 1, 1
        convs = (layer for layer in model.modules() if isinstance(layer, nn.Conv2d))
        for conv in convs:
            stride_h, stride_w = conv.stride
            factor_h *= stride_h
            factor_w *= stride_w

        if factor_h != factor_w:
            return (factor_h, factor_w)
        return factor_h

    def forward(self, images, output_hidden_states=True):
        """Run the trunk's ``forward_features`` on ``images``.

        ``output_hidden_states`` is accepted only for call-site compatibility and
        is ignored.
        """
        trunk = self.model
        return trunk.forward_features(images)

class OpenCLIPVisionTower(GeneralVisionTower):
    """Vision tower backed by an open_clip timm image encoder hosted on the HF Hub.

    ``vision_tower`` is a Hugging Face Hub repo id of an open_clip model; an
    optional ``'hf-hub:'`` prefix is accepted. Only the image trunk is kept and it
    is frozen (``requires_grad_(False)``).
    """

    def __init__(self, vision_tower, args, delay_load=False):
        # Intentionally bypasses GeneralVisionTower.__init__ and runs the next
        # initializer in the MRO instead; all state this tower needs is set here.
        super(GeneralVisionTower, self).__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.force_image_size = getattr(args, 'mm_force_imsize', None)

        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            # Delayed load: only the open_clip config is read now; weights are
            # loaded later by load_model().
            self.cfg_only = open_clip.factory._get_hf_config(self._hub_id(self.vision_tower_name))

    @staticmethod
    def _hub_id(name):
        """Return ``name`` without a leading ``'hf-hub:'`` prefix (the bare Hub repo id)."""
        prefix = 'hf-hub:'
        return name[len(prefix):] if name.startswith(prefix) else name

    @staticmethod
    def _image_hw(image_size):
        """Split an int or a ``(height, width)`` sequence into a ``(height, width)`` pair."""
        if isinstance(image_size, (list, tuple)):
            return image_size[0], image_size[1]
        return image_size, image_size

    def load_model(self, device_map=None):
        """Load the open_clip model, keep its frozen image trunk and build the processor.

        Args:
            device_map: forwarded to ``get_device`` to choose the target device.

        Raises:
            ValueError: if the (prefix-stripped) name is not an existing Hub repo.
        """
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        # Accept names with or without the 'hf-hub:' prefix so a name that was
        # already rewritten below (and possibly saved) can be loaded again.
        hub_id = self._hub_id(self.vision_tower_name)
        if not repo_exists(hub_id):
            raise ValueError(f'Unknown vision tower: {self.vision_tower_name}')

        config = open_clip.factory._get_hf_config(hub_id)
        model_config = config['model_cfg']
        preprocess_config = config['preprocess_cfg']
        if self.force_image_size is not None:
            model_config['vision_cfg']['image_size'] = self.force_image_size
        height, width = self._image_hw(model_config['vision_cfg']['image_size'])
        preprocess_config['image_size'] = {'height': height, 'width': width}
        # open_clip loads Hub checkpoints through the explicit 'hf-hub:' schema.
        self.vision_tower_name = 'hf-hub:' + hub_id

        model, preprocess = open_clip.create_model_from_pretrained(
            self.vision_tower_name,
            force_image_size=self.force_image_size,
            device=get_device(device_map),
        )
        self.image_processor = OpenCLIPImageProcessor(preprocess, config=preprocess_config)
        self.vision_tower = OpenCLIPConvNeXt(model.visual)
        self.vision_tower.requires_grad_(False)

        # Drop the full model reference; only the image trunk is retained.
        del model
        self.is_loaded = True

    def feature_select(self, image_forward_outs: torch.Tensor):
        """Flatten a feature map into a token sequence.

        Args:
            image_forward_outs: trunk output, presumably ``(N, C, H, W)`` or
                ``(C, H, W)`` (TODO confirm against the trunk's forward_features).

        Returns:
            torch.Tensor: ``(N, H*W, C)`` or ``(H*W, C)`` respectively.

        Raises:
            NotImplementedError: if a layer other than -1 is requested.
            ValueError: if ``select_feature`` is not ``'patch'``.
        """
        if self.select_layer != -1:
            raise NotImplementedError('OpenCLIP does not support selecting intermediate layers')
        if self.select_feature != 'patch':
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        # Merge the two spatial dims, then move channels last. flatten() returns a
        # view when possible and copies otherwise (view() would raise instead).
        return image_forward_outs.flatten(-2).transpose(-2, -1)

    @property
    def dtype(self):
        """dtype of the first parameter of the wrapped trunk."""
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the first parameter of the wrapped trunk."""
        return next(self.vision_tower.parameters()).device

