### Preamble ##########################################################################################################

"""
Runner for computing no-reference image quality metrics (NIQE, BRISQUE and TReS) for experiments. Requires the Python
package `pyiqa`.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import numpy as np
import pandas as pd
import torch
import timm
import os
import pyiqa
from tqdm import tqdm
from torchvision import transforms
from misc.path_configs import EXPERIMENT_DIR
from misc.dataset_readers import NatAdvDiffImageDataset, ImageNetDataset, ImageDataset

from PIL import Image

from misc.classifier_pipeline import resnet50_config, inceptionv3_config, vith14_config, ClassifierPipeline
from misc.path_configs import IMAGENET_CLASSES

#######################################################################################################################

# Experiment sub-directories (relative to EXPERIMENT_DIR) to score, in processing order.
# Experiments that already have a row in the results CSV are skipped.
EXPERIMENTS_TO_COMPUTE = [
    "ClassifierFree2000", "AdvComp",
    # DDIMNatADiff1 with suffixes _0.0 to _0.5, followed by the remaining DDIMNatADiff variants
    "DDIMNatADiff1_0.0", "DDIMNatADiff1_0.1", "DDIMNatADiff1_0.2",
    "DDIMNatADiff1_0.3", "DDIMNatADiff1_0.4", "DDIMNatADiff1_0.5",
    "DDIMNatADiff2", "DDIMNatADiff3",
    "DDIMNatADiff1_Similarity", "DDIMNatADiff2_Similarity", "DDIMNatADiff3_Similarity",
    # DDIMAdvDiff variants
    "DDIMAdvDiff1", "DDIMAdvDiff2", "DDIMAdvDiff3",
    "DDIMAdvDiff1_Similarity", "DDIMAdvDiff2_Similarity", "DDIMAdvDiff3_Similarity",
    # Remaining methods, three runs each
    "AdversarialContentAttack1", "AdversarialContentAttack2", "AdversarialContentAttack3",
    "DiffAttack1", "DiffAttack2", "DiffAttack3",
    "NCF1", "NCF2", "NCF3",
]
# None -> score every image of an experiment; an integer -> score that many randomly drawn images instead
# (it must not exceed the number of images in the experiment).
SUB_SAMPLE = None

#######################################################################################################################

def _to_unit_float(x):
    """Cast an image tensor to float32 and divide by 255.

    NOTE(review): assumes the dataset yields 0-255 pixel tensors, so the result lies in [0, 1] as pyiqa expects.
    """
    return x.to(dtype=torch.float32) / 255


# A module-level function (not a lambda) keeps the transform picklable, which DataLoader worker processes need
# under the "spawn"/"forkserver" multiprocessing start methods.
transform = transforms.Compose([_to_unit_float])

RESULTS_CSV = EXPERIMENT_DIR / "noref_image_quality.csv"
RESULT_COLUMNS = ("Experiment_Name", "NIQE", "BRISQUE", "TReS", "ImQual_Sample_Size")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _load_results():
    """Return previously saved results as a column -> list dict, or empty columns if no CSV exists yet."""
    if os.path.exists(RESULTS_CSV):
        return pd.read_csv(RESULTS_CSV).to_dict(orient="list")
    return {column: [] for column in RESULT_COLUMNS}


def _save_results(results):
    """Write the column -> list results dict to RESULTS_CSV, sorted by experiment name."""
    pd.DataFrame(results).sort_values(by="Experiment_Name").to_csv(RESULTS_CSV, index=False)


def _create_metrics():
    """Build the pyiqa metrics once; keys are the result-column names, in evaluation order."""
    return {
        "NIQE": pyiqa.create_metric("niqe", device=DEVICE),
        "BRISQUE": pyiqa.create_metric("brisque", device=DEVICE),
        "TReS": pyiqa.create_metric("tres", device=DEVICE),
    }


def _build_dataset(experiment_name):
    """Load an experiment's images, optionally reduced to SUB_SAMPLE distinct randomly chosen images.

    Raises:
        ValueError: if SUB_SAMPLE exceeds the number of images in the experiment.
    """
    dataset = ImageDataset(EXPERIMENT_DIR / experiment_name, transforms=transform)
    if SUB_SAMPLE is None:
        return dataset
    if SUB_SAMPLE > len(dataset):
        raise ValueError("`SUB_SAMPLE` can not be larger than the size of the experiment dataset")
    # Sample WITHOUT replacement so no image is counted twice (randint could repeat indices).
    idxs = np.random.choice(len(dataset), size=SUB_SAMPLE, replace=False).tolist()
    return torch.utils.data.Subset(dataset, idxs)


def _mean_scores(dataset, metrics):
    """Return {metric name: mean per-image score} over a non-empty dataset."""
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=10, num_workers=10, shuffle=False, pin_memory=DEVICE == "cuda"
    )
    sums = dict.fromkeys(metrics, 0.0)
    for im_batch in tqdm(dataloader):
        for name, metric in metrics.items():
            sums[name] += torch.sum(metric(im_batch)).item()
    return {name: total / len(dataset) for name, total in sums.items()}


def main():
    """Score every experiment not yet present in RESULTS_CSV and write the updated table back."""
    results = _load_results()
    metrics = None  # created lazily so a run where everything is cached never loads the models
    for experiment_name in EXPERIMENTS_TO_COMPUTE:
        if experiment_name in results["Experiment_Name"]:
            print(f"=== SKIPPING {experiment_name} ===")
            continue

        print(f"=== COMPUTING STATISTICS FOR {experiment_name} ===")
        if metrics is None:
            metrics = _create_metrics()

        dataset = _build_dataset(experiment_name)
        if len(dataset) == 0:
            # Averages are undefined; leave it unrecorded so it is retried on the next run.
            print(f"=== WARNING: NO IMAGES FOUND FOR {experiment_name}, NOT RECORDED ===")
            continue

        means = _mean_scores(dataset, metrics)
        results["Experiment_Name"].append(experiment_name)
        for name, value in means.items():
            results[name].append(value)
        results["ImQual_Sample_Size"].append(len(dataset))
        # Persist after every experiment so a later crash does not discard finished work.
        _save_results(results)

    _save_results(results)


# The guard stops DataLoader worker processes from re-running the whole script when they re-import this module
# (spawn/forkserver start methods).
if __name__ == "__main__":
    main()

#######################################################################################################################
