import math
from re import S
from matplotlib import legend
from matplotlib.legend import Legend

import os
import numpy as np

import logging
# getLogger() with no name returns the root logger; not referenced in the code
# shown in this file.
logger = logging.getLogger()
import math
from typing import *

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MaxNLocator
# Switch matplotlib to seaborn's default theme for every figure created below.
sns.set()

# Global switch read by certify() and by the plot selection in __main__:
#   True  -> pa_exp is lowered and pb_exp raised by the margin
#            sqrt(ln(1/alpha) / (2 * N_m)) before cal_bound is evaluated,
#            and plots are written under plots/cer_acc_conf/;
#   False -> cal_bound is evaluated on the empirical pa_exp / pb_exp, and
#            plots are written under plots/cer_acc/.
is_confidence = True


def certify(filename, args, eps, use_confidence=None):
    """Compute a certified bound for every sample listed in *filename*.

    The file is read as tab-separated values with the columns ``pa_exp``,
    ``pb_exp`` and ``is_acc``. For each row, :func:`cal_bound` is evaluated on
    the pair (pa, pb). Afterwards the fraction of rows whose ``is_acc`` equals
    True is printed.

    Args:
        filename: Path of the tab-separated results file.
        args: Settings dict. ``'delta'`` and ``'barc'`` are always passed to
            ``cal_bound``; ``'alpha'`` and ``'N_m'`` are read only when
            confidence bounds are used.
        eps: Value forwarded as ``eps`` to ``cal_bound``.
        use_confidence: ``True`` uses ``pa = max(1e-8, pa_exp - m)`` and
            ``pb = min(1 - 1e-8, pb_exp + m)`` with the Hoeffding-style margin
            ``m = sqrt(ln(1/alpha) / (2 * N_m))``; ``False`` uses ``pa_exp`` /
            ``pb_exp`` unchanged. ``None`` (the default) defers to the
            module-level ``is_confidence`` flag, so existing three-argument
            calls behave as before.

    Returns:
        ``(None, cert_bound, is_acc)``. The first slot is always ``None`` and
        is kept only for caller compatibility. ``cert_bound`` is a NumPy array
        with one bound per row. ``is_acc`` is the ``is_acc`` column as a NumPy
        array.
    """
    if use_confidence is None:
        use_confidence = is_confidence

    df = pd.read_csv(filename, delimiter="\t")
    pa_exp = np.array(df["pa_exp"])
    pb_exp = np.array(df["pb_exp"])
    is_acc = np.array(df["is_acc"])

    if use_confidence:
        margin = np.sqrt(np.log(1 / args['alpha']) / 2 / args['N_m'])
        # Keep pa >= 1e-8 and pb <= 1 - 1e-8 after widening the interval.
        pa = np.maximum(1e-8, pa_exp - margin)
        pb = np.minimum(1 - 1e-8, pb_exp + margin)
    else:
        pa, pb = pa_exp, pb_exp

    cert_bound = np.array([
        cal_bound(eps=eps, delta=args['delta'], barc=args['barc'],
                  ja=ja, jb=jb)
        for ja, jb in zip(pa, pb)
    ])

    num_samples = len(pa_exp)
    # Elementwise ``== True`` mirrors the original per-row test. An empty file
    # reports nan instead of raising ZeroDivisionError.
    acc_num = np.count_nonzero(is_acc == True)  # noqa: E712
    acc = acc_num * 1.0 / num_samples if num_samples else float("nan")
    print("acc is %f" % acc)
    return None, cert_bound, is_acc


class Accuracy:
    """Interface for a metric that is evaluated over a sequence of radii.

    Subclasses override :meth:`at_radii`; this base version only raises.
    """

    def at_radii(self, radii: np.ndarray):
        """Return the metric value for each radius in *radii* (abstract)."""
        raise NotImplementedError()


class CertifiedRate(Accuracy):
    """Share of samples whose certified bound reaches a given radius.

    Prediction correctness (``is_acc``) is stored but not used here; only
    ``cert_bound >= radius`` is counted.
    """

    def __init__(self, filename, args, eps):
        # certify() returns (None, bounds, is_acc); the first slot is unused.
        _, bounds, correct = certify(filename, args, eps)
        self.cert_bound = bounds
        self.is_acc = correct

    def at_radii(self, radii: np.ndarray) -> np.ndarray:
        """Evaluate :meth:`at_radius` at every radius and stack the results."""
        rates = list(map(self.at_radius, radii))
        return np.array(rates)

    def at_radius(self, radius: float):
        """Return the mean of ``cert_bound >= radius`` over all samples."""
        reached = self.cert_bound >= radius
        return reached.mean()


class CertifiedAcc(Accuracy):
    """Share of samples that are correct AND certified at a given radius."""

    def __init__(self, filename, args, eps):
        # certify() returns (None, bounds, is_acc); the first slot is unused.
        _, bounds, correct = certify(filename, args, eps)
        self.cert_bound = bounds
        self.is_acc = correct

    def at_radii(self, radii: np.ndarray) -> np.ndarray:
        """Evaluate :meth:`at_radius` at every radius and stack the results."""
        accuracies = list(map(self.at_radius, radii))
        return np.array(accuracies)

    def at_radius(self, radius: float):
        """Return the mean of ``(cert_bound >= radius) & is_acc``."""
        certified = self.cert_bound >= radius
        return np.logical_and(certified, self.is_acc).mean()


class Line:
    """One curve to draw: the metric to evaluate plus how to present it.

    Attributes:
        quantity: object whose ``at_radii(radii)`` supplies the y values.
        legend: text used for this curve in the plot legend.
        plot_fmt: format string passed to ``plt.plot`` ("" = default style).
        scale_x: factor the radii are multiplied by before plotting.
    """

    def __init__(self,
                 quantity: Accuracy,
                 legend: str,
                 plot_fmt: str = "",
                 scale_x: float = 1):
        self.quantity, self.legend = quantity, legend
        self.plot_fmt, self.scale_x = plot_fmt, scale_x


def plot_certified_accuracy(outfile: str,
                            title: str,
                            max_radius: float,
                            lines: List[Line],
                            radius_step: float = 0.0001) -> None:
    """Plot certified accuracy against the radius ``k`` and save PDF + PNG.

    Every ``Line`` is evaluated on ``np.arange(0, max_radius + radius_step,
    radius_step)`` through ``line.quantity.at_radii``. Its x values are the
    radii multiplied by ``line.scale_x``.

    Args:
        outfile: Output path without extension; ``.pdf`` and ``.png``
            (300 dpi) are appended. A missing parent directory is created.
        title: Figure title.
        max_radius: Largest radius evaluated; also the right x-axis limit.
        lines: Curves to draw; their ``legend`` strings label the legend in
            the same order.
        radius_step: Spacing of the radius grid.
    """
    radii = np.arange(0, max_radius + radius_step, radius_step)

    out_dir = os.path.dirname(outfile)
    if out_dir:
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    try:
        for line in lines:
            plt.plot(radii * line.scale_x, line.quantity.at_radii(radii),
                     line.plot_fmt)

        plt.ylim((0, 1))
        # NOTE(review): x values are scaled by line.scale_x but the limit uses
        # the unscaled max_radius; confirm this is intended for scale_x != 1.
        plt.xlim((0, max_radius))
        plt.tick_params(labelsize=14)
        plt.xlabel("k", fontsize=22)
        # Integer ticks on the k axis.
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        plt.ylabel("certified accuracy", fontsize=22)
        plt.legend([line.legend for line in lines],
                   loc='upper right',
                   fontsize=16)

        plt.title(title, fontsize=22)
        plt.tight_layout()
        plt.savefig(outfile + ".pdf")
        plt.savefig(outfile + ".png", dpi=300)
    finally:
        # Release the figure even if evaluating a curve or saving fails.
        plt.close(fig)


def plot_certified_rate(outfile: str,
                        title: str,
                        max_radius: float,
                        lines: List[Line],
                        radius_step: float = 0.0001) -> None:
    """Plot certified rate against the radius ``k`` and save PDF + PNG.

    Every ``Line`` is evaluated on ``np.arange(0, max_radius + radius_step,
    radius_step)`` through ``line.quantity.at_radii``. Its x values are the
    radii multiplied by ``line.scale_x``.

    Args:
        outfile: Output path without extension; ``.pdf`` and ``.png``
            (300 dpi) are appended. A missing parent directory is created.
        title: Figure title.
        max_radius: Largest radius evaluated; also the right x-axis limit.
        lines: Curves to draw; their ``legend`` strings label the legend in
            the same order.
        radius_step: Spacing of the radius grid.
    """
    radii = np.arange(0, max_radius + radius_step, radius_step)

    out_dir = os.path.dirname(outfile)
    if out_dir:
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    try:
        for line in lines:
            plt.plot(radii * line.scale_x, line.quantity.at_radii(radii),
                     line.plot_fmt)

        plt.ylim((0, 1))
        # NOTE(review): x values are scaled by line.scale_x but the limit uses
        # the unscaled max_radius; confirm this is intended for scale_x != 1.
        plt.xlim((0, max_radius))
        plt.tick_params(labelsize=14)
        plt.xlabel("k", fontsize=16)
        plt.ylabel("certified rate", fontsize=16)
        plt.legend([line.legend for line in lines],
                   loc='upper right',
                   fontsize=16)

        plt.title(title, fontsize=20)
        plt.tight_layout()
        plt.savefig(outfile + ".pdf")
        plt.savefig(outfile + ".png", dpi=300)
    finally:
        # Release the figure even if evaluating a curve or saving fails.
        plt.close(fig)


def cal_bound(eps, delta, barc, ja, jb):
    """Return the bound ``k`` for one sample.

    Computes::

        k = ln((ja * (e^eps - 1) + delta * barc)
               / (jb * (e^eps - 1) + delta * barc)) / (2 * eps)

    ``k`` is 0 when ``ja == jb`` and negative when ``ja < jb`` (for positive
    ``eps``, ``delta * barc``).

    Raises:
        ZeroDivisionError: if ``eps == 0`` or the denominator is zero.
        ValueError: if the logarithm's argument is not positive.
    """
    # expm1 gives e^eps - 1 without cancellation error for small eps.
    growth = math.expm1(eps)
    slack = delta * barc
    logterm = (ja * growth + slack) / (jb * growth + slack)
    k = 1 / (2 * eps) * math.log(logterm)
    # Called once per sample; a debug log instead of print keeps stdout usable.
    logging.getLogger(__name__).debug("cal_bound: k=%s", k)
    return k


def cal_ja(eps, delta, jb):
    """Return ``e^(2*eps) * jb + (1 + e^eps) * delta``."""
    scaled_jb = math.exp(2 * eps) * jb
    delta_term = (1 + math.exp(eps)) * delta
    return scaled_jb + delta_term


def get_dp_result(folder_prefix, saved_model_name):
    """Return the ``eps`` column of ``<folder_prefix><saved_model_name>/all_exp.csv``.

    The model name is printed first, as a progress marker.

    Args:
        folder_prefix: Root folder; concatenated directly with the model name.
        saved_model_name: Model sub-directory (a trailing slash is optional).

    Returns:
        List of the ``eps`` values in file order.
    """
    filename = os.path.join(folder_prefix + saved_model_name, 'all_exp.csv')
    print(saved_model_name)
    df = pd.read_csv(filename)
    # Column access is independent of the DataFrame index labels.
    return df['eps'].tolist()


import argparse

# Experiment switches. is_cifar picks the CIFAR-10 settings instead of MNIST.
# is_insdp picks the "insdp" variant, which changes the --delta default below.
is_insdp = False
is_cifar = False

parser = argparse.ArgumentParser()

# Dataset setting: results root for each (is_cifar, is_insdp) combination.
# NOTE(review): all four entries are the same placeholder; fill in real paths.
_DEFAULT_FOLDERS = {
    (True, True): 'root/folder/path/',
    (True, False): 'root/folder/path/',
    (False, True): 'root/folder/path/',
    (False, False): 'root/folder/path/',
}
path = _DEFAULT_FOLDERS[(is_cifar, is_insdp)]
parser.add_argument('--folder_prefix', type=str, default=path)

# Smoothing setting.
parser.add_argument('--N_m', type=int, default=1000)
parser.add_argument('--barc', type=float, default=1.0)
parser.add_argument('--delta', type=float,
                    default=1e-5 if is_insdp else 0.0029)
parser.add_argument('--epoch', type=int, default=1 if is_cifar else 3)

# Evaluation setting.
parser.add_argument('--alpha', type=float, default=0.01)

if __name__ == '__main__':
    # Parse CLI options into a plain dict; the helpers above index it by key.
    args = vars(parser.parse_args())
    print(args)

    saved_epoch = args['epoch']
    # The per-model eps list is indexed with saved_epoch - 1 below, so epochs
    # are 1-based; 0 or negatives would silently pick rows from the end.
    if saved_epoch < 1:
        parser.error("--epoch must be >= 1")

    # Model directories (relative to --folder_prefix) and per-run noise levels.
    # NOTE(review): the names are placeholders, and ``noises`` is not used
    # below (its lengths also differ from the model lists) — verify.
    if is_cifar:
        if is_insdp:
            saved_model_names = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
            noises = [8, 7, 6, 5, 4, 3, 2, 1]
        else:
            saved_model_names = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
            noises = [4, 3, 2.6, 2.3, 1.7, 1.3, 1, 0.8, 0.5]
    else:
        if is_insdp:
            saved_model_names = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
            noises = [15, 10, 8, 5, 4, 3, 2, 1]
        else:
            saved_model_names = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
            noises = [
                3, 2.7, 2.5, 2.3, 2.1, 1.9, 1.8, 1.7, 1.5, 1, 0.8, 0.6, 0.5
            ]

    # Resolve each model's epsilon at the chosen epoch and its results file.
    epsilons = []
    filenames = []
    for saved_model_name in saved_model_names:
        epss = get_dp_result(args['folder_prefix'], saved_model_name)
        epsilon = epss[saved_epoch - 1]
        epsilons.append(epsilon)
        filenames.append(os.path.join(
            args['folder_prefix'] + saved_model_name,
            "Epoch%dM%dEps%.4f.txt" % (saved_epoch, args['N_m'], epsilon)))

    # Raw string: "\e" is not a valid escape sequence (SyntaxWarning on 3.12+).
    # Only certified-accuracy curves are plotted; CertifiedRate lines can be
    # built the same way and passed to plot_certified_rate if needed.
    lines = [
        Line(quantity=CertifiedAcc(filename, args, epsilon),
             legend=r"$\epsilon$ = " + str(epsilon))
        for filename, epsilon in zip(filenames, epsilons)
    ]

    # (is_confidence, is_cifar, is_insdp) -> (output stem, title, max radius)
    plot_settings = {
        (False, True, True):
            ('plots/cer_acc/insdp_cer_acc_cifar', "(d) CIFAR-10", 9.0),
        (False, True, False):
            ('plots/cer_acc/cer_acc_cifar', "(b) CIFAR-10", 7.0),
        (False, False, True):
            ('plots/cer_acc/insdp_cer_acc_mnist', "(c) MNIST", 16.0),
        (False, False, False):
            ('plots/cer_acc/cer_acc_mnist', "(a) MNIST", 6.0),
        (True, True, True):
            ('plots/cer_acc_conf/insdp_cer_acc_cifar', "(d) CIFAR-10", 4.0),
        (True, True, False):
            ('plots/cer_acc_conf/cer_acc_cifar', "(b) CIFAR-10", 6.0),
        (True, False, True):
            ('plots/cer_acc_conf/insdp_cer_acc_mnist', "(c) MNIST", 8.0),
        (True, False, False):
            ('plots/cer_acc_conf/cer_acc_mnist', "(a) MNIST", 4.0),
    }
    outfile, title, max_radius = plot_settings[
        (bool(is_confidence), bool(is_cifar), bool(is_insdp))]
    plot_certified_accuracy(outfile, title, max_radius, lines)
