# Imports
import numpy as np
import torch
import re
import torchvision.transforms as transforms
from torchmetrics.multimodal.clip_score import CLIPScore
from torchmetrics.multimodal import CLIPImageQualityAssessment
from moviepy.editor import VideoFileClip
import os
from PIL import Image
import argparse

# Command-line interface
parser = argparse.ArgumentParser(description='Get clip scores from video frames')

# CLIP-IQA prompt names evaluated on every frame; also the keys of the IQA result dicts.
IQA_LIST = ['quality', 'noisiness', 'sharpness', 'real', 'natural']


def _str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "no".

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is True, so every
    non-empty value would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Root folder: every sub-folder is scored from <sub-folder>/save/it<iteration>-test
parser.add_argument(
    '--root_folder',
    type=str,
    default='./example/',
    help='Path to the folder whose sub-folders contain the rendered results',
)

# Iteration whose test renders (save/it<iteration>-test) are scored.
# The default no longer depends on CUDA availability (it used to become 'cpu').
parser.add_argument(
    '--iteration',
    type=str,
    default='9000',
    help='Iteration folder to compute score',
)

# Append a view-dependent suffix (", front view", ...) to the prompt per frame.
# Accepts "--prompt_augmentation", "--prompt_augmentation True" or "... False".
parser.add_argument(
    '--prompt_augmentation',
    type=_str2bool,
    nargs='?',
    const=True,
    default=False,
    help='Whether to include prompt augmentation for CLIP Score computation',
)

# Print per-frame scores in addition to the per-folder averages
parser.add_argument(
    '--print_individual_scores',
    action='store_true',
    default=False,
    help='print individual frame scores',
)

# Device for the CLIP networks: a CUDA device index, or "cpu".
# NOTE(review): the default GPU index 6 is machine-specific.
parser.add_argument(
    '--device',
    type=str,
    default='6' if torch.cuda.is_available() else 'cpu',
    help='CUDA device index to run inference on (e.g. 0), or cpu',
)

args = parser.parse_args()


def get_score(image, prompt, normalize=False):
    """Return the CLIP score between a single image and a text prompt.

    Uses the module-level ``metric_clipscore`` (torchmetrics ``CLIPScore``)
    exactly as ``compute_scores`` does: the image is scaled from [0, 1] to
    [0, 255] before scoring. The previous body referenced undefined names
    (``clip``, ``model_clip``, ``bname``, ...) and could never run.

    Args:
        image: a single image tensor with values in [0, 1], presumably shaped
            (C, H, W) and already on ``metric_clipscore``'s device — TODO confirm
            at call sites.
        prompt: text prompt to compare against.
        normalize: if True, divide the score by 100, as the old body intended.

    Returns:
        The CLIP score as a Python float.
    """
    with torch.no_grad():
        score = metric_clipscore(image * 255, prompt).detach().item()
    return score / 100.0 if normalize else score




# File extensions treated as rendered frames; any other file in the folder is skipped.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')

# Inclusive frame-number ranges and the view suffix appended to the prompt
# for them when --prompt_augmentation is enabled.
_VIEW_SUFFIXES = (
    (0, 6, ', front view'),
    (7, 18, ', side view'),
    (19, 32, ', overhead view'),
    (33, 43, ', side view'),
    (44, 49, ', front view'),
)


def _frame_sort_key(image_file):
    """Order numerically named frames numerically ("2.png" < "10.png"), others after."""
    stem = os.path.splitext(image_file)[0]
    if stem.isdigit():
        return (0, int(stem), '')
    return (1, 0, stem)


def _view_suffix(image_file):
    """Return the prompt suffix for a frame file name, or '' when no range applies."""
    stem = os.path.splitext(image_file)[0]
    if not stem.isdigit():
        return ''
    image_number = int(stem)
    for first, last, suffix in _VIEW_SUFFIXES:
        if first <= image_number <= last:
            return suffix
    return ''


def compute_scores(image_folder, iteration):
    """Score the rendered test frames of one result folder with CLIPScore and CLIP-IQA.

    Frames are read from ``<args.root_folder>/<image_folder>/save/it<iteration>-test``.
    The text prompt is the part of ``image_folder`` before the first ``'@'``,
    with underscores replaced by spaces. Per-folder averages (and, with
    ``--print_individual_scores``, per-frame scores) are printed.

    Uses the module-level ``args``, ``device``, ``metric_clipscore`` and
    ``metric_clipiqa``.

    Args:
        image_folder: name of a sub-folder of ``args.root_folder``.
        iteration: iteration string used to build the test-render folder name.

    Returns:
        ``(average_clip_score, {iqa_key: average_iqa_score})``, or ``None``
        if the iteration folder is missing, no prompt can be parsed from the
        folder name, or the folder contains no image files.
    """
    complete_path = os.path.join(args.root_folder, image_folder, f'save/it{iteration}-test')
    if not os.path.isdir(complete_path):
        print(f"Iteration does not exist for {image_folder}")
        return None

    # Extract the prompt from the folder name (text before the first '@')
    match = re.search(r'([^@]+)@', image_folder)
    if match is None:
        print(f"Cannot extract a prompt from folder name {image_folder}")
        return None
    prompt = match.group(1).replace('_', ' ')

    # Deterministic frame order so "Frame i" refers to the i-th numbered frame
    image_files = sorted(
        (name for name in os.listdir(complete_path)
         if name.lower().endswith(_IMAGE_EXTENSIONS)),
        key=_frame_sort_key,
    )
    if not image_files:
        print(f"No frames found for {image_folder}")
        return None

    print(f'prompt: {prompt:60s}', end=' ')

    # Keep only the leftmost 512x512 region of each saved frame, then resize to 224x224
    transform_pil = transforms.Compose([
        transforms.Lambda(lambda img: img.crop((0, 0, 512, 512))),
        transforms.Resize((224, 224)),
    ])
    # PIL image -> float tensor (C, H, W) with values in [0, 1]
    transform_tensor = transforms.ToTensor()

    scores = []
    iqa_scores = {key: [] for key in IQA_LIST}

    for i, image_file in enumerate(image_files):
        image_path = os.path.join(complete_path, image_file)
        # The context manager closes the file; convert('RGB') guarantees 3 channels
        with Image.open(image_path) as img:
            frame = transform_pil(img.convert('RGB'))
        frame_tensor = transform_tensor(frame).to(device)

        if args.prompt_augmentation:
            frame_prompt = prompt + _view_suffix(image_file)
        else:
            frame_prompt = prompt

        # CLIPScore is given [0, 255] pixel values, hence the scaling from [0, 1]
        score = metric_clipscore(frame_tensor * 255, frame_prompt).detach().item()

        # CLIP-IQA divides its input by data_range (torchmetrics default 1.0), so it
        # must receive the [0, 1] tensor; scaling by 255 here would corrupt the scores.
        iqa_dictionary = metric_clipiqa(frame_tensor.unsqueeze(0))
        cur_iqa_scores = {key: iqa_dictionary[key].detach().item() for key in IQA_LIST}

        if args.print_individual_scores:
            print(f'Frame {i+1} CLIP Score for : {score:0.2f}', end=' ')
            for key in IQA_LIST:
                print(f'IQA {key}: {cur_iqa_scores[key]:0.2f}', end=' ')
            print("\n")

        scores.append(score)
        for key in IQA_LIST:
            iqa_scores[key].append(cur_iqa_scores[key])

    # Per-folder mean ± std
    average_score = np.mean(scores)
    print(f"CLIP score: {average_score:0.2f} ± {np.std(scores):0.2f}", end=' ')
    for key in IQA_LIST:
        print(f"IQA {key}: {np.mean(iqa_scores[key]):0.2f} ± {np.std(iqa_scores[key]):0.2f}", end=' ')
    print("\n")

    average_iqa_scores = {key: np.mean(iqa_scores[key]) for key in IQA_LIST}
    return average_score, average_iqa_scores


# Resolve the torch device from --device:
#   - "cpu" (or no CUDA available) -> CPU
#   - a bare index such as "6"     -> that CUDA GPU
#   - anything else (e.g. "cuda:1") -> passed to torch.device unchanged
# Previously any non-index value produced an invalid string such as "cuda:cpu".
_device_arg = str(args.device).strip()
if not torch.cuda.is_available() or _device_arg.lower() == 'cpu':
    device = torch.device("cpu")
elif _device_arg.isdigit():
    device = torch.device("cuda:" + _device_arg)
else:
    device = torch.device(_device_arg)

# Initialize CLIPScore metric. Other checkpoints that can be used:
# 1. clip-vit-base-patch16
# 2. clip-vit-base-patch32
# 3. clip-vit-large-patch14
# 4. clip-vit-large-patch14-336
metric_clipscore = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32").to(device)

# Initialize CLIP-IQA with the same prompts as IQA_LIST, so that every key read
# from its result dict exists. data_range keeps its default (1.0): inputs are
# expected in [0, 1].
metric_clipiqa = CLIPImageQualityAssessment(prompts=tuple(IQA_LIST)).to(device)

print('Networks Loaded')

root_folder = args.root_folder
# Only sub-directories can hold results. Sort them so folders are processed and
# printed in a deterministic order (os.listdir order is arbitrary).
image_folders = sorted(
    name for name in os.listdir(root_folder)
    if os.path.isdir(os.path.join(root_folder, name))
)

total_scores = []
total_iqa_scores = {key: [] for key in IQA_LIST}
for image_folder in image_folders:
    res = compute_scores(image_folder, args.iteration)
    if res is None:
        continue
    score, iqa_scores = res
    total_scores.append(score)
    for key in IQA_LIST:
        total_iqa_scores[key].append(iqa_scores[key])

# Mean ± std of the per-folder averages
print("\n\n Final results:")
if not total_scores:
    # np.mean([]) would print NaN with a RuntimeWarning; report the cause instead
    print(f"No folders in {root_folder} had results for iteration {args.iteration}")
else:
    print(f"CLIP score: {np.mean(total_scores):0.2f} ± {np.std(total_scores):0.2f}")
    for key in IQA_LIST:
        print(f"IQA {key}: {np.mean(total_iqa_scores[key]):0.2f} ± {np.std(total_iqa_scores[key]):0.2f}")
    
