import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig, ViTConfig
from llava.model.fmri_encoder.vit3d import CLIPVision3dModelWithProjection
from llava.model.fmri_encoder.vit3d_decoder import ViT3dWithProjectionModel


class CLIPVisionTower(nn.Module):
    """Vision-tower wrapper hosting either a 3D ViT or a CLIP vision model.

    The backbone is chosen in :meth:`load_model` from the tower name: a name
    containing ``'3d'`` loads ``ViT3dWithProjectionModel`` together with a
    function-based volume preprocessor; any other name loads
    ``CLIPVisionModel`` with its ``CLIPImageProcessor``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        """Read tower settings from ``args`` and optionally load the backbone.

        Args:
            vision_tower: Model name or path. Also decides which backbone is
                used (see :meth:`load_model`).
            args: Attribute container. ``mm_vision_select_layer`` is required;
                ``mm_vision_select_feature`` (``'patch'``),
                ``tune_vision_tower`` (``False``), ``mm_vision_patch_size``
                (``14``), ``mm_vision_image_size`` (``(83, 104, 81)``),
                ``vision_tower_from_scratch`` (``False``) and
                ``unfreeze_mm_vision_tower`` (``False``) are optional, with
                the defaults shown in parentheses.
            delay_load: When true, and ``args.unfreeze_mm_vision_tower`` is
                not truthy, only a ``CLIPVisionConfig`` is fetched and the
                backbone is left unloaded.
        """
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.tune_vision_tower = getattr(args, 'tune_vision_tower', False)

        self.patch_size = getattr(args, 'mm_vision_patch_size', 14)
        self.fmri_shape = getattr(args, 'mm_vision_image_size', (83, 104, 81))
        self.total_fmri_size = self.fmri_shape[0] * self.fmri_shape[1] * self.fmri_shape[2]
        from_scratch = getattr(args, 'vision_tower_from_scratch', False)

        self.loss_fn = nn.MSELoss()

        # Load eagerly unless loading is delayed AND the tower stays frozen.
        if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model(from_scratch=from_scratch)
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

        self._dtype = torch.float32

    def _compute_padding(self):
        """Return ``F.pad`` widths rounding each fMRI axis up to a patch multiple.

        The list is ordered last axis first, as ``F.pad`` expects, with one
        ``(before, after)`` pair per axis. When the required padding is odd
        the extra voxel goes on the ``after`` side. An empty list is returned
        when ``patch_size`` is falsy.
        """
        padding = []
        if self.patch_size:
            for dim in reversed(self.fmri_shape):
                pad_size = (self.patch_size - dim % self.patch_size) % self.patch_size
                before = pad_size // 2
                padding.extend([before, pad_size - before])
        return padding

    def _build_fmri_processor(self):
        """Build the callable used as ``image_processor`` for 3D towers.

        The padding widths are computed once, when the processor is built.
        """
        padding = self._compute_padding()

        def processor(x, mean=None, std=None, *args, **kwargs):
            """Resize (if needed), normalise and pad one volume.

            Returns the processed tensor with a new leading axis of size 1.
            Extra positional and keyword arguments are accepted and ignored.
            """
            # Compare as plain tuples: ``fmri_shape`` may be a list (e.g. when
            # it comes from a JSON config), and tuple-vs-list inequality is
            # always true, which would force a needless resample and float cast.
            target_shape = tuple(self.fmri_shape)
            if tuple(x.shape[-3:]) != target_shape:
                # Two leading axes are added for the trilinear resample and
                # removed again afterwards.
                x = F.interpolate(
                    x.float().unsqueeze(0).unsqueeze(0),
                    size=target_shape,
                    mode='trilinear',
                    align_corners=False,
                )
                x = x.squeeze(0).squeeze(0)
            if mean is None or std is None:
                warnings.warn('mean and std not provided, may cause issues with 3D models.')
            else:
                # The small constant guards against division by zero.
                x = (x - mean) / (std + 1e-5)
            x = F.pad(x, pad=padding, mode='constant', value=0.)
            return x.unsqueeze(0)

        return processor

    def load_model(self, device_map=None, from_scratch=False):
        """Instantiate the backbone and its preprocessor; no-op if already loaded.

        Args:
            device_map: Forwarded to ``CLIPVisionModel.from_pretrained`` only;
                it is not used for ``'3d'`` towers.
            from_scratch: For ``'3d'`` towers, build the model from its
                ``ViTConfig`` instead of calling ``from_pretrained`` on the
                model class.
        """
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        if '3d' in self.vision_tower_name:
            self.image_processor = self._build_fmri_processor()
            if from_scratch:
                config = ViTConfig.from_pretrained(self.vision_tower_name)
                self.vision_tower = ViT3dWithProjectionModel(config)
            else:
                self.vision_tower = ViT3dWithProjectionModel.from_pretrained(self.vision_tower_name)
        else:
            self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
            self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)

        # Freeze the backbone unless tuning was requested.
        self.vision_tower.requires_grad_(bool(self.tune_vision_tower))

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the features to expose from a backbone output.

        ``'patch'`` returns ``hidden_states[select_layer]`` with the first
        token dropped (presumably the CLS token -- confirm against the
        backbone). ``'cls_patch'`` returns the output's ``image_embeds``.

        Raises:
            ValueError: If ``select_feature`` is neither of the above.
        """
        if self.select_feature == 'patch':
            image_features = image_forward_outs.hidden_states[self.select_layer]
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_forward_outs.image_embeds
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    # Runs without ``torch.no_grad()``; the backbone's parameters are trainable
    # when ``tune_vision_tower`` is set.
    def forward(self, images, vision_embeds=None):
        """Encode ``images`` and return ``(image_features, vision_loss)``.

        Args:
            images: Either a tensor passed to the backbone as-is, or a list of
                tensors, each of which is given a leading axis of size 1 and
                encoded separately.
            vision_embeds: Passed to the backbone as ``labels``.

        Returns:
            For tensor input, the selected features (cast back to the input
            dtype) and the backbone output's ``loss``. For list input, two
            lists holding the same per item.
        """
        if isinstance(images, list):
            image_features = []
            vision_loss = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    labels=vision_embeds,
                    output_hidden_states=True,
                )
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
                vision_loss.append(image_forward_out.loss)
            return image_features, vision_loss

        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype),
            labels=vision_embeds,
            output_hidden_states=True,
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features, image_forward_outs.loss

    @property
    def dummy_feature(self):
        """A ``(1, hidden_size)`` zero tensor on the backbone's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype reported by the loaded backbone."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device reported by the loaded backbone."""
        return self.vision_tower.device

    @property
    def config(self):
        """Backbone config when loaded, otherwise the config fetched in ``__init__``."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """``hidden_size`` taken from :attr:`config`."""
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        """Approximate patches per axis, assuming a cubic volume.

        NOTE(review): uses exponent ``0.333`` as an approximate cube root and
        returns a float (floor division of a float); kept as-is for
        compatibility with existing callers.
        """
        return (self.total_fmri_size ** 0.333) // self.config.patch_size

    @property
    def num_patches(self):
        """Number of patch tokens.

        Uses the length of the patch embedding's ``token_ids`` when that
        attribute exists; otherwise the unpadded voxel count divided by the
        patch volume.

        NOTE(review): the fallback ignores the padding added by the fMRI
        preprocessor -- confirm it matches the backbone's real token count.
        """
        patch_embeddings = self.vision_tower.embeddings.patch_embeddings
        if hasattr(patch_embeddings, 'token_ids'):
            print('Using token_ids', len(patch_embeddings.token_ids))
            return len(patch_embeddings.token_ids)
        return self.total_fmri_size // (self.config.patch_size ** 3)
