import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights, ResNet101_Weights, ShuffleNet_V2_X1_0_Weights
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import timm
import pandas as pd
import random
from collections import defaultdict

from config.config import DATASET_TEST_EXTERNAL,PREDICTIONS_TEST_EXTERNAL, CLASS_DIR_IMAGENET,OBJECTNET_TO_IMAGENET

def normalize_label(s):
    """Canonicalize a class label so differently-written labels compare equal.

    Leading/trailing whitespace is stripped, underscores become spaces, and the
    result is lower-cased (e.g. ``"  Alarm_Clock "`` -> ``"alarm clock"``).
    """
    cleaned = s.strip()
    cleaned = cleaned.replace("_", " ")
    return cleaned.lower()

class FilenameLabelDataset(Dataset):
    """Image dataset over pre-resolved ``(path, label)`` pairs.

    Each item is ``(image, label, subset, image_id)``:
      - ``image``: the file loaded as RGB, passed through *transform* if given;
      - ``label``: the label stored alongside the path in *samples*;
      - ``subset``: ``subset_dict[basename(path)]``;
      - ``image_id``: the file's basename without its extension.

    Args:
        samples: sequence of ``(path, label)`` tuples.
        subset_dict: maps image basename to a subset name. Every sample's
            basename must be present, otherwise ``__getitem__`` raises ``KeyError``.
        transform: optional callable applied to the RGB image.
    """

    def __init__(self, samples, subset_dict, transform=None):
        self.samples = samples
        self.subset_dict = subset_dict
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        filename = os.path.basename(path)
        subset = self.subset_dict[filename]
        # Close the file handle deterministically. convert() returns a fully
        # loaded copy, so the image stays usable after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label, subset, os.path.splitext(filename)[0]

def _build_label_index(class_index_path, mapping_path):
    """Load the ImageNet vocabulary and map ObjectNet labels onto ImageNet indices.

    Args:
        class_index_path: JSON file mapping an ImageNet class index (string key)
            to its class word.
        mapping_path: JSON file mapping an ObjectNet label to an ImageNet word.

    Returns:
        ``(synonym_to_index, index_to_word, missing_mappings)``:
          - ``synonym_to_index``: normalized ObjectNet label -> ImageNet index;
          - ``index_to_word``: ImageNet index -> original class word;
          - ``missing_mappings``: normalized ObjectNet labels whose ImageNet word
            was not found in the vocabulary.
    """
    with open(class_index_path, "r") as f:
        imagenet_index_to_word = json.load(f)
    with open(mapping_path, "r") as f:
        objectnet_to_imagenet_words = json.load(f)

    index_to_word = {int(k): v for k, v in imagenet_index_to_word.items()}
    # Normalize the ImageNet side with the same function as the ObjectNet side.
    # Otherwise words containing underscores or stray whitespace never match.
    # If two indices share a word, the later entry wins.
    word_to_index = {normalize_label(v): i for i, v in index_to_word.items()}

    normalized_objectnet_to_imagenet = {
        normalize_label(k): normalize_label(v)
        for k, v in objectnet_to_imagenet_words.items()
    }

    synonym_to_index = {}
    missing_mappings = []
    for objnet_label, imagenet_word in normalized_objectnet_to_imagenet.items():
        if imagenet_word in word_to_index:
            synonym_to_index[objnet_label] = word_to_index[imagenet_word]
        else:
            missing_mappings.append(objnet_label)
    return synonym_to_index, index_to_word, missing_mappings


def _collect_image_paths(dataset_path):
    """Return every .jpg/.jpeg/.png file under *dataset_path*, sorted.

    ``os.walk`` does not guarantee a listing order. Sorting makes the seeded
    subset split and the row order of every output CSV reproducible across
    machines.
    """
    extensions = (".jpg", ".jpeg", ".png")
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(dataset_path)
        for name in files
        if name.lower().endswith(extensions)
    )


def _label_from_path(path):
    """Return the normalized label encoded before the first ``"__"`` in the filename."""
    return normalize_label(os.path.basename(path).split("__")[0])


def _seed_everything(seed):
    """Seed Python's ``random`` and torch; make cuDNN deterministic when CUDA is present."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _assign_subsets(label_to_paths):
    """Shuffle labels with ``random`` and deal them round-robin into three subsets.

    Labels are sorted first, so the split depends only on the ``random`` seed
    and not on the dict's insertion order.

    Returns:
        dict mapping image basename -> ``"subset_A"`` / ``"subset_B"`` / ``"subset_C"``.
    """
    labels = sorted(label_to_paths)
    random.shuffle(labels)
    subset_names = ("subset_A", "subset_B", "subset_C")
    subset_dict = {}
    for i, label in enumerate(labels):
        subset = subset_names[i % len(subset_names)]
        for path in label_to_paths[label]:
            subset_dict[os.path.basename(path)] = subset
    return subset_dict


def _write_id_csv(path, ids):
    """Write *ids* in sorted order, one per line, below an ``image_id`` header."""
    with open(path, "w") as f:
        f.write("image_id\n")
        f.writelines(f"{img_id}\n" for img_id in sorted(ids))


def _evaluate_model(model, model_name, dataloader, device, index_to_word):
    """Run *model* over *dataloader* and collect per-image predictions.

    Returns:
        ``(df, correct_ids, accuracy)``:
          - ``df``: one row per image, with model-prefixed columns;
          - ``correct_ids``: the set of image IDs whose top-1 prediction equals
            the label;
          - ``accuracy``: top-1 accuracy, or 0.0 for an empty loader.
    """
    image_ids, subsets = [], []
    true_classes, pred_classes = [], []
    true_scores, pred_scores = [], []

    model.eval()
    with torch.no_grad():
        for images, labels, batch_subsets, batch_ids in tqdm(dataloader, desc=model_name):
            images, labels = images.to(device), labels.to(device)
            probs = F.softmax(model(images), dim=1)
            preds = torch.argmax(probs, dim=1)

            # gather selects probs[i, idx[i]] on the device. Each batch is then
            # copied to the host once, instead of syncing per element.
            pred_score = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
            true_score = probs.gather(1, labels.unsqueeze(1)).squeeze(1)

            pred_scores.extend(pred_score.cpu().tolist())
            true_scores.extend(true_score.cpu().tolist())
            pred_classes.extend(preds.cpu().tolist())
            true_classes.extend(labels.cpu().tolist())
            subsets.extend(batch_subsets)
            image_ids.extend(batch_ids)

    matches = [p == t for p, t in zip(pred_classes, true_classes)]
    correct_ids = {img_id for img_id, ok in zip(image_ids, matches) if ok}
    total = len(matches)
    accuracy = sum(matches) / total if total > 0 else 0.0

    df = pd.DataFrame({
        "image_id": image_ids,
        "subset": subsets,
        f"{model_name}_true_class": true_classes,
        f"{model_name}_true_class_word": [index_to_word.get(t, "UNK") for t in true_classes],
        f"{model_name}_true_score": true_scores,
        f"{model_name}_pred_class": pred_classes,
        f"{model_name}_pred_class_word": [index_to_word.get(p, "UNK") for p in pred_classes],
        f"{model_name}_pred_score": pred_scores,
    })
    return df, correct_ids, accuracy


def main():
    """Evaluate several ImageNet-pretrained classifiers on the external test set.

    Steps:
      1. Map labels encoded in filenames (the text before ``"__"``) to ImageNet
         class indices. Images whose label has no mapping are skipped.
      2. Split the mapped labels into three subsets with a seeded shuffle.
      3. Run every model in ``model_specs``. Write into
         ``PREDICTIONS_TEST_EXTERNAL`` one predictions CSV per model, plus CSVs
         of clean IDs, IDs correct in any model, IDs correct in all models, and
         a per-model accuracy summary.
    """
    DATASET_PATH = DATASET_TEST_EXTERNAL
    OUTPUT_DIR = PREDICTIONS_TEST_EXTERNAL
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    synonym_to_index, index_to_word, missing_mappings = _build_label_index(
        CLASS_DIR_IMAGENET, OBJECTNET_TO_IMAGENET
    )

    print(f"\n Mapped {len(synonym_to_index)} ObjectNet labels to ImageNet indices")
    print(f"  Missing {len(missing_mappings)} ObjectNet labels")
    if missing_mappings:
        print(f"Some missing mappings: {missing_mappings[:10]} ...")

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    device = (
        torch.device("cuda") if torch.cuda.is_available()
        else torch.device("mps") if torch.backends.mps.is_available()
        else torch.device("cpu")
    )

    all_image_paths = _collect_image_paths(DATASET_PATH)

    # Keep only images whose filename label maps to an ImageNet index.
    label_to_paths = defaultdict(list)
    corrected_samples = []
    for path in all_image_paths:
        raw_label = _label_from_path(path)
        if raw_label not in synonym_to_index:
            continue
        label_to_paths[raw_label].append(path)
        corrected_samples.append((path, synonym_to_index[raw_label]))

    SEED = 42
    _seed_everything(SEED)
    subset_dict = _assign_subsets(label_to_paths)

    # Save clean image IDs (with known labels)
    clean_ids = [os.path.splitext(os.path.basename(p))[0] for p, _ in corrected_samples]
    clean_ids_path = os.path.join(OUTPUT_DIR, "clean_ids.csv")
    _write_id_csv(clean_ids_path, clean_ids)
    print(f"\n  Saved clean image IDs (with known labels) to {clean_ids_path}")

    print(f"\n  Total images: {len(all_image_paths)}")
    print(f" Images with known labels: {len(corrected_samples)}")
    print(f" Skipped images: {len(all_image_paths) - len(corrected_samples)}")

    dataset = FilenameLabelDataset(corrected_samples, subset_dict, transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    model_specs = {
        "deit_tiny_patch16_224": lambda: timm.create_model('deit_tiny_patch16_224', pretrained=True),
        "deit_small_patch16_224": lambda: timm.create_model('deit_small_patch16_224', pretrained=True),
        "deit_base_patch16_224": lambda: timm.create_model('deit_base_patch16_224', pretrained=True),
        "resnet101": lambda: models.resnet101(weights=ResNet101_Weights.DEFAULT),
        "resnet50": lambda: models.resnet50(weights=ResNet50_Weights.DEFAULT),
        "shufflenet_v2_x1_0": lambda: models.shufflenet_v2_x1_0(weights=ShuffleNet_V2_X1_0_Weights.DEFAULT),
    }

    model_accuracies = []
    per_model_correct_ids = {}

    for model_name, constructor in model_specs.items():
        print(f"\n  Evaluating {model_name}...")
        model = constructor().to(device)
        df, correct_ids, accuracy = _evaluate_model(
            model, model_name, dataloader, device, index_to_word
        )
        del model  # drop the reference before the next model is built

        per_model_correct_ids[model_name] = correct_ids
        print(f" Top-1 Accuracy for {model_name}: {accuracy:.4f}")
        model_accuracies.append({
            "model_name": model_name,
            "accuracy": accuracy,
        })

        out_path = os.path.join(OUTPUT_DIR, f"{model_name}.csv")
        df.to_csv(out_path, index=False)
        print(f"📄 Saved predictions to {out_path}")

    # Save union of correct image IDs
    global_correct_matches = set().union(*per_model_correct_ids.values())
    matching_ids_path = os.path.join(OUTPUT_DIR, "correct_image_ids.csv")
    _write_id_csv(matching_ids_path, global_correct_matches)
    print(f"\n Saved image IDs correct in ANY model to {matching_ids_path}")

    # Save intersection of correct image IDs. set.intersection() needs at
    # least one argument, so guard the no-model case.
    common_correct_ids = (
        set.intersection(*per_model_correct_ids.values())
        if per_model_correct_ids
        else set()
    )
    common_ids_path = os.path.join(OUTPUT_DIR, "correct_image_ids_all_models.csv")
    _write_id_csv(common_ids_path, common_correct_ids)
    print(f"  Saved image IDs correct in ALL models to {common_ids_path}")

    summary_df = pd.DataFrame(model_accuracies)
    summary_csv = os.path.join(OUTPUT_DIR, "test_prediction.csv")
    summary_df.to_csv(summary_csv, index=False)
    print(f"\n Saved accuracy summary to {summary_csv}")
    print(" All models evaluated successfully.")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script. The guard is also
    # required because the DataLoader inside main() uses worker processes.
    main()
