from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, Resize, CenterCrop, ToTensor, InterpolationMode
import os
import huggingface_hub
import hpsv2
from hpsv2 import img_score
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

# Per-channel (R, G, B) normalization mean/std (the OpenAI CLIP statistics).
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

# Checkpoint filename in the "xswu/HPSv2" Hub repo for each supported HPS version.
hps_version_map = {
    "v2.0": "HPS_v2_compressed.pt",
    "v2.1": "HPS_v2.1_compressed.pt",
}
# Local HPS directory: $HPS_ROOT when it is set, otherwise ~/.cache/hpsv2.
environ_root = os.environ.get('HPS_ROOT')
root_path = os.path.expanduser('~/.cache/hpsv2') if environ_root is None else environ_root

class MaskAwareNormalize(nn.Module):
    """Normalize RGB channels while passing a 4th (alpha/mask) channel through unchanged.

    Accepts an unbatched ``(C, H, W)`` tensor or a batched ``(..., C, H, W)``
    tensor. When ``C == 4`` only the first three channels are normalized with
    ``mean``/``std``, and the fourth channel is concatenated back as-is.
    Otherwise the whole tensor is normalized.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.normalize = Normalize(mean=mean, std=std)

    def forward(self, tensor):
        # The channel axis is dim -3 for both (C, H, W) and (N, C, H, W).
        # Checking dim 0 would misread a batch of 4 RGB images as one RGBA image.
        if tensor.dim() >= 3 and tensor.shape[-3] == 4:
            rgb = tensor[..., :3, :, :]
            extra = tensor[..., 3:, :, :]
            return torch.cat([self.normalize(rgb), extra], dim=-3)
        return self.normalize(tensor)
        
class ResizeMaxSize(nn.Module):
    """Resize so one side equals ``max_size``, then pad to a ``max_size`` square.

    With ``fn='max'`` (the default), the longest side is scaled to ``max_size``
    and the shorter side is padded symmetrically with ``fill``. With
    ``fn='min'``, the shortest side is scaled to ``max_size``. The longer side
    then gets negative padding.

    NOTE(review): the ``fn='min'`` path relies on torchvision's ``F.pad``
    cropping for negative padding. Confirm this for the installed torchvision.

    Accepts a PIL image, or a tensor whose last two dims are (H, W), e.g.
    ``(C, H, W)`` or ``(N, C, H, W)``.
    """

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # The original `min if fn == 'min' else min` silently ignored `fn`.
        self.fn = min if fn == 'min' else max
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            # The last two dims are (H, W) for unbatched and batched tensors alike.
            height, width = img.shape[-2:]
        else:
            width, height = img.size
        scale = self.max_size / float(self.fn(height, width))
        new_size = (height, width)
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
        # Pad even when no resize happened. Otherwise a non-square input whose
        # reference side already equals max_size would come out non-square.
        pad_h = self.max_size - new_size[0]
        pad_w = self.max_size - new_size[1]
        if pad_h or pad_w:
            # Padding order is [left, top, right, bottom].
            img = F.pad(
                img,
                padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2],
                fill=self.fill,
            )
        return img

def differentiable_image_transform(
    image_size: int,
    resize_longest_max: bool = False,
    fill_color: int = 0,
):
    """Build a tensor-in/tensor-out preprocessing pipeline for the HPS image encoder.

    There is no PIL-to-tensor conversion stage, so the input must already be an
    image tensor. This is intended to keep every stage differentiable with
    respect to the pixels.

    Args:
        image_size: Target side length. A square ``(s, s)`` list or tuple is
            collapsed to the int ``s``.
        resize_longest_max: If True, fit the longest side with ``ResizeMaxSize``
            and pad with ``fill_color``. Otherwise, apply a bicubic ``Resize``
            followed by ``CenterCrop``.
        fill_color: Padding value for the ``ResizeMaxSize`` branch.

    Returns:
        A ``Compose`` of the resize stage(s) followed by ``MaskAwareNormalize``
        with the OpenAI mean/std.
    """
    square_pair = isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    if square_pair:
        # An int size makes Resize() scale the shortest edge, preserving aspect ratio.
        image_size = image_size[0]

    normalize = MaskAwareNormalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)

    if not resize_longest_max:
        stages = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    else:
        stages = [ResizeMaxSize(image_size, fill=fill_color)]

    stages.append(normalize)
    return Compose(stages)

class HPSV2Score:
    """HPSv2 scorer for a single image/prompt pair.

    Images given as a file path or PIL image go through the model's
    ``preprocess_val`` pipeline. Image tensors go through the tensor-only
    ``differentiable_image_transform`` pipeline, so the returned score keeps an
    autograd graph back to the input pixels.
    """

    def __init__(self, model_path=None, device='cuda', hps_version="v2.0"):
        """Load the HPSv2 model and its checkpoint.

        Args:
            model_path: Local checkpoint path. When None, the file for
                ``hps_version`` is downloaded from the "xswu/HPSv2" Hub repo.
            device: Device that the checkpoint is mapped to and the model is moved to.
            hps_version: Key of ``hps_version_map`` ("v2.0" or "v2.1"). An
                unknown key raises KeyError when ``model_path`` is None.
        """
        self.device = device

        model_dict = img_score.initialize_model()
        if model_dict is None:
            # NOTE(review): some hpsv2 releases populate the module-level
            # img_score.model_dict instead of returning it. Verify against the
            # installed version.
            model_dict = img_score.model_dict
        self.model = model_dict['model']
        self.preprocess_val = model_dict['preprocess_val']
        self.tensor_preprocess_val = differentiable_image_transform(
            self.model.visual.image_size,
            resize_longest_max=True,
        )

        # NOTE(review): root_path is created but not passed to hf_hub_download
        # below. Confirm whether cache_dir=root_path was intended.
        os.makedirs(root_path, exist_ok=True)
        if model_path is None:
            model_path = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[hps_version])

        # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.tokenizer = get_tokenizer('ViT-H-14')
        self.model = self.model.to(device)
        self.model.eval()

    def _score_image(self, image, prompt):
        """Score an already-preprocessed ``(1, C, H, W)`` image batch against ``prompt``."""
        image = image.to(device=self.device, non_blocking=True)
        text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
        # torch.autocast("cuda") is the non-deprecated spelling of torch.cuda.amp.autocast().
        with torch.autocast("cuda"):
            outputs = self.model(image, text)
            image_features, text_features = outputs["image_features"], outputs["text_features"]
            logits_per_image = image_features @ text_features.T
            hps_score = torch.diagonal(logits_per_image)
        return hps_score[0]

    def score(self, img, prompt):
        """Return the HPSv2 score of ``img`` for ``prompt``.

        Args:
            img: A file path (str), a PIL image, or an image tensor shaped
                ``(C, H, W)`` or ``(1, C, H, W)``.
            prompt: Text prompt to score the image against.

        Returns:
            A 0-d tensor: the image-text feature dot product for the pair.

        Raises:
            ValueError: If a 4-D tensor has a batch size other than 1.
            TypeError: If ``img`` is not a str, tensor, or PIL image.
        """
        if isinstance(img, str):
            # Close the file once it has been converted to a tensor.
            with Image.open(img) as pil_img:
                image = self.preprocess_val(pil_img)
        elif isinstance(img, torch.Tensor):
            if img.dim() == 4:
                if img.shape[0] != 1:
                    raise ValueError(
                        f"Expected a single image; got a batch of {img.shape[0]}"
                    )
                print("Warning: Squeezing image tensor from 4D to 3D")
                img = img.squeeze(0)
            image = self.tensor_preprocess_val(img)
        elif isinstance(img, Image.Image):
            image = self.preprocess_val(img)
        else:
            raise TypeError(
                f"img must be a path, PIL image, or torch.Tensor; got {type(img).__name__}"
            )
        return self._score_image(image.unsqueeze(0), prompt)
