import os
import json
import os
import torch
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from glob import glob
from functools import partial
from torchvision import transforms
from pipelines.utils import get_coco_dict
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*?Your .*? set is empty.*?")

def parse_args(argv=None):
    """Parse the command-line options of the CLIP score evaluation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behavior callers
            relying on ``parse_args()`` already get.

    Returns:
        argparse.Namespace with ``generated_dir``, ``anno_dir``,
        ``dataset_dir``, ``model_name_or_path`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description="CLIP score evaluation")
    parser.add_argument(
        "--generated_dir",
        type=str,
        default="/sda/home/ada6k4_05/exp6/sd3-cfg=10-base-steps=30-seed=625",
        help="Directory with generated jpg/png images, or 'coco' to score the COCO images themselves.",
    )
    parser.add_argument(
        "--anno_dir",
        type=str,
        default="/sda/home/ada6k4_05/datasets/annotations/captions_val2017.json",
        help="Path to the COCO captions annotation JSON file.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="/sda/home/ada6k4_05/datasets/val2017",
        help="Directory containing the COCO images.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="openai/clip-vit-base-patch16",
        help="CLIP model name or path forwarded to the clip_score metric.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="Number of images scored per clip_score call.",
    )
    return parser.parse_args(argv)

def calculate_clip_score(images, prompts, clip_score_fn):
    """Score one batch of images against their prompts.

    Each image is turned into a numpy array, pushed through a torchvision
    pipeline that resizes it to 512x512, and the stacked batch is scaled by
    255 and cast to ``uint8`` before being handed to ``clip_score_fn``.

    Args:
        images: Iterable of images (the caller passes RGB PIL images).
        prompts: Prompts forwarded unchanged to ``clip_score_fn``.
        clip_score_fn: Callable invoked as ``clip_score_fn(tensor, prompts)``;
            its result must support ``.detach()`` and ``float()``.

    Returns:
        The score as a Python float rounded to 4 decimal places.
    """
    # Resize pipeline: array -> PIL -> 512x512 -> tensor.
    to_512 = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])

    resized = [to_512(np.array(img)).numpy() for img in images]
    batch = np.stack(resized)

    # The metric is fed uint8 pixel values, so scale and cast the batch.
    batch_uint8 = (batch * 255).astype("uint8")
    score = clip_score_fn(torch.from_numpy(batch_uint8), prompts).detach()
    return round(float(score), 4)

def eval_clip(args):
    """Compute CLIP scores of images against their COCO captions.

    When ``args.generated_dir`` is ``'coco'`` the COCO images returned by
    ``get_coco_dict`` are scored; otherwise every ``*.jpg`` / ``*.png`` file
    directly inside ``args.generated_dir`` is scored, and the file name
    without its extension is used as the image id for the caption lookup.

    Args:
        args: Namespace providing ``model_name_or_path``, ``anno_dir``,
            ``dataset_dir``, ``generated_dir`` and ``batch_size``.

    Returns:
        A list with one entry per image. Scores are computed per batch, so
        every image of a batch carries that batch's score.
    """
    from torchmetrics.functional.multimodal import clip_score
    clip_score_fn = partial(
        clip_score,
        model_name_or_path=args.model_name_or_path
    )

    id2caption, id2image_path = get_coco_dict(
        annotation_file=args.anno_dir,
        images_dir=args.dataset_dir,
        percent=1,
    )

    if args.generated_dir != 'coco':
        print(f'Using images from {args.generated_dir}')
        images_paths = []
        for ext in ('jpg', 'png'):
            images_paths += glob(os.path.join(args.generated_dir, '*.' + ext))
        # glob order is filesystem dependent; sort so batches are reproducible.
        images_paths.sort()
        # str.strip('.jpg') removes a *set of characters* from both ends, not
        # a suffix, so the id is derived with splitext/basename instead.
        image_ids = [
            os.path.splitext(os.path.basename(image_path))[0]
            for image_path in images_paths
        ]
    else:
        print(f'Using images from COCO dataset')
        image_ids = list(id2image_path)
        images_paths = [id2image_path[image_id] for image_id in image_ids]

    print(f'{len(image_ids)} images from {len(id2image_path)}')

    # Surface ids without a caption instead of scoring them unnoticed.
    missing = sum(1 for image_id in image_ids if image_id not in id2caption)
    if missing:
        print(f'Warning: {missing} of {len(image_ids)} images have no caption in the annotations')

    clip_scores = []
    for start in tqdm(range(0, len(image_ids), args.batch_size)):
        batch_paths = images_paths[start:start + args.batch_size]
        batch_ids = image_ids[start:start + args.batch_size]

        images = []
        for image_path in batch_paths:
            # convert() returns a new image, so the file can be closed here.
            with Image.open(image_path) as image:
                images.append(image.convert('RGB'))
        captions = [id2caption[image_id] for image_id in batch_ids]

        # Named batch_score so the imported clip_score function is not shadowed.
        batch_score = calculate_clip_score(images, captions, clip_score_fn)
        clip_scores.extend([batch_score] * len(batch_ids))

    return clip_scores


def get_coco_val_pairs(
        annotation_file='/mnt/mydisk/_datasets/annotations/captions_val2017.json',
        images_dir='/mnt/mydisk/_datasets/val2017',
    ):
    """Map every annotated image id to ``(image_path, caption)``.

    The annotation JSON must provide an ``images`` list (entries with ``id``
    and ``file_name``) and an ``annotations`` list (entries with ``image_id``
    and ``caption``). When an image has several annotations, the last one in
    the file is the caption that is kept.

    Args:
        annotation_file: Path of the captions annotation JSON file.
        images_dir: Directory that is joined with each ``file_name``.

    Returns:
        dict keyed by image id, ordered by first appearance in ``annotations``.
    """
    with open(annotation_file, 'r') as handle:
        coco = json.load(handle)

    file_names = {}
    for image in coco['images']:
        file_names[image['id']] = image['file_name']

    pairs = {}
    for annotation in coco['annotations']:
        image_id, caption = annotation['image_id'], annotation['caption']
        # Later annotations of the same image overwrite earlier ones.
        pairs[image_id] = (os.path.join(images_dir, file_names[image_id]), caption)

    return pairs


def get_coco_image_ids_and_captions(
        annotation_file='/mnt/mydisk/_datasets/annotations/captions_val2017.json',
        images_dir='/mnt/mydisk/_datasets/val2017',
        percent=1.0,
    ):
    """Return zero-padded image ids and captions for a leading share of COCO.

    Args:
        annotation_file: Path of the captions annotation JSON file.
        images_dir: Image directory forwarded to ``get_coco_val_pairs``.
        percent: Fraction of the images to keep, taken from the start.
            Values above 1 print a warning and return everything; values
            below 0 select nothing.

    Returns:
        ``(image_ids, captions)`` as two parallel lists; ids are strings
        left-padded with zeros to 12 characters.
    """
    pairs = get_coco_val_pairs(annotation_file, images_dir)
    image_ids = [str(image_id).zfill(12) for image_id in pairs]
    captions = [caption for _image_path, caption in pairs.values()]

    num_images = len(image_ids)
    # Warn only on a real over-request; asking for exactly 100% is normal.
    if percent > 1:
        print(f"Warning: Requested {percent*100}% of images, but only {num_images} images available. Returning all images.")
    # Clamp so a negative percent cannot turn into a from-the-end slice.
    num_images_to_select = min(num_images, max(0, int(num_images * percent)))
    image_ids = image_ids[:num_images_to_select]
    captions = captions[:num_images_to_select]

    print(f"Selected {len(image_ids)} images out of {num_images} total images.")
    print(f"Selected {len(captions)} captions out of {num_images} total captions.")

    return image_ids, captions

def get_coco(
        annotation_file='/mnt/mydisk/_datasets/annotations/captions_val2017.json',
        images_dir='/mnt/mydisk/_datasets/val2017',
        percent=1.0,
    ):
    """Return ids, captions and image paths for a leading share of COCO.

    Args:
        annotation_file: Path of the captions annotation JSON file.
        images_dir: Image directory forwarded to ``get_coco_val_pairs``.
        percent: Fraction of the images to keep, taken from the start.
            Values above 1 print a warning and return everything; values
            below 0 select nothing.

    Returns:
        ``(image_ids, captions, image_paths)`` as three parallel lists; ids
        are strings left-padded with zeros to 12 characters.
    """
    pairs = get_coco_val_pairs(annotation_file, images_dir)
    image_ids = []
    captions = []
    image_paths = []
    for image_id, (image_path, caption) in pairs.items():
        image_ids.append(str(image_id).zfill(12))
        captions.append(caption)
        image_paths.append(image_path)

    num_images = len(image_ids)
    # Warn only on a real over-request; asking for exactly 100% is normal.
    if percent > 1:
        print(f"Warning: Requested {percent*100}% of images, but only {num_images} images available. Returning all images.")
    # Clamp so a negative percent cannot turn into a from-the-end slice.
    num_images_to_select = min(num_images, max(0, int(num_images * percent)))

    return (
        image_ids[:num_images_to_select],
        captions[:num_images_to_select],
        image_paths[:num_images_to_select],
    )


def get_coco_dict(
        annotation_file='/mnt/mydisk/_datasets/annotations/captions_val2017.json',
        images_dir='/mnt/mydisk/_datasets/val2017',
        percent=1.0,
    ):
    """Build id -> caption and id -> image path lookups from ``get_coco``.

    Args:
        annotation_file: Path of the captions annotation JSON file.
        images_dir: Directory containing the images.
        percent: Fraction of the images to keep, forwarded to ``get_coco``.

    Returns:
        ``(id2caption, id2image_path)``. Both are ``defaultdict(str)``, so
        looking up an unknown id yields ``''`` rather than raising.
    """
    from collections import defaultdict

    ids, caps, paths = get_coco(annotation_file, images_dir, percent)
    triples = list(zip(ids, caps, paths))

    id2caption = defaultdict(str, ((image_id, caption) for image_id, caption, _ in triples))
    id2image_path = defaultdict(str, ((image_id, path) for image_id, _, path in triples))
    return id2caption, id2image_path