import numpy as np
import matplotlib.pyplot as plt
plt.rc('font',family='Times New Roman')
from matplotlib import rcParams
from typing import *
import pandas as pd
import seaborn as sns
import math
import os

# Global matplotlib settings applied to every figure drawn by this script:
# Latin text in Times New Roman, STIX math glyphs, SimSun as the serif
# fallback, and an ASCII hyphen instead of the Unicode minus sign.
config = {
    'font.family': 'Times New Roman',
    'font.size': 12,
    'mathtext.fontset': 'stix',
    'font.serif': ['SimSun'],
    'axes.unicode_minus': False,
}
rcParams.update(config)

# Experiment selectors: they decide which result directory is read and where
# the figure is written.
method = "rrwm"
label = "feature"

# Values embedded in the result-file names built by the entry point at the
# bottom of this file.
n = 100
n0 = 10
sample = 100
sigma = 0.001

# `length` is passed as the largest plotted radius and `scale` is the default
# spacing between evaluated radii; both depend on the selected label.
if label == "feature":
    length = 0.01
    scale = 0.0001
else:
    length = 10
    scale = 0.01

file_name = "figure_rrwm"
data_file_name = "result_rrwm_" + label + "/"

# Create <file_name>/<method> (and the parent) in one step. exist_ok=True
# replaces the separate exists()/mkdir() pairs, which could fail if the
# directory appeared between the check and the creation.
os.makedirs(os.path.join(file_name, method), exist_ok=True)

class Accuracy:
    """Interface for a quantity that can be evaluated over a set of radii."""

    def at_radii(self, radii: np.ndarray):
        """Evaluate the quantity at every radius in ``radii``.

        Subclasses are expected to override this method; the base
        implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError


class ApproximateAccuracy(Accuracy):
    """Accuracy computed from a tab-separated results file.

    The file must provide a ``correct`` column and a ``radius`` column. It is
    read again on every call to ``at_radii``.
    """

    def __init__(self, data_file_path: str):
        # Only the path is stored; nothing is read until at_radii() is called.
        self.data_file_path = data_file_path

    def at_radii(self, radii: np.ndarray) -> np.ndarray:
        """Return an array holding ``at_radius`` for each entry of ``radii``."""
        table = pd.read_csv(self.data_file_path, sep="\t")
        per_radius = [self.at_radius(table, r) for r in radii]
        return np.array(per_radius)

    def at_radius(self, df: pd.DataFrame, radius: float):
        """Return the mean over rows of ``correct & (radius column >= radius)``."""
        is_correct = df["correct"]
        reaches_radius = df["radius"] >= radius
        certified = is_correct & reaches_radius
        return certified.mean()


class Line:
    """One curve of the figure: a quantity plus how it is labelled and drawn."""

    def __init__(self, quantity: Accuracy, legend: str, plot_fmt: str = "", scale_x: float = 1):
        # quantity: object whose at_radii() supplies the y values.
        # legend:   text shown for this curve in the legend.
        # plot_fmt: format string forwarded to plt.plot.
        # scale_x:  factor applied to the radii before plotting.
        self.quantity, self.legend = quantity, legend
        self.plot_fmt, self.scale_x = plot_fmt, scale_x


def plot_certified_accuracy(outfile: str, title: str, max_radius: float,
                            lines: List[Line], radius_step: float = scale) -> None:
    """Plot certified accuracy against radius and save the figure as a PNG.

    Args:
        outfile: Output path without extension; ".png" is appended.
        title: Title drawn above the axes.
        max_radius: Upper x-axis limit and upper end of the evaluated radii.
        lines: Curves to draw; legend entries follow this order.
        radius_step: Spacing between consecutive evaluated radii.
    """
    print(outfile)
    radii = np.arange(0, max_radius + radius_step, radius_step)

    linestyles = ['solid', 'dashed', 'dotted', 'solid']
    colors = ['r', 'y', 'g', 'b']

    # A single call suffices: it sets the global font family before any text
    # of this figure is created.
    plt.rc('font', family='Times New Roman')

    fig = plt.figure()
    try:
        for idx, line in enumerate(lines):
            # Wrap around the style tables so that more than four curves reuse
            # the styles from the start instead of raising IndexError.
            plt.plot(radii * line.scale_x, line.quantity.at_radii(radii), line.plot_fmt,
                     linestyle=linestyles[idx % len(linestyles)],
                     color=colors[idx % len(colors)])

        plt.ylim((0, 1.0))
        plt.xlim((0, max_radius))
        plt.tick_params(labelsize=16)

        # Raw string: "\e" is not a valid escape sequence, so the non-raw
        # literal only worked by accident and warns on recent Python versions.
        plt.xlabel(r"$\ell_{2}$" + " " + "radius", fontsize=23)
        plt.ylabel("certified accuracy", fontsize=23)
        plt.legend([entry.legend for entry in lines], loc='upper right', fontsize=14)
        # Lay out once before and once after the title is added.
        plt.tight_layout()
        plt.title(title, fontsize=20)
        plt.tight_layout()
        plt.savefig(outfile + ".png", dpi=300)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

if __name__ == "__main__":
    # All result files share this prefix; only the suffix names the variant.
    result_prefix = (data_file_name + method + "/noise_scale" + str(sigma)
                     + "_n" + str(n) + "_n0" + str(n0) + "_sample" + str(sample))

    # (result-file suffix, legend text) for each curve, in drawing order.
    # Raw strings keep the LaTeX backslashes without relying on the invalid
    # escape sequences "\e", "\S" and "\q".
    curves = [
        ("_cov_Llower", r"$\ell_{2}^{lower} SCR-GM$"),
        ("_cov_Lvolume", r"$\ell_{2}^{\Sigma} SCR-GM$"),
        ("_cov_Lmax", r"$\ell_{2}^{max} SCR-GM$"),
        ("_RS_Lmax", r"$\ell_{2}^{lower}\quad\ell_{2}^{max}\quad\ell_{2}^{\Sigma} RS$"),
    ]

    plot_certified_accuracy(
        file_name + "/" + method + "/" + label + "_noise_scale" + str(sigma),
        "RRWM",
        length,
        [Line(ApproximateAccuracy(result_prefix + suffix), legend)
         for suffix, legend in curves],
    )

