import numpy as np
import os
import argparse
import os.path
import pandas as pd
import matplotlib.pyplot as plt
from os import path
import seaborn as sns
from matplotlib.lines import Line2D

def block_label(algorithm):
    """Return the x-axis category label for an algorithm directory name.

    Names of the form ``MoM_<k>_<...>`` map to ``<k>`` (plotted against the
    "number of blocks" axis). Names without an underscore, such as
    ``"Standard"``, have no block count and are returned unchanged. The
    original ``name[1]`` lookup raised ``IndexError`` for such names.
    """
    parts = algorithm.split("_")
    return parts[1] if len(parts) > 1 else algorithm


def collect_results(
    algorithm_list,
    datasets=("Synthetic_adversary",),
    output_dims=(1,),
    anomaly_ratios=(0.01,),
    seeds=(0, 1, 2, 3, 4),
    root="./Result",
):
    """Load per-seed R2 results from ``<root>/<data>/<output_d>/<ratio>/<alg>/<seed>/result.npy``.

    Each ``result.npy`` is expected to hold a pickled dict with at least the
    keys ``"r2_clean"`` and ``"r2_poison"``. Missing files are skipped.

    NOTE: ``np.load(..., allow_pickle=True)`` unpickles the file, which can
    execute arbitrary code -- only point ``root`` at trusted result files.

    Returns:
        list of ``(label, r2_clean, r2_poison, output_d)`` tuples, in the
        iteration order data -> algorithm -> output_d -> ratio -> seed.
    """
    rows = []
    for data in datasets:
        for algorithm in algorithm_list:
            label = block_label(algorithm)
            for output_d in output_dims:
                for anomaly_ratio in anomaly_ratios:
                    for seed in seeds:
                        result_path = "{}/{}/{}/{}/{}/{}/result.npy".format(
                            root, data, output_d, anomaly_ratio, algorithm, seed
                        )
                        if not os.path.isfile(result_path):
                            continue
                        results_dict = np.load(result_path, allow_pickle=True).item()
                        rows.append(
                            (
                                label,
                                float(results_dict["r2_clean"]),
                                float(results_dict["r2_poison"]),
                                output_d,
                            )
                        )
    return rows


def plot_results(rows):
    """Draw side-by-side box plots of clean and attacked R2 per block count.

    Args:
        rows: non-empty sequence of ``(label, r2_clean, r2_poison, output_d)``
            tuples as returned by ``collect_results``.
    """
    labels = [row[0] for row in rows]
    r2_clean = np.array([row[1] for row in rows], dtype=float)
    r2_poison = np.array([row[2] for row in rows], dtype=float)

    # One set_theme call: a later set_theme(style=...) without font_scale
    # would reset the font scale back to 1.
    sns.set_theme(style="whitegrid", font_scale=2)
    _, (ax_clean, ax_poison) = plt.subplots(1, 2, figsize=(18.5, 10.5))

    sns.boxplot(x=labels, y=r2_clean, ax=ax_clean)
    ax_clean.set_xlabel("number of blocks")
    ax_clean.set_ylabel("R2")
    ax_clean.set_title("clean test data")

    # NOTE(review): the attacked-data panel plots exp(r2_poison), not the raw
    # R2 -- presumably to compress very negative values; confirm intended.
    sns.boxplot(x=labels, y=np.exp(r2_poison), ax=ax_poison)
    ax_poison.set_xlabel("number of blocks")
    ax_poison.set_ylabel("exp(R2)")
    ax_poison.set_title("adversarial attacked test data")

    plt.show()


def main():
    """Parse CLI args, collect stored results, and plot them."""
    parser = argparse.ArgumentParser(description="getresults")
    # NOTE(review): --metric is accepted but not used yet; the r2_* keys are
    # always read. Kept so existing command lines keep working.
    parser.add_argument("--metric", type=str, default="r2", required=False)
    parser.parse_args()

    mom_list = ["MoM_{}_1024".format(k) for k in (5, 50, 100, 150, 200)]
    algorithm_list = ["Standard"] + mom_list
    print(["*{0:<13}".format(" ")] + algorithm_list)

    rows = collect_results(algorithm_list)
    if not rows:
        # np.stack([]) used to fail here with an opaque ValueError.
        print("no result.npy files found under ./Result; nothing to plot")
        return
    plot_results(rows)
    print("figure")


if __name__ == "__main__":
    main()
