from typing import *
from PIL.Image import Image as PILImage
from torch import Tensor

import numpy as np
from skimage.metrics import structural_similarity as calculate_ssim
import torch
import torch.nn.functional as F
from transformers import (
    CLIPImageProcessor, CLIPVisionModelWithProjection,
    CLIPTokenizer, CLIPTextModelWithProjection,
)
import ImageReward as RM
from kiui.lpips import LPIPS


class TextConditionMetrics:
    """Text-to-image alignment metrics: CLIP similarity, CLIP R-Precision and ImageReward.

    All models are placed on ``cuda:{device_idx}``.
    """

    def __init__(self,
        clip_name: str = "openai/clip-vit-base-patch32",
        rm_name: str = "ImageReward-v1.0",
        device_idx: int = 0,
        rm_download_root: Optional[str] = "/.cache/hf",
    ):
        """Load the CLIP encoders and the ImageReward model.

        Args:
            clip_name: Hugging Face id of the CLIP checkpoint used for both the
                image and the text encoder.
            rm_name: ImageReward checkpoint name forwarded to ``RM.load``.
            device_idx: Index of the CUDA device all models are moved to.
            rm_download_root: Passed through to ``RM.load(download_root=...)``.
                The default keeps the previous hard-coded value. Note that it is
                an absolute path at the filesystem root.
        """
        self.device = f"cuda:{device_idx}"

        self.image_processor = CLIPImageProcessor.from_pretrained(clip_name)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_name).to(self.device).eval()

        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_name).to(self.device).eval()

        # Pass the device explicitly. Without it RM.load uses its own default
        # device (presumably plain "cuda"), so the reward model would not
        # follow ``device_idx`` like the CLIP models above.
        self.rm_model = RM.load(rm_name, device=self.device, download_root=rm_download_root)

    def _embed_images(self, images: List[PILImage]) -> Tensor:
        """Return L2-normalised float32 CLIP image embeddings of shape (N, D)."""
        pixel_values = self.image_processor(images, return_tensors="pt").pixel_values.to(self.device)
        embeds = self.image_encoder(pixel_values).image_embeds.float()
        return embeds / embeds.norm(p=2, dim=-1, keepdim=True)

    def _embed_texts(self, texts: List[str]) -> Tensor:
        """Return L2-normalised float32 CLIP text embeddings of shape (N, D)."""
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        embeds = self.text_encoder(tokens.input_ids.to(self.device)).text_embeds.float()
        return embeds / embeds.norm(p=2, dim=-1, keepdim=True)

    @torch.no_grad()
    def evaluate(self,
        image: Union[PILImage, List[PILImage]],
        text: Union[str, List[str]],
    ) -> Tuple[float, float, float]:
        """Score paired images and prompts.

        Args:
            image: One image or a list of images.
            text: One prompt or a list of prompts. ``text[i]`` is paired with
                ``image[i]``.

        Returns:
            ``(clip_sim, clip_rprec, rm_score)``:
            - ``clip_sim``: mean cosine similarity between each image and its
              own prompt.
            - ``clip_rprec``: fraction of images whose most similar prompt
              (among all prompts in this batch) is their own prompt.
            - ``rm_score``: mean ImageReward score over the pairs.

        Raises:
            ValueError: if the number of images and prompts differ, or if
                there are no pairs.
        """
        images = [image] if isinstance(image, PILImage) else list(image)
        texts = [text] if isinstance(text, str) else list(text)
        if len(images) != len(texts):
            raise ValueError(f"Got {len(images)} images but {len(texts)} texts.")
        if not images:
            raise ValueError("At least one image/text pair is required.")

        image_embeds = self._embed_images(images)  # (N, D)
        text_embeds = self._embed_texts(texts)  # (N, D)
        # Internal invariant: both CLIP projections share the embedding size.
        assert image_embeds.shape == text_embeds.shape

        # Row i holds the cosine similarity of image i against every prompt.
        clip_scores = image_embeds @ text_embeds.T  # (N, N)

        # 1. CLIP similarity: the diagonal holds the matched pairs.
        clip_sim = clip_scores.diag().mean().item()

        # 2. CLIP R-Precision: an image is a hit if its own prompt ranks first.
        targets = torch.arange(len(texts), device=self.device)
        clip_rprec = (clip_scores.argmax(dim=1) == targets).float().mean().item()

        # 3. ImageReward, scored one pair at a time.
        rm_scores = [self.rm_model.score(txt, img) for img, txt in zip(images, texts)]
        rm_score = float(np.mean(rm_scores))

        return clip_sim, clip_rprec, rm_score


class ImageConditionMetrics:
    """Full-reference image metrics between predictions and ground truth.

    Computes PSNR, SSIM and LPIPS for each image pair. Each metric is reported
    as a ``(mean, std)`` pair over the batch.
    """

    def __init__(self,
        lpips_net: str = "vgg",
        lpips_res: int = 256,
        device_idx: int = 0,
    ):
        """Load the LPIPS network.

        Args:
            lpips_net: Backbone name forwarded to ``LPIPS(net=...)``.
            lpips_res: Side length of the square resolution both inputs are
                bilinearly resized to before LPIPS.
            device_idx: Index of the CUDA device used for all tensor work.
        """
        self.device = f"cuda:{device_idx}"
        self.lpips_loss = LPIPS(net=lpips_net).to(self.device).eval()
        self.lpips_res = lpips_res

    def _pil_to_batch(self, images: List[PILImage]) -> Tensor:
        """Stack PIL images into a float32 (N, 3, H, W) tensor in [0, 1] on ``self.device``."""
        # Force 3 channels. Otherwise grayscale images break the HWC -> CHW
        # transpose, and RGBA images feed 4 channels into LPIPS.
        # All images must share the same H and W to be stacked.
        arrays = [np.asarray(img.convert("RGB")).transpose(2, 0, 1) for img in images]
        return torch.from_numpy(np.stack(arrays) / 255.).float().to(self.device)

    def _resize_for_lpips(self, x: Tensor) -> Tensor:
        """Map ``x`` from [0, 1] to [-1, 1] and resize it to ``lpips_res`` x ``lpips_res``."""
        return F.interpolate(
            x * 2. - 1.,
            (self.lpips_res, self.lpips_res), mode="bilinear", align_corners=False,
        )

    @staticmethod
    def _to_uint8(x: Tensor) -> np.ndarray:
        """Quantise a [0, 1] float tensor to a uint8 numpy array on the CPU."""
        # Clamp so out-of-range values saturate instead of wrapping in the
        # uint8 cast. Round instead of truncating, because k/255 * 255 in
        # float32 can land just below the integer k.
        return (x.clamp(0., 1.) * 255.).round().to(torch.uint8).cpu().numpy()

    @staticmethod
    def _mean_std(values) -> Tuple[float, float]:
        """Return ``(mean, population std)`` of ``values`` as Python floats."""
        arr = np.asarray(values, dtype=np.float64)
        return float(arr.mean()), float(arr.std())

    @torch.no_grad()
    def evaluate(self,
        image: Union[Tensor, PILImage, List[PILImage]],
        gt: Union[Tensor, PILImage, List[PILImage]],
        chunk_size: Optional[int] = None,
        input_tensor: bool = False,
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]:
        """Compare predicted images against ground truth, pair by pair.

        Args:
            image: Predicted image(s). Either PIL image(s), or a float tensor of
                shape (N, 3, H, W) or (3, H, W) with values in [0, 1].
            gt: Ground-truth image(s), in the same form and count as ``image``.
            chunk_size: Number of pairs per LPIPS forward pass. Defaults to the
                whole batch.
            input_tensor: Treat ``image``/``gt`` as tensors. Tensor inputs are
                also detected automatically.

        Returns:
            ``((psnr_mean, psnr_std), (ssim_mean, ssim_std), (lpips_mean, lpips_std))``
            as Python floats. All three stds are population stds (ddof=0), so a
            single pair yields 0 rather than NaN.

        Raises:
            ValueError: on empty input, mismatched counts or shapes, or a
                non-positive ``chunk_size``.
        """
        if input_tensor or isinstance(image, Tensor):
            image_pt = image.to(device=self.device, dtype=torch.float32)
            gt_pt = gt.to(device=self.device, dtype=torch.float32)
            # Accept a single (3, H, W) image as a batch of one.
            if image_pt.dim() == 3:
                image_pt = image_pt.unsqueeze(0)
            if gt_pt.dim() == 3:
                gt_pt = gt_pt.unsqueeze(0)
        else:
            if isinstance(image, PILImage):
                image = [image]
            if isinstance(gt, PILImage):
                gt = [gt]
            if len(image) != len(gt):
                raise ValueError(f"Got {len(image)} images but {len(gt)} ground-truth images.")
            if len(image) == 0:
                raise ValueError("At least one image pair is required.")
            image_pt = self._pil_to_batch(image)
            gt_pt = self._pil_to_batch(gt)

        if image_pt.dim() != 4 or image_pt.shape != gt_pt.shape:
            raise ValueError(
                f"Expected matching (N, C, H, W) inputs, got image {tuple(image_pt.shape)} "
                f"and gt {tuple(gt_pt.shape)}."
            )
        num_pairs = image_pt.shape[0]
        if num_pairs == 0:
            raise ValueError("At least one image pair is required.")
        if chunk_size is None:
            chunk_size = num_pairs
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}.")

        # 1. LPIPS, computed in chunks so one forward pass never sees more
        #    than `chunk_size` pairs.
        lpips = torch.cat([
            self.lpips_loss(
                self._resize_for_lpips(image_pt[start:start + chunk_size]),
                self._resize_for_lpips(gt_pt[start:start + chunk_size]),
            ).reshape(-1)
            for start in range(0, num_pairs, chunk_size)
        ])

        # 2. PSNR with peak value 1 (inputs in [0, 1]). An identical pair has
        #    MSE 0 and hence PSNR +inf, which then propagates into the mean.
        mse = (gt_pt - image_pt).pow(2).mean(dim=[1, 2, 3])
        psnr = -10. * torch.log10(mse)

        # 3. SSIM on 8-bit images, one pair at a time, with channels on axis 0.
        image_u8 = self._to_uint8(image_pt)
        gt_u8 = self._to_uint8(gt_pt)
        ssim = [
            calculate_ssim(pred, ref, channel_axis=0)
            for pred, ref in zip(image_u8, gt_u8)
        ]

        return (
            self._mean_std(psnr.cpu().numpy()),
            self._mean_std(ssim),
            self._mean_std(lpips.cpu().numpy()),
        )
