import os
import argparse
import pandas as pd

import torch
from tqdm import tqdm

from functools import partial

from cleanfid import fid

from eval.evaluate_utils import update_summary_with_lock, get_and_save_individual_results, \
    parse_from_images_dir


def parse_arguments():
    """Parse the command-line options of the FID evaluation script.

    Options:
        --images_dir / -i: required path to the folder holding the generated
            images, e.g. "images/<prompt_file_name>/<model_name>".
        --recursive_depth / -rd: integer, defaults to 0. It is parsed here but
            is not read by the rest of this script.

    Returns:
        argparse.Namespace: namespace with the attributes ``images_dir`` and
        ``recursive_depth``.
    """
    parser = argparse.ArgumentParser(
        description='Evaluation script for FID (Frechet Inception Distance) scores of generated images'
    )
    # Required option (given as a flag, not positionally)
    parser.add_argument(
        '--images_dir', '-i',
        type=str,
        required=True,
        help='Path to the image folder, e.g., "images/<prompt_file_name>/<model_name>".'
    )
    parser.add_argument(
        '--recursive_depth', '-rd',
        type=int,
        default=0,
        help='Depth of recursion to find the image files and build the model name structure. Default is 0.'
    )

    return parser.parse_args()


def _list_subdirectories(path):
    """Return the sorted names of the immediate sub-directories of ``path``."""
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


def evaluate_against_original_generated_images(images_dir):
    """Compute one FID score per concept folder found under ``images_dir``.

    The directory is expected to be laid out as
    ``images_dir/<category>/<concept>/``. Every concept folder is scored with
    ``fid.compute_fid`` against the custom reference statistics named
    "coco2014_val" in "clean" mode.
    NOTE(review): those custom statistics presumably have to be registered
    with clean-fid beforehand -- confirm against the project setup.

    Args:
        images_dir: Path to the root folder containing the category folders.

    Returns:
        list[dict]: one entry per concept with the keys ``'category'``,
        ``'concept'`` and ``'fid_score'``. Empty when no category/concept
        directories are found.
    """
    # Use the GPU when present, otherwise fall back to the CPU; the device is
    # passed on explicitly so the fallback actually takes effect.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize a list to store results
    evaluation_results = []

    print("Starting FID evaluation...")
    # Only directories are considered (stray files such as ".DS_Store" are
    # skipped); sorting makes the processing order reproducible.
    image_categories = _list_subdirectories(images_dir)  # e.g, ['erased', 'other', ...]

    print("Found categories:", image_categories)
    for category in image_categories:
        print("Processing category:", category)

        category_dir = os.path.join(images_dir, category)
        category_concepts = _list_subdirectories(category_dir)  # e.g., ['adam driver', ...]
        print("Found concepts:", category_concepts)
        for concept in tqdm(category_concepts):
            print("Processing concept:", concept)

            image_folder = os.path.join(category_dir, concept)

            score = fid.compute_fid(
                image_folder,
                dataset_name="coco2014_val",
                mode="clean",
                dataset_split="custom",
                device=device,
            )

            # Append results
            evaluation_results.append({
                'category': category,                   # e.g., 'erased'
                'concept': concept,                     # e.g., 'adam driver'
                'fid_score': score
            })
            print(evaluation_results[-1])

    print("FID Score Evaluation Completed.")

    return evaluation_results


def main():
    """Run the FID evaluation for the image directory given on the command line.

    Steps: parse the arguments, validate the image directory, derive the
    result/summary locations from the directory path, compute (and save, via
    ``get_and_save_individual_results``) the per-concept FID scores, average
    them per category, and write that summary through
    ``update_summary_with_lock``.

    Raises:
        FileNotFoundError: if ``--images_dir`` is not an existing directory.
    """
    # Parse command-line arguments
    args = parse_arguments()

    # Ensure the image directory exists. An explicit raise is used instead of
    # ``assert`` so the check still runs under ``python -O``.
    if not os.path.isdir(args.images_dir):
        raise FileNotFoundError(
            f"Image directory does not exist or is not a directory: {args.images_dir}"
        )

    # The prompt file name is returned by the helper but not needed here.
    _prompt_file_name, full_model_name, results_dir, summary_dir = parse_from_images_dir(args.images_dir)

    # ----- Save Aggregated Results to CSV -----
    get_quality_results = partial(evaluate_against_original_generated_images, images_dir=args.images_dir)

    # Call the generalized function to compute and save the per-concept FID results
    results_df = get_and_save_individual_results('fid', get_quality_results, results_dir)

    # Mean FID score per category, as a two-column summary DataFrame
    # ('category', 'fid_score')
    summary_df = results_df.groupby('category')['fid_score'].mean().reset_index()

    # Update the summary (under filelock) for this particular scenario with the results of the current model
    update_summary_with_lock(summary_df, full_model_name, summary_dir, metric_name="fid")

    # Print the summary dataframe
    print(summary_df)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
