from utils_CNN_cifar10 import get_args, main, EXP_DIR
from toydata import read_txt, save_txt
import json
import time

import os
import numpy as np
import matplotlib.pyplot as plt

# Parse the command-line options once; every path below is derived from them.
args = get_args()

# Pick the run-level subdirectory under "<EXP_DIR>test/": an explicit
# identifier wins, then the debug folder, then a name built from the node count.
if args.identifier:
    _run_subdir = f"{args.identifier}/"
elif args.debug:
    _run_subdir = "debug/"
else:
    _run_subdir = f"n{args.n}_f0NA_LT_m0/"

# INP_DIR is the run-level directory itself; OUT_DIR is its "output/" child.
INP_DIR = EXP_DIR + "test/" + _run_subdir
OUT_DIR = INP_DIR + "output/"

# The heterogeneity tag is the first truthy value of noniid / dirichlet,
# or the literal False when neither option is set.
_niid_tag = args.noniid or args.dirichlet or False

# Experiment-specific folder name: one underscore-separated token per setting.
foldername = "_".join((
    f"{args.agg}",
    f"{args.attack}",
    f"niid{_niid_tag}",
    f"n{args.n}",
    f"f{args.f}",
    f"m{args.momentum}",
    f"nlpsize{args.nlpsize}",
    f"nlpobj{args.nlpobj}",
    f"mix{args.mixing}",
    f"clip{args.grad_clip}",
    f"s{args.bucketing}",
    f"seed{args.seed}",
))

# Log directory handed to the training entry point for this experiment.
LOG_DIR = INP_DIR + foldername

# Directory receiving the exported loss/accuracy/time text files; it is
# created here, at module load, if it does not exist yet.
save_dir = EXP_DIR + 'images_cifar10/' + foldername
if not os.path.exists(save_dir):
    os.makedirs(save_dir, exist_ok=True)

# Training schedule: the per-epoch batch cap is the same in every mode; debug
# runs only use fewer epochs.
# NOTE(review): the previous comment here read "Number of iterations = 960",
# which matches neither 50 * 10 nor 50 * 30 — confirm what it referred to.
MAX_BATCHES_PER_EPOCH = 50
EPOCHS = 10 if args.debug else 30

# Line colour used for each baseline attack in the comparison plots.
nodecolor = {"ALIE": "grey", "BF": "blue", "IPM": "green", "mimic": "orange", "MinMax": "violet",
             "MinSum": "purple", "LF": "chocolate"}


def nlpcolor_obj(nlpobj):
    """Return a line colour keyed on the ``nlpobj`` value.

    Not referenced by the plotting code below; kept so the name stays defined.
    """
    if nlpobj == 1:
        return "red"
    elif nlpobj == 0:
        return "chocolate"
    elif 0 < nlpobj < 1:
        return "salmon"
    elif nlpobj > 1:
        return "gold"
    else:
        return "darkred"


def nlpcolor(nlpsize):
    """Return the line colour for an SSNLP curve of the given ``nlpsize``.

    Sizes 64, 16 and 4 have dedicated colours; any other size is "darkred".
    """
    if nlpsize == 64:
        return "red"
    elif nlpsize == 16:
        return "salmon"
    elif nlpsize == 4:
        return "gold"
    else:
        return "darkred"


def _parse_stats_line(raw):
    """Decode one line of the stats file into a Python object.

    The line is made JSON-compatible by plain substring replacement: single
    quotes become double quotes, and every occurrence of "nan" or "inf" is
    replaced by the sentinel 1.5e+31 (so "-inf" becomes -1.5e+31).
    NOTE(review): the replacement is not token-aware, so it would also rewrite
    those substrings inside keys or string values — confirm the stats file
    never contains any.
    """
    raw = raw.replace("'", '"')
    raw = raw.replace("nan", '1.5e+31')
    raw = raw.replace("inf", '1.5e+31')
    return json.loads(raw)


def _new_metric_axes(ylabel, ymax):
    """Create a 10x10 figure whose axes are labelled 'Epoch' vs. ``ylabel``."""
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_xlim(0, 70)
    ax.set_ylim(0, ymax)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    return fig, ax


if not args.plot:
    # --- Training mode: run the experiment, then export its curves as text. ---
    start = time.perf_counter_ns()
    main(args, LOG_DIR, EPOCHS, MAX_BATCHES_PER_EPOCH)
    end = time.perf_counter_ns()

    # perf_counter_ns() returns integer nanoseconds; convert once and reuse.
    elapsed_seconds = (end - start) * 0.000000001
    elapsed_minutes = elapsed_seconds / 60
    print(args.agg, args.attack, elapsed_seconds, elapsed_minutes)

    # Read back the "stats" file from LOG_DIR, one record per non-empty line.
    # Blank lines are skipped so a trailing newline cannot break decoding.
    filename = os.path.join(LOG_DIR, "stats")
    with open(filename, 'r') as f:
        data = [_parse_stats_line(line) for line in f if line.strip()]

    # Build "<index> <value>" rows; training and validation records are
    # numbered independently, starting from 1.
    ls_train = []
    ac_train = []
    ls_test = []
    ac_test = []
    for record in data:
        record_type = record['_meta']['type']
        if record_type == 'train':
            losses, accs = ls_train, ac_train
        elif record_type == 'validation':
            losses, accs = ls_test, ac_test
        else:
            # Records of any other type are not exported.
            continue
        index = len(losses) + 1
        losses.append(f"{index} {record['Loss']}")
        accs.append(f"{index} {record['top1']}")

    save_txt(ls_train, save_dir + '/losses_train.txt')
    save_txt(ac_train, save_dir + '/accs_train.txt')
    save_txt(ls_test, save_dir + '/losses_test.txt')
    save_txt(ac_test, save_dir + '/accs_test.txt')
    save_txt([str(elapsed_minutes)], save_dir + '/time.txt')

else:
    # --- Plot mode: one loss figure and one accuracy figure per aggregator. ---
    # NOTE(review): curves are read from EXP_DIR + 'images/cifar10_<...>' while
    # training mode writes to `save_dir` (EXP_DIR + 'images_cifar10/<...>', whose
    # name also contains a `_clip` token) — confirm how results get from one
    # location to the other.
    niid_tag = args.noniid or args.dirichlet or False
    run_tag = f"n{args.n}_f{args.f}_m{args.momentum}"

    fig_folder = EXP_DIR + 'images/' + f"img_cifar10_niid{niid_tag}_{run_tag}_seed{args.seed}/"
    # makedirs also creates a missing parent "images/" directory and does not
    # fail when the folder already exists.
    os.makedirs(fig_folder, exist_ok=True)

    for agg in ["avg", "cm", "cp", "rfa", "krum", "tm"]:
        loss_fig, loss_ax = _new_metric_axes('Loss', 5.5)
        acc_fig, acc_ax = _new_metric_axes('Accuracy', 100)

        # Each curve is (attack name in the folder name, legend label,
        # nlpsize token in the folder name, line colour).
        curves = [
            (attack, attack, args.nlpsize, nodecolor[attack])
            for attack in ["ALIE", "BF", "IPM", "mimic", "MinMax", "MinSum", "LF"]
        ]
        if agg != 'avg':
            # SSNLP curves are added for every aggregator except plain averaging.
            curves += [
                ("SSNLP", "SSNLP" + ' ' + str(size), size, nlpcolor(size))
                for size in [1, 4, 16, 64]
            ]

        for attack, label, nlpsize, color in curves:
            grid_identifier = f"cifar10_{agg}_{attack}_niid{niid_tag}_{run_tag}_nlpsize{nlpsize}_nlpobj{args.nlpobj}_mix{args.mixing}_s{args.bucketing}_seed{args.seed}"
            path = EXP_DIR + 'images/' + grid_identifier
            acc_xs, acc_ys = read_txt(path + '/accs_train.txt')
            loss_xs, loss_ys = read_txt(path + '/losses_train.txt')

            loss_ax.plot(loss_xs, loss_ys, linewidth=1, label=label, color=color)
            acc_ax.plot(acc_xs, acc_ys, linewidth=1, label=label, color=color)

        for fig, ax, title_suffix, file_suffix, legend_loc in (
            (loss_fig, loss_ax, "Loss", "loss", "upper right"),
            (acc_fig, acc_ax, "Accuracy", "acc", "lower right"),
        ):
            # Use the labels attached to each plotted line so that the legend
            # always matches the curves actually drawn (a fixed label list
            # mislabelled the SSNLP curves and dropped some of them).
            ax.legend(loc=legend_loc, fontsize=10)
            ax.set_title(f"{agg}_{run_tag} {title_suffix}")
            fig_dir = fig_folder + f"cifar10_niid{niid_tag}_{agg}_{run_tag}_{file_suffix}" + '.pdf'
            fig.savefig(fig_dir, format='pdf')
            print('{} saved.'.format(fig_dir))
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)
