from datasets import load_dataset, DownloadConfig, Image as HFImage
import os
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification

from PIL import Image, ImageFile, UnidentifiedImageError
ImageFile.LOAD_TRUNCATED_IMAGES = True

from safetensors.torch import load_file as safe_load_file

import io
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import cv2
import random
from pathlib import Path

from sklearn.metrics import roc_auc_score, f1_score, accuracy_score


class MetricAccumulator:
    """Collect binary-classification outcomes and summarise them.

    Keeps a confusion-matrix tally per source name (``per_model``) and the
    flat lists of labels, hard predictions and scores needed for the
    dataset-wide metrics.
    """

    def __init__(self):
        # Unknown source names start with an all-zero confusion matrix.
        self.per_model = defaultdict(lambda: {"TP": 0, "TN": 0, "FP": 0, "FN": 0})
        self.labels = []
        self.preds = []
        self.scores = []

    def add(self, labels, preds, probs, model_names):
        """Record one batch of results.

        The four arguments are iterated in lockstep. A falsy source name is
        replaced by ``"real"`` when the label is 0 and ``"unknown"`` otherwise.
        """
        for raw_label, raw_pred, raw_prob, raw_name in zip(labels, preds, probs, model_names):
            truth = int(raw_label)
            guess = int(raw_pred)
            score = float(raw_prob)

            if raw_name:
                source = raw_name
            else:
                source = "real" if truth == 0 else "unknown"

            # Label 1 is the positive class; every other label counts as negative.
            if truth == 1:
                cell = "TP" if guess == 1 else "FN"
            else:
                cell = "TN" if guess == 0 else "FP"
            self.per_model[source][cell] += 1

            self.labels.append(truth)
            self.preds.append(guess)
            self.scores.append(score)

    def per_model_metrics(self):
        """Return ``{name: {"acc", "tpr", "tnr", "total", "correct"}}``.

        Any ratio whose denominator is zero is reported as ``0.0``.
        """

        def ratio(numerator, denominator):
            return numerator / denominator if denominator else 0.0

        summary = {}
        for source, counts in self.per_model.items():
            tp, tn, fp, fn = (counts[key] for key in ("TP", "TN", "FP", "FN"))
            hits = tp + tn
            seen = hits + fp + fn
            summary[source] = {
                "acc": ratio(hits, seen),
                "tpr": ratio(tp, tp + fn),
                "tnr": ratio(tn, tn + fp),
                "total": seen,
                "correct": hits,
            }
        return summary

    def overall_metrics(self):
        """Return AUC-ROC, F1 and accuracy over everything recorded so far.

        All three values are NaN when nothing has been added. AUC-ROC and F1
        individually fall back to NaN when the scorer raises ``ValueError``.
        """
        if len(self.labels) == 0:
            return {"auc_roc": float("nan"), "f1": float("nan"), "accuracy": float("nan")}

        y_true = np.array(self.labels)
        y_pred = np.array(self.preds)
        y_score = np.array(self.scores)

        summary = {}
        guarded = (
            ("auc_roc", roc_auc_score, y_score),
            ("f1", f1_score, y_pred),
        )
        for key, scorer, estimate in guarded:
            try:
                summary[key] = scorer(y_true, estimate)
            except ValueError:
                summary[key] = float("nan")
        summary["accuracy"] = accuracy_score(y_true, y_pred)
        return summary


def estimate_blur_laplacian(img_np):
    """Return the variance of the Laplacian of ``img_np``.

    ``img_np`` is converted to grayscale with ``cv2.COLOR_RGB2GRAY`` before a
    64-bit float Laplacian is taken, so it is expected to be an RGB array.
    """
    grayscale = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    laplacian = cv2.Laplacian(grayscale, cv2.CV_64F)
    return laplacian.var()


def degrade_image_to_match_laion5(img_pil, real_blur_vals, real_res_vals,
                                  noise_var=0.0005, jpeg_quality_range=(70, 95), seed=None):
    """Degrade an image using resolution/blur statistics sampled from real images.

    Steps, in order: rescale to the pixel area of a randomly chosen reference
    resolution (aspect ratio kept), apply a light Gaussian blur if the image is
    clearly sharper than a randomly chosen reference blur value, add Gaussian
    noise, and round-trip the result through JPEG.

    Args:
        img_pil: Input ``PIL.Image``; converted to RGB if it is in another mode.
        real_blur_vals: 1-D collection of Laplacian-variance values to sample
            the target blur from.
        real_res_vals: Collection of ``(height, width)`` pairs to sample the
            target resolution from.
        noise_var: Noise variance on a 0-1 intensity scale; the standard
            deviation used on 8-bit data is ``int(255 * sqrt(noise_var))``.
        jpeg_quality_range: ``(low, high)`` passed to ``np.random.randint``,
            so ``high`` itself is never drawn.
        seed: When given, seeds the global ``random``, NumPy and OpenCV
            generators so the whole degradation is reproducible.

    Returns:
        The degraded image as an RGB ``PIL.Image``.

    Raises:
        RuntimeError: If the JPEG encode/decode round trip fails.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        # cv2.randn draws from OpenCV's own generator, which the two calls
        # above do not touch; without this the noise differs between runs.
        # The mask keeps the value inside the C ``int`` range OpenCV expects.
        cv2.setRNGSeed(int(seed) & 0x7FFFFFFF)

    # The OpenCV calls below assume a 3-channel RGB array.
    if img_pil.mode != "RGB":
        img_pil = img_pil.convert("RGB")

    target_h, target_w = random.choice(real_res_vals)
    orig_w, orig_h = img_pil.size
    orig_area = orig_w * orig_h
    target_area = target_h * target_w
    # Match the target pixel count while preserving the original aspect ratio.
    scale = (target_area / orig_area) ** 0.5
    new_w = max(1, int(orig_w * scale))
    new_h = max(1, int(orig_h * scale))

    img_pil = img_pil.resize((new_w, new_h), Image.BILINEAR)
    img_np = np.array(img_pil)

    target_blur = np.random.choice(real_blur_vals)
    blur_val = estimate_blur_laplacian(img_np)
    # Only soften images that are more than 20% sharper than the sampled target.
    if blur_val > target_blur * 1.2:
        img_np = cv2.GaussianBlur(img_np, (0, 0), sigmaX=0.3, sigmaY=0.3)

    sigma = int(255 * (noise_var ** 0.5))
    if sigma > 0:
        noise = np.zeros_like(img_np, dtype=np.int16)
        cv2.randn(noise, 0, sigma)
        # Add in int16 so negative noise is representable, then saturate to uint8.
        img_np = cv2.add(img_np.astype(np.int16), noise, dtype=cv2.CV_8U)

    # Plain Python int: OpenCV's parameter list is a vector of C ints.
    quality = int(np.random.randint(*jpeg_quality_range))
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    # OpenCV codecs work in BGR order. Swapping channels around the round trip
    # makes the JPEG luma/chroma artefacts match what a real encoder would
    # produce for this image instead of one with red and blue exchanged.
    bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    ok, encimg = cv2.imencode('.jpg', bgr, encode_param)
    if not ok:
        raise RuntimeError("JPEG encoding failed during degradation")
    decoded = cv2.imdecode(encimg, cv2.IMREAD_COLOR)
    if decoded is None:
        raise RuntimeError("JPEG decoding failed during degradation")
    img_np = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    return Image.fromarray(img_np)


class SemiTruthEvalDataset(Dataset):
    """Map-style dataset over ``semi-truths/Semi-Truths-Evalset``.

    Rows without a ``png`` image, and rows whose ``__url__`` contains
    ``"mask"``, are dropped at construction time. Each item is
    ``(pixel_values, label, model_name)`` where ``label`` is 0 when the row's
    ``__url__`` contains ``"original"`` and 1 otherwise.
    """

    def __init__(self, split="train", cache_dir=None, processor=None):
        """Load ``split`` and filter it.

        Args:
            split: Dataset split name passed to ``load_dataset``.
            cache_dir: Cache directory passed to ``load_dataset``.
            processor: Callable image processor; required by ``__getitem__``.
        """
        self.processor = processor
        ds = load_dataset(
            "semi-truths/Semi-Truths-Evalset",
            split=split,
            cache_dir=cache_dir,
        )
        ds = ds.cast_column("png", HFImage(decode=True))
        # One pass covers both conditions; filtering twice would decode every
        # image an extra time for no change in the result.
        self.ds = ds.filter(
            lambda ex: ex["png"] is not None
            and "mask" not in (ex.get("__url__") or ""),
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label, model_name)`` for row ``idx``.

        Raises:
            ValueError: If no processor was supplied.
        """
        # Fail before paying for the image decode.
        if self.processor is None:
            raise ValueError("Processor must be provided for SemiTruthEvalDataset")

        ex = self.ds[int(idx)]
        image = ex["png"]
        # PNGs may be grayscale, palette or RGBA; feed the processor RGB, as
        # the other evaluation datasets in this file do.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

        url = ex.get("__url__", "") or ""
        label = 0 if "original" in url else 1
        if label == 0:
            model_name = "real"
        else:
            key = ex.get("__key__", "") or ""
            # The source name is the first path component of the key, if any.
            model_name = key.split("/")[0] if "/" in key else "fake"
        return pixel_values, label, model_name


class GenImageEvalDataset(Dataset):
    """Map-style dataset over a GenImage-style directory tree.

    Expected layout: ``<base_dir>/<model_name>/<split>/{ai,nature}/<image>``.
    Files under ``ai`` get label 1 and files under ``nature`` get label 0.
    Each item is ``(pixel_values, label, model_name)``, with ``model_name``
    reported as ``"real"`` for label-0 samples.
    """

    def __init__(self, base_dir, model_names, processor=None, split="val"):
        """Index every ``.jpg``/``.jpeg``/``.png`` file under the given folders.

        Args:
            base_dir: Root directory of the dataset.
            model_names: Sub-directory names to scan; missing ones are skipped.
            processor: Callable image processor; required by ``__getitem__``.
            split: Split sub-directory name.

        Raises:
            FileNotFoundError: If ``base_dir`` is empty or not a directory.
            RuntimeError: If no image file is found.
        """
        self.processor = processor
        self.samples = []
        self.base_dir = base_dir
        if not base_dir or not os.path.isdir(base_dir):
            raise FileNotFoundError(f"GenImage directory not found at {base_dir}")

        for model_name in model_names:
            model_dir = os.path.join(base_dir, model_name)
            if not os.path.isdir(model_dir):
                continue
            for class_name, label in [("ai", 1), ("nature", 0)]:
                class_dir = os.path.join(model_dir, split, class_name)
                if not os.path.isdir(class_dir):
                    continue
                # os.listdir order is arbitrary; sorting makes the sample
                # index stable across runs and machines.
                for fname in sorted(os.listdir(class_dir)):
                    if fname.lower().endswith((".jpg", ".jpeg", ".png")):
                        self.samples.append((os.path.join(class_dir, fname), label, model_name))

        if not self.samples:
            raise RuntimeError(f"No images discovered for GenImage split '{split}' at {base_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label, model_name)`` for sample ``idx``.

        If the image cannot be opened, the following samples are tried in
        order (wrapping around) and the first readable one is returned with
        its own label and model name.

        Raises:
            ValueError: If no processor was supplied.
            RuntimeError: If no sample in the dataset can be opened.
        """
        # Fail before paying for the image decode.
        if self.processor is None:
            raise ValueError("Processor required for GenImageEvalDataset")

        total = len(self.samples)
        position = idx % total
        for _ in range(total):
            path, label, model_name = self.samples[position]
            try:
                # The context manager releases the file handle even when
                # decoding fails part-way through.
                with Image.open(path) as handle:
                    image = handle.convert("RGB")
                break
            except (UnidentifiedImageError, OSError):
                position = (position + 1) % total
        else:
            raise RuntimeError("Failed to load any valid GenImage image")

        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        model_out = model_name if label == 1 else "real"
        return pixel_values, label, model_out


def collate_eval_batch(batch):
    """Collate ``(pixel_values, label, model_name)`` samples into one batch.

    Returns a tuple of the stacked pixel tensors, a ``torch.long`` tensor of
    labels, and the model names as a plain list.
    """
    tensors, targets, sources = [], [], []
    for sample in batch:
        tensors.append(sample[0])
        targets.append(sample[1])
        sources.append(sample[2])

    stacked = torch.stack(tensors)
    label_tensor = torch.tensor(targets, dtype=torch.long)
    return stacked, label_tensor, sources


def _openfake_image_to_rgb(image):
    """Coerce an example's ``image`` field to an RGB ``PIL.Image``.

    Accepts a PIL image, a NumPy array, raw bytes, or a dict carrying either
    ``"bytes"`` or ``"path"``.

    Raises:
        ValueError: For any other type, or a dict with neither key set.
    """
    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))
        elif isinstance(image, dict):
            if image.get("bytes") is not None:
                image = Image.open(io.BytesIO(image["bytes"]))
            elif image.get("path"):
                image = Image.open(image["path"])
            else:
                raise ValueError(f"Unsupported image dict keys: {image.keys()}")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def evaluate_openfake(model, processor, args, device):
    """Stream the OpenFake test split through ``model`` and collect metrics.

    Args:
        model: Classifier whose output exposes ``.logits`` with at least two
            columns; column 1 is used as the positive-class score.
        processor: Callable image processor returning ``"pixel_values"``.
        args: Namespace providing ``cache_dir``, ``batch_size`` and
            ``degradation``.
        device: Device the pixel batches are moved to.

    Returns:
        A ``MetricAccumulator`` holding every evaluated example.

    When ``args.degradation`` is set, label-1 images are passed through
    ``degrade_image_to_match_laion5`` using the statistics stored in
    ``real_train_stats.npz`` two directories above this file.
    """
    real_blur_vals = None
    real_res_vals = None
    if args.degradation:
        # The statistics are only needed for degradation, so a run without
        # the flag must not depend on the file being present.
        stats_path = Path(__file__).resolve().parent.parent / "real_train_stats.npz"
        # The context manager closes the archive; the arrays read inside it
        # stay valid afterwards.
        with np.load(stats_path) as real_train_stats:
            real_blur_vals = real_train_stats['blur_vals']
            real_res_vals = real_train_stats['res_vals']

    download_config = DownloadConfig(cache_dir=args.cache_dir)
    dataset = load_dataset(
        "Anonymous460/OpenFake",
        split="test",
        streaming=True,
        download_config=download_config,
    )

    def preprocess(example):
        image = _openfake_image_to_rgb(example["image"])

        raw_label = example["label"]
        if isinstance(raw_label, str):
            # Any string other than "real" (case-insensitive) counts as label 1.
            label = 0 if raw_label.lower() == "real" else 1
        else:
            label = int(raw_label)

        if label == 1 and args.degradation:
            image = degrade_image_to_match_laion5(
                image,
                real_blur_vals,
                real_res_vals,
            )
        inputs = processor(image, return_tensors="pt")
        model_name = example.get("model")
        if not model_name:
            model_name = "real" if label == 0 else "unknown"
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "label": label,
            "model_name": model_name,
        }

    dataset = dataset.map(preprocess)

    batch_size = args.batch_size
    buffer_pixels, buffer_labels, buffer_models = [], [], []
    accumulator = MetricAccumulator()

    # Same guarantee evaluate_with_loader gives: inference-mode layers even if
    # the caller forgot to switch the model over.
    model.eval()

    def process_batch():
        """Run the buffered examples through the model and empty the buffers."""
        if not buffer_pixels:
            return
        pixel_batch = torch.stack(buffer_pixels).to(device)
        labels = buffer_labels.copy()
        models = buffer_models.copy()
        buffer_pixels.clear()
        buffer_labels.clear()
        buffer_models.clear()
        with torch.no_grad():
            outputs = model(pixel_batch)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
            preds = logits.argmax(dim=-1).cpu().numpy()
        accumulator.add(labels, preds.tolist(), probs.tolist(), models)

    # A streaming dataset has no length, so batches are assembled by hand.
    for example in tqdm(dataset, desc="Evaluating OpenFake"):
        buffer_pixels.append(example["pixel_values"])
        buffer_labels.append(int(example["label"]))
        buffer_models.append(example["model_name"] or ("real" if example["label"] == 0 else "unknown"))
        if len(buffer_pixels) >= batch_size:
            process_batch()
    # Flush the final, possibly short, batch.
    process_batch()
    return accumulator


def evaluate_with_loader(model, dataloader, device, desc="Evaluating"):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Each batch must unpack to ``(pixel_values, labels, model_names)``. The
    softmax probability of column 1 is the score, and scores of at least 0.5
    become prediction 1. The model is put in eval mode and gradients are
    disabled. Returns the filled ``MetricAccumulator``; ``desc`` is the
    progress-bar caption.
    """
    results = MetricAccumulator()
    model.eval()
    with torch.no_grad():
        for batch_pixels, batch_labels, batch_names in tqdm(dataloader, desc=desc):
            output = model(batch_pixels.to(device))
            class_probs = torch.softmax(output.logits, dim=-1)
            fake_probs = class_probs[:, 1].cpu().numpy()
            hard_preds = (fake_probs >= 0.5).astype(int)
            results.add(
                batch_labels.cpu().numpy().tolist(),
                hard_preds.tolist(),
                fake_probs.tolist(),
                batch_names,
            )
    return results


def print_summary(accumulator):
    """Print one line per source name (sorted by name), then the overall line.

    ``accumulator`` must provide ``per_model_metrics()`` and
    ``overall_metrics()`` as ``MetricAccumulator`` does.
    """
    breakdown = accumulator.per_model_metrics()
    for name, stats in sorted(breakdown.items(), key=lambda entry: entry[0]):
        correct, total = stats["correct"], stats["total"]
        acc, tpr, tnr = stats["acc"], stats["tpr"], stats["tnr"]
        print(f"Model: {name} — Acc: {acc:.4f} ({correct}/{total}), TPR: {tpr:.4f}, TNR: {tnr:.4f}")

    overall = accumulator.overall_metrics()
    auc, f1, acc = (overall[key] for key in ("auc_roc", "f1", "accuracy"))
    print(f"Overall — AUC-ROC: {auc:.4f}, F1: {f1:.4f}, Acc: {acc:.4f}")


def load_model_and_processor(cache_dir, checkpoint_path=None):
    """Build the SwinV2 processor and a two-class classifier, optionally restored.

    Args:
        cache_dir: Cache directory passed to both ``from_pretrained`` calls.
        checkpoint_path: Optional weights to load on top of the pretrained
            backbone. May be a ``.safetensors`` file, a file readable by
            ``torch.load``, or a directory containing ``pytorch_model.bin``
            or ``model.safetensors``.

    Returns:
        ``(processor, model)``. The model's classifier is a fresh
        ``torch.nn.Linear`` with two outputs unless the checkpoint overwrites
        it.
    """
    import warnings

    backbone = "microsoft/swinv2-small-patch4-window16-256"
    processor = AutoImageProcessor.from_pretrained(
        backbone, cache_dir=cache_dir, use_fast=True
    )
    model = AutoModelForImageClassification.from_pretrained(
        backbone, cache_dir=cache_dir
    )
    # Swap the pretrained head for a two-class one.
    model.num_labels = 2
    model.config.num_labels = 2
    model.classifier = torch.nn.Linear(model.swinv2.num_features, model.num_labels)

    if checkpoint_path:
        ckpt_file = os.fspath(checkpoint_path)
        if os.path.isdir(ckpt_file):
            bin_file = os.path.join(ckpt_file, "pytorch_model.bin")
            safetensors_file = os.path.join(ckpt_file, "model.safetensors")
            # Keep the historical choice when the .bin file exists; otherwise
            # accept a directory that only holds safetensors weights.
            if not os.path.isfile(bin_file) and os.path.isfile(safetensors_file):
                ckpt_file = safetensors_file
            else:
                ckpt_file = bin_file

        if ckpt_file.endswith(".safetensors"):
            state = safe_load_file(ckpt_file, device='cpu')
        else:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only use checkpoints from a trusted source
            # (or pass weights_only=True where the installed torch supports it).
            state = torch.load(ckpt_file, map_location='cpu')

        # strict=False tolerates key mismatches, so surface them: a mismatch
        # here means part of the model keeps its initial weights.
        incompatible = model.load_state_dict(state, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {ckpt_file} did not match the model exactly: "
                f"{len(incompatible.missing_keys)} missing key(s) "
                f"{list(incompatible.missing_keys)[:5]}, "
                f"{len(incompatible.unexpected_keys)} unexpected key(s) "
                f"{list(incompatible.unexpected_keys)[:5]}",
                stacklevel=2,
            )
    return processor, model


def main():
    """Parse the command line, load the model, evaluate one dataset, print results.

    The default cache directory is ``$SCRATCH/.cache`` when ``SCRATCH`` is set
    and ``.cache`` otherwise. The GenImage data is looked up under
    ``$SCRATCH/genimage``, or ``./genimage`` when ``SCRATCH`` is unset.
    """
    scratch_dir = os.environ.get("SCRATCH")
    if scratch_dir:
        default_cache_dir = os.path.join(scratch_dir, ".cache")
    else:
        default_cache_dir = ".cache"

    parser = argparse.ArgumentParser(description="Evaluate SwinV2 detectors across datasets")
    # Registered in this order so the generated help text is unchanged.
    cli_options = (
        ("--resume_from_checkpoint", dict(type=str, default=None, help="Path to checkpoint to load")),
        ("--num_workers", dict(type=int, default=4, help="DataLoader worker processes")),
        ("--cache_dir", dict(type=str, default=default_cache_dir, help="Cache directory for datasets and models")),
        ("--batch_size", dict(type=int, default=64, help="Evaluation batch size")),
        ("--dataset", dict(type=str, default="openfake", choices=["openfake", "semi-truths", "genimage"], help="Dataset to evaluate on")),
        ("--degradation", dict(action='store_true', help="Apply degradation to OpenFake fake images before evaluation")),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    args.cache_dir = os.path.abspath(args.cache_dir or default_cache_dir)
    os.makedirs(args.cache_dir, exist_ok=True)

    processor, model = load_model_and_processor(args.cache_dir, args.resume_from_checkpoint)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    dataset_choice = args.dataset.lower()
    if dataset_choice == "openfake":
        # OpenFake is streamed and batched by hand, so it bypasses DataLoader.
        accumulator = evaluate_openfake(model, processor, args, device)
    else:
        if dataset_choice == "semi-truths":
            dataset = SemiTruthEvalDataset(split="train", cache_dir=args.cache_dir, processor=processor)
            progress_caption = "Evaluating Semi-Truths"
        else:  # genimage
            base_dir = Path(os.environ.get("SCRATCH") or Path.cwd()) / "genimage"
            model_folders = [
                "ADM",
                "BigGAN",
                "glide",
                "Midjourney",
                "stable_diffusion_v_1_4",
                "stable_diffusion_v_1_5",
                "VQDM",
                "wukong",
            ]
            dataset = GenImageEvalDataset(str(base_dir), model_folders, processor=processor, split="val")
            progress_caption = "Evaluating GenImage"

        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            collate_fn=collate_eval_batch,
        )
        accumulator = evaluate_with_loader(model, loader, device, desc=progress_caption)

    print_summary(accumulator)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
