import numpy as np
import matplotlib.pyplot as plt
plt.rc('font',family='Times New Roman')
from matplotlib import rcParams
from typing import *
import pandas as pd
import seaborn as sns
import math
import os

# Matplotlib defaults applied globally, so every figure drawn by this script
# shares the same typography.
config = {
    "font.family": "Times New Roman",
    "font.size": 12,
    # Font set used for mathtext ($...$) labels.
    "mathtext.fontset": "stix",
    "font.serif": ["SimSun"],
    # Draw minus signs with the ASCII hyphen rather than the Unicode minus.
    "axes.unicode_minus": False,
}
rcParams.update(config)

# Sub-directory name shared by the input results and the output figures.
method = "ngmv2"
# Display name; the __main__ section passes it as the figure title.
method_name = "NGMv2"
# Radius variant to plot: one of "Lmax", "Lvolume" or "Llower".
L_type = "Llower"

# Largest radius plotted on the x-axis for each radius variant.
_MAX_RADIUS_BY_L_TYPE = {"Lmax": 30, "Lvolume": 40, "Llower": 8}
if L_type not in _MAX_RADIUS_BY_L_TYPE:
    # Fail here with a clear message instead of leaving `length` undefined
    # and hitting a NameError much later.
    raise ValueError(
        "unknown L_type %r; expected one of %s"
        % (L_type, sorted(_MAX_RADIUS_BY_L_TYPE))
    )
length = _MAX_RADIUS_BY_L_TYPE[L_type]

# Root directory for generated figures and root directory of the input data.
file_name = "figure_deep"
data_file_name = "result_deep_certify/"

# Create <file_name>/<method> (and the parent) in one call; exist_ok avoids
# the check-then-create race of a separate existence test.
os.makedirs(os.path.join(file_name, method), exist_ok=True)

class Accuracy:
    """Interface for objects that can report an accuracy value per radius."""

    def at_radii(self, radii: np.ndarray):
        """Return the accuracy at each radius in ``radii``.

        This base implementation always raises ``NotImplementedError``;
        subclasses provide the actual computation.
        """
        raise NotImplementedError

class ApproximateAccuracy(Accuracy):
    """Accuracy computed from a tab-separated results file.

    The file must provide a ``correct`` column and a ``radius`` column.
    """

    def __init__(self, data_file_path: str):
        # Only the path is stored; the file is read on each at_radii() call.
        self.data_file_path = data_file_path

    def at_radii(self, radii: np.ndarray) -> np.ndarray:
        """Load the results file and return one accuracy value per radius."""
        table = pd.read_csv(self.data_file_path, delimiter="\t")
        accuracies = [self.at_radius(table, r) for r in radii]
        return np.array(accuracies)

    def at_radius(self, df: pd.DataFrame, radius: float):
        """Return the mean over rows of ``correct & (radius column >= radius)``."""
        reaches_radius = df["radius"] >= radius
        counted = df["correct"] & reaches_radius
        return counted.mean()

class Line:
    """Description of one curve to draw on a certified-accuracy figure.

    Attributes are plain data read by ``plot_certified_accuracy``:

    * ``quantity``  -- object whose ``at_radii`` supplies the y-values.
    * ``legend``    -- text shown for this curve in the legend.
    * ``plot_fmt``  -- format string forwarded to ``plt.plot``.
    * ``scale_x``   -- factor the radii are multiplied by on the x-axis.
    """

    def __init__(self, quantity: Accuracy, legend: str, plot_fmt: str = "", scale_x: float = 1):
        self.legend = legend
        self.quantity = quantity
        self.scale_x = scale_x
        self.plot_fmt = plot_fmt

def plot_certified_accuracy(outfile: str, title: str, max_radius: float,
                            lines: List[Line], radius_step: float = 0.01) -> None:
    """Plot accuracy-versus-radius curves and save the figure as a PNG.

    Args:
        outfile: Output path without extension; ".png" is appended.
        title: Title drawn above the axes.
        max_radius: Upper x-axis limit. Curves are evaluated on
            ``np.arange(0, max_radius + radius_step, radius_step)``.
        lines: Curves to draw, in legend order. Line styles and colours
            repeat after four curves.
        radius_step: Spacing between the evaluated radii.

    Raises:
        ValueError: If the module-level ``L_type`` is not "Lmax", "Llower"
            or "Lvolume".
    """
    print(outfile)

    # Raw strings keep the LaTeX backslashes literal: "\e" and "\S" are not
    # valid Python escape sequences and trigger a warning in normal literals.
    axis_symbols = {
        "Lmax": r"$\ell_{2}^{max}$",
        "Llower": r"$\ell_{2}^{lower}$",
        "Lvolume": r"$\ell_{2}^{\Sigma}$",
    }
    if L_type not in axis_symbols:
        raise ValueError(
            "unknown L_type %r; expected one of %s"
            % (L_type, sorted(axis_symbols))
        )
    L_output = axis_symbols[L_type]

    radii = np.arange(0, max_radius + radius_step, radius_step)
    linestyle_str = ['solid', 'dashed', 'dotted', 'solid']
    color_str = ['r', 'y', 'g', 'b']

    # Set the font family once, before any text artist is created.
    plt.rc('font', family='Times New Roman')

    fig = plt.figure()
    try:
        for index, line in enumerate(lines):
            # Wrap around so more than four curves do not raise IndexError.
            style = index % len(linestyle_str)
            plt.plot(radii * line.scale_x, line.quantity.at_radii(radii), line.plot_fmt,
                     linestyle=linestyle_str[style], color=color_str[style])

        plt.ylim((0, 1.0))
        plt.xlim((0, max_radius))
        plt.tick_params(labelsize=20)

        plt.xlabel(L_output + " " + "radius", fontsize=26)
        plt.ylabel("certified accuracy", fontsize=26)
        plt.legend([line.legend for line in lines], loc='upper right', fontsize=16)
        # NOTE(review): the layout is tightened both before and after the
        # title is added; the first call looks redundant but is kept so the
        # rendered output stays unchanged.
        plt.tight_layout()
        plt.title(title, fontsize=24)
        plt.tight_layout()
        plt.savefig(outfile + ".png", dpi=600)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

if __name__ == "__main__":
    # (result-file infix, legend label) for each compared curve, in plot order.
    curve_specs = [
        ("_cov_pro0_cov_", "$SCR-GM$"),
        ("_cov_pro10_ancer_", "$ ANCER$"),
        ("_cov_pro10_iso_", "$ DDRS$"),
        ("_cov_pro10_RS_", "$ RS$"),
    ]

    # Values embedded after "sigmaori" in the result-file names. Figures for
    # 10, 15 and 20 use the same naming scheme but are currently disabled.
    sigmas = (1, 5)

    for sigma in sigmas:
        # Part of the input path common to all four curves of this figure.
        prefix = data_file_name + method + "/sigmaori" + str(sigma) + "_n1000_n0100_sample100"
        plot_certified_accuracy(
            file_name + "/" + method + "/ori_sigma" + str(sigma) + "_" + L_type,
            method_name,
            length,
            [
                Line(ApproximateAccuracy(prefix + infix + L_type), label)
                for infix, label in curve_specs
            ],
        )

