import argparse
import os
from pathlib import Path
from typing import Iterable, List, Tuple

import numpy as np
import torch
from safetensors.torch import load_file as safe_load_file
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification

from PIL import Image, ImageFile
# Process-wide Pillow setting: decode truncated image files as far as possible
# instead of raising, so a partially written file does not abort the whole run.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class WildImageDataset(Dataset):
    """Image dataset backed by a directory structure of labeled subfolders.

    Every entry of ``subsets`` names a direct child directory of ``root``.
    Files in a subset called ``real`` (case-insensitive) get label 0; files in
    any other subset get label 1. Subset directories that do not exist are
    skipped silently.

    Args:
        root: Directory that contains the subset folders.
        subsets: Names of the subset folders to scan. Any iterable is accepted,
            including one-shot iterators.
        processor: Callable image processor; it is invoked as
            ``processor(image, return_tensors="pt")`` and must return a mapping
            with a ``"pixel_values"`` entry.

    Raises:
        RuntimeError: If no image file is found in any of the subsets.
    """

    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root: str, subsets: Iterable[str], processor: "AutoImageProcessor"):
        self.root = root
        self.processor = processor
        self.samples: List[Tuple[str, int]] = []  # (path, label)

        # Materialize once: ``subsets`` may be a generator, and it is needed
        # again for the error message below.
        subset_names = list(subsets)

        for subset in subset_names:
            subset_dir = os.path.join(root, subset)
            label = 0 if subset.lower() == "real" else 1
            if not os.path.isdir(subset_dir):
                continue
            # os.listdir() order is arbitrary; sort so the sample order (and
            # therefore any per-index output) is reproducible across runs.
            for fname in sorted(os.listdir(subset_dir)):
                if not fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                path = os.path.join(subset_dir, fname)
                if os.path.isfile(path):
                    self.samples.append((path, label))

        if not self.samples:
            raise RuntimeError(f"No images found under {root} for subsets {subset_names}")

    def __len__(self) -> int:
        """Return the number of discovered image files."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample and return ``(pixel_values, path, label)``."""
        path, label = self.samples[idx]
        # The context manager closes the underlying file handle right away;
        # convert() returns an independent in-memory image, so it stays usable.
        with Image.open(path) as raw_image:
            image = raw_image.convert("RGB")
        # The processor output carries a leading batch dimension of size 1; drop it.
        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        return pixel_values, path, label


def collate_fn(batch):
    """Merge ``(pixel_values, path, label)`` samples into a single batch.

    Returns a 3-tuple: the pixel tensors stacked along a new leading
    dimension, the list of paths in batch order, and the labels as an
    int64 tensor.
    """
    tensors, sample_paths, targets = [], [], []
    for sample in batch:
        tensors.append(sample[0])
        sample_paths.append(sample[1])
        targets.append(sample[2])

    return (
        torch.stack(tensors, dim=0),
        sample_paths,
        torch.tensor(targets, dtype=torch.long),
    )


def load_model(
    cache_dir: str,
    checkpoint_path: str | None,
    model_name: str = "microsoft/swinv2-small-patch4-window16-256",
):
    """Build the image processor and a two-class SwinV2 classifier.

    The pretrained backbone is loaded from ``model_name`` and its classification
    head is replaced by a freshly initialized 2-output linear layer. If
    ``checkpoint_path`` is given, its weights are then loaded on top.

    Args:
        cache_dir: Directory passed to ``from_pretrained`` as the download cache.
        checkpoint_path: Optional weights to load. Either a weights file
            (``.safetensors`` or a ``torch.save`` file) or a directory containing
            ``model.safetensors`` or ``pytorch_model.bin``. Falsy values skip loading.
        model_name: Pretrained model identifier. It must resolve to a SwinV2
            classifier, because the new head is sized from ``model.swinv2.num_features``.

    Returns:
        ``(processor, model)``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is a directory that contains
            none of the known weight files.
    """
    processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=cache_dir, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(model_name, cache_dir=cache_dir)

    # Swap the pretrained head for a binary one.
    model.num_labels = 2
    model.config.num_labels = 2
    model.classifier = torch.nn.Linear(model.swinv2.num_features, model.num_labels)

    if checkpoint_path:
        ckpt_path = Path(checkpoint_path)
        if ckpt_path.is_dir():
            # Checkpoint directories may hold either serialization format.
            candidates = [ckpt_path / "model.safetensors", ckpt_path / "pytorch_model.bin"]
            weights_file = next((c for c in candidates if c.is_file()), None)
            if weights_file is None:
                expected = ", ".join(c.name for c in candidates)
                raise FileNotFoundError(f"No weights file ({expected}) found in {ckpt_path}")
        else:
            weights_file = ckpt_path

        if weights_file.suffix == ".safetensors":
            state = safe_load_file(str(weights_file), device="cpu")
        else:
            # SECURITY: torch.load unpickles the file. weights_only=True limits
            # unpickling to tensors and plain containers, which is all a state
            # dict needs; only load checkpoints from sources you trust.
            state = torch.load(weights_file, map_location="cpu", weights_only=True)

        # strict=False tolerates partial checkpoints, but silence here would hide a
        # checkpoint whose keys do not match at all, so report what was skipped.
        load_result = model.load_state_dict(state, strict=False)
        if load_result.missing_keys:
            print(
                f"Warning: {len(load_result.missing_keys)} model parameter(s) not found in "
                f"{weights_file}: {load_result.missing_keys[:5]}"
            )
        if load_result.unexpected_keys:
            print(
                f"Warning: {len(load_result.unexpected_keys)} checkpoint key(s) ignored from "
                f"{weights_file}: {load_result.unexpected_keys[:5]}"
            )

    return processor, model


def evaluate(model, dataloader, device, threshold: float = 0.5):
    """Run the model over ``dataloader`` and collect per-subset statistics.

    A sample is predicted as class 1 when its softmax probability for class 1
    is at least ``threshold``. Samples are grouped by the name of the folder
    that directly contains the image file.

    Returns:
        A 3-tuple ``(results, probs, targets)`` where ``results`` maps each
        subset name to ``{"correct", "total", "errors"}`` (``errors`` holding
        ``(path, prob, label, pred)`` tuples for misclassified samples), and
        ``probs`` / ``targets`` are arrays of class-1 probabilities and labels
        over all samples in iteration order.
    """
    model.eval()
    per_subset = {}
    prob_log: List[float] = []
    target_log: List[int] = []

    with torch.no_grad():
        for batch_pixels, batch_paths, batch_labels in tqdm(dataloader, desc="Evaluating wild set"):
            batch_pixels = batch_pixels.to(device)
            truth = batch_labels.numpy()
            logits = model(batch_pixels).logits
            class_probs = torch.softmax(logits, dim=-1)
            fake_probs = class_probs[:, 1].cpu().numpy()
            decisions = (fake_probs >= threshold).astype(int)

            prob_log += fake_probs.tolist()
            target_log += truth.tolist()

            for prob, pred, label, path in zip(fake_probs, decisions, truth, batch_paths):
                # The parent folder name identifies the subset.
                parent_name = os.path.basename(os.path.dirname(path))
                subset = parent_name if parent_name else "unknown"
                if subset not in per_subset:
                    per_subset[subset] = {"correct": 0, "total": 0, "errors": []}
                stats = per_subset[subset]

                stats["total"] += 1
                if pred != label:
                    stats["errors"].append((path, float(prob), int(label), int(pred)))
                else:
                    stats["correct"] += 1

    return per_subset, np.array(prob_log), np.array(target_log)


def summarize(results, probs: np.ndarray, targets: np.ndarray, threshold: float = 0.5):
    """Print per-subset and global metrics and return the ROC curve data.

    Args:
        results: Mapping from subset name to a dict with ``"total"`` and
            ``"correct"`` counts, as produced by ``evaluate``.
        probs: Class-1 probabilities, one per sample.
        targets: Ground-truth labels (0 or 1), aligned with ``probs``.
        threshold: Decision threshold applied to ``probs`` for the global
            accuracies. Pass the same value that was used for ``results`` so
            the per-subset and global numbers agree.

    Returns:
        ``(fpr, tpr, auc)`` when ``targets`` contains both classes, otherwise
        ``None`` (ROC AUC is undefined for empty or single-class targets).
    """
    for subset in sorted(results):
        total = results[subset]["total"]
        correct = results[subset]["correct"]
        acc = correct / total if total else 0.0
        print(f"{subset}: {acc:.4f} ({correct}/{total})")

    if targets.size == 0:
        print("No targets available to compute global metrics.")
        return None

    # ROC AUC needs positives and negatives. Evaluating a single subset is a
    # legitimate use, so skip the AUC instead of failing before the accuracies.
    has_both_classes = np.unique(targets).size > 1
    auc = None
    if has_both_classes:
        auc = roc_auc_score(targets, probs)
        print(f"\nOverall ROC AUC: {auc:.4f}")
    else:
        print("\nOverall ROC AUC: undefined (only one class present in targets)")

    preds = (probs >= threshold).astype(int)
    correct_mask = preds == targets

    def accuracy_with_sem(mask):
        """Return (accuracy, binomial standard error); NaNs for an empty mask."""
        if mask.size == 0:
            return float("nan"), float("nan")
        acc = mask.mean()
        sem = np.sqrt(acc * (1 - acc) / mask.size)
        return acc, sem

    overall_acc, overall_sem = accuracy_with_sem(correct_mask)
    real_acc, real_sem = accuracy_with_sem(correct_mask[targets == 0])
    fake_acc, fake_sem = accuracy_with_sem(correct_mask[targets == 1])

    print(f"Accuracy (overall): {overall_acc:.4f} ± {overall_sem:.4f} (SEM)")
    print(f"Accuracy (real): {real_acc:.4f} ± {real_sem:.4f} (SEM)")
    print(f"Accuracy (fake): {fake_acc:.4f} ± {fake_sem:.4f} (SEM)")

    if not has_both_classes:
        return None

    fpr, tpr, _ = roc_curve(targets, probs)
    return fpr, tpr, auc


def main():
    """Parse CLI arguments, evaluate the model, print metrics and save the ROC plot."""
    # Defaults live under $SCRATCH when it is set, otherwise in the working directory.
    scratch_dir = os.environ.get("SCRATCH")
    default_cache = os.path.join(scratch_dir, ".cache") if scratch_dir else ".cache"
    default_dataset = os.path.join(scratch_dir, "wild") if scratch_dir else "wild"

    parser = argparse.ArgumentParser(description="Evaluate SwinV2 on wild image dataset")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Checkpoint (dir or file) to load weights from")
    parser.add_argument("--dataset_dir", type=str, default=default_dataset, help="Path to the wild dataset root")
    parser.add_argument("--subsets", type=str, nargs="*", default=["real", "fake"], help="List of subfolders to evaluate")
    parser.add_argument("--cache_dir", type=str, default=default_cache, help="Cache directory for models and processor")
    parser.add_argument("--batch_size", type=int, default=32, help="Evaluation batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of DataLoader workers")
    parser.add_argument("--threshold", type=float, default=0.5, help="Decision threshold on fake probability")
    parser.add_argument("--save_roc", type=str, default="auc_curve.png", help="Optional path to save ROC curve plot")
    args = parser.parse_args()

    cache_dir = os.path.abspath(args.cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    processor, model = load_model(cache_dir, args.checkpoint_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    dataset = WildImageDataset(args.dataset_dir, args.subsets, processor)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )

    results, probs, targets = evaluate(model, dataloader, device, threshold=args.threshold)
    # Use the same threshold for the global accuracies as for the per-subset ones.
    roc_data = summarize(results, probs, targets, threshold=args.threshold)

    if roc_data is not None and args.save_roc:
        fpr, tpr, auc = roc_data
        # Imported lazily so matplotlib is only required when a plot is written.
        import matplotlib.pyplot as plt

        # Create the target directory up front so savefig does not fail on a new path.
        os.makedirs(os.path.dirname(os.path.abspath(args.save_roc)), exist_ok=True)

        plt.figure()
        plt.plot(fpr, tpr, label=f"ROC curve (AUC = {auc:.4f})")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig(args.save_roc)
        plt.close()


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
