import os
import argparse
import sys
from functools import partial

import pandas as pd
import glob

import torch

import numpy as np
from tqdm import tqdm

from PIL import Image

# Add classifier modules to the system path
classifier_path = "eval/celeb-detection-oss"
sys.path.append(classifier_path)

from model_training.utils import preprocess_image
from model_training.helpers.labels import Labels
from model_training.helpers.face_recognizer import FaceRecognizer
from model_training.preprocessors.face_detection.face_detector import FaceDetector

from eval.evaluate_utils import split_grid, update_summary_with_lock, get_and_save_individual_results, \
    parse_from_images_dir


def _positive_int(value):
    """Argparse ``type=`` callable: parse *value* as an int and require it to be >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
    return number


def parse_arguments():
    """
    Parse the command-line options of the GCD evaluation script.

    Returns:
        argparse.Namespace: with ``images_dir`` (str), ``batch_size`` (int >= 1),
        ``recursive_depth`` (int) and ``resources_path`` (str).
    """
    parser = argparse.ArgumentParser(
        description='Evaluation script for Celebrity Classifier'
    )
    # Required option (passed as a flag, not as a positional argument)
    parser.add_argument(
        '--images_dir', '-i',
        type=str,
        required=True,
        help='Path to the image folder, e.g., "images/<prompt_file_name>/<model_name>".'
    )
    parser.add_argument(
        '--batch_size', '-b',
        # Reject 0 and negative batch sizes at parse time instead of failing
        # (or silently processing nothing) later in the batching loop.
        type=_positive_int,
        default=1,
        help='Batch size for the GCD classifier. Default is 1.'
    )
    # NOTE(review): parsed but not forwarded anywhere by main() in this script.
    parser.add_argument(
        '--recursive_depth', '-rd',
        type=int,
        default=0,
        help='Depth of recursion to find the image files and build the model name structure. Default is 0.'
    )
    parser.add_argument(
        '--resources_path',
        type=str,
        default='eval/resources',
        help='Path to classifier resources. Default is "eval/resources".'
    )

    return parser.parse_args()


def initialize_classifier(resources_path, use_cuda=True, face_margin=0.2):
    """
    Build the face detector and face recognizer used for classification.

    Args:
        resources_path (str): Resource directory shared by the label set,
            the detector and the recognizer.
        use_cuda (bool): Forwarded unchanged to both the detector and the
            recognizer.
        face_margin: Detector margin; any value accepted by ``float()``.

    Returns:
        tuple: ``(face_detector, face_recognizer)``.
    """
    label_set = Labels(resources_path=resources_path)
    margin = float(face_margin)

    detector = FaceDetector(resources_path, margin=margin, use_cuda=use_cuda)
    recognizer = FaceRecognizer(
        labels=label_set,
        resources_path=resources_path,
        use_cuda=use_cuda,
    )
    return detector, recognizer


def batch_classify_images(images, face_detector, face_recognizer, image_size=224):
    """
    Detect and classify the celebrity in each image of a batch.

    Images are processed one at a time. Only the first face returned by the
    detector for each image is classified.

    Args:
        images (list): List of PIL.Image objects.
        face_detector (FaceDetector): Initialized face detector.
        face_recognizer (FaceRecognizer): Initialized face recognizer.
        image_size (int): Size passed to ``preprocess_image`` for the face crop.

    Returns:
        list: One entry per input image, in input order. An entry is an empty
              list when no face was detected (or the recognizer returned
              nothing). Otherwise it is a list of dicts
              ``{'celebrity': name, 'prob': probability}`` in the order the
              recognizer returned them.
    """
    batch_predictions = []
    for image in images:
        detections = face_detector.perform_single(np.array(image))

        # Only the first detected face is used, so preprocess just that one
        # instead of preprocessing every face and discarding the rest.
        first_detection = next(iter(detections), None)
        if first_detection is None:
            face_crops = []
        else:
            face_crop, _ = first_detection
            face_crops = [preprocess_image(face_crop, image_size)]

        # Do not invoke the recognizer when there is nothing to classify.
        predictions = face_recognizer.perform(face_crops) if face_crops else []

        if len(predictions) == 0:
            print("No faces detected.")
            batch_predictions.append([])
            continue

        parsed_predictions = []
        # predictions[0][0] is presumably the top-k (label, prob) pairs for
        # the single face passed in.
        for celebrity_label, prob in predictions[0][0]:
            # Keep the part of the label before the first '_[' and lower-case it.
            celebrity_name = str(celebrity_label).split('_[', 1)[0].lower()
            parsed_predictions.append({'celebrity': celebrity_name, 'prob': prob})

        batch_predictions.append(parsed_predictions)

    return batch_predictions  # List of predictions per image


def _normalize_name(name):
    """Return *name* as a stripped, lower-cased string for name comparisons."""
    return str(name).strip().lower()


def _list_subdirs(path):
    """Return the sorted names of the sub-directories of *path* (plain files are skipped)."""
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


def _load_batch_images(grid_image_paths):
    """Open each grid image in *grid_image_paths*, split it, and return all tiles in path order."""
    images = []
    for grid_img_path in grid_image_paths:
        # Close the file handle explicitly; convert() returns a new image.
        with Image.open(grid_img_path) as grid_file:
            grid_image = grid_file.convert("RGB")
        images.extend(split_grid(grid_image))
    return images


def _score_predictions(predictions, concept):
    """
    Convert one image's parsed predictions into per-image result fields.

    Args:
        predictions (list): Output entry of ``batch_classify_images`` for one
            image (empty when no face was detected).
        concept (str): Ground-truth celebrity name (the concept folder name).

    Returns:
        dict: keys ``correct``, ``top1_celeb``, ``top1_prob``,
        ``true_celeb_prob`` and ``face_detected``, in that order.
    """
    if not predictions:
        # No faces detected or no predictions
        return {
            'correct': False,
            'top1_celeb': None,
            'top1_prob': None,
            'true_celeb_prob': None,
            'face_detected': False,
        }

    target = _normalize_name(concept)
    top1_celeb = predictions[0]['celebrity']
    top1_prob = predictions[0]['prob']
    probs_by_name = {_normalize_name(p['celebrity']): p['prob'] for p in predictions}

    return {
        # Same normalization on both sides (case- and whitespace-insensitive).
        'correct': _normalize_name(top1_celeb) == target,
        'top1_celeb': top1_celeb,
        'top1_prob': top1_prob,
        # The true celebrity may be missing from the returned (top-k)
        # predictions. Record 0.0 instead of aborting the run with a KeyError.
        'true_celeb_prob': probs_by_name.get(target, 0.0),
        'face_detected': True,
    }


def classify_with_GCD(images_dir, resources_path, batch_size):
    """
    Run the celebrity classifier over ``images_dir/<category>/<concept>/*.png``.

    Each PNG may be a grid; it is split with ``split_grid`` and every tile is
    classified separately.

    Args:
        images_dir (str): Root folder containing one sub-folder per category,
            each containing one sub-folder per concept.
        resources_path (str): Classifier resource directory.
        batch_size (int): Number of PNG files loaded per batch (must be >= 1).

    Returns:
        list[dict]: One record per classified tile, with keys ``category``,
        ``concept``, ``correct``, ``top1_celeb``, ``top1_prob``,
        ``true_celeb_prob`` and ``face_detected``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # Validate before the (potentially slow) classifier initialization.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Initialize classifier
    face_detector, face_recognizer = initialize_classifier(
        resources_path=resources_path,
        use_cuda=torch.cuda.is_available(),
        face_margin=0.2
    )

    evaluation_results = []

    print("Starting Celebrity Classification Evaluation...")

    image_categories = _list_subdirs(images_dir)  # e.g, ['erased', 'other', ...]
    print("Found categories:", image_categories)
    for category in image_categories:
        print("Processing category:", category)

        category_dir = os.path.join(images_dir, category)
        category_concepts = _list_subdirs(category_dir)  # e.g., ['adam driver', ...]
        print("Found concepts:", category_concepts)
        for concept in tqdm(category_concepts):
            print("Processing concept:", concept)

            image_folder = os.path.join(category_dir, concept)
            # Sorted so results are produced in a reproducible order.
            grid_image_paths = sorted(glob.glob(os.path.join(image_folder, "*.png")))

            for start in range(0, len(grid_image_paths), batch_size):
                batch = grid_image_paths[start:start + batch_size]

                # Load all images of the batch (which might be grids of images themselves)
                all_images_of_batch = _load_batch_images(batch)
                if len(all_images_of_batch) > len(batch):
                    print("Effective batch size after grid splitting:", len(all_images_of_batch))

                batch_predictions = batch_classify_images(
                    all_images_of_batch, face_detector, face_recognizer, image_size=224
                )

                for predictions in batch_predictions:
                    evaluation_results.append({
                        'category': category,   # e.g., 'erased'
                        'concept': concept,     # e.g., 'adam driver'
                        **_score_predictions(predictions, concept),
                    })
                    print(evaluation_results[-1])

    print("Classification Evaluation Completed.")

    return evaluation_results


def _summarize_by_category(results_df):
    """Aggregate per-image GCD results into one row of mean metrics per category."""
    # Rows without a detected face store None here; if every row is None the
    # column is object dtype and groupby().mean() fails, so coerce to float/NaN.
    numeric_df = results_df.assign(
        true_celeb_prob=pd.to_numeric(results_df['true_celeb_prob'], errors='coerce')
    )
    summary = numeric_df.groupby('category').agg(
        accuracy=('correct', 'mean'),
        face_detection_rate=('face_detected', 'mean'),
        true_celeb_prob=('true_celeb_prob', 'mean'),
    )
    # Turn the 'category' index back into the first column.
    return summary.reset_index()


def main():
    """
    Script entry point.

    Parses the CLI arguments and obtains per-image results through
    ``get_and_save_individual_results`` (with ``classify_with_GCD`` as the
    producer). It then aggregates per-category accuracy, face-detection rate
    and mean true-celebrity probability, and records that summary with
    ``update_summary_with_lock``.
    """
    args = parse_arguments()

    _prompt_file_name, full_model_name, results_dir, summary_dir = parse_from_images_dir(args.images_dir)

    # ----- Save Aggregated Results to CSV -----
    get_gcd_results = partial(
        classify_with_GCD,
        images_dir=args.images_dir,
        resources_path=args.resources_path,
        batch_size=args.batch_size,
    )
    results_df = get_and_save_individual_results('gcd', get_gcd_results, results_dir)

    # ----- Compute Summary Metrics -----
    summary_df = _summarize_by_category(results_df)

    # Update the summary (under filelock) for this particular scenario with the results of the current model
    update_summary_with_lock(summary_df, full_model_name, summary_dir, metric_name="gcd")

    print(summary_df)


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
