import os
import json
import gc
from pathlib import Path

import numpy as np
import pandas as pd
from torchvision.transforms import CenterCrop
from skimage.color import rgb2lab, gray2rgb
from skimage.transform import rescale
from imageio.v3 import imread
from PIL import Image
from tqdm import tqdm

from src.conformal import ConformalCalibratedModel
from skimage.filters import gaussian
from scipy.stats import bootstrap  # type: ignore

# Project layout: this file sits one directory below the project root, and the
# SinSR checkout is a sibling directory of the project root.
root_path = Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"


# Dataset under evaluation and its category sub-folders.
dataset_basename = "LIU4K_v2_train"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# Inputs live under SinSR's datasets/masks folders; model outputs under this project.
_sinsr_datasets = sinsr_path / "datasets"
downscaled_path = _sinsr_datasets / f"{dataset_basename}_downscaled"
output_path = root_path / "output_sinsr" / dataset_basename
dataset_path = _sinsr_datasets / dataset_basename
diffs_path = sinsr_path / "masks" / dataset_basename


def my_rgb2lab(x):
    """Convert an image to CIELAB and rescale each channel.

    A 2-D (grayscale) input is first expanded to RGB. After conversion, L* is
    divided by 100, and a*/b* are divided by 254 and shifted by 0.5, which maps
    a*/b* values in [-127, 127] onto [0, 1].
    """
    rgb = x if len(x.shape) != 2 else gray2rgb(x)
    lab = rgb2lab(rgb)
    channel_scale = np.array([[[100, 127 * 2, 127 * 2]]])
    channel_shift = np.array([[[0, 0.5, 0.5]]])
    return lab / channel_scale + channel_shift


def my_center_crop(x):
    """Return the central 1024 x 1024 crop of ``x`` as a NumPy array.

    ``x`` is cast to uint8 (no clipping) so it can go through PIL and
    torchvision's ``CenterCrop``.
    """
    pil_image = Image.fromarray(x.astype(np.uint8))
    cropper = CenterCrop(1024)
    return np.asarray(cropper(pil_image))


def my_left_corner_crop(x):
    """Return the top-left window of at most 2048 x 2048 over the first two axes.

    This is basic slicing, so for NumPy arrays the result is a view that shares
    memory with ``x``. An axis shorter than 2048 keeps its full extent.
    """
    side = 2048
    rows, cols = slice(0, side), slice(0, side)
    return x[rows, cols]


def infer_mean(
    x: list[float],
    *,
    confidence_level: float = 0.95,
    n_resamples: int = 500,
) -> tuple[float, float]:
    """Estimate the mean of ``x`` from a bootstrap confidence interval.

    Uses SciPy's ``"basic"`` bootstrap with a freshly seeded generator on every
    call, so results vary slightly between calls.

    Args:
        x: Sample values.
        confidence_level: Coverage of the bootstrap interval (default 0.95).
        n_resamples: Number of bootstrap resamples (default 500).

    Returns:
        ``(center, width)``: the midpoint of the confidence interval and its
        full width (``high - low``), i.e. the estimate reads ``center ± width / 2``.
    """
    ci = bootstrap(
        (x,),
        np.mean,
        confidence_level=confidence_level,
        n_resamples=n_resamples,
        method="basic",
        random_state=np.random.default_rng(),
    ).confidence_interval
    return (ci.low + ci.high) * 0.5, ci.high - ci.low


def safemax(x):
    """Return ``np.max(x)``, or the integer 0 when ``x`` is empty."""
    return np.max(x) if len(x) else 0


# One key per PNG mask under diffs_path, e.g. "Animal/0001" for
# diffs_path/Animal/0001.png. Keys are taken relative to diffs_path (rather than
# from just the immediate parent folder name) so the same key addresses the
# parallel dataset / downscaled / output trees at any nesting depth.
_diffs_root = diffs_path.resolve()
valid_fnames = sorted(
    fname.relative_to(_diffs_root).with_suffix("").as_posix()
    for fname in _diffs_root.glob("**/*.png")
    if fname.is_file()
)

fnames_size = len(valid_fnames)
print(f"Number of valid filenames: {fnames_size}")


# High-resolution references: top-left 2048 x 2048 crop, converted to rescaled
# Lab in the same pass. The crop is a slice view that keeps the whole decoded
# image alive, so converting immediately avoids holding every full-size raw
# image in memory alongside the Lab arrays.
high_res_images = [
    my_rgb2lab(my_left_corner_crop(imread(dataset_path / f"{i}.png")))
    for i in tqdm(valid_fnames, desc="read high res")
]

# Downscaled inputs, converted to rescaled Lab (no crop applied).
low_res_images = [
    my_rgb2lab(imread(downscaled_path / f"{i}.png"))
    for i in tqdm(valid_fnames, desc="read low res")
]

# Saved model predictions, one "<key>_pred.npy" per image.
preds = [
    np.load(output_path / f"{i}_pred.npy")
    for i in tqdm(valid_fnames, desc="load preds")
]

# Difference masks read as single-channel "L" images. `mode` is the imageio v3
# pillow-plugin keyword; `pilmode` is its deprecated alias.
diffs = [
    imread(diffs_path / f"{i}.png", mode="L")
    for i in tqdm(valid_fnames, desc="load diffs")
]

# Gaussian sigma(s) used to blur the semantic diff masks; each value also picks
# the precomputed "_k{sigma}_diff.npy" files for the non-semantic run.
diff_sigmas = [32]
# NOTE(review): kernel_sizes is not referenced in the code shown here.
kernel_sizes = [31]
# Level passed to calibrate(alphas=[alpha]) — presumably the target miscoverage rate.
alpha = 0.1

# How many samples to poison: round(0.05 * k * n) for k = 1..20,
# i.e. roughly 5%, 10%, ..., 100% of the dataset.
poisoning_amounts = [round(0.05 * step * fnames_size) for step in range(1, 21)]

# Sweep variants: blurred mask images vs. precomputed diff arrays.
names = ["semantic", "non-semantic"]


def _load_clipped_mask(fname):
    """Load "<fname>_s64_varmask.npy" and clip it into [0, its 0.95 quantile]."""
    mask = np.load(output_path / f"{fname}_s64_varmask.npy")
    # Cap the upper tail at the 95th percentile; negative values become 0.
    return np.clip(mask, 0, np.quantile(mask, 0.95))


# Load and clip in a single pass so the unclipped masks are never all held in
# memory at the same time as their clipped copies.
gprob_masks = [
    _load_clipped_mask(i) for i in tqdm(valid_fnames, desc="load masks")
]

def _build_gdiffs(name, diff_sigma):
    """Return the per-image difference maps passed to calibration for one sweep.

    For ``"non-semantic"`` the maps are loaded from precomputed
    ``<key>_k{diff_sigma}_diff.npy`` files in ``output_path``; for any other
    name the module-level ``diffs`` masks are Gaussian-blurred with
    ``sigma=diff_sigma``.
    """
    if name == "non-semantic":
        return [
            np.load(output_path / f"{i}_k{diff_sigma}_diff.npy")
            for i in tqdm(valid_fnames, desc="load diffs")
        ]
    return [
        gaussian(diff, sigma=diff_sigma)
        for diff in tqdm(diffs, desc="gaussian diffs")
    ]


def _save_contamination_csv(name, diff_sigma, thresholds, upper_bounds):
    """Write one sweep's results (one row per poisoning amount) to CSV.

    The file is
    ``root_path/contamination/contamination_{name}_alpha{alpha}_dim{diff_sigma}.csv``;
    its parent directory is created if missing.
    """
    save_path = (
        root_path
        / "contamination"
        / f"contamination_{name}_alpha{alpha}_dim{diff_sigma}.csv"
    )
    save_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(
        {
            "PoisoningAmount": poisoning_amounts,
            "Threshold": thresholds,
            "UpperBound": upper_bounds,
        }
    )
    df.to_csv(save_path, index=False)


for name in names:
    for sigma_idx, diff_sigma in enumerate(tqdm(diff_sigmas)):
        thresholds = []
        upper_bounds = []
        gdiffs = _build_gdiffs(name, diff_sigma)

        if name == "semantic" and sigma_idx == len(diff_sigmas) - 1:
            # The raw masks are only needed to build the semantic maps. Free them
            # only after the last sigma has used them; deleting them on the first
            # sigma would raise NameError once diff_sigmas has two entries.
            del diffs
            gc.collect()

        # Poisoning is cumulative: the first `n_poisoned` maps are already
        # zeros, so each step only replaces the newly poisoned ones.
        n_poisoned = 0
        for poisoning_amount in tqdm(poisoning_amounts, desc="poisoning amounts"):
            tqdm.write(
                f"{name} | diff_sigma: {diff_sigma} | poisoning_amount: {poisoning_amount}"
            )

            # Poison: the first `poisoning_amount` maps become all-zero arrays.
            for k in range(n_poisoned, poisoning_amount):
                gdiffs[k] = np.zeros_like(gdiffs[k])
            n_poisoned = max(n_poisoned, poisoning_amount)

            # NOTE(review): predictions/masks are passed as a one-shot zip
            # iterator while the images are a list; confirm calibrate only
            # consumes it once.
            conformal = ConformalCalibratedModel.calibrate(
                None,
                list(zip(low_res_images, high_res_images)),
                zip(preds, gprob_masks),
                alphas=[alpha],
                diffs=gdiffs,
                method="dynprog",
            )

            threshold = conformal.thresholds[alpha]
            thresholds.append(threshold)
            tqdm.write(f"threshold: {threshold}")

            # alpha * (n + 1) / (n_new + 1), where n_new counts the samples whose
            # diff map was left unpoisoned.
            n_new = fnames_size - poisoning_amount
            upper_bounds.append(alpha * (fnames_size + 1) / (n_new + 1))

        _save_contamination_csv(name, diff_sigma, thresholds, upper_bounds)

        # Release this sweep's maps before the next sweep builds its own set.
        del gdiffs
        gc.collect()
