import os
import argparse
import pandas as pd

import torch
from tqdm import tqdm
from PIL import Image
import clip

from functools import partial

from eval.evaluate_utils import update_summary_with_lock, get_and_save_individual_results, \
    parse_from_images_dir


def parse_arguments(argv=None):
    """Parse command-line options for the CLIP-score evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, exactly like ``parser.parse_args()``.

    Returns:
        argparse.Namespace with:
            images_dir (str): root folder of generated images (required).
            recursive_depth (int): recursion depth used to locate image files
                and build the model-name structure (default 0).
    """
    parser = argparse.ArgumentParser(
        description='Evaluation script for CLIP score (image-prompt alignment)'
    )
    # Required *named* option (not positional): pass it as -i / --images_dir.
    parser.add_argument(
        '--images_dir', '-i',
        type=str,
        required=True,
        help='Path to the image folder, e.g., "images/<prompt_file_name>/<model_name>".'
    )
    parser.add_argument(
        '--recursive_depth', '-rd',
        type=int,
        default=0,
        help='Depth of recursion to find the image files and build the model name structure. Default is 0.'
    )

    return parser.parse_args(argv)


def _collect_entries(images_dir, prompt_csv):
    """Read the prompt CSV and pair every row with its generated image path.

    Each image is expected at ``<images_dir>/other/coco/<id>_<concept>.png``.
    Raises AssertionError if the CSV or any expected image is missing.
    """
    assert os.path.exists(prompt_csv), f"Could not find prompt CSV at {prompt_csv}"

    df = pd.read_csv(prompt_csv)
    print(f"Loaded {len(df)} prompts from {prompt_csv}")

    image_root = os.path.join(images_dir, "other", "coco")
    entries = []
    for _, row in df.iterrows():
        idx = row["id"]
        concept = row["concept"]

        img_path = os.path.join(image_root, f"{idx}_{concept}.png")
        assert os.path.exists(img_path), f"Image file not found: {img_path}"

        entries.append({
            "img_path": img_path,
            "prompt": row["prompt"],
            "id": idx,
            "category": row["category"],
            "concept": concept,
        })
    return entries


def _load_image(img_path, preprocess):
    """Open an image as RGB, apply the CLIP preprocess transform, and close the file."""
    # The context manager guarantees the underlying file handle is released
    # even if conversion or preprocessing raises.
    with Image.open(img_path) as img:
        return preprocess(img.convert("RGB"))


def evaluate_clip_score(images_dir, batch_size=100,
                        prompt_csv="prompts/general_prompts/utility_eval_mscoco_10k_prompts.csv",
                        model_name="ViT-B/32"):
    """Compute a per-image CLIP score between each generated image and its prompt.

    The score is the cosine similarity of the L2-normalised CLIP image and
    text embeddings. It is NOT rescaled by 100 or clipped at 0.

    Args:
        images_dir: Root image folder; images are read from
            ``<images_dir>/other/coco/<id>_<concept>.png``.
        batch_size: Number of image/prompt pairs encoded per forward pass.
            Must be >= 1.
        prompt_csv: CSV with ``id``, ``category``, ``concept`` and ``prompt``
            columns. The default is the original hard-coded path, resolved
            relative to the current working directory.
        model_name: CLIP checkpoint name passed to ``clip.load``.

    Returns:
        A list of dicts with keys ``id``, ``category``, ``concept``,
        ``prompt`` and ``clip_score``, in CSV row order.

    Raises:
        ValueError: if ``batch_size`` < 1.
        AssertionError: if the prompt CSV or any expected image is missing.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Validate inputs before loading the (comparatively expensive) model.
    entries = _collect_entries(images_dir, prompt_csv)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, preprocess = clip.load(model_name, device=device)

    results = []
    for start in tqdm(range(0, len(entries), batch_size)):
        batch = entries[start:start + batch_size]

        images = [_load_image(e["img_path"], preprocess) for e in batch]
        texts = [e["prompt"] for e in batch]

        with torch.no_grad():
            image_input = torch.stack(images).to(device)
            # truncate=True: prompts longer than CLIP's 77-token context are
            # cut instead of raising and aborting the whole evaluation.
            text_input = clip.tokenize(texts, truncate=True).to(device)

            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_input)

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Row-wise dot product gives the diagonal of image @ text.T
            # without materialising the full batch x batch matrix.
            similarities = (image_features * text_features).sum(dim=-1).cpu().tolist()

        for entry, score in zip(batch, similarities):
            results.append({
                "id": entry["id"],
                "category": entry["category"],
                "concept": entry["concept"],
                "prompt": entry["prompt"],
                "clip_score": score,
            })

        # Progress sample; tqdm.write keeps the progress bar intact.
        tqdm.write(str(results[-1]))

    return results


def main():
    """Script entry point: score generated images with CLIP and record per-category means.

    Steps:
        1. Parse CLI args and resolve results/summary output locations from
           the image directory.
        2. Compute (and save) per-image CLIP scores.
        3. Average the scores per ``category`` and merge them into the shared
           summary under a lock.
    """
    args = parse_arguments()
    images_dir = args.images_dir
    assert os.path.exists(images_dir), f"Directory not found: {images_dir}"

    # The prompt file name is not needed here; keep only model name and output dirs.
    _, full_model_name, results_dir, summary_dir = parse_from_images_dir(images_dir)
    # NOTE(review): args.recursive_depth is parsed but not used in this script.

    # Bind the image directory so the helper can invoke the evaluator without arguments.
    scorer = partial(evaluate_clip_score, images_dir=images_dir)
    results_df = get_and_save_individual_results('clip_score', scorer, results_dir)

    # One row per category with its mean CLIP score (columns: category, clip_score).
    summary_df = (
        results_df
        .groupby('category', as_index=False)['clip_score']
        .mean()
    )

    update_summary_with_lock(summary_df, full_model_name, summary_dir, metric_name="clip_score")

    print("Per-category CLIP scores:")
    print(summary_df)


if __name__ == "__main__":
    main()
