from PIL import Image
import os
import io
import numpy as np
import torch
import torchvision
from importlib import resources
# Traversable root of the importable `assets` package, resolved once when this
# module is imported. inpaint() loads reference images from its
# `sample_images/` subdirectory.
ASSETS_PATH = resources.files("assets")

# def aesthetic_score(
#     torch_dtype=None,
#     aesthetic_target=None,
#     grad_scale=0,
#     device=None,
#     return_loss=False,
#     accelerator=None
# ):

#     target_size = 224
#     normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
#                                                 std=[0.26862954, 0.26130258, 0.27577711])
#     scorer = AestheticScorerDiff(dtype=torch_dtype).to(device, dtype=torch_dtype)
#     scorer.requires_grad_(False)
#     if accelerator is not None:
#         print("accelerate reward")
#         scorer = accelerator.prepare(scorer)
#     target_size = 224

#     if not return_loss:
#         def _fn(images, prompts):
#             if images.min() < 0: # normalize unnormalized images
#                 images = ((images / 2) + 0.5).clamp(0, 1) 
                
#             images = torchvision.transforms.Resize(target_size)(images)
#             images = normalize(images).to(images.dtype)
#             scores = scorer(images)

#             return scores

#         return _fn

#     else:
#         def loss_fn(images, prompts):
#             if images.min() < 0: # normalize unnormalized images
#                 images = ((images / 2) + 0.5).clamp(0, 1) 
                
#             images = torchvision.transforms.Resize(target_size)(images)
#             images = normalize(images).to(images.dtype)
#             scores = scorer(images)

#             if aesthetic_target is None: # default maximization
#                 loss = -1 * scores
#             else:
#                 # using L1 to keep on same scale
#                 loss = abs(scores - aesthetic_target)
#             return loss * grad_scale, scores


#         return loss_fn

# def hps_score(inference_dtype=None, device=None, return_loss=False, accelerator=None):
#     model_name = "ViT-H-14"
#     model, preprocess_train, preprocess_val = create_model_and_transforms(
#         model_name,
#         'laion2B-s32B-b79K',
#         precision=inference_dtype,
#         device=device,
#         jit=False,
#         force_quick_gelu=False,
#         force_custom_text=False,
#         force_patch_dropout=False,
#         force_image_size=None,
#         pretrained_image=False,
#         image_mean=None,
#         image_std=None,
#         light_augmentation=True,
#         aug_cfg={},
#         output_dict=True,
#         with_score_predictor=False,
#         with_region_predictor=False
#     )    
    
#     tokenizer = get_tokenizer(model_name)
    
#     checkpoint_path = f"{os.path.expanduser('~')}/.cache/huggingface/hub/models--xswu--HPSv2/snapshots/697403c78157020a1ae59d23f111aa58ced35b0a/HPS_v2_compressed.pt"
#     # force download of model via score
#     hpsv2.score([], "")
    
#     checkpoint = torch.load(checkpoint_path, map_location=device)
#     model.load_state_dict(checkpoint['state_dict'])
#     tokenizer = get_tokenizer(model_name)
#     model = model.to(device, dtype=inference_dtype)
#     model.eval()
#     if accelerator is not None:
#         model = accelerator.prepare(model)

#     target_size =  224
#     normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
#                                                 std=[0.26862954, 0.26130258, 0.27577711])
    
#     if not return_loss:
#         def _fn(images, prompts):
#             if images.min() < 0: # normalize unnormalized images
#                 images = ((images / 2) + 0.5).clamp(0, 1) 
#             x_var = torchvision.transforms.Resize(target_size)(images)
#             x_var = normalize(x_var).to(images.dtype)        
#             caption = tokenizer(prompts)
#             caption = caption.to(device)
#             outputs = model(x_var, caption)
#             image_features, text_features = outputs["image_features"], outputs["text_features"]
#             logits = image_features @ text_features.T
#             scores = torch.diagonal(logits)
#             return scores
        
#         return _fn

#     else:
#         def loss_fn(images, prompts):    
#             if images.min() < 0: # normalize unnormalized images
#                 images = ((images / 2) + 0.5).clamp(0, 1) 
#             x_var = torchvision.transforms.Resize(target_size)(images)
#             x_var = normalize(x_var).to(images.dtype)        
#             caption = tokenizer(prompts)
#             caption = caption.to(device)
#             outputs = model(x_var, caption)
#             image_features, text_features = outputs["image_features"], outputs["text_features"]
#             logits = image_features @ text_features.T
#             scores = torch.diagonal(logits)
#             loss = 1.0 - scores
#             return  loss, scores

#         return loss_fn

def clip_score(
    inference_dtype=None,
    device=None,
    return_loss=False,
    accelerator=None
):
    """Create a CLIP image-text reward (or loss) callable.

    A ``CLIPScorer`` is built in float32 on ``device`` and frozen with
    ``requires_grad_(False)``. ``inference_dtype`` and ``accelerator`` are
    accepted so the signature matches the other reward factories in this
    module, but neither is used here (the scorer is not passed to
    ``accelerator.prepare``).

    Args:
        inference_dtype: Unused.
        device: Device forwarded to ``CLIPScorer``.
        return_loss: If True, return ``loss_fn(images, prompts) -> (loss, scores)``
            with ``loss = -scores``; otherwise return ``_fn(images, prompts) -> scores``.
        accelerator: Unused.

    Returns:
        The scoring callable. A batch whose minimum value is negative is
        treated as lying in [-1, 1] and rescaled to [0, 1] before scoring.
    """
    from clip_scorer import CLIPScorer

    clip = CLIPScorer(dtype=torch.float32, device=device)
    clip.requires_grad_(False)

    def _rescale(batch):
        # Heuristic: any negative value means the batch is still in [-1, 1].
        if batch.min() < 0:
            batch = ((batch / 2) + 0.5).clamp(0, 1)
        return batch

    if return_loss:
        def loss_fn(images, prompts):
            scores = clip(_rescale(images), prompts)
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        return clip(_rescale(images), prompts)

    return _fn

def aesthetic_score(
    torch_dtype=None,
    aesthetic_target=None,
    grad_scale=0,
    device=None,
    return_loss=False,
    accelerator=None
):
    """Create an aesthetic-predictor reward (or loss) callable.

    An ``AestheticScorer`` is built in float32 on ``device`` and frozen with
    ``requires_grad_(False)``. When ``accelerator`` is given, the scorer is
    passed through ``accelerator.prepare``. ``torch_dtype`` is not used.

    Args:
        torch_dtype: Unused.
        aesthetic_target: If None, the loss maximizes the score
            (``loss = -scores``). Otherwise the loss is the L1 distance
            ``abs(scores - aesthetic_target)``.
        grad_scale: Multiplier applied to the loss. NOTE: the default of 0
            makes the returned loss identically zero; callers must pass a
            non-zero value to get a useful training signal.
        device: Device forwarded to ``AestheticScorer``.
        return_loss: If True, return ``loss_fn(images, prompts) -> (loss * grad_scale, scores)``;
            otherwise return ``_fn(images, prompts) -> scores``.
        accelerator: Optional object whose ``prepare`` wraps the scorer.

    Returns:
        The scoring callable. ``prompts`` is accepted but ignored. A batch
        whose minimum value is negative is treated as lying in [-1, 1] and
        rescaled to [0, 1].
    """
    from aesthetic_scorer import AestheticScorer

    predictor = AestheticScorer(dtype=torch.float32, device=device)
    predictor.requires_grad_(False)
    if accelerator is not None:
        print("accelerate reward")
        predictor = accelerator.prepare(predictor)

    def _rescale(batch):
        # Heuristic: any negative value means the batch is still in [-1, 1].
        if batch.min() < 0:
            batch = ((batch / 2) + 0.5).clamp(0, 1)
        return batch

    if return_loss:
        def loss_fn(images, prompts):
            scores = predictor(_rescale(images))
            # No target: maximize the score. Otherwise an L1 distance keeps the
            # loss on the same scale as the score itself.
            loss = -1 * scores if aesthetic_target is None else abs(scores - aesthetic_target)
            return loss * grad_scale, scores

        return loss_fn

    def _fn(images, prompts):
        return predictor(_rescale(images))

    return _fn


def hps_score(
    inference_dtype=None,
    device=None,
    return_loss=False,
    accelerator=None
):
    """Create an HPSv2 image-text reward (or loss) callable.

    An ``HPSv2Scorer`` is built in float32 on ``device`` and frozen with
    ``requires_grad_(False)``. When ``accelerator`` is given, the scorer is
    passed through ``accelerator.prepare``. ``inference_dtype`` is not used.

    Args:
        inference_dtype: Unused.
        device: Device forwarded to ``HPSv2Scorer``.
        return_loss: If True, return ``loss_fn(images, prompts) -> (loss, scores)``
            with ``loss = 1.0 - scores``; otherwise return ``_fn(images, prompts) -> scores``.
        accelerator: Optional object whose ``prepare`` wraps the scorer.

    Returns:
        The scoring callable. A batch whose minimum value is negative is
        treated as lying in [-1, 1] and rescaled to [0, 1] before scoring.
    """
    from hpsv2_scorer import HPSv2Scorer

    hps = HPSv2Scorer(dtype=torch.float32, device=device)
    hps.requires_grad_(False)
    if accelerator is not None:
        print("accelerate reward")
        hps = accelerator.prepare(hps)

    def _rescale(batch):
        # Heuristic: any negative value means the batch is still in [-1, 1].
        if batch.min() < 0:
            batch = ((batch / 2) + 0.5).clamp(0, 1)
        return batch

    if return_loss:
        def loss_fn(images, prompts):
            scores = hps(_rescale(images), prompts)
            return 1.0 - scores, scores

        return loss_fn

    def _fn(images, prompts):
        return hps(_rescale(images), prompts)

    return _fn


def ImageReward(
    inference_dtype=None,
    device=None,
    return_loss=False,
    accelerator=None
):
    """Create an ImageReward image-text reward (or loss) callable.

    An ``ImageRewardScorer`` is built in float32 on ``device`` and frozen with
    ``requires_grad_(False)``. When ``accelerator`` is given, the scorer is
    passed through ``accelerator.prepare``. ``inference_dtype`` is not used.

    Args:
        inference_dtype: Unused.
        device: Device forwarded to ``ImageRewardScorer``.
        return_loss: If True, return ``loss_fn(images, prompts) -> (loss, scores)``
            with ``loss = -scores``; otherwise return ``_fn(images, prompts) -> scores``.
        accelerator: Optional object whose ``prepare`` wraps the scorer.

    Returns:
        The scoring callable. A batch whose minimum value is negative is
        treated as lying in [-1, 1] and rescaled to [0, 1] before scoring.
    """
    from ImageReward_scorer import ImageRewardScorer

    reward_model = ImageRewardScorer(dtype=torch.float32, device=device)
    reward_model.requires_grad_(False)
    if accelerator is not None:
        print("accelerate reward")
        reward_model = accelerator.prepare(reward_model)

    def _rescale(batch):
        # Heuristic: any negative value means the batch is still in [-1, 1].
        if batch.min() < 0:
            batch = ((batch / 2) + 0.5).clamp(0, 1)
        return batch

    if return_loss:
        def loss_fn(images, prompts):
            scores = reward_model(_rescale(images), prompts)
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        return reward_model(_rescale(images), prompts)

    return _fn
    


def PickScore(
    inference_dtype=None,
    device=None,
    return_loss=False,
    accelerator=None
):
    """Create a PickScore image-text reward (or loss) callable.

    A ``PickScoreScorer`` is built in float32 on ``device`` and frozen with
    ``requires_grad_(False)``. ``inference_dtype`` and ``accelerator`` are
    accepted so the signature matches the other reward factories in this
    module, but neither is used here (the scorer is not passed to
    ``accelerator.prepare``).

    Args:
        inference_dtype: Unused.
        device: Device forwarded to ``PickScoreScorer``.
        return_loss: If True, return ``loss_fn(images, prompts) -> (loss, scores)``
            with ``loss = -scores``; otherwise return ``_fn(images, prompts) -> scores``.
        accelerator: Unused.

    Returns:
        The scoring callable. A batch whose minimum value is negative is
        treated as lying in [-1, 1] and rescaled to [0, 1] before scoring.
    """
    from PickScore_scorer import PickScoreScorer

    picker = PickScoreScorer(dtype=torch.float32, device=device)
    picker.requires_grad_(False)

    def _rescale(batch):
        # Heuristic: any negative value means the batch is still in [-1, 1].
        if batch.min() < 0:
            batch = ((batch / 2) + 0.5).clamp(0, 1)
        return batch

    if return_loss:
        def loss_fn(images, prompts):
            scores = picker(_rescale(images), prompts)
            return -scores, scores

        return loss_fn

    def _fn(images, prompts):
        return picker(_rescale(images), prompts)

    return _fn

def inpaint(x, width, y, height, sample_name, return_loss=False):
    """Build a masked-MSE reward against a reference image (inpainting).

    Loads ``sample_images/{sample_name}.png`` from the ``assets`` package as
    the reference. A mask is built that is 1 everywhere except the rectangle
    ``[y:y+height, x:x+width]``, which is 0. The returned callable scores
    generated images by the negative mean squared error to the normalized
    reference. Pixels inside the rectangle are zeroed by the mask but still
    count in the mean's denominator.

    Args:
        x: Column of the rectangle's top-left corner, in reference-image pixels.
        width: Rectangle width in pixels.
        y: Row of the rectangle's top-left corner, in reference-image pixels.
        height: Rectangle height in pixels.
        sample_name: File stem of the reference PNG under ``assets/sample_images``.
        return_loss: If False, the callable is ``_fn(images, prompts) -> scores``.
            If True, it is ``loss_fn(images, prompts) -> (loss, scores)`` with
            ``loss = -scores``. ``prompts`` is accepted only so the signature
            matches the other reward callables; it is ignored.

    Returns:
        ``(callable, masked_target)``. ``masked_target`` is the un-normalized
        reference (``ToTensor`` output, shape ``(1, 3, H, W)``) with the
        rectangle zeroed.

    Generated images are resized to the reference's exact ``(H, W)`` in both
    modes, so they align pixel-for-pixel with the reference and the mask.
    Previously the loss mode resized to a fixed 224 and only worked for
    224x224 references.
    """
    image_path = ASSETS_PATH.joinpath(f"sample_images/{sample_name}.png")

    # OpenAI CLIP image mean/std, applied to both the reference and the
    # generated images so they are compared in the same space.
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                std=[0.26862954, 0.26130258, 0.27577711])

    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so the pixels remain available after the file closes.
    with Image.open(image_path) as reference:
        target = torchvision.transforms.ToTensor()(reference.convert("RGB")).unsqueeze(0)
    print(target.shape)

    mask = torch.ones_like(target)
    mask[:, :, y:y+height, x:x+width] = 0

    masked_target = mask * target  # save masked target before normalization

    target = normalize(target)
    # Exact (H, W) of the reference; generated images are resized to this size.
    target_hw = tuple(target.shape[-2:])

    def _scores(images):
        """Return per-image negative masked MSE against the reference."""
        if images.min() < 0:  # heuristic: batch is in [-1, 1]; map to [0, 1]
            images = ((images / 2) + 0.5).clamp(0, 1)

        images = torchvision.transforms.Resize(target_hw)(images)
        images = normalize(images).to(images.dtype)

        # target/mask have batch dimension 1 and broadcast over the image batch.
        ref = target.to(device=images.device, dtype=images.dtype)
        keep = mask.to(device=images.device, dtype=images.dtype)

        return -((images - ref) ** 2 * keep).mean(dim=[1, 2, 3])

    if not return_loss:
        def _fn(images, prompts):
            return _scores(images)

        return _fn, masked_target

    def loss_fn(images, prompts):
        scores = _scores(images)
        loss = -scores
        return loss, scores

    return loss_fn, masked_target