from dataclasses import dataclass
from multiplot import *
import random
import os
import pickle
from typing import Any

# Root directory under which the per-job result directories and the figures
# directory are looked up / created.
WORKING_DIRECTORY = "/path/to/working/directory"
# Method identifiers; they double as directory names and as keys of
# DISPLAY_NAMES and COLORS below.
METHODS = [
    "Eps_Greedy",
    "Greedy",
    "NeuralUCB",
    "PLOT",
    "AdvPLOT_1e_nor",
    "Adv_1e_nor",
]
# Dataset identifiers; they double as directory names and figure titles.
DATASETS = [
    "Adult",
    "Bank",
    "Crime",
    "German",
    "MNIST",
]
# First seed, and seed of the generator that produces the following ones.
SEED = 20_000
NUM_SEEDS = 5
VERSION_SUFFIX = ""
# When non-empty, an extra "protected" figure is produced per dataset.
PROTECTED_CHARACTERISTIC = ""
PARALLEL = True
FAST = False
# Fixed experiment settings. BATCH and DECAY are not referenced by the code
# visible in this file.
BATCH = 32
DECAY = 0.05
TIMESTEPS = 2500
NUM_EXPERIMENTS = 5
EXPERIMENT_RESULTS_DIR = "experiment_results"

# VERSION starts with "_" and JOB_NAME inserts another "_" in front of it, so
# the job name contains double underscores; this is kept as-is because the
# name is used to locate directories on disk.
VERSION = (
    "_".join(("", f"{TIMESTEPS}t", "_".join(METHODS), "_".join(DATASETS)))
    + VERSION_SUFFIX
    + PROTECTED_CHARACTERISTIC
)
JOB_PREFIX = "exp_"
PARALLEL_STR = "_parallel" if PARALLEL else ""
JOB_NAME = JOB_PREFIX + PARALLEL_STR + "_" + VERSION

# The plotted subsets default to everything (same list objects, not copies).
METHODS_FOR_PLOT = METHODS
DATASETS_FOR_PLOT = DATASETS

# Method identifier -> label shown in figure legends.
DISPLAY_NAMES = dict(
    AdvPLOT_1e_nor="AdOpt",
    Adv_1e_nor="Adversarial",
    AdOpt="AdOpt",
    PLOT="PLOT",
    NeuralUCB="NeuralUCB",
    Greedy="Greedy",
    Eps_Greedy="Epsilon Greedy",
)

FIG_DIRECTORY = "figs"
# Paths relative to a <job>/<results dir>/<dataset>/<method>/ directory.
DATA_FILE = "data/data_dump.p"
TIME_FILE = "data/time.txt"
LINEWIDTH = 1
LINESTYLE = "solid"
# Method identifier -> line colour.
COLORS = dict(
    Greedy="green",
    NeuralUCB="grey",
    AdvPLOT_1e_nor="red",
    Adv_1e_nor="orange",
    PLOT="blue",
    Eps_Greedy="purple",
)


@dataclass
class PlotData:
    """Aggregated curves for one method, in the form handed to the plot helpers.

    In this file an instance is filled from the values returned by
    ``aggregate`` and passed as ``experiment_results`` to ``plot_regret`` /
    ``plot_protected``. Field order is part of the positional constructor
    interface and must not change.
    """
    # Mean / standard deviation, across experiments, of each experiment's
    # ``mean_train_cum_regret_averages`` curve.
    mean_train_cum_regret_averages: np.ndarray
    std_train_cum_regret_averages: np.ndarray
    # Same statistics for the ``..._justadv_...`` regret curve
    # (presumably the adversarial-only component - TODO confirm).
    mean_train_cum_regret_justadv_averages: np.ndarray
    std_train_cum_regret_justadv_averages: np.ndarray
    # Mean / standard deviation, across experiments, of
    # ``mean_protected_accepted_averages``.
    mean_protected_accepted_averages: Any
    std_protected_accepted_averages: Any
    # Mean across experiments of ``mean_actual_protected_accepted``; drawn as
    # the "True Value" line in the protected-characteristic figure.
    mean_actual_protected_accepted: Any


def aggregate(data):
    """Combine several experiment records into per-timestep statistics.

    Args:
        data: Non-empty sequence of experiment records. ``record[0]`` is
            taken as the timesteps (from the first record only) and
            ``record[2]`` must expose the attributes
            ``mean_train_cum_regret_averages``,
            ``mean_train_cum_regret_justadv_averages``,
            ``mean_protected_accepted_averages`` and
            ``mean_actual_protected_accepted``. Each attribute must have the
            same shape in every record so the values can be stacked.

    Returns:
        The 8-tuple ``(timesteps, means, stds, means_justadv, stds_justadv,
        protected_means, protected_stds, true_protected)``. Means and
        standard deviations are taken across records (axis 0); ``np.std`` is
        used with its default arguments.

    Raises:
        ValueError: If ``data`` is empty (previously an opaque
            ``IndexError``), or if numpy cannot stack the per-record arrays.
    """
    if len(data) == 0:
        raise ValueError(
            'aggregate() needs at least one experiment record, got none')

    def stacked(attribute):
        # One row per experiment record.
        return np.stack([getattr(exp[2], attribute) for exp in data])

    timesteps = data[0][0]

    regrets = stacked('mean_train_cum_regret_averages')
    means = np.mean(regrets, axis=0)
    stds = np.std(regrets, axis=0)

    regrets_justadv = stacked('mean_train_cum_regret_justadv_averages')
    means_justadv = np.mean(regrets_justadv, axis=0)
    stds_justadv = np.std(regrets_justadv, axis=0)

    protected = stacked('mean_protected_accepted_averages')
    protected_means = np.mean(protected, axis=0)
    protected_stds = np.std(protected, axis=0)

    true_protected = np.mean(
        stacked('mean_actual_protected_accepted'), axis=0)

    return (timesteps, means, stds, means_justadv, stds_justadv,
            protected_means, protected_stds, true_protected)


def _seed_sequence():
    """Return the ``NUM_SEEDS`` seeds whose result directories are read.

    The first seed is ``SEED``; every later one is the next
    ``randint(1, 20_000)`` draw of a generator seeded with ``SEED``.
    """
    # A private Random instance produces the same draws as
    # random.seed(SEED) + random.randint(...) without depending on (or
    # disturbing) the module-level generator state.
    rng = random.Random(SEED)
    seeds = []
    seed = SEED
    for _ in range(NUM_SEEDS):
        seeds.append(seed)
        seed = rng.randint(1, 20_000)
    return seeds


def _experiment_paths(seed, dataset, method, filename):
    """Return the path of ``filename`` for every experiment run of one seed."""
    return [
        os.path.join(WORKING_DIRECTORY, f'{JOB_NAME}_{seed}_{run}',
                     EXPERIMENT_RESULTS_DIR, dataset, method, filename)
        for run in range(NUM_EXPERIMENTS)
    ]


def _load_experiments(seeds, dataset, method):
    """Unpickle every available result file for the given seeds.

    Missing files are reported on stdout and skipped, so the returned list
    may be empty.
    """
    experiments = []
    for seed in seeds:
        for path in _experiment_paths(seed, dataset, method, DATA_FILE):
            try:
                handle = open(path, 'rb')
            except FileNotFoundError:
                print(f'file {path} not found - skipping it.')
                continue
            # NOTE(review): pickle.load can execute arbitrary code; only
            # point WORKING_DIRECTORY at result files you trust.
            with handle:
                experiments.append(pickle.load(handle))
    return experiments


def _load_times(seeds, dataset, method):
    """Read the first number of every available time file for the seeds.

    Missing or empty files are reported on stdout and skipped.
    """
    times = []
    for seed in seeds:
        for path in _experiment_paths(seed, dataset, method, TIME_FILE):
            try:
                handle = open(path, 'rb')
            except FileNotFoundError:
                print(f'file {path} not found - skipping it.')
                continue
            with handle:
                tokens = handle.readline().split()
            if not tokens:
                print(f'file {path} is empty - skipping it.')
                continue
            # float() accepts the bytes token directly.
            times.append(float(tokens[0]))
    return times


def _collect(seeds, dataset):
    """Aggregate every method's results for ``dataset`` over ``seeds``.

    Returns ``(timesteps, plot_data)``, two dicts keyed by method in
    ``METHODS_FOR_PLOT`` order. Methods without any result file are reported
    and left out instead of crashing inside ``aggregate``.
    """
    timesteps = {}
    plot_data = {}
    for method in METHODS_FOR_PLOT:
        experiments = _load_experiments(seeds, dataset, method)
        if not experiments:
            print(f'no results found for {method} on {dataset} - '
                  f'leaving it out.')
            continue
        (timesteps[method], means, stds, means_justadv, stds_justadv,
         protected_means, protected_stds, true_protected) = aggregate(
            experiments)
        plot_data[method] = PlotData(
            mean_train_cum_regret_averages=means,
            std_train_cum_regret_averages=stds,
            mean_train_cum_regret_justadv_averages=means_justadv,
            std_train_cum_regret_justadv_averages=stds_justadv,
            mean_protected_accepted_averages=protected_means,
            std_protected_accepted_averages=protected_stds,
            mean_actual_protected_accepted=true_protected)
    return timesteps, plot_data


def _save_current_figure(plot_name):
    """Save the current matplotlib figure as ``<figs dir>/<plot_name>.png``."""
    figs_directory = os.path.join(WORKING_DIRECTORY, FIG_DIRECTORY)
    if not os.path.isdir(figs_directory):
        try:
            # exist_ok guards against another process creating it meanwhile.
            os.makedirs(figs_directory, exist_ok=True)
        except OSError:
            print("Creation of figs directories failed")
            # Saving cannot succeed without the directory; surface the
            # real cause instead of failing later inside savefig.
            raise
        else:
            print("Successfully created the figs directory")
    target = os.path.join(figs_directory, plot_name + '.png')
    print(f'saving figure to {target}')
    plt.savefig(target, bbox_inches="tight", dpi=300)
    plt.close('all')


def _plot_regret_figure(dataset, timesteps, plot_data, plot_name,
                        with_legend):
    """Draw one regret curve per method and save the figure."""
    for method, results in plot_data.items():
        plot_regret(
            timesteps=timesteps[method],
            experiment_results=results,
            color=COLORS[method],
            label=DISPLAY_NAMES[method]
        )
    if with_legend:
        plt.legend(bbox_to_anchor=(1.05, 1), fontsize=8, loc="upper left")
    plt.title(f'{dataset}')
    _save_current_figure(plot_name)


def _plot_protected_figure(dataset, timesteps, plot_data, plot_name):
    """Draw the protected-characteristic curves plus the true value."""
    for method, results in plot_data.items():
        plot_protected(
            timesteps=timesteps[method],
            experiment_results=results,
            color=COLORS[method],
            label=DISPLAY_NAMES[method]
        )
    # The true value is drawn once, taken from the last plotted method
    # (the original code used the loop variable left over from the loop).
    reference = list(plot_data)[-1]
    plt.plot(
        timesteps[reference],
        plot_data[reference].mean_actual_protected_accepted,
        label='True Value',
        linestyle='--',
        color='black',
    )
    plt.legend(bbox_to_anchor=(1.05, 1), fontsize=8, loc="upper left")
    plt.title(f'Dataset: {dataset}, Characteristic: {PROTECTED_CHARACTERISTIC}')
    _save_current_figure(plot_name)


def main():
    """Produce the per-seed figures, the collated figures and the summary.

    1. One regret figure per (seed, dataset), without legend.
    2. One regret figure per dataset over all seeds (``..._total``), plus a
       protected-characteristic figure when ``PROTECTED_CHARACTERISTIC`` is
       set.
    3. One stdout line per (dataset, method) with the mean run time and the
       final collated mean regret.
    """
    seeds = _seed_sequence()
    methods_tag = "_".join(METHODS_FOR_PLOT)

    # random seeds separately
    for seed in seeds:
        for dataset in DATASETS_FOR_PLOT:
            timesteps, plot_data = _collect([seed], dataset)
            if not plot_data:
                print(f'nothing to plot for {dataset} with seed {seed}.')
                continue
            _plot_regret_figure(
                dataset, timesteps, plot_data,
                f'{dataset}_regrets_{methods_tag}_{TIMESTEPS}T_{seed}',
                with_legend=False)

    # collate all random seeds together
    collated = {}
    for dataset in DATASETS_FOR_PLOT:
        timesteps, plot_data = _collect(seeds, dataset)
        # Kept for the summary below so the pickles are not loaded twice.
        collated[dataset] = plot_data
        if not plot_data:
            print(f'nothing to plot for {dataset}.')
            continue
        _plot_regret_figure(
            dataset, timesteps, plot_data,
            f'{dataset}_regrets_{methods_tag}_{TIMESTEPS}T_total',
            with_legend=True)
        if PROTECTED_CHARACTERISTIC:
            _plot_protected_figure(
                dataset, timesteps, plot_data,
                f'{dataset}_protected_{PROTECTED_CHARACTERISTIC}_{methods_tag}_{TIMESTEPS}T_total')

    # timing / final-regret summary
    for dataset in DATASETS_FOR_PLOT:
        for method in METHODS_FOR_PLOT:
            times = _load_times(seeds, dataset, method)
            if method not in collated[dataset]:
                # Already reported as missing while collating.
                continue
            means = collated[dataset][method].mean_train_cum_regret_averages
            # Avoid numpy's empty-mean RuntimeWarning when no timing exists.
            mean_time = np.mean(times) if times else float('nan')
            print(f'{method}_{dataset} time: {mean_time:.5}, final regret: {means[-1]:.4}')


if __name__ == '__main__':
    main()
