import os
import json
import torch
from PIL import Image
import numpy as np
from numpy.linalg import norm
from transformers import AutoImageProcessor, AutoModel
import logging
from tqdm import tqdm
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's decompression-bomb size
# check, so arbitrarily large images can be opened; load_and_resize_image
# downsizes oversized images itself before they are used.
Image.MAX_IMAGE_PIXELS = None

# Module-wide logging: INFO level and above, with timestamps.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def load_and_resize_image(image_path, pixel_threshold=178956970):
    """Open an image, downscale it if it is too large, and return it in RGB mode.

    Args:
        image_path: Path of the image file to open.
        pixel_threshold: Maximum allowed ``width * height``. Larger images are
            resized (LANCZOS, aspect ratio preserved) to roughly this many
            pixels. The default, 178956970, is twice Pillow's default
            ``MAX_IMAGE_PIXELS`` (89478485), i.e. the size at which Pillow
            would normally raise ``DecompressionBombError``.

    Returns:
        A new, fully loaded ``PIL.Image.Image`` in ``'RGB'`` mode that does not
        depend on the (now closed) source file.

    Raises:
        Any exception raised while opening, resizing or converting the image;
        it is logged first and then re-raised unchanged.
    """
    try:
        # The context manager guarantees the file handle is released even when
        # an error occurs or the image has several frames.
        with Image.open(image_path) as img:
            width, height = img.size
            num_pixels = width * height

            if num_pixels > pixel_threshold:
                logging.warning(f"Image {os.path.basename(image_path)} ({width * height} pixels) exceeds threshold, resizing...")
                ratio = (pixel_threshold / num_pixels) ** 0.5
                # Clamp to 1 so extremely elongated images never get a
                # zero-sized side (resize() rejects 0).
                new_width = max(1, int(width * ratio))
                new_height = max(1, int(height * ratio))
                resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
                logging.info(f"Resized to {new_width}x{new_height}")
                return resized.convert('RGB')

            # convert() loads the pixels into a new image, so the result stays
            # valid after the with-block closes the source file.
            return img.convert('RGB')
    except Exception as e:
        logging.error(f"Failed to load or resize image {image_path}: {e}")
        raise

def get_embedding_hf(image, processor, model, device):
    """Embed one image with a Hugging Face vision model.

    The image is preprocessed by *processor* into PyTorch tensors, moved to
    *device*, and run through *model* with gradient tracking disabled.

    Returns:
        ``pooler_output`` of the model as a NumPy array on the CPU (presumably
        shape ``(1, hidden_size)`` for a single image — confirm against the
        model), or ``None`` if any step raised; the error is logged.
    """
    try:
        batch = processor(images=image, return_tensors="pt")
        batch = batch.to(device)
        with torch.no_grad():
            pooled = model(**batch).pooler_output
        return pooled.cpu().numpy()
    except Exception as err:
        logging.error(f"Failed to extract embedding: {err}")
        return None

def cosine_similarity(a, b, epsilon=1e-8):
    """Return the cosine similarity between two vectors or between matrix rows.

    * Both inputs 1-D (shape ``(D,)``): returns a scalar.
    * Otherwise (shapes ``(N, D)`` and ``(M, D)``; a 1-D input is treated as a
      single row): returns an ``(N, M)`` array whose ``[i, j]`` entry is the
      similarity of row ``i`` of *a* and row ``j`` of *b*.

    Args:
        a, b: Array-likes with one or two dimensions.
        epsilon: Added to the denominator so zero vectors give 0 instead of a
            division by zero.

    Raises:
        ValueError: If either input has more than two dimensions.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim > 2 or b.ndim > 2:
        raise ValueError(
            f"cosine_similarity expects 1-D or 2-D inputs, got {a.ndim}-D and {b.ndim}-D"
        )

    if a.ndim <= 1 and b.ndim <= 1:
        return np.dot(a, b) / ((norm(a) * norm(b)) + epsilon)

    a2 = np.atleast_2d(a)
    b2 = np.atleast_2d(b)
    # Per-row norms: norm() without an axis would give the Frobenius norm of the
    # whole matrix, which is only correct when each input has a single row.
    a_norm = norm(a2, axis=1, keepdims=True)  # (N, 1)
    b_norm = norm(b2, axis=1, keepdims=True)  # (M, 1)
    return np.dot(a2, b2.T) / (a_norm * b_norm.T + epsilon)

def _load_existing_scores(output_file_path):
    """Return previously saved results from *output_file_path*, or ``{}``.

    A missing file yields ``{}``. An unreadable or malformed file is logged and
    also yields ``{}`` so the run can continue (it is overwritten on save).
    """
    if not os.path.exists(output_file_path):
        return {}
    try:
        # Results are written as UTF-8 with ensure_ascii=False, so they must be
        # read back as UTF-8; the platform default encoding (e.g. GBK/cp1252 on
        # Windows) would fail on or garble non-ASCII image names.
        with open(output_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        logging.warning(f"Could not read existing results '{output_file_path}' ({e}); starting from scratch.")
        return {}
    if not isinstance(data, dict):
        logging.warning(f"Existing results '{output_file_path}' are not a JSON object; starting from scratch.")
        return {}
    return data


def _save_scores(model_scores, output_file_path, model_name):
    """Write *model_scores* to *output_file_path* as UTF-8 JSON; log (never raise) on failure."""
    tmp_path = output_file_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(model_scores, f, indent=4, ensure_ascii=False)
        # Atomic replace: an interrupted write never clobbers the previous results.
        os.replace(tmp_path, output_file_path)
        logging.info(f"Results for model {model_name} saved to: {output_file_path}\n")
    except Exception as e:
        logging.error(f"Failed to save JSON file for model {model_name}: {e}")


def _score_image(generated_img_path, label_img_path, label_embedding_cache, model, processor, device):
    """Return the cosine similarity of a generated image and its label image.

    Returns ``None`` if either embedding could not be computed. Label
    embeddings are memoized in *label_embedding_cache* (keyed by path) because
    the same label image is compared against every model's output. Exceptions
    raised while loading images propagate to the caller.
    """
    generated_embedding = get_embedding_hf(load_and_resize_image(generated_img_path), processor, model, device)

    label_embedding = label_embedding_cache.get(label_img_path)
    if label_embedding is None:
        label_embedding = get_embedding_hf(load_and_resize_image(label_img_path), processor, model, device)
        if label_embedding is not None:
            label_embedding_cache[label_img_path] = label_embedding

    if generated_embedding is None or label_embedding is None:
        return None
    return cosine_similarity(generated_embedding, label_embedding).item()


def _process_category(model_name, task_name, model_folder_path, label_imgs_folder,
                      model_scores, label_embedding_cache, model, processor, device):
    """Score all images of one model/task folder, updating *model_scores* in place."""
    generated_category_path = os.path.join(model_folder_path, task_name)
    # isdir (not exists): a same-named regular file would make listdir() crash.
    if not os.path.isdir(generated_category_path):
        logging.warning(f"Generated images directory not found, skipping: {generated_category_path}")
        return

    logging.info(f"  Processing Category: {task_name}")

    task_entry = model_scores.setdefault(task_name, {'average_score': 0, 'image_scores': {}})
    image_scores = task_entry.setdefault('image_scores', {})

    category_scores = []
    # Case-insensitive extension match (".PNG" counts) and sorted order so runs
    # are reproducible regardless of filesystem listing order.
    image_files = sorted(
        f for f in os.listdir(generated_category_path)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    for img_name in tqdm(image_files, desc=f"  {model_name}/{task_name}"):
        # Resume support: reuse scores saved by a previous run.
        if img_name in image_scores:
            category_scores.append(image_scores[img_name])
            continue

        label_img_path = os.path.join(label_imgs_folder, img_name)
        if not os.path.exists(label_img_path):
            logging.warning(f"Corresponding label image not found, skipping: {label_img_path}")
            continue

        generated_img_path = os.path.join(generated_category_path, img_name)
        try:
            similarity = _score_image(generated_img_path, label_img_path, label_embedding_cache,
                                      model, processor, device)
        except Exception as e:
            logging.error(f"Error processing image {img_name}: {e}")
            continue
        if similarity is None:
            continue

        image_scores[img_name] = similarity
        category_scores.append(similarity)

    if category_scores:
        average_score = float(np.mean(category_scores))
        task_entry['average_score'] = average_score
        logging.info(f"  Category '{task_name}' Average Score: {average_score:.4f}")
    else:
        logging.warning(f"  No images were successfully processed for category '{task_name}'.")


def process_all_models(source_folder1, source_folder2, output_dir, model_names, task_names, model, processor, device):
    """Score generated images against label images and save per-model JSON results.

    For each model in *model_names*, every image in
    ``<source_folder1>/<model_name>/<task_name>`` is paired by file name with
    ``<source_folder2>/label_imgs/<file name>``; both are embedded with
    *processor*/*model* on *device* and compared by cosine similarity.
    Results go to ``<output_dir>/<model_name>.json`` with the structure
    ``{task_name: {"average_score": float, "image_scores": {img_name: float}}}``.

    An existing results file is loaded first; images already scored in it are
    not recomputed, but their stored score still counts toward the average.
    Results are saved even if processing is interrupted (e.g. Ctrl-C), so a
    rerun resumes where it stopped.

    Returns:
        None. Aborts early (logging an error) if the label folder is missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    label_imgs_folder = os.path.join(source_folder2, "label_imgs")
    if not os.path.isdir(label_imgs_folder):
        logging.error(f"Label images folder not found at '{label_imgs_folder}'. Aborting.")
        return

    # Shared across models: each label image is loaded and embedded only once.
    label_embedding_cache = {}

    for model_name in model_names:
        model_folder_path = os.path.join(source_folder1, model_name)
        if not os.path.isdir(model_folder_path):
            logging.warning(f"Model folder '{model_folder_path}' not found, skipping.")
            continue

        logging.info(f"--- Processing Model: {model_name} ---")
        output_file_path = os.path.join(output_dir, f"{model_name}.json")
        model_scores = _load_existing_scores(output_file_path)

        try:
            # Iterate over the task-category folders.
            for task_name in task_names:
                _process_category(model_name, task_name, model_folder_path, label_imgs_folder,
                                  model_scores, label_embedding_cache, model, processor, device)
        finally:
            # Persist partial progress even when an exception/interrupt escapes.
            _save_scores(model_scores, output_file_path, model_name)

if __name__ == '__main__':

    # Root folder holding one sub-directory of generated images per model.
    source_folder1 = "gpt5imgs"

    # NOTE(review): process_all_models looks for "<source_folder2>/label_imgs",
    # so this resolves to "label_imgs/label_imgs" — confirm the intended layout.
    source_folder2 = "label_imgs"

    # One "<model_name>.json" results file is written here per model.
    output_dir = "results"

    # Local checkpoint directory passed to from_pretrained().
    model_path_local = 'dinov2-base'

    # Discover the model folders first so a bad input path is reported before
    # the comparatively slow model load.
    try:
        model_names = sorted(
            d for d in os.listdir(source_folder1)
            if os.path.isdir(os.path.join(source_folder1, d))
        )
        if not model_names:
            logging.error(f"No model subdirectories found in '{source_folder1}'. Please check the path.")
    except FileNotFoundError:
        logging.error(f"Source folder '{source_folder1}' not found.")
        model_names = []

    task_names = ["Interaction_Authoring"]

    if model_names:
        logging.info(f"Loading model from {model_path_local}...")
        try:
            processor = AutoImageProcessor.from_pretrained(model_path_local)
            model = AutoModel.from_pretrained(model_path_local)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)
            model.eval()
            logging.info(f"Model successfully loaded to {device}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            # exit() comes from the site module and exits with status 0;
            # report the failure with a non-zero exit status instead.
            raise SystemExit(1)

        logging.info(f"Models to process: {model_names}")
        logging.info(f"Tasks to process: {task_names}")
        process_all_models(source_folder1, source_folder2, output_dir, model_names, task_names, model, processor, device)
        logging.info("All tasks completed.")