import os
import json
import argparse
from pathlib import Path

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torchvision.transforms import CenterCrop
from skimage.color import rgb2lab
from skimage.transform import rescale
from imageio.v3 import imread
from PIL import Image
from tqdm import tqdm
import pandas as pd
from icecream import ic
from scipy.signal import convolve2d  # type: ignore
from typing import Callable, List, Tuple

# Disable PIL's decompression-bomb guard so arbitrarily large images can be
# opened.
Image.MAX_IMAGE_PIXELS = None

# The project root is the grandparent directory of this file; the SinSR
# checkout is expected to sit next to it.
root_path = Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"

dataset_basename = "LIU4K_v2_valid"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# NOTE(review): these dataset/output paths (like ``subdatasets`` and ``mask``
# below) are not used by the plotting code visible in this script; confirm
# whether anything else relies on them before removing.
downscaled_path = sinsr_path / "datasets" / (dataset_basename + "_downscaled")
output_path = root_path / "output_sinsr" / dataset_basename
dataset_path = sinsr_path / "datasets" / dataset_basename
diffs_path = sinsr_path / "masks" / dataset_basename

# Figures are written to ``fig_output_path``; CSV tables are read from
# ``table_path``.
fig_output_path = root_path / "figs-neurips"
fig_output_path.mkdir(exist_ok=True)
table_path = root_path / "table_data"

# Render all text through LaTeX (a working LaTeX installation is required).
# amsmath provides the ``\text`` command used in the axis labels.
matplotlib.rc('text', usetex=True)
matplotlib.rc('text.latex', preamble=r'\usepackage{amsmath}')
plt.rcParams.update({"figure.dpi": 150})
plt.rcParams["font.family"] = "Times"

# Qualitative palette; series pick their colour by index.
COLORS = ['#377eb8', '#ff7f00', '#4daf4a',
          '#f781bf', '#a65628', '#984ea3',
          '#999999', '#e41a1c', '#dede00']


# Experiment parameters. ``kernel_size`` and ``alpha`` select which CSV
# tables are read.
kernel_size = 32
mask = 's64_varmask'  # other values used before: 'varmask', 'ker5_varmask', 'point_varmask'
alpha = 0.1
# fig_1/ax_1: average PSNR vs. fidelity level.
# fig_2/ax_2: fidelity error vs. leaked amount.
fig_1, ax_1 = plt.subplots(1, 1, figsize=(5, 3), tight_layout=True)
fig_2, ax_2 = plt.subplots(1, 1, figsize=(5, 3), tight_layout=True)

# Fidelity error as a function of the leaked (poisoning) amount. The first
# data row of the CSV is skipped (positional slice).
poisoning = pd.read_csv(
    table_path / f"contamination_non-semantic_alpha{alpha}_dim{kernel_size}.csv"
).iloc[1:]

# The first remaining row serves as a constant "w/o poisoning" baseline.
# NOTE(review): this presumes that row is the zero / smallest leaked
# amount; confirm against the CSV.
# Assigning a scalar broadcasts it to every row.
poisoning["BaselineFidelityError"] = poisoning["MethodFidelityError"].iloc[0]
poisoning["BaselineFidelityErrorInterval"] = poisoning["MethodFidelityErrorInterval"].iloc[0]
print(poisoning)

# --- Fig. 4a: average PSNR vs. fidelity level ---------------------------------
# One curve with a shaded interval band per poisoned-data type. The baseline
# and the Prop. 3.1 lower bound are drawn once, from the first
# (non-semantic) table.
for i, name in enumerate(('Non-semantic', 'Semantic')):
    # As with the poisoning table, the first data row of each CSV is skipped.
    psnr = pd.read_csv(table_path / f"psnr_{name.lower()}_dim{kernel_size}.csv").iloc[1:]
    psnr["FidelityLevel"] = psnr["FidelityLevel"].round(3)

    if i == 0:
        # Dashed baseline ("W/o our method") with its interval band.
        ax_1.plot(psnr["FidelityLevel"], psnr["BaselinePSNR"], '--',
                  label="W/o our method", color=COLORS[0])
        ax_1.fill_between(
            psnr["FidelityLevel"],
            psnr["BaselinePSNR"] - psnr["BaselinePSNRInterval"],
            psnr["BaselinePSNR"] + psnr["BaselinePSNRInterval"],
            alpha=0.1,
            facecolor=COLORS[0],
        )

        # Dashed theoretical lower bound. The region above it is shaded up
        # to 35, which matches the upper y-limit of the PSNR axis.
        ax_1.plot(psnr["FidelityLevel"], psnr["LowerBound"],
                  label="Lower bound (Prop. 3.1)", linestyle='--', color=COLORS[4])
        ax_1.fill_between(
            psnr["FidelityLevel"],
            psnr["LowerBound"],
            35,
            alpha=0.05,
            facecolor=COLORS[4],
        )

    # PSNR curve for this data type, with its interval band.
    ax_1.plot(psnr["FidelityLevel"], psnr["PSNR"],
              label=f"{name} $D_p$",
              color=COLORS[i + 1])
    ax_1.fill_between(
        psnr["FidelityLevel"],
        psnr["PSNR"] - psnr["PSNRInterval"],
        psnr["PSNR"] + psnr["PSNRInterval"],
        alpha=0.1,
        facecolor=COLORS[i + 1],
    )

# --- Fig. 4b: fidelity error vs. leaked amount --------------------------------
# Drawn once; this panel depends only on the ``poisoning`` table, not on the
# PSNR tables read above.

# Dashed constant "W/o poisoning" baseline with its interval band.
ax_2.plot(poisoning["PoisoningAmount"], poisoning["BaselineFidelityError"], '--',
          label="W/o poisoning", color=COLORS[0])
ax_2.fill_between(
    poisoning["PoisoningAmount"],
    poisoning["BaselineFidelityError"] - poisoning["BaselineFidelityErrorInterval"],
    poisoning["BaselineFidelityError"] + poisoning["BaselineFidelityErrorInterval"],
    alpha=0.1,
    facecolor=COLORS[0],
)

# Dashed theoretical upper bound. The region below it is shaded down to 0.
ax_2.plot(poisoning["PoisoningAmount"], poisoning["UpperBound"],
          label="Upper bound (Prop. 3.2)", linestyle='--', color=COLORS[4])
ax_2.fill_between(
    poisoning["PoisoningAmount"],
    0,
    poisoning["UpperBound"],
    alpha=0.05,
    facecolor=COLORS[4],
)

# Method's fidelity error on non-semantic poisoned data, with its interval
# band.
ax_2.plot(poisoning["PoisoningAmount"], poisoning["MethodFidelityError"],
          label="Non-semantic $D_p$",
          color=COLORS[1])
ax_2.fill_between(
    poisoning["PoisoningAmount"],
    poisoning["MethodFidelityError"] - poisoning["MethodFidelityErrorInterval"],
    poisoning["MethodFidelityError"] + poisoning["MethodFidelityErrorInterval"],
    alpha=0.1,
    facecolor=COLORS[1],
)


    
# Per-axis decoration: legend placement, axis labels and fixed y-range.
# ax_1 is the PSNR panel (Fig. 4a); ax_2 is the fidelity-error panel (Fig. 4b).
_axis_specs = (
    (ax_1, r"Fidelity Level ($\alpha$)", "Average PSNR\n[Higher is better]", (12.5, 35)),
    (ax_2, r"Leaked Amount ($n_{\text{leaked}}$)", "Fidelity Error\n[Lower is better]", (0.025, 0.2)),
)
for target_ax, x_text, y_text, (y_lo, y_hi) in _axis_specs:
    target_ax.legend(loc='lower right')
    target_ax.set_xlabel(x_text)
    target_ax.set_ylabel(y_text)
    target_ax.set_ylim(y_lo, y_hi)

# Lay out each figure and write it as a 150-dpi PNG with a tight bounding box.
for target_fig, png_name in ((fig_1, "fig4a.png"), (fig_2, "fig4b.png")):
    target_fig.tight_layout()
    target_fig.savefig(fig_output_path / png_name, dpi=150, bbox_inches='tight')
