from torch.utils.tensorboard import SummaryWriter
import os
from torchmetrics import Accuracy
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import time
import multiprocessing
from pathlib import Path
from torchvision.models import resnet18, mobilenetv3, mobilenet_v3_small, mobilenet_v3_large
from models.vit_small import ViT
from models.swin import swin_t
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple
import json
from torchvision.transforms import InterpolationMode
from PIL import Image
# The baseline models are taken from https://github.com/kentaroy47/vision-transformers-cifar10.
# The last test uses the Gaze-CIFAR-10 dataset, available at https://github.com/rekkles2/Gaze-CIFAR-10.

# Settings for the PGD and occlusion robustness tests run in the __main__ section.
# pgd_eps / pgd_alpha / pgd_steps are forwarded to _pgd_linf as eps / alpha / steps.
pgd_eps= 8 / 255
pgd_alpha = 2 / 255
pgd_steps = 20
# occl_patch is the mask side length handed to _random_mask;
# occl_K is how many independently masked passes are run per batch.
occl_patch= 16
occl_K= 5

size = 32  # 32 for CIFAR, 48 for FER
# Per-dataset normalisation: normalize1 carries one channel statistic,
# normalize2 / normalize3 carry three.
# NOTE(review): normalize2 uses std (0.247, 0.243, 0.261) while `std` further
# down is (0.2023, 0.1994, 0.2010) -- two different CIFAR-10 std conventions
# live in this file; confirm which one the evaluated checkpoint was trained with.
normalize1 = transforms.Normalize((0.485,), (0.229,))
normalize2 = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
normalize3 = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
# Evaluation pipelines (ToTensor followed by the matching Normalize, no augmentation).
trans_FER = transforms.Compose([transforms.ToTensor(), normalize1])
trans_CIFAR10 = transforms.Compose([transforms.ToTensor(), normalize2])
trans_CIFAR100 = transforms.Compose([transforms.ToTensor(), normalize3])
# Resize to 256x256 without normalisation; not referenced in the visible part of this file.
trans2 = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# Statistics used only by transform_train / transform_test below.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
size = 32  # re-assignment with the same value as above
#mean = (0.5071, 0.4867, 0.4408)
#std = (0.2675, 0.2565, 0.2761)
# Augmented training pipeline: padded random crop, resize, random flip, normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
# Matching deterministic pipeline: resize and normalise only.
transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
# Dataset selection: exactly one train/test pair is active (CIFAR-10 with
# trans_CIFAR10); the commented pairs are the alternatives. Both active calls
# use download=True and store the data under ./data/.
#train_val_dataset = datasets.ImageFolder(root="data/fer2013/train", transform=trans_FER)
#test_dataset = datasets.ImageFolder(root="data/fer2013/test", transform=trans_FER)
train_val_dataset = datasets.CIFAR10("./data/", train=True, download=True, transform=trans_CIFAR10)
test_dataset = datasets.CIFAR10("./data/", train=False, download=True, transform=trans_CIFAR10)
#train_val_dataset = datasets.CIFAR10("./data/", train=True, download=True, transform=transform_train)
#test_dataset = datasets.CIFAR10("./data/", train=False, download=True, transform=transform_test)
#train_val_dataset = datasets.CIFAR100("./data/", train=True, download=True, transform=trans_CIFAR100)
#test_dataset = datasets.CIFAR100("./data/", train=False, download=True, transform=trans_CIFAR100)

# ---- split / loader settings ----------------------------------------
valid_size = 0.1      # fraction of the training set held out for validation
random_seed = 1       # seeds the one-off index shuffle below
batch_size = 128
num_workers = 1
pin_memory = False

# ---- seeded train / validation split --------------------------------
# All training indices are shuffled once with a fixed NumPy seed; the first
# `split` of them become the validation subset, the remainder the training subset.
num_train = len(train_val_dataset)
split = int(np.floor(valid_size * num_train))
indices = list(range(num_train))
np.random.seed(random_seed)
np.random.shuffle(indices)
valid_idx = indices[:split]
train_idx = indices[split:]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# ---- loaders ----------------------------------------------------------
# Options shared by the three loaders. drop_last=True is set on the
# validation and test loaders as well, so a trailing partial batch is skipped.
_shared_loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=True,
)

train_dataloader = DataLoader(
    train_val_dataset, sampler=train_sampler, **_shared_loader_options
)

val_dataloader = DataLoader(
    train_val_dataset, sampler=valid_sampler, **_shared_loader_options
)

test_dataloader = DataLoader(
    test_dataset, shuffle=False, **_shared_loader_options
)

# ---------------------------------------------------------------------
# Models

class ResNet18ForCIFAR(nn.Module):
    """
    torchvision ResNet-18 whose final fully connected layer is resized
    to ``num_classes`` outputs.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    pretrained : bool
        When True the backbone is initialised from torchvision's
        ``IMAGENET1K_V1`` weights; when False (default) it is randomly
        initialised. The replaced ``fc`` layer is always freshly initialised.
    """

    def __init__(self, num_classes: int = 10, pretrained=False):
        super().__init__()
        # Honour the `pretrained` flag (it used to be ignored) through the
        # `weights=` API, which this file already uses for MobileNetV3.
        weights = "IMAGENET1K_V1" if pretrained else None
        self.model = resnet18(weights=weights)
        # Read the feature width from the backbone instead of hard-coding it.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped ResNet-18 for ``x``."""
        return self.model(x)



class MobileNetV3SmallCIFAR(nn.Module):
    """
    MobileNetV3 backbone adapted to small (CIFAR-sized) inputs.

    NOTE: despite the class name, the backbone built here is
    ``mobilenet_v3_large``. The name is left unchanged so existing callers
    keep working.

      - the first convolution is rebuilt with stride 1 (stock model: stride 2)
      - only the LAST Linear layer of the classifier is replaced, so that it
        outputs ``num_classes`` logits

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    pretrained : bool
        When True the backbone starts from torchvision's ``IMAGENET1K_V1``
        weights; when False (default) it is randomly initialised.
    """
    def __init__(self, num_classes: int = 10, pretrained: bool = False):
        super().__init__()

        # ---- load backbone (the `pretrained` flag used to be ignored) ----
        weights = "IMAGENET1K_V1" if pretrained else None
        self.model = mobilenet_v3_large(weights=weights)

        # ---- adapt first conv (stride 1) ----
        # Geometry is read from the stock layer rather than hard-coded, so the
        # replacement always matches the backbone that was actually built.
        old_conv = self.model.features[0][0]
        new_conv = nn.Conv2d(
            in_channels=old_conv.in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=1,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        # Re-use the stock kernel values; only the stride differs.
        with torch.no_grad():
            new_conv.weight.copy_(old_conv.weight)
            if old_conv.bias is not None:
                new_conv.bias.copy_(old_conv.bias)
        self.model.features[0][0] = new_conv          # keep BN + activation

        # ---- replace LAST Linear only ----
        # The earlier classifier layers are kept as built by torchvision.
        in_features_last = self.model.classifier[-1].in_features
        self.model.classifier[-1] = nn.Linear(in_features_last, num_classes)

    def forward(self, x):
        """Return the logits produced by the adapted backbone for ``x``."""
        return self.model(x)

def _random_mask(x, patch=16, fill=0.0):
    """Apply a single square mask of size `patch` at a random position."""
    B, C, H, W = x.size()
    y0 = torch.randint(0, H - patch, (B,), device=x.device)
    x0 = torch.randint(0, W - patch, (B,), device=x.device)
    for i in range(B):
        x[i, :, y0[i]:y0[i] + patch, x0[i]:x0[i] + patch] = fill
    return x


def _pgd_linf(forward, x, y,
              eps=8/255, alpha=2/255, steps=20):
    # random start
    x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)

    for _ in range(steps):
        x_adv.requires_grad_(True)              #  make it a leaf each iter
        with torch.enable_grad():               #  ensure graph recording
            logits = forward(x_adv)             #     (no .detach()!)
            loss   = F.cross_entropy(logits, y)

        grad, = torch.autograd.grad(loss, x_adv, retain_graph=False)
        x_adv = (x_adv + alpha * grad.sign()).detach()   #   detach & clip
        x_adv = torch.max(torch.min(x_adv, x + eps), x - eps).clamp(0, 1)

    return x_adv

class UpsampledCIFAR10Test(Dataset):
    """
    Test-split images of the Gaze-CIFAR-10 dataset as (image, label) pairs.

    One mapping JSON per class (`gaze2cifar_map_0.json` … `_9.json`) is read.
    Each JSON maps a file name to a metadata dict; only entries whose
    ``"split"`` is ``"test"`` are used, and ``"label"`` selects both the
    target class and the class sub-folder the file is looked up in.

    Parameters
    ----------
    mapping_dir : str | Path
        Folder that contains `gaze2cifar_map_0.json … _9.json`.
    gaze_root   : str | Path
        Root directory that holds the class sub-folders (`0/ … 9/`)
        with the image files named in the mappings.
    transform   : torchvision.transforms.Compose  (optional)
        If omitted, a default "Resize(32x32, bicubic) → ToTensor → Normalize"
        pipeline is used.
    verbose     : bool
        Print missing mapping files and a kept/dropped summary.

    Raises
    ------
    RuntimeError
        If no usable sample was found in any mapping file.
    """

    def __init__(
        self,
        mapping_dir: str | Path,
        gaze_root: str | Path,
        transform: transforms.Compose | None = None,
        verbose: bool = True,
    ):
        self.samples: List[Tuple[str, int]] = []
        dropped: List[str] = []  # files listed in a mapping but absent on disk

        mapping_dir = Path(mapping_dir)
        gaze_root   = Path(gaze_root)

        # ----------------------------------------------------------
        #  Collect (img_path, label) tuples from the 10 mapping JSONs
        #  (classes 0..9; the loop used to stop at class 8)
        # ----------------------------------------------------------
        for cls in range(10):
            map_file = mapping_dir / f"gaze2cifar_map_{cls}.json"
            if not map_file.exists():
                if verbose:
                    print(f"[UpsampledCIFAR] mapping file missing: {map_file}")
                continue

            with open(map_file, encoding="utf-8") as f:
                mapping: dict = json.load(f)

            for fname, meta in mapping.items():
                if meta.get("split") != "test":
                    continue

                label = meta["label"]  # 0‥9
                img_path = gaze_root / str(label) / fname

                if img_path.is_file():
                    self.samples.append((img_path.as_posix(), label))
                else:
                    dropped.append(img_path.as_posix())

        # Summary and sanity check run once, AFTER every class was scanned.
        # Inside the loop the check could fire after class 0 alone, and was
        # skipped entirely whenever a mapping file was missing.
        if verbose:
            print(f"[UpsampledCIFAR] kept {len(self.samples):,} images ; "
                  f"dropped {len(dropped):,} missing files")

        if not self.samples:
            raise RuntimeError("No valid samples – check paths & JSONs.")

        # ----------------------------------------------------------
        #  Default preprocessing (down-sample + normalise)
        # ----------------------------------------------------------
        if transform is None:
            mean = (0.4914, 0.4822, 0.4465)
            std  = (0.2023, 0.1994, 0.2010)
            transform = transforms.Compose([
                transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        self.transform = transform

    # -------------  Dataset interface  ---------------------------
    def __len__(self) -> int:
        """Number of (path, label) entries collected at construction time."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Return ``(transformed_image, label)`` for entry `idx`.

        Safety net: if the file vanished after construction, the following
        entries are tried in turn (wrapping around) and the first readable
        one is returned together with ITS label. `self.samples` is never
        mutated, so `len()` stays stable while a sampler is iterating.

        Raises
        ------
        IndexError
            If `idx` is out of range.
        RuntimeError
            If none of the recorded files can be opened any more.
        """
        total = len(self.samples)
        path, label = self.samples[idx]  # out-of-range idx -> IndexError

        for step in range(1, total + 1):
            try:
                # context manager closes the file handle right after decoding
                with Image.open(path) as handle:
                    img = handle.convert("RGB")
            except FileNotFoundError:
                path, label = self.samples[(idx + step) % total]
                continue
            return self.transform(img), label

        raise RuntimeError("None of the recorded image files can be opened.")



# -----------------------------------------------------------------
#  Convenience helper -- returns a DataLoader ready for testing
# -----------------------------------------------------------------
def get_upsampled_cifar10_test_loader(
    mapping_dir: str | Path,
    gaze_root: str | Path,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = False,
    shuffle: bool = False,
) -> DataLoader:
    """
    Wrap an ``UpsampledCIFAR10Test`` dataset in a ``DataLoader``.

    The dataset is built from `mapping_dir` and `gaze_root` with its default
    transform; the remaining arguments are forwarded unchanged to the
    ``DataLoader``. Sample order is deterministic unless `shuffle` is True.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    gaze_test_set = UpsampledCIFAR10Test(mapping_dir, gaze_root)
    return DataLoader(gaze_test_set, **loader_options)

def _normalization_stats(dataset, device):
    """Return ``(mean, std)`` of the Normalize step in ``dataset.transform``.

    Both tensors are shaped (1, C, 1, 1) so they broadcast over an image
    batch. If the dataset has no ``transforms.Normalize`` step, (0, 1) is
    returned, i.e. the data is assumed to be in raw pixel space already.
    """
    pipeline = getattr(dataset, "transform", None)
    for step in getattr(pipeline, "transforms", [pipeline]):
        if isinstance(step, transforms.Normalize):
            mean_t = torch.as_tensor(step.mean, dtype=torch.float32, device=device)
            std_t = torch.as_tensor(step.std, dtype=torch.float32, device=device)
            return mean_t.view(1, -1, 1, 1), std_t.view(1, -1, 1, 1)
    return (torch.zeros(1, 1, 1, 1, device=device),
            torch.ones(1, 1, 1, 1, device=device))


def _run_test(tag, model, loader, loss_fn, accuracy, device,
              perturb=None, repeats=1):
    """Evaluate `model` on `loader`, print one summary line, return the numbers.

    Parameters
    ----------
    tag : str
        Label printed in square brackets at the start of the summary line.
    perturb : callable | None
        ``perturb(x, y) -> inputs``; called outside inference mode so it may
        use gradients. ``None`` evaluates the batch as loaded.
    repeats : int
        How many times each batch is perturbed and scored.

    Returns
    -------
    (test_loss, test_acc, sec_per_img) : tuple of float
        Loss and accuracy are weighted by batch size, so a smaller last batch
        does not bias them; the time is wall-clock seconds per image drawn
        from the loader (all `repeats` included).
    """
    model.eval()
    loss_sum, acc_sum = 0.0, 0.0
    n_scored, n_images = 0, 0

    tic = time.time()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        n_batch = y.size(0)
        n_images += n_batch
        for _ in range(repeats):
            inputs = x if perturb is None else perturb(x, y)
            with torch.inference_mode():
                y_pred = model(inputs)
                loss_sum += loss_fn(y_pred, y).item() * n_batch
                acc_sum += accuracy(y_pred, y).item() * n_batch
            n_scored += n_batch
    elapsed = (time.time() - tic) / max(n_images, 1)

    test_loss = loss_sum / max(n_scored, 1)
    test_acc = acc_sum / max(n_scored, 1)
    print(f"[{tag}] Test loss: {test_loss: .5f}| Acc: {test_acc:.4f} | sec/img: {elapsed:.5f}")
    return test_loss, test_acc, elapsed


if __name__ == "__main__":
    multiprocessing.freeze_support()
    num_classes = 10
    MODEL_PATH = Path("models")
    MODEL_PATH.mkdir(parents=True, exist_ok=True)
    MODEL_NAME = "ViT_CIFAR_o.pth"
    #MODEL_NAME = "ResNet_CIFAR.pth"
    #MODEL_NAME = "MobileNetl_CIFAR.pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---- build the architecture that matches the checkpoint ----
    model_loaded = ViT(
            image_size=size,
            patch_size=4,
            num_classes=num_classes,
            dim=int(512),
            depth=4,
            heads=6,
            mlp_dim=256,
            dropout=0.1,
            emb_dropout=0.1
    )
    # Alternatives (built only when selected, not unconditionally):
    # model_loaded = swin_t(window_size=4, num_classes=num_classes,
    #                       downscaling_factors=(2, 2, 2, 1))
    # model_loaded = ResNet18ForCIFAR(num_classes = num_classes)
    # model_loaded = MobileNetV3SmallCIFAR(num_classes = num_classes)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model_loaded.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    model_loaded.to(device)

    accuracy = Accuracy(task='multiclass', num_classes=num_classes).to(device)
    loss_fn = nn.CrossEntropyLoss()

    # ---------------- clean test ----------------
    _run_test("Clean test", model_loaded, test_dataloader, loss_fn, accuracy, device)

    # ---------------- PGD test ----------------
    # The loaders yield NORMALISED tensors, while eps / alpha and the [0, 1]
    # clamp inside _pgd_linf are defined in raw pixel space. The attack is
    # therefore run in pixel space, with normalisation folded into the
    # forward function. The model stays in eval mode so dropout and
    # batch-norm statistics are frozen; gradients are still available.
    mean_t, std_t = _normalization_stats(test_dataset, device)

    def f(z):
        return model_loaded((z.to(device) - mean_t) / std_t)

    def _pgd_inputs(x, y):
        x_pixels = (x * std_t + mean_t).clamp(0, 1)
        x_adv = _pgd_linf(f, x_pixels, y,
                          eps=pgd_eps,
                          alpha=pgd_alpha,
                          steps=pgd_steps)
        return (x_adv - mean_t) / std_t

    _run_test("PGD test", model_loaded, test_dataloader, loss_fn, accuracy, device,
              perturb=_pgd_inputs)

    # ---------------- occlusion test ----------------
    # occl_K independently placed masks per batch; the mask is applied to a
    # copy because _random_mask writes in place.
    _run_test("Occlusion test", model_loaded, test_dataloader, loss_fn, accuracy, device,
              perturb=lambda x, y: _random_mask(x.clone(), patch=occl_patch),
              repeats=occl_K)

    # from here you need the Gaze-CIFAR-10 dataset, available at https://github.com/rekkles2/Gaze-CIFAR-10

    test_loader = get_upsampled_cifar10_test_loader(
        mapping_dir="./",  # or your actual path
        gaze_root="../Gaze-CIFAR-10/test data",
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    _run_test("Gaze-CIFAR test", model_loaded, test_loader, loss_fn, accuracy, device)
