import os
import json
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.models import (
    ResNet50_Weights,
    ResNet101_Weights,
    ShuffleNet_V2_X1_0_Weights
)
import timm
from tqdm import tqdm
import pandas as pd
import torch.nn.functional as F
from PIL import Image

from config.config import PREDICTIONS_TRAIN, IMAGENET_DIR


class FilenameLabelDataset(Dataset):
    """Map-style dataset over an explicit list of ``(path, label)`` pairs.

    Each item is loaded from disk with PIL, converted to RGB, and passed
    through ``transform`` when one is given.

    Args:
        samples: Sequence of ``(image_path, label)`` tuples.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples  # List of (path, label)
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is the transformed result when a transform is set;
        otherwise it is an RGB PIL image.
        """
        path, label = self.samples[idx]
        # Open inside a context manager so the underlying file handle is
        # released promptly. convert() returns a new, fully loaded image,
        # so the result stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label

def _collect_samples(dataset_path):
    """Walk ``dataset_path`` and return ``(path, label)`` pairs for images.

    Only ``.jpg``/``.jpeg``/``.png`` files (case-insensitive) are considered.
    The label is the integer prefix before the first underscore in the
    filename (e.g. ``"123_foo.jpg"`` -> ``123``). Files whose prefix is not
    an integer are reported and skipped. Directories and files are visited
    in sorted order so the output row order is reproducible across machines.
    """
    samples = []
    for root, dirs, files in os.walk(dataset_path):
        dirs.sort()  # in-place sort makes os.walk descend deterministically
        for file in sorted(files):
            if not file.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            try:
                label = int(file.split("_")[0])
            except ValueError:
                print(f"Skipping file with unexpected name format: {file}")
                continue
            samples.append((os.path.join(root, file), label))
    return samples


def _evaluate_model(model, dataloader, device, desc):
    """Run ``model`` over ``dataloader`` and collect per-image results.

    Args:
        model: Classifier already on ``device`` and in eval mode.
        dataloader: Yields ``(images, labels)`` batches.
        device: Device that inputs are moved to.
        desc: Label for the progress bar and error messages.

    Returns:
        Tuple ``(true_scores, pred_scores, pred_classes, true_classes,
        accuracy)``. The first four are per-image lists in dataloader order.

    Raises:
        ValueError: If a label falls outside ``[0, num_model_outputs)``.
    """
    true_scores, pred_scores = [], []
    pred_classes, true_classes = [], []
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=desc):
            logits = model(images.to(device))
            num_classes = logits.shape[1]

            # Validate on the CPU copy before device indexing. An out-of-range
            # label would otherwise fail as an opaque device-side assert
            # (CUDA) or a bare IndexError.
            if labels.numel() and (
                int(labels.min()) < 0 or int(labels.max()) >= num_classes
            ):
                raise ValueError(
                    f"{desc}: labels must lie in [0, {num_classes}), got "
                    f"min={int(labels.min())}, max={int(labels.max())}"
                )
            labels = labels.to(device)

            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            # gather() keeps every index tensor on the same device as probs.
            # This avoids mixing a CPU arange with device-resident indices.
            pred_score = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
            true_score = probs.gather(1, labels.unsqueeze(1)).squeeze(1)

            pred_scores.extend(pred_score.cpu().tolist())
            true_scores.extend(true_score.cpu().tolist())
            pred_classes.extend(preds.cpu().tolist())
            true_classes.extend(labels.cpu().tolist())

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total > 0 else 0.0
    return true_scores, pred_scores, pred_classes, true_classes, accuracy


def main():
    """Evaluate pretrained classifiers on ``IMAGENET_DIR/split_data/train``.

    True labels are parsed from filename prefixes. For each model in
    ``model_specs``, this writes ``<model_name>.csv`` into
    ``PREDICTIONS_TRAIN``. Each row is one image and holds the true-class
    score, predicted-class score, predicted class and true class. After all
    models run, it writes ``summary.csv`` with one accuracy row per model.
    """
    dataset_path = os.path.join(IMAGENET_DIR, "split_data/train")
    output_dir = PREDICTIONS_TRAIN
    os.makedirs(output_dir, exist_ok=True)

    # Prefer CUDA, then Apple MPS, then CPU.
    device = (
        torch.device("cuda") if torch.cuda.is_available()
        else torch.device("mps") if torch.backends.mps.is_available()
        else torch.device("cpu")
    )

    # Shared ImageNet-style preprocessing for every model.
    # NOTE(review): individual weights may define their own eval transforms
    # (e.g. torchvision `weights.transforms()` or timm's resolved data
    # config, which can use another resize size or interpolation). Reported
    # accuracies may therefore differ slightly from published numbers.
    # Confirm whether that matters here.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    samples = _collect_samples(dataset_path)

    dataset = FilenameLabelDataset(samples, transform=transform)
    # shuffle=False keeps batch order aligned with `image_ids` below.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    # splitext strips only the final extension, so names with extra dots
    # (e.g. "12_a.b.jpg") keep their full stem.
    image_ids = [os.path.splitext(os.path.basename(p))[0] for p, _ in samples]
    if len(set(image_ids)) != len(image_ids):
        print("Warning: duplicate image ids found; CSV index will not be unique.")

    # Models to evaluate; constructors are lazy so only one is built at a time.
    # Presumably all output ImageNet-1k class indices matching the filename
    # label convention. TODO confirm.
    model_specs = {
        "deit_tiny_patch16_224": lambda: timm.create_model('deit_tiny_patch16_224', pretrained=True),
        "deit_small_patch16_224": lambda: timm.create_model('deit_small_patch16_224', pretrained=True),
        "deit_base_patch16_224": lambda: timm.create_model('deit_base_patch16_224', pretrained=True),
        "resnet101": lambda: models.resnet101(weights=ResNet101_Weights.DEFAULT),
        "resnet50": lambda: models.resnet50(weights=ResNet50_Weights.DEFAULT),
        "shufflenet_v2_x1_0": lambda: models.shufflenet_v2_x1_0(weights=ShuffleNet_V2_X1_0_Weights.DEFAULT)
    }

    model_accuracies = []

    for model_name, constructor in model_specs.items():
        print(f"\nEvaluating {model_name}...")
        model = constructor().to(device)
        model.eval()

        (true_scores, pred_scores, pred_classes,
         true_classes, accuracy) = _evaluate_model(model, dataloader, device, model_name)

        # Drop the model before the next one is built, so two models never
        # occupy device memory at once.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

        print(f"Accuracy for {model_name}: {accuracy:.4f}")

        model_accuracies.append({
            "model_name": model_name,
            "accuracy": accuracy
        })

        # Per-image predictions, indexed by image id.
        df = pd.DataFrame({
            f"{model_name}_true_score": true_scores,
            f"{model_name}_pred_score": pred_scores,
            f"{model_name}_pred_class": pred_classes,
            f"{model_name}_true_class": true_classes,
        }, index=image_ids)
        df.index.name = "image_id"

        out_path = os.path.join(output_dir, f"{model_name}.csv")
        df.to_csv(out_path)
        print(f"Saved predictions to {out_path}")

    summary_path = os.path.join(output_dir, "summary.csv")
    pd.DataFrame(model_accuracies).to_csv(summary_path, index=False)
    print(f"\nSaved summary accuracy report to {summary_path}")
    print("\n=== All models evaluated ===")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
