import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights, ResNet101_Weights, ShuffleNet_V2_X1_0_Weights
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import timm
import pandas as pd
import random
from collections import defaultdict
from config.config import  DATASET_TEST_INTERNAL, PREDICTIONS_TEST_INTERNAL
from pathlib import Path

class FilenameLabelDataset(torch.utils.data.Dataset):
    """Image dataset that yields ``(image, label, subset)`` triples.

    Args:
        samples: Iterable of ``(path, label)`` pairs. ``label`` may be
            anything ``int()`` accepts (e.g. ``3`` or ``"3"``).
        subset_dict: Mapping from an image file's basename to the name of
            the subset that image belongs to.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, samples, subset_dict, transform=None):
        # Copy both containers so later mutation by the caller cannot
        # change the dataset's contents or length.
        self.samples = list(samples)
        self.subset_dict = dict(subset_dict)
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, label, subset)`` where ``image`` is the RGB image
            (after ``transform`` when one was given), ``label`` is a 0-dim
            ``torch.long`` tensor and ``subset`` is the subset name looked up
            by the file's basename.

        Raises:
            KeyError: If the file's basename is missing from ``subset_dict``.
            ValueError: If the stored label cannot be converted with ``int()``.
        """
        path, label = self.samples[idx]
        subset = self.subset_dict[os.path.basename(path)]
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle as soon as convert() has decoded the pixels
        # instead of leaving that to garbage collection.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Labels are returned as long tensors so they can be used as indices.
        label = torch.tensor(int(label), dtype=torch.long)
        return image, label, subset

# File extensions (compared lower-case) treated as images when scanning.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Subset names handed out round-robin to the shuffled class labels.
_SUBSET_NAMES = ("subset_A", "subset_B", "subset_C")


def _collect_labeled_images(dataset_path):
    """Return ``(path, int_label)`` for every usable image under ``dataset_path``.

    The label is the integer before the first ``_`` in the file name. Files
    whose name does not start with an integer are reported and left out, so
    every returned path is guaranteed to have a label.
    """
    labeled = []
    for root, _, files in os.walk(dataset_path):
        for file in files:
            if not file.lower().endswith(_IMAGE_EXTENSIONS):
                continue
            try:
                label = int(file.split("_")[0])
            except ValueError:
                print(f"Skipping file with unexpected name format: {file}")
                continue
            labeled.append((os.path.join(root, file), label))
    return labeled


def _assign_label_subsets(labels):
    """Shuffle ``labels`` with the global ``random`` state and map each label
    to a subset name, cycling through ``_SUBSET_NAMES``."""
    shuffled = list(labels)
    random.shuffle(shuffled)
    return {
        label: _SUBSET_NAMES[i % len(_SUBSET_NAMES)]
        for i, label in enumerate(shuffled)
    }


def _evaluate_model(model, dataloader, device, desc):
    """Run ``model`` over ``dataloader`` and collect per-image predictions.

    Returns:
        Tuple ``(columns, accuracy)``. ``columns`` maps ``"subset"``,
        ``"true_score"``, ``"pred_score"``, ``"pred_class"`` and
        ``"true_class"`` to lists with one entry per image, in loader order.
        ``accuracy`` is the fraction of correct top-1 predictions, or ``0.0``
        when the loader yields nothing.
    """
    columns = {
        "subset": [],
        "true_score": [],
        "pred_score": [],
        "pred_class": [],
        "true_class": [],
    }
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels, batch_subsets in tqdm(dataloader, desc=desc):
            images = images.to(device)
            labels = labels.to(device)

            probs = F.softmax(model(images), dim=1)
            preds = torch.argmax(probs, dim=1)

            # gather picks one probability per row using index tensors that
            # already live on the same device as probs.
            pred_score = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
            true_score = probs.gather(1, labels.unsqueeze(1)).squeeze(1)

            columns["pred_score"].extend(pred_score.cpu().tolist())
            columns["true_score"].extend(true_score.cpu().tolist())
            columns["pred_class"].extend(preds.cpu().tolist())
            columns["true_class"].extend(labels.cpu().tolist())
            columns["subset"].extend(batch_subsets)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total > 0 else 0.0
    return columns, accuracy


def main():
    """Evaluate several pretrained classifiers on the internal test set.

    Every image under ``DATASET_TEST_INTERNAL`` whose file name starts with
    ``<int>_`` is scored by each model. Per-image results are written to
    ``<model_name>.csv`` and the per-model accuracies to
    ``test_prediction.csv``, all inside ``PREDICTIONS_TEST_INTERNAL``.
    """
    dataset_path = DATASET_TEST_INTERNAL
    # Path() is a no-op for Path inputs and lets a plain string path work too.
    output_dir = Path(PREDICTIONS_TEST_INTERNAL)

    # make sure the directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    # Set device
    device = (
        torch.device("cuda") if torch.cuda.is_available()
        else torch.device("mps") if torch.backends.mps.is_available()
        else torch.device("cpu")
    )

    # Image preprocessing
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    # Only files with a parseable label become samples. Building the samples
    # from a second, unfiltered pass would feed already-skipped files to the
    # DataLoader, where the subset lookup fails.
    samples = _collect_labeled_images(dataset_path)

    # Seed everything before the label shuffle.
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Unique labels in first-seen order.
    # NOTE(review): that order comes from os.walk, which follows the
    # filesystem's enumeration order, so the seeded shuffle reproduces the
    # same split only when files are enumerated identically. Consider sorting
    # the labels first if the split must match across machines.
    unique_labels = dict.fromkeys(label for _, label in samples)
    label_to_subset = _assign_label_subsets(unique_labels)

    # Subsets are keyed by basename because that is what the dataset looks up.
    subset_dict = {
        os.path.basename(path): label_to_subset[label]
        for path, label in samples
    }
    image_ids = [
        os.path.splitext(os.path.basename(path))[0] for path, _ in samples
    ]

    dataset = FilenameLabelDataset(samples, subset_dict, transform=transform)
    # shuffle=False keeps loader order aligned with image_ids.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    # Constructors are lambdas so each model is only built when its turn comes.
    model_specs = {
        "deit_tiny_patch16_224": lambda: timm.create_model('deit_tiny_patch16_224', pretrained=True),
        "deit_small_patch16_224": lambda: timm.create_model('deit_small_patch16_224', pretrained=True),
        "deit_base_patch16_224": lambda: timm.create_model('deit_base_patch16_224', pretrained=True),
        "resnet101": lambda: models.resnet101(weights=ResNet101_Weights.DEFAULT),
        "resnet50": lambda: models.resnet50(weights=ResNet50_Weights.DEFAULT),
        "shufflenet_v2_x1_0": lambda: models.shufflenet_v2_x1_0(weights=ShuffleNet_V2_X1_0_Weights.DEFAULT)
    }

    model_accuracies = []

    for model_name, constructor in model_specs.items():
        print(f"\nEvaluating {model_name}...")
        model = constructor().to(device)
        model.eval()

        columns, accuracy = _evaluate_model(model, dataloader, device, model_name)
        # Drop the reference so the weights can be freed before the next
        # model is constructed.
        del model

        print(f"Accuracy for {model_name}: {accuracy:.4f}")
        model_accuracies.append({
            "model_name": model_name,
            "accuracy": accuracy
        })

        df = pd.DataFrame({
            "subset": columns["subset"],
            f"{model_name}_true_score": columns["true_score"],
            f"{model_name}_pred_score": columns["pred_score"],
            f"{model_name}_pred_class": columns["pred_class"],
            f"{model_name}_true_class": columns["true_class"],
        }, index=image_ids)
        df.index.name = "image_id"

        # Save to a distinct CSV named after the model
        out_path = output_dir / f"{model_name}.csv"
        df.to_csv(out_path)
        print(f"Saved predictions to {out_path}")

    # Save summary CSV
    summary_path = output_dir / "test_prediction.csv"
    summary_df = pd.DataFrame(model_accuracies)
    summary_df.to_csv(summary_path, index=False)
    print(f"\nSaved summary accuracy report to {summary_path}")
    print("\n=== All models evaluated ===")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
