import os
import argparse
from functools import partial
import torch
import pandas as pd
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm
import glob
from eval.evaluate_utils import update_summary_with_lock, get_and_save_individual_results, parse_from_images_dir


def parse_arguments():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with the attributes ``images_dir``,
        ``recursive_depth``, ``resources_path``, ``model_ckpt``, ``device``
        and ``batch_size``.
    """
    # (flags, keyword options) pairs, registered in the order they should
    # appear in the generated help text.
    # NOTE(review): main() in this file only reads images_dir, device and
    # batch_size; the remaining options are parsed but not consumed here.
    option_specs = [
        (
            ('--images_dir', '-i'),
            dict(
                type=str,
                required=True,
                help='Path to the image folder, e.g., "images/<prompt_file_name>/<model_name>".',
            ),
        ),
        (
            ('--recursive_depth', '-rd'),
            dict(type=int, default=0, help='Depth for recursive search of image files.'),
        ),
        (
            ('--resources_path',),
            dict(
                type=str,
                default='eval/resources',
                help='Path to classifier resources (model checkpoints).',
            ),
        ),
        (
            ('--model_ckpt',),
            dict(
                type=str,
                default=None,
                help='Path to a CIFAR-10 model checkpoint (.pth). Defaults to <resources_path>/cifar10_resnet18.pth',
            ),
        ),
        (
            ('--device',),
            dict(type=str, default='cuda:0', help='Torch device.'),
        ),
        (
            ('--batch_size',),
            dict(type=int, default=250, help='Inference batch size.'),
        ),
    ]

    cli = argparse.ArgumentParser(
        description='Evaluation script for CIFAR-10 Object Classifier'
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()

def load_cifar10_model(device: str):
    """Instantiate the pretrained CIFAR-10 ResNet-18 ready for inference.

    Args:
        device: Torch device string the model is moved to (e.g. ``"cuda:0"``).

    Returns:
        The model object returned by ``resnet18(pretrained=True)``, switched
        to evaluation mode and moved to ``device``.
    """
    # Imported lazily so the PyTorch_CIFAR10 package is only required when a
    # model is actually loaded.
    from PyTorch_CIFAR10.cifar10_models.resnet import resnet18

    classifier = resnet18(pretrained=True)
    classifier.to(device)
    classifier.eval()  # inference mode: no dropout / batch-norm updates
    return classifier


def classify_with_cifar10(images_dir, device, batch_size):
    """Classify every PNG under ``images_dir`` with the CIFAR-10 ResNet-18.

    The directory is walked as ``<images_dir>/<category>/<concept>/*.png``.
    Each image is converted to RGB, resized to 32x32, normalized with the
    per-channel mean/std constants below and scored by the model. The top-1
    prediction is compared against the concept folder name (case-insensitive,
    spaces treated as underscores).

    Args:
        images_dir: Root folder containing one sub-folder per category, each
            holding one sub-folder per concept.
        device: Torch device the model and the batches are moved to.
        batch_size: Maximum number of images per forward pass. ``None`` runs
            each concept folder as a single batch.

    Returns:
        A list with one dict per image, with the keys ``category``,
        ``concept``, ``correct`` (bool), ``top1_concept`` (predicted class
        name) and ``top1_score`` (softmax probability of that class).

    Raises:
        ValueError: If ``batch_size`` is given but is not a positive integer.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer or None, got {batch_size}")

    # Load model
    model = load_cifar10_model(device)

    # CIFAR-10 class names, indexed by the model's output position
    concepts = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]
    normalized_classes = [name.lower().replace(' ', '_') for name in concepts]

    # Preprocess for CIFAR-10 sized inputs
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2470, 0.2435, 0.2616)),
    ])

    # normpath strips a trailing separator; without it basename() would be ''
    # for "path/to/model/" and the target-skip check below would never fire.
    target_name = os.path.basename(os.path.normpath(images_dir))

    results = []

    # Only descend into directories (stray files such as .DS_Store would make
    # os.listdir fail) and sort so the result order is reproducible.
    image_categories = sorted(
        entry for entry in os.listdir(images_dir)
        if os.path.isdir(os.path.join(images_dir, entry))
    )  # e.g, ['erased', 'other', ...]
    print("Found categories:", image_categories)
    for category in image_categories:
        print("Processing category:", category)

        category_dir = os.path.join(images_dir, category)
        category_concepts = sorted(
            entry for entry in os.listdir(category_dir)
            if os.path.isdir(os.path.join(category_dir, entry))
        )  # e.g., ['adam driver', ...]
        print("Found concepts:", category_concepts)
        for concept in tqdm(category_concepts):

            # NOTE: this is a substring match against the directory name, so a
            # short concept name can match an unrelated target name.
            if category == 'other' and concept in target_name:
                print(f"CONCEPT {concept} skipped for 'other' category as it appears to be the target ({target_name})!")
                continue

            image_folder = os.path.join(category_dir, concept)

            # Escape the folder so names containing glob metacharacters
            # ('[', '*', '?') are matched literally.
            image_paths = sorted(glob.glob(os.path.join(glob.escape(image_folder), "*.png")))
            if not image_paths:
                # torch.stack cannot handle an empty list; nothing to score.
                print(f"No PNG images found in {image_folder}, skipping.")
                continue

            all_images = []
            for image_path in image_paths:
                # Context manager closes the underlying file handle promptly.
                with Image.open(image_path) as image:
                    all_images.append(preprocess(image.convert("RGB")))
            data = torch.stack(all_images)

            # Per-folder step size; the batch_size argument itself is left
            # untouched so a small folder does not shrink later batches.
            step = len(data) if batch_size is None else min(batch_size, len(data))
            normalized_concept = concept.lower().replace(' ', '_')

            # Inference loop
            for start in range(0, len(data), step):
                batch = data[start:start + step].to(device)
                with torch.no_grad():
                    logits = model(batch)
                    probs, idxs = torch.topk(logits.softmax(dim=1), 1, dim=1)
                top_scores = probs[:, 0].cpu().tolist()
                top_indices = idxs[:, 0].cpu().tolist()

                for class_idx, score in zip(top_indices, top_scores):
                    results.append({
                        'category': category,
                        'concept': concept,
                        'correct': normalized_concept == normalized_classes[class_idx],
                        'top1_concept': concepts[class_idx],
                        'top1_score': float(score),
                    })

    return results


def main():
    """Run the CIFAR-10 classifier evaluation for one image directory.

    Parses the CLI arguments, obtains the per-image classification results
    through ``get_and_save_individual_results``, reduces them to the mean of
    the ``correct`` column per category and hands that summary to
    ``update_summary_with_lock`` under the metric name ``"cifar10"``.
    """
    args = parse_arguments()
    _prompt_file_name, full_model_name, results_dir, summary_dir = parse_from_images_dir(args.images_dir)

    # Bind the classifier arguments now so the helper can invoke it without
    # needing to know them.
    classify = partial(
        classify_with_cifar10,
        images_dir=args.images_dir,
        device=args.device,
        batch_size=args.batch_size,
    )
    results_df = get_and_save_individual_results("cifar10_resnet18", classify, results_dir)

    # Accuracy per category = fraction of images whose top-1 class matched
    # the concept folder name.
    per_category = results_df.groupby('category')['correct'].mean()
    summary_df = pd.DataFrame({
        'category': per_category.index,
        'accuracy': per_category.values,
    })

    # Merge this model's numbers into the shared summary (done under a file
    # lock by the helper, per its name).
    update_summary_with_lock(summary_df, full_model_name, summary_dir, metric_name="cifar10")

    print(summary_df)

# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
