import os
import json
import argparse
from pathlib import Path

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torchvision.transforms import CenterCrop
from skimage.color import rgb2lab, lab2rgb
from skimage.transform import rescale
from imageio.v3 import imread
from PIL import Image
from tqdm import tqdm
import pandas as pd
from icecream import ic
from scipy.signal import convolve2d  # type: ignore
from skimage.filters import gaussian
from typing import Callable, List, Tuple
from scipy.stats import bootstrap  # type: ignore

# Command-line interface: a single optional boolean flag, `--force`.
# Arguments are parsed at import time, so this file is meant to be run as a
# script rather than imported.
parser = argparse.ArgumentParser()
parser.add_argument("--force", action="store_true")
args = parser.parse_args()
# NOTE(review): `force` is not read anywhere else in the visible code.
force = args.force
# Per Pillow's documentation, setting MAX_IMAGE_PIXELS to None turns off its
# image-size (decompression bomb) limit, so arbitrarily large images can be
# opened.
Image.MAX_IMAGE_PIXELS = None
    
# Directory layout. Every location is derived from this file's position:
# `root_path` is the grandparent directory of this script, and the SinSR
# checkout is expected to sit next to `root_path`.
root_path = Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"


dataset_basename = "LIU4K_v2_valid"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# Inputs that live inside the SinSR checkout.
downscaled_path = sinsr_path / "datasets" / (dataset_basename + "_downscaled")
hr_path = sinsr_path / "datasets" / dataset_basename
diffs_path = sinsr_path / "masks" / dataset_basename

# Inputs/outputs that live under `root_path`.
sr_path = root_path / "output_sinsr" / dataset_basename
contamination_path = root_path / "contamination"

# Output directory for the generated CSV tables. It was previously assigned
# twice with the same value; it is now defined once and created if missing.
table_path = root_path / "table_data"
table_path.mkdir(exist_ok=True)

# Set the `preamble` entry of matplotlib's `text.latex` rc group so that the
# amsmath package is loaded.
# NOTE(review): presumably only takes effect when LaTeX text rendering
# (`text.usetex`) is enabled — confirm.
matplotlib.rc('text.latex', preamble=r'\usepackage{amsmath}')


# Palette of nine hex colour codes.
# NOTE(review): not referenced anywhere else in the visible code.
COLORS = ['#377eb8', '#ff7f00', '#4daf4a',
          '#f781bf', '#a65628', '#984ea3',
          '#999999', '#e41a1c', '#dede00']


# Sorted identifiers of the form "<parent directory name>/<file stem>" for
# every PNG file found recursively under `hr_path`, skipping files whose
# immediate parent directory is named "Capture".
valid_fnames = sorted(
    "/".join((png.parent.name, png.stem))
    for png in hr_path.resolve().glob("**/*.png")
    if png.is_file() and png.parent.name != "Capture"
)

# NOTE(review): `dfs` is never filled in the visible code.
dfs = []
# Interpolated into the threshold CSV names ("dim{kernel_size}") and the
# per-image diff file names ("_k{kernel_size}_diff.npy").
kernel_size = 32
# Suffix of the per-image mask files ("<name>_{mask_name}.npy").
# Other names noted by the author: 'varmask', 'ker5_varmask', 'point_varmask'.
mask_name = 's64_varmask'


def my_left_corner_crop(x):
    """Return the top-left corner of ``x``, at most 2048 entries per axis.

    Only the first two axes are sliced; any further axes are kept whole.
    Inputs smaller than 2048 along an axis are returned uncropped on that
    axis.
    """
    first_2048 = slice(None, 2048)
    return x[first_2048, first_2048]

def my_rgb2lab(x):
    """Convert ``x`` with ``rgb2lab`` and rescale the three output channels.

    After conversion, channel 0 is divided by 100, while channels 1 and 2 are
    divided by 254 (``127 * 2``) and then shifted by 0.5.

    NOTE(review): assumes the channel axis is last (e.g. ``(H, W, 3)``), as
    the scale/offset arrays have shape ``(1, 1, 3)`` — confirm with callers.
    """
    lab = rgb2lab(x)
    channel_scale = np.array([[[100, 127 * 2, 127 * 2]]])
    channel_offset = np.array([[[0, 0.5, 0.5]]])
    return lab / channel_scale + channel_offset


def safemax(x):
    """Return the maximum of ``x``, or ``0`` when ``x`` has no elements.

    Emptiness is judged by the total element count (``np.size``) rather than
    ``len``. ``len`` only inspects the first axis, so an array of shape
    ``(3, 0)`` used to slip past the guard and make ``np.max`` raise
    ``ValueError``.

    Args:
        x: Array-like of numbers (any number of dimensions).

    Returns:
        ``np.max(x)`` for non-empty input, otherwise the integer ``0``.
    """
    if np.size(x) == 0:
        return 0
    return np.max(x)
    
def safemean(x):
    """Return the mean of ``x``, or ``0`` when ``x`` has no elements.

    Emptiness is judged by the total element count (``np.size``) rather than
    ``len``. ``len`` only inspects the first axis, so an array of shape
    ``(4, 0)`` used to slip past the guard and make ``np.mean`` return NaN
    (with a runtime warning).

    Args:
        x: Array-like of numbers (any number of dimensions).

    Returns:
        ``np.mean(x)`` for non-empty input, otherwise the integer ``0``.
    """
    if np.size(x) == 0:
        return 0
    return np.mean(x)


def infer_mean(x: list[float], *, confidence_level: float = 0.95,
               n_resamples: int = 500, rng=None) -> tuple[float, float]:
    """Summarise the mean of ``x`` with a basic bootstrap confidence interval.

    Args:
        x: Sample values. SciPy's ``bootstrap`` requires at least two
            observations.
        confidence_level: Confidence level of the interval.
        n_resamples: Number of bootstrap resamples.
        rng: Optional ``numpy.random.Generator`` (or anything SciPy accepts
            as ``random_state``). When omitted, a fresh unseeded generator is
            used, so results differ slightly between calls.

    Returns:
        ``(center, width)`` where ``center`` is the midpoint of the
        confidence interval (not the plain sample mean) and ``width`` is the
        full interval width ``high - low``.
    """
    if rng is None:
        rng = np.random.default_rng()

    result = bootstrap((x,), np.mean,
                       confidence_level=confidence_level,
                       n_resamples=n_resamples,
                       method='basic',
                       random_state=rng)
    interval = result.confidence_interval
    return (interval.low + interval.high) * 0.5, interval.high - interval.low


# `alpha` is interpolated into the threshold CSV file names read here and into
# the output CSV names written at the end of the script.
alpha = 0.1

# Threshold tables read from `contamination_path`; the last row of each CSV
# is dropped.
_non_semantic_csv = contamination_path / f"contamination_non-semantic_alpha{alpha}_dim{kernel_size}.csv"
thresholds_non_semantic = pd.read_csv(_non_semantic_csv)[:-1]

# NOTE(review): only the commented-out 'semantic' configuration uses this
# table, but the file must still exist for the script to run.
_semantic_csv = contamination_path / f"contamination_semantic_alpha{alpha}_dim{kernel_size}.csv"
thresholds_semantic = pd.read_csv(_semantic_csv)[:-1]

# One diff array per entry of `valid_fnames`, loaded from `sr_path`.
non_semantic_diffs = [
    np.load(sr_path / f"{name}_k{kernel_size}_diff.npy")
    for name in tqdm(valid_fnames, desc="load diffs")
]
# Disabled variant kept for reference.
# NOTE(review): `original_diffs` is not defined in the visible code.
# semantic_diffs = [
#                 gaussian(diff, sigma=kernel_size)
#                 for diff in tqdm(original_diffs, desc="gaussian diffs")
#             ]

# One mask array per entry of `valid_fnames`, in the same order as the diffs.
prob_masks = [
    np.load(sr_path / f"{name}_{mask_name}.npy")
    for name in tqdm(valid_fnames, desc="load masks")
]

def _make_config(name, thresholds, diffs):
    """Build the mutable result record for one evaluation configuration.

    The four result lists start empty and are filled by the evaluation loop.
    """
    return {
        'name': name,
        'fidelity_errors': [],
        'fidelity_errors_interval': [],
        'baseline_fidelity_error': [],
        'baseline_fidelity_error_interval': [],
        'thresholds': thresholds,
        'diffs': diffs,
    }


# Configurations to evaluate; each one produces its own output CSV.
configs = [
    _make_config('non-semantic', thresholds_non_semantic, non_semantic_diffs),
    # Disabled (requires `semantic_diffs`, which is currently commented out):
    # _make_config('semantic', thresholds_semantic, semantic_diffs),
]


# Evaluation: for every configuration, bootstrap the per-image maximum of the
# diff arrays without masking (baseline) and with the mask applied at each
# threshold (method), then write one CSV per configuration to `table_path`.

# The alpha window depends on neither the configuration nor the threshold, so
# it is evaluated once here instead of once per threshold.
# NOTE(review): if alpha falls outside [0.05, 0.50] no threshold is
# evaluated and the method columns stay empty, so building the DataFrame
# below will most likely fail on mismatched column lengths.
alpha = float(alpha)
alpha_in_range = not (alpha > 0.50 or alpha < 0.05)

for config in configs:
    thresholds = config['thresholds']
    diffs = config['diffs']

    # Baseline: maximum over each whole diff array, repeated for every
    # threshold row so it lines up with the method columns.
    samples = [safemax(diff) for diff in tqdm(diffs, desc="load samples")]
    baseline = infer_mean(samples)
    config['baseline_fidelity_error'] = [baseline[0]] * len(thresholds)
    config['baseline_fidelity_error_interval'] = [baseline[1]] * len(thresholds)

    if alpha_in_range:
        for threshold in thresholds["Threshold"]:
            # Method: maximum over the diff entries whose mask value is below
            # the threshold; safemax yields 0 when nothing is selected.
            samples = [
                safemax(diff[prob_mask < threshold])
                for diff, prob_mask in tqdm(zip(diffs, prob_masks),
                                            total=len(diffs),
                                            desc="load samples")
            ]

            fidelity_error = infer_mean(samples)
            config['fidelity_errors'].append(fidelity_error[0])
            config['fidelity_errors_interval'].append(fidelity_error[1])

    new_data = {
        "PoisoningAmount": thresholds["PoisoningAmount"],
        "Threshold": thresholds["Threshold"],
        "UpperBound": thresholds["UpperBound"],
        "MethodFidelityError": config['fidelity_errors'],
        "MethodFidelityErrorInterval": config['fidelity_errors_interval'],
        "BaselineFidelityError": config['baseline_fidelity_error'],
        "BaselineFidelityErrorInterval": config['baseline_fidelity_error_interval'],
    }

    # Save the aggregated statistics (one row per threshold).
    df = pd.DataFrame(new_data)
    df.to_csv(table_path /
              f"contamination_{config['name']}_alpha{alpha}_dim{kernel_size}.csv",
              index=False)
