from datasets import load_dataset, DownloadConfig, Image as HFImage
import numpy as np
import os
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from transformers import default_data_collator
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import argparse

import random
import cv2
from PIL import Image, ImageFile, UnidentifiedImageError
import io
from torchvision import transforms
from torch.utils.data import Dataset, Subset
from collections import Counter
from pathlib import Path
ImageFile.LOAD_TRUNCATED_IMAGES = True     # ← tolerate partial / corrupt files


def estimate_blur_laplacian(img_np):
    """Return the variance of the Laplacian of an RGB image.

    Args:
        img_np: Image array in RGB channel order, as accepted by
            ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)``.

    Returns:
        The variance of the 64-bit float Laplacian response of the
        grayscale version of ``img_np``.
    """
    grayscale = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
    laplacian_response = cv2.Laplacian(grayscale, cv2.CV_64F)
    return laplacian_response.var()

def degrade_image_to_match_laion5(img_pil, real_blur_vals, real_res_vals,
                                  noise_var=0.0005, jpeg_quality_range=(70, 95), seed=None):
    """
    Lightly degrades an image to mimic real training images' resolution and blur distribution.

    Four independent steps are each applied with probability 0.2, in this
    order: area-matching resize, mild Gaussian blur, additive Gaussian noise,
    and a JPEG encode/decode round trip.

    Args:
        img_pil: Input PIL image. Converted to RGB first if it is in any
            other mode, because the OpenCV calls below expect 3 channels.
        real_blur_vals: Sequence of Laplacian-variance values to sample a
            blur target from.
        real_res_vals: Sequence of ``(height, width)`` pairs to sample a
            target resolution from; only the product (area) is used, so the
            input's aspect ratio is kept.
        noise_var: Noise variance on a [0, 1] intensity scale; it is converted
            to an integer sigma on the 0-255 scale.
        jpeg_quality_range: ``(low, high)`` bounds passed to ``randint``
            (``high`` is exclusive).
        seed: Optional seed. When given, the Python/NumPy draws made by this
            call are reproducible. The noise samples come from ``cv2.randn``,
            which is not seeded here.

    Returns:
        A new RGB ``PIL.Image`` holding the (possibly) degraded picture.
    """
    # Use private generators when a seed is supplied. They yield the same
    # sequence that seeding the global generators would, but leave the
    # process-wide `random` / `np.random` state untouched for other callers.
    if seed is not None:
        py_rng = random.Random(seed)
        np_rng = np.random.RandomState(seed)
    else:
        py_rng = random
        np_rng = np.random

    # cvtColor(RGB2GRAY) and the JPEG round trip both need a 3-channel image.
    if img_pil.mode != "RGB":
        img_pil = img_pil.convert("RGB")

    # === Step 1: Resize to match real training image resolution ===
    if py_rng.random() < 0.2:
        target_h, target_w = py_rng.choice(real_res_vals)
        orig_w, orig_h = img_pil.size
        orig_area = orig_w * orig_h
        # Multiply as Python floats so narrow integer dtypes cannot overflow.
        target_area = float(target_h) * float(target_w)
        scale = (target_area / orig_area) ** 0.5
        new_w = max(1, int(orig_w * scale))
        new_h = max(1, int(orig_h * scale))

        img_pil = img_pil.resize((new_w, new_h), Image.BILINEAR)
    img_np = np.array(img_pil)

    # === Step 2: Blur only when the image is noticeably sharper than the target ===
    if py_rng.random() < 0.2:
        target_blur = np_rng.choice(real_blur_vals)
        blur_val = estimate_blur_laplacian(img_np)
        if blur_val > target_blur * 1.2:
            # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
            img_np = cv2.GaussianBlur(img_np, (0, 0), sigmaX=0.3, sigmaY=0.3)

    # === Step 3: Add light Gaussian noise (OpenCV) ===
    if py_rng.random() < 0.2:
        sigma = int(255 * (noise_var ** 0.5))
        if sigma > 0:
            noise = np.zeros_like(img_np, dtype=np.int16)
            cv2.randn(noise, 0, sigma)             # in-place Gaussian noise
            # Add in int16 and let OpenCV saturate back to uint8.
            img_np = cv2.add(img_np.astype(np.int16), noise, dtype=cv2.CV_8U)

    # === Step 4: Mild JPEG compression ===
    if py_rng.random() < 0.2:
        quality = int(np_rng.randint(*jpeg_quality_range))
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        # OpenCV codecs work in BGR order; convert so luma/chroma are computed
        # from the right channels, then convert back after decoding.
        bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        ok, encimg = cv2.imencode('.jpg', bgr, encode_param)
        if ok:
            decoded = cv2.imdecode(encimg, cv2.IMREAD_COLOR)
            if decoded is not None:
                img_np = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    return Image.fromarray(img_np)


class SemiTruthDataset(Dataset):
    """Dataset backed by local Semi-Truths folders.

    Scans ``base_dir/inpainting`` and ``base_dir/editing`` (label 1) and
    ``base_dir/real`` (label 0) for ``.jpg``/``.jpeg``/``.png`` files.
    Missing class folders are skipped. Items are ``(pixel_values, label)``.
    """

    def __init__(self, base_dir, processor=None, transform=None):
        """Index the image files under ``base_dir``.

        Args:
            base_dir: Root folder containing the class sub-folders.
            processor: Callable image processor; required by ``__getitem__``.
            transform: Optional augmentation, applied to roughly half of the
                fetched images.

        Raises:
            FileNotFoundError: If ``base_dir`` is empty or not a directory.
            RuntimeError: If no image file is found.
        """
        self.samples = []
        self.processor = processor
        self.transform = transform

        if not base_dir or not os.path.isdir(base_dir):
            raise FileNotFoundError(f"Semi-Truths directory not found at {base_dir}")

        for class_name, label in [("inpainting", 1), ("editing", 1), ("real", 0)]:
            class_dir = os.path.join(base_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and processes.
            self.samples.extend(
                (os.path.join(class_dir, fname), label)
                for fname in sorted(os.listdir(class_dir))
                if fname.lower().endswith((".jpg", ".jpeg", ".png"))
            )

        if not self.samples:
            raise RuntimeError(f"No images discovered under {base_dir}")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(pixel_values, label)``.

        If the file cannot be opened, the following samples are tried in
        order (wrapping around) and the first readable one is returned
        together with its own label.

        Raises:
            RuntimeError: If the dataset is empty or no file can be read.
            ValueError: If no processor was supplied.
        """
        if not self.samples:
            raise RuntimeError("Semi-Truths dataset is empty")
        # Fail before doing any file I/O when the dataset is misconfigured.
        if self.processor is None:
            raise ValueError("Processor required for SemiTruthDataset")

        total = len(self.samples)
        for offset in range(total):
            path, label = self.samples[(idx + offset) % total]
            try:
                # The context manager closes the file handle promptly;
                # convert() returns an independent, fully loaded image.
                with Image.open(path) as handle:
                    image = handle.convert("RGB")
                break
            except (UnidentifiedImageError, OSError):
                continue
        else:
            raise RuntimeError("Failed to load any valid Semi-Truths image")

        if self.transform is not None and random.random() < 0.5:
            image = self.transform(image)

        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        return pixel_values, label


class SemiTruthEval(Dataset):
    """Wrap semi-truths/Semi-Truths-Evalset and return (pixel_values, label, model_name).

    Rows without a decoded ``png`` and rows whose ``__url__`` contains
    ``"mask"`` are dropped at construction time.
    """

    def __init__(self, split="train", cache_dir=None, processor=None, transform=None):
        """Load and filter the evaluation split.

        Args:
            split: Dataset split name passed to ``load_dataset``.
            cache_dir: Cache directory passed to ``load_dataset``.
            processor: Image processor; takes precedence over ``transform``.
            transform: Fallback callable used only when ``processor`` is None.
        """
        self.ds = load_dataset(
            "semi-truths/Semi-Truths-Evalset",
            split=split,
            cache_dir=cache_dir,
        )

        self.ds = self.ds.cast_column("png", HFImage(decode=True))
        # Single pass: this predicate already covers the "png is not None"
        # check, so a separate filter for it would only walk the dataset twice.
        self.ds = self.ds.filter(
            lambda ex: ex["png"] is not None
            and "mask" not in (ex.get("__url__") or ""),
        )

        self.processor = processor
        self.transform = transform

    def __len__(self):
        """Return the number of rows kept after filtering."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label, model_name)`` for row ``idx``.

        ``label`` is 0 when ``"original"`` occurs in the row's ``__url__`` and
        1 otherwise. ``model_name`` is the part of ``__key__`` before the
        first ``/``; without a ``/`` it falls back to ``"real"``/``"fake"``
        according to the label.

        Raises:
            RuntimeError: If the dataset is empty.
            ValueError: If neither a processor nor a transform is set.
        """
        if len(self.ds) == 0:
            raise RuntimeError("Semi-Truths eval dataset is empty")

        ex = self.ds[int(idx)]
        image = ex["png"]

        # Normalise palette/alpha/grayscale images to 3 channels, as the
        # folder-backed datasets in this file do with convert("RGB").
        mode = getattr(image, "mode", None)
        if mode is not None and mode != "RGB":
            image = image.convert("RGB")

        if self.processor is not None:
            pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        elif self.transform is not None:
            pixel_values = self.transform(image)
        else:
            raise ValueError("Either processor or transform must be provided for SemiTruthEval")

        url = ex.get("__url__", "") or ""
        label = 0 if "original" in url else 1

        key = ex.get("__key__", "") or ""
        model_name = key.split("/")[0] if "/" in key else ("real" if label == 0 else "fake")

        return pixel_values, label, model_name


class GenImageDataset(Dataset):
    """Dataset wrapper for GenImage corpus under $SCRATCH/genimage.

    For every name in ``model_names`` the folders
    ``base_dir/<model>/<split>/ai`` (label 1) and
    ``base_dir/<model>/<split>/nature`` (label 0) are scanned for
    ``.jpg``/``.jpeg``/``.png`` files. Missing folders are skipped.
    Items are ``(pixel_values, label)``.
    """

    def __init__(self, base_dir, model_names, processor=None, transform=None, split="train"):
        """Index the image files for the requested split.

        Args:
            base_dir: Root folder containing one sub-folder per generator.
            model_names: Iterable of generator folder names to include.
            processor: Callable image processor; required by ``__getitem__``.
            transform: Optional augmentation, applied to roughly half of the
                fetched images.
            split: Name of the split sub-folder inside each generator folder.

        Raises:
            FileNotFoundError: If ``base_dir`` is empty or not a directory.
            RuntimeError: If no image file is found.
        """
        self.samples = []
        self.processor = processor
        self.transform = transform

        if not base_dir or not os.path.isdir(base_dir):
            raise FileNotFoundError(f"GenImage directory not found at {base_dir}")

        for model_name in model_names:
            model_dir = os.path.join(base_dir, model_name)
            if not os.path.isdir(model_dir):
                continue
            for class_name, label in [("ai", 1), ("nature", 0)]:
                class_dir = os.path.join(model_dir, split, class_name)
                if not os.path.isdir(class_dir):
                    continue
                # os.listdir returns entries in arbitrary order; sorting keeps
                # the index -> sample mapping identical across runs/processes.
                self.samples.extend(
                    (os.path.join(class_dir, fname), label)
                    for fname in sorted(os.listdir(class_dir))
                    if fname.lower().endswith((".jpg", ".jpeg", ".png"))
                )

        if not self.samples:
            raise RuntimeError(f"No images discovered for GenImage split '{split}' at {base_dir}")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(pixel_values, label)``.

        If the file cannot be opened, the following samples are tried in
        order (wrapping around) and the first readable one is returned
        together with its own label.

        Raises:
            RuntimeError: If the dataset is empty or no file can be read.
            ValueError: If no processor was supplied.
        """
        if not self.samples:
            raise RuntimeError("GenImage dataset is empty")
        # Fail before doing any file I/O when the dataset is misconfigured.
        if self.processor is None:
            raise ValueError("Processor required for GenImageDataset")

        total = len(self.samples)
        for offset in range(total):
            path, label = self.samples[(idx + offset) % total]
            try:
                # The context manager closes the file handle promptly;
                # convert() returns an independent, fully loaded image.
                with Image.open(path) as handle:
                    image = handle.convert("RGB")
                break
            except (UnidentifiedImageError, OSError):
                continue
        else:
            raise RuntimeError("Failed to load any valid GenImage image")

        if self.transform is not None and random.random() < 0.5:
            image = self.transform(image)

        pixel_values = self.processor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        return pixel_values, label


# Cache location defaults: use "$SCRATCH/.cache" when the SCRATCH environment
# variable is set to a non-empty value, otherwise a relative ".cache" folder.
scratch_dir = os.environ.get("SCRATCH")
if scratch_dir:
    default_cache_dir = os.path.join(scratch_dir, ".cache")
else:
    default_cache_dir = ".cache"

def _to_rgb_pil(image):
    """Coerce a dataset image field into an RGB ``PIL.Image``.

    Accepts a PIL image, a numpy array, raw encoded bytes, or a dict that
    carries either ``"bytes"`` or ``"path"``.

    Raises:
        ValueError: If the value is of none of the supported kinds.
    """
    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))
        elif isinstance(image, dict):
            if image.get("bytes") is not None:
                image = Image.open(io.BytesIO(image["bytes"]))
            elif image.get("path"):
                image = Image.open(image["path"])
            else:
                raise ValueError(f"Unsupported image dict keys: {image.keys()}")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def _label_to_int(raw_label):
    """Map a raw label to an integer class id (0 = real, 1 = fake).

    The strings ``"real"``/``"fake"`` (any case) are mapped by name; anything
    else goes through ``int()``, so numeric labels and numeric strings work.
    """
    if isinstance(raw_label, str):
        name = raw_label.lower()
        if name == "real":
            return 0
        # Without this branch a textual "fake" label would crash in int().
        if name == "fake":
            return 1
    return int(raw_label)


def _collate_image_label_tuples(batch):
    """Stack ``(pixel_values, label, ...)`` tuples into a model input dict.

    Only the first two elements of every item are used, so datasets that
    return extra trailing fields are accepted.
    """
    pixel_values = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}


def _binary_classification_metrics(pred):
    """Compute accuracy, precision, recall, F1 and ROC-AUC for class 1.

    ``pred.predictions`` is read as a 2-column logits array and
    ``pred.label_ids`` as the integer ground truth. Class 1 is predicted
    when its softmax probability is at least 0.5.
    """
    logits = torch.from_numpy(pred.predictions)
    probs = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
    preds = (probs >= 0.5).astype(int)
    labels = pred.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
    try:
        auc_roc = roc_auc_score(labels, probs)
    except ValueError:
        # ROC-AUC is undefined when only one class is present in `labels`;
        # report NaN instead of aborting the whole evaluation.
        auc_roc = float("nan")
    acc = accuracy_score(labels, preds)
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc_roc": auc_roc,
    }


def main(args):
    """Build the ``Trainer`` and evaluation dataset for the selected pipeline.

    Args:
        args: Namespace providing ``output_dir``, ``num_epochs``,
            ``batch_size``, ``learning_rate``, ``num_workers``, ``cache_dir``
            and ``dataset`` (one of ``"openfake"``, ``"semi-truths"``,
            ``"genimage"``, case-insensitive). An optional ``seed`` attribute
            is forwarded to the OpenFake degradation step.

    Returns:
        ``(trainer, eval_data)``: the configured trainer and the evaluation
        dataset it was built with.

    Raises:
        ValueError: If ``args.dataset`` is not a supported choice.

    Side effects:
        Creates the cache directory, rewrites ``args.cache_dir`` to its
        absolute path and sets the ``WANDB_PROJECT`` environment variable.
    """
    cache_dir = os.path.abspath(args.cache_dir or default_cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    args.cache_dir = cache_dir

    dataset_choice = args.dataset.lower()
    project_map = {
        "openfake": "SwinOpenFake",
        "semi-truths": "Semi-Truths",
        "genimage": "GenImage",
    }
    if dataset_choice not in project_map:
        raise ValueError(f"Unsupported dataset choice: {args.dataset}")
    os.environ["WANDB_PROJECT"] = project_map[dataset_choice]

    processor = AutoImageProcessor.from_pretrained(
        "microsoft/swinv2-small-patch4-window16-256",
        cache_dir=cache_dir,
        use_fast=True,
    )
    model = AutoModelForImageClassification.from_pretrained(
        "microsoft/swinv2-small-patch4-window16-256",
        cache_dir=cache_dir,
    )

    # Replace the pretrained classification head with a fresh 2-way layer.
    model.num_labels = 2
    model.config.num_labels = 2
    model.classifier = torch.nn.Linear(model.swinv2.num_features, model.num_labels)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(256, scale=(0.5, 1.0), ratio=(0.33, 3.0)),
        transforms.ColorJitter(contrast=0.5, brightness=0.3, saturation=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(p=0.1),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 1.0)),
    ])

    # TrainingArguments shared by every pipeline; the branches below add
    # their own save/eval cadence and pipeline-specific options.
    shared_training_kwargs = dict(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        save_strategy="steps",
        save_total_limit=20,
        logging_steps=100,
        eval_strategy="steps",
        metric_for_best_model="f1",
        greater_is_better=True,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=torch.cuda.is_available(),
        load_best_model_at_end=True,
        run_name=f"swinv2-{dataset_choice}",
        report_to="wandb",
    )
    compute_metrics = _binary_classification_metrics
    trainer_cls = Trainer

    if dataset_choice == "openfake":
        stats_path = Path(__file__).resolve().parent.parent / "real_train_stats.npz"
        # Read both arrays inside the context manager so the archive's file
        # handle is closed instead of staying open for the whole run.
        with np.load(stats_path) as real_train_stats:
            real_blur_vals = real_train_stats['blur_vals']
            real_res_vals = real_train_stats['res_vals']

        download_config = DownloadConfig(cache_dir=cache_dir)
        train_data = load_dataset(
            "Anonymous460/OpenFake",
            split="train",
            streaming=True,
            download_config=download_config,
        )
        eval_data = load_dataset(
            "Anonymous460/OpenFake",
            split="test",
            streaming=True,
            download_config=download_config,
        )
        degrade_seed = getattr(args, "seed", None)

        def preprocess_train(example):
            """Decode, degrade fake images, randomly augment, then process."""
            image = _to_rgb_pil(example["image"])
            label = _label_to_int(example["label"])
            if label == 1:
                image = degrade_image_to_match_laion5(
                    image,
                    real_blur_vals,
                    real_res_vals,
                    seed=degrade_seed,
                )
            if random.random() < 0.5:
                image = train_transform(image)
            inputs = processor(image, return_tensors="pt")
            return {"pixel_values": inputs["pixel_values"].squeeze(0), "label": label}

        def preprocess_eval(example):
            """Decode and process without any degradation or augmentation."""
            image = _to_rgb_pil(example["image"])
            label = _label_to_int(example["label"])
            inputs = processor(image, return_tensors="pt")
            return {"pixel_values": inputs["pixel_values"].squeeze(0), "label": label}

        train_data = train_data.map(preprocess_train)
        eval_data = eval_data.map(preprocess_eval)
        data_collator = default_data_collator

        # The streamed dataset is given an explicit step budget.
        # NOTE(review): 600000 is presumably the approximate number of
        # training samples per epoch - confirm.
        max_steps = 600000 // max(args.batch_size, 1) * max(args.num_epochs, 1)
        training_args = TrainingArguments(
            save_steps=500,
            eval_steps=500,
            max_steps=max_steps,
            **shared_training_kwargs,
        )

    elif dataset_choice == "semi-truths":
        base_dir = Path(scratch_dir or Path.cwd()) / "semi-truths"
        train_data = SemiTruthDataset(str(base_dir), processor=processor, transform=train_transform)
        eval_data = SemiTruthEval(split="train", cache_dir=cache_dir, processor=processor)

        if len(eval_data) > 10000:
            # Fixed seed: every run and every process evaluates on the same
            # 10k subset, so metrics are comparable across checkpoints.
            subset_rng = np.random.RandomState(0)
            indices = subset_rng.choice(len(eval_data), 10000, replace=False)
            eval_data = Subset(eval_data, indices.tolist())

        # Inverse-frequency class weights for the loss.
        counts = Counter(label for _, label in train_data.samples)
        total = sum(counts.values())
        class_weights = [total / max(counts.get(0, 1), 1), total / max(counts.get(1, 1), 1)]
        weight_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

        data_collator = _collate_image_label_tuples

        class WeightedTrainer(Trainer):
            """Trainer whose loss is class-weighted cross-entropy."""

            # `num_items_in_batch` is accepted (and unused) because newer
            # Trainer versions pass it to compute_loss as a keyword argument.
            def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
                labels = inputs["labels"]
                outputs = model(pixel_values=inputs["pixel_values"], labels=labels)
                logits = outputs.logits
                loss_fct = torch.nn.CrossEntropyLoss(weight=weight_tensor.to(logits.device))
                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
                return (loss, outputs) if return_outputs else loss

        training_args = TrainingArguments(
            save_steps=2000,
            eval_steps=2000,
            ddp_find_unused_parameters=False,
            **shared_training_kwargs,
        )
        trainer_cls = WeightedTrainer

    else:  # dataset_choice == "genimage"
        base_dir = Path(scratch_dir or Path.cwd()) / "genimage"
        model_folders = [
            "ADM",
            "BigGAN",
            "glide",
            "Midjourney",
            "stable_diffusion_v_1_4",
            "stable_diffusion_v_1_5",
            "VQDM",
            "wukong",
        ]
        train_data = GenImageDataset(str(base_dir), model_folders, processor=processor, transform=train_transform, split="train")
        # No augmentation on the validation split: random crops/jitter would
        # make the evaluation metrics noisy and not comparable between runs.
        eval_data = GenImageDataset(str(base_dir), model_folders, processor=processor, transform=None, split="val")

        data_collator = _collate_image_label_tuples

        training_args = TrainingArguments(
            save_steps=2000,
            eval_steps=2000,
            ddp_find_unused_parameters=False,
            **shared_training_kwargs,
        )

    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=eval_data,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    return trainer, eval_data

# Command-line entry point: parse options, build the trainer, train, evaluate.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a model on OpenFake dataset")
    parser.add_argument("--output_dir", type=str, default="./swinv2-finetuned-openfake", help="Output directory for model checkpoints")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training and evaluation")
    parser.add_argument("--learning_rate", type=float, default=8e-5, help="Learning rate for the optimizer")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes")
    parser.add_argument("--cache_dir", type=str, default=default_cache_dir, help="Cache directory for datasets and models")
    parser.add_argument(
        "--dataset",
        type=str,
        default="openfake",
        choices=["openfake", "semi-truths", "genimage"],
        help="Select which dataset pipeline to use",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path to a checkpoint directory or checkpoint name to resume training from"
    )
    args = parser.parse_args()
    # `eval_data` rather than `eval`, to avoid shadowing the builtin.
    trainer, eval_data = main(args)
    # Honour --resume_from_checkpoint; its default (None) starts from scratch,
    # which matches the behaviour when the option is not given.
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # evaluate the model
    eval_results = trainer.evaluate(eval_dataset=eval_data)
    print(f"Evaluation results: {eval_results}")
