import os
import json
import argparse
from pathlib import Path

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torchvision.transforms import CenterCrop
from skimage.color import rgb2lab
from skimage.transform import rescale
from imageio.v3 import imread
from PIL import Image
from tqdm import tqdm
import pandas as pd
from icecream import ic
from scipy.signal import convolve2d  # type: ignore
from scipy.stats import bootstrap  # type: ignore

# Command-line interface: a single opt-in switch that selects between
# recomputing the result tables and reusing the CSV files already on disk.
parser = argparse.ArgumentParser(
    description=(
        "Plot fidelity error and average conformal mask size against the "
        "fidelity level and save the figure as fig2.png."
    ),
)
parser.add_argument(
    "--force",
    action="store_true",
    help=(
        "recompute the result tables from the saved .npy arrays and rewrite "
        "the CSV files; without this flag the existing CSV files are loaded"
    ),
)
args = parser.parse_args()
force = args.force

# The project root is the directory that contains this script's directory;
# the SinSR tree is looked up as a sibling directory of the project root.
root_path = Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"

dataset_basename = "LIU4K_v2_valid"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# Inputs located inside the SinSR tree.
dataset_path = sinsr_path / "datasets" / dataset_basename
downscaled_path = sinsr_path / "datasets" / f"{dataset_basename}_downscaled"
diffs_path = sinsr_path / "masks" / dataset_basename

# Arrays read by this script live under the project root.
output_path = root_path / "output_sinsr" / dataset_basename

# Destinations for the figure and the CSV tables; created on first use.
fig_output_path = root_path / "figs-neurips"
table_path = root_path / "table_data"
for _out_dir in (fig_output_path, table_path):
    _out_dir.mkdir(exist_ok=True)


# Global matplotlib styling: LaTeX text rendering (with amsmath available),
# 150 dpi figures and the "Times" font family.
plt.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r'\usepackage{amsmath}',
    "figure.dpi": 150,
    "font.family": "Times",
})

# Qualitative palette; only COLORS[0] is used below (baseline curve).
COLORS = [
    '#ff7f00', '#377eb8', '#4daf4a',
    '#f781bf', '#a65628', '#984ea3',
    '#999999', '#e41a1c', '#dede00',
]

# Light-to-dark shades: index 0 is a family's main colour, the remaining
# entries are assigned to the individual kernel sizes.
BLUES = ["#4292c6", "#2171b5", "#08306b"]
GREENS = ["#41ab5d", "#238b45", "#00441b"]


# Identifiers of the form "<parent directory name>/<file stem>" for every PNG
# file found recursively under the dataset directory, skipping files whose
# parent directory is named "Capture".
valid_fnames = [
    f"{png.parent.name}/{png.stem}"
    for png in dataset_path.resolve().glob("**/*.png")
    if png.parent.name != "Capture" and png.is_file()
]


# Blur widths (in pixels) evaluated for each family of D_p.
non_semantic_kernel_sizes = [32, 64]  # 1 and 16 are currently disabled
semantic_kernel_sizes = [32, 64]  # 128 is currently disabled

# Suffix of the mask arrays to load; other suffixes noted by the author:
# 'varmask', 'ker5_varmask', 'point_varmask'.
mask = 's64_varmask'

# One configuration per (family, kernel size). Each entry records where its
# CSV table is written/read, where its thresholds JSON lives, and the colours
# used when plotting: 'main_color' is shared by the family, 'color' is the
# per-kernel-size shade (palette index 1, 2, ...).
configs = [
    {
        'name': 'Non-semantic $D_p$',
        'save_path': table_path / f"avg_mask_fidelity_non-semantic_dim{size}.csv",
        'thresholds_path': root_path / f"conformal/sinsr-thresholds-s64-dp{size}-bs.json",
        'kernel_size': size,
        'main_color': BLUES[0],
        'color': BLUES[shade],
    }
    for shade, size in enumerate(non_semantic_kernel_sizes, start=1)
] + [
    {
        'name': 'Semantic $D_p$',
        'save_path': table_path / f"avg_mask_fidelity_semantic_dim{size}.csv",
        'thresholds_path': root_path / f"conformal/sinsr-thresholds-s64-d{size}-bs.json",
        'kernel_size': size,
        'main_color': GREENS[0],
        'color': GREENS[shade],
    }
    for shade, size in enumerate(semantic_kernel_sizes, start=1)
]

if force:
    
    def my_left_corner_crop(x):
        """Keep at most the first 2048 entries along each of the first two axes of ``x``."""
        window = slice(None, 2048)
        return x[window, window]

    def my_rgb2lab(x):
        """Convert ``x`` with ``rgb2lab`` and rescale the three channels.

        The first channel is divided by 100; the other two are divided by
        254 (``127 * 2``) and shifted by +0.5.
        """
        lab = rgb2lab(x)
        scale = np.array([[[100, 127 * 2, 127 * 2]]])
        offset = np.array([[[0, 0.5, 0.5]]])
        return lab / scale + offset


    def safemax(x):
        """Return ``np.max(x)``, or 0 when ``len(x)`` is zero."""
        return np.max(x) if len(x) != 0 else 0


    def infer_mean(x: list[float]) -> tuple[float, float]:
        """Summarise a bootstrap confidence interval for the mean of ``x``.

        Runs ``scipy.stats.bootstrap`` on ``np.mean`` with 500 resamples, the
        "basic" method, a 95% confidence level and a freshly created
        (unseeded) random generator, so repeated calls give slightly
        different results.

        Args:
            x: the observations whose mean is being estimated.

        Returns:
            A ``(center, width)`` pair: the midpoint of the confidence
            interval and its full width (``high - low``). The first value is
            the interval midpoint, not the plain sample mean, and the second
            is the whole interval, not a half-width.
        """
        result = bootstrap(
            (x,),
            np.mean,
            confidence_level=0.95,
            n_resamples=500,
            method='basic',
            random_state=np.random.default_rng(),
        )
        interval = result.confidence_interval
        center = (interval.low + interval.high) * 0.5
        width = interval.high - interval.low
        return center, width

    

    # One mask array per validation image, loaded in valid_fnames order from
    # "<output_path>/<fname>_<mask>.npy".
    prob_masks = [
        np.load(output_path / f"{fname}_{mask}.npy")
        for fname in tqdm(valid_fnames, desc="load masks")
    ]
    
    
    # Build one result table per configuration, keep it on the config for
    # plotting and write it to the configuration's CSV file.
    for config in tqdm(configs, desc="configurations"):
        kernel_size = config['kernel_size']

        # Difference arrays for this kernel size, loaded in valid_fnames order
        # so that diffs[i] pairs with prob_masks[i].
        diffs = [
            np.load(output_path / f"{fname}_k{kernel_size}_diff.npy")
            for fname in tqdm(valid_fnames, desc="load diffs")
        ]

        # Mapping from fidelity level (JSON string key) to mask threshold.
        with open(config['thresholds_path'], "r", encoding='utf-8') as file:
            thresholds = json.load(file)

        # Table columns; one entry is appended per retained fidelity level.
        alphas = []
        method_fidelity_error = []
        method_fidelity_error_interval = []
        avg_mask_size_error = []
        avg_mask_size_error_interval = []

        for alpha_key, threshold in tqdm(thresholds.items()):
            alpha = float(alpha_key)
            if alpha > 0.50:
                # Levels above 0.5 are left out of the table.
                continue

            # Per image: largest difference among the pixels whose mask value
            # is strictly below the threshold (0 when no pixel qualifies).
            samples = [
                safemax(diff[prob_mask < threshold])
                for diff, prob_mask in tqdm(zip(diffs, prob_masks), desc="load samples")
            ]

            # Per image: fraction of pixels whose mask value is strictly
            # above the threshold.
            eff_samples = [np.mean(prob_mask > threshold) for prob_mask in prob_masks]

            fidelity_center, fidelity_width = infer_mean(samples)
            size_center, size_width = infer_mean(eff_samples)

            alphas.append(alpha)
            method_fidelity_error.append(fidelity_center)
            method_fidelity_error_interval.append(fidelity_width)
            avg_mask_size_error.append(size_center)
            avg_mask_size_error_interval.append(size_width)

        # Baseline without masking: largest difference over each whole array.
        # It does not depend on alpha, so the same value fills every row.
        baseline_center, baseline_width = infer_mean(
            [safemax(diff) for diff in tqdm(diffs, desc="load samples")]
        )
        n_rows = len(alphas)

        df = pd.DataFrame(
            {
                "FidelityLevel": alphas,
                "BaselineFidelityError": [baseline_center] * n_rows,
                "BaselineFidelityErrorInterval": [baseline_width] * n_rows,
                "MethodFidelityError": method_fidelity_error,
                "MethodFidelityErrorInterval": method_fidelity_error_interval,
                "AverageConformalMaskSize": avg_mask_size_error,
                "AverageConformalMaskSizeInterval": avg_mask_size_error_interval,
            }
        )

        # Keep the table for plotting and persist it for later runs.
        config['df'] = df
        df.to_csv(config['save_path'], index=False)
else:
    # Reuse the CSV tables written by an earlier --force run.
    for config in configs:
        cached = pd.read_csv(config['save_path'])
        # NOTE(review): the final CSV row is discarded on this path only; the
        # --force branch plots every row it computes — confirm this is intended.
        config['df'] = cached.iloc[:-1]

def _plot_with_band(ax, df, value_col, label, color, fmt=None):
    """Plot ``df[value_col]`` against the fidelity level with a shaded band.

    The tables store, for every quantity, the midpoint of its bootstrap
    confidence interval in ``<value_col>`` and the interval's *full* width
    (``high - low``) in ``<value_col>Interval``. The band is therefore drawn
    at ``value ± interval / 2``, which reproduces exactly ``[low, high]``;
    using the full width on each side would double the displayed interval.

    Args:
        ax: matplotlib axes to draw on.
        df: table holding "FidelityLevel", ``value_col`` and
            ``value_col + "Interval"`` columns.
        value_col: name of the column to plot.
        label: legend label of the curve.
        color: colour of both the curve and the band.
        fmt: optional matplotlib format string (e.g. ``'--'``); when omitted
            the default line style is used.
    """
    x = df["FidelityLevel"]
    center = df[value_col]
    half_width = df[f"{value_col}Interval"] / 2

    line_args = (x, center) if fmt is None else (x, center, fmt)
    ax.plot(*line_args, label=label, color=color)
    ax.fill_between(
        x,
        center - half_width,
        center + half_width,
        alpha=0.1,
        facecolor=color,
    )


# Three panels: fidelity error (left), average mask size for the non-semantic
# configurations (middle) and for the remaining configurations (right).
fig, (ax_1, ax_2, ax_3) = plt.subplots(1, 3, figsize=(9, 2.1), gridspec_kw={"wspace": 0.06},
                                        constrained_layout=True)
len_config = len(configs)

# The baseline curve is drawn once, from the first 32-pixel configuration.
first = True
for config in configs:
    df = config['df']
    kernel_size = config['kernel_size']
    name = config['name']
    main_color = config['main_color']
    color = config['color']

    # Left panel: only the 32-pixel configurations are shown.
    if kernel_size == 32:
        if first:
            first = False
            df["FidelityLevel"] = df["FidelityLevel"].round(3)
            _plot_with_band(ax_1, df, "BaselineFidelityError",
                            "W/o our method", COLORS[0], fmt='--')

        _plot_with_band(ax_1, df, "MethodFidelityError", name, main_color)

    # Mask-size panels: one curve per kernel size, split by configuration name.
    mask_ax = ax_2 if name == "Non-semantic $D_p$" else ax_3
    _plot_with_band(mask_ax, df, "AverageConformalMaskSize",
                    f"{kernel_size}-pixel-wide blur for $D_p$", color)
    
# Every panel shares the same x-axis label.
for ax in (ax_1, ax_2, ax_3):
    ax.set_xlabel(r"Fidelity Level ($\alpha$)")

# Left panel: fidelity error.
ax_1.legend(loc='lower right')
ax_1.set_ylabel("Fidelity Error\n[Lower is better]")

# Middle and right panels: mask size shown as a percentage (1.0 == 100%).
for ax in (ax_2, ax_3):
    ax.legend(loc='upper right')
    ax.set_ylabel("Average Mask Size\n[Lower is better]")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0))

# Only the middle panel has its x-range restricted.
ax_2.set_xlim(0.0, 0.3)

fig.savefig(fig_output_path / "fig2.png", dpi=150, bbox_inches='tight')
