import numpy as np
import os, sys
import pandas
import argparse

import torch

# Project root, taken from the environment; a KeyError is raised right here
# if MULTITASK_COMPRESSION_DIR is not set.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR']

# Put the project root and two of its sub-directories on the module search
# path (presumably where the project-local modules imported below live).
sys.path.extend([
    MULTITASK_COMPRESSION_DIR,
    '{}/utils/'.format(MULTITASK_COMPRESSION_DIR),
    '{}/multitask/'.format(MULTITASK_COMPRESSION_DIR),
])

# Directory under which per-model results are read and plots are written.
SCRATCH_DIR = '{}/scratch/'.format(MULTITASK_COMPRESSION_DIR)

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Column names of the results table built in plot_loss_tradeoff; the same
# strings are handed to sns.lineplot as its x, y and hue variables.
loss_1_var, loss_2_var, train_type_var = (
    "Target Task Loss",
    "Adversarial Task Loss",
    "Privacy Type",
)

def plot_loss_tradeoff(z_dim):
    """Plot target-task loss against adversarial-task loss for one z_dim.

    For each of the 'train' and 'test' splits, the function loads
    ``z_<z_dim>.pkl`` from every configured ADV and LDP result directory,
    collects the ``<split>_loss_1`` / ``<split>_loss_2`` entries (plus the
    ``bm_``-prefixed entries for LDP runs, shown as "Benchmark") into one
    table, and saves a seaborn line plot of it to PLOT_DIR.

    The function reads these module-level names, which this file only binds
    in its ``__main__`` section: BASE_DIR, PLOT_DIR, ADV_labels, LDP_labels,
    model_name and fig_ext. ``load_pkl``, ``sns`` and ``plt`` are expected to
    be provided by the star imports at the top of the file.

    Args:
        z_dim: Value substituted into the pickle file name ``z_<z_dim>.pkl``
            and into the output plot file name.

    A result file that cannot be loaded, or that lacks an expected key, is
    reported on stdout and skipped; the remaining result files are still
    processed.
    """
    pkl_file = "z_{}.pkl".format(z_dim)

    # Legend entry used for each training type.
    legend_names = {
        'ADV': "Adversarial (Ours)",
        'LDP': "Local Differential Privacy",
    }

    for loss_type in ['train', 'test']:

        loss_type_1 = loss_type + '_loss_1'
        loss_type_2 = loss_type + '_loss_2'

        # Rows are gathered in a plain list and turned into a DataFrame once,
        # instead of growing the DataFrame one row at a time.
        rows = []

        for train_type in ['ADV', 'LDP']:
            sub_dir = BASE_DIR + '/compression_with_{}/'.format(train_type)
            labels = ADV_labels if train_type == 'ADV' else LDP_labels

            for label in labels:
                pkl_path = sub_dir + label + '/' + pkl_file

                # Best-effort per result file: a single missing or malformed
                # pickle must not discard the other labels of this type.
                try:
                    result_dict = load_pkl(pkl_path)

                    rows.append([result_dict[loss_type_1],
                                 result_dict[loss_type_2],
                                 legend_names[train_type]])

                    if train_type == 'LDP':
                        # LDP result files also carry 'bm_'-prefixed losses,
                        # plotted as a separate "Benchmark" series.
                        rows.append([result_dict['bm_' + loss_type_1],
                                     result_dict['bm_' + loss_type_2],
                                     "Benchmark"])
                except Exception as e:
                    print("Skipping {}: {!r}".format(pkl_path, e))
                    continue

        loss_results_df = pandas.DataFrame(
            rows, columns=[loss_1_var, loss_2_var, train_type_var])

        plot_file = PLOT_DIR + '/{}_{}_loss_tradeoff_z_dim_{}.{}'.format(
            model_name, loss_type, z_dim, fig_ext)

        sns.lineplot(x=loss_1_var, y=loss_2_var, data=loss_results_df,
                     hue=train_type_var, marker='o')

        plt.legend(loc="lower right")
        plt.savefig(plot_file)
        # Close the figure so the next split starts from a clean canvas.
        plt.close()


if __name__ == '__main__':

    # Every option is dereferenced unconditionally below, so each one is
    # marked required: omitting one now yields a usage error instead of an
    # AttributeError/TypeError traceback.
    parser = argparse.ArgumentParser(
        description='Plot target-vs-adversarial loss trade-off curves.')
    parser.add_argument('--model_name', type=str, required=True,
                        help='model name; results are read from '
                             '<scratch>/<model_name>')
    parser.add_argument('--adv_weights', type=str, required=True,
                        help='comma-separated floats; each selects the '
                             'result sub-directory weight_<value>')
    parser.add_argument('--ldp_epsilons', type=str, required=True,
                        help='comma-separated floats, paired by position '
                             'with --ldp_deltas')
    parser.add_argument('--ldp_deltas', type=str, required=True,
                        help='comma-separated floats, paired by position '
                             'with --ldp_epsilons')
    parser.add_argument('--z_dims', type=str, required=True,
                        help='comma-separated ints; one set of plots is '
                             'produced per value')
    parser.add_argument('--fig_ext', type=str, required=True,
                        help='file extension of the saved figures')
    args = parser.parse_args()

    # NOTE: the names below are bound at module level on purpose.
    # plot_loss_tradeoff reads model_name, fig_ext, ADV_labels, LDP_labels,
    # BASE_DIR and PLOT_DIR as globals.
    model_name = args.model_name
    fig_ext = args.fig_ext

    ADV_weights = [float(weight) for weight in args.adv_weights.split(',')]
    LDP_epsilons = [float(eps) for eps in args.ldp_epsilons.split(',')]
    LDP_deltas = [float(delta) for delta in args.ldp_deltas.split(',')]
    z_dims = [int(dim) for dim in args.z_dims.split(',')]

    # Each epsilon needs a delta at the same position; report a shortfall as
    # a usage error rather than letting it surface as an IndexError.
    if len(LDP_deltas) < len(LDP_epsilons):
        parser.error(
            '--ldp_deltas must supply a value for every --ldp_epsilons entry '
            '(got {} deltas for {} epsilons)'.format(
                len(LDP_deltas), len(LDP_epsilons)))

    # Sub-directory names under compression_with_ADV / compression_with_LDP.
    ADV_labels = ['weight_{}'.format(weight) for weight in ADV_weights]
    # zip stops at the last epsilon; any surplus deltas are ignored, exactly
    # as with the previous index-based pairing.
    LDP_labels = ['e_{}_d_{}'.format(eps, delta)
                  for eps, delta in zip(LDP_epsilons, LDP_deltas)]

    BASE_DIR = SCRATCH_DIR + model_name
    PLOT_DIR = BASE_DIR + '/loss_tradeoff_plot/'
    # Presumably wipes any previous plot directory and recreates it empty;
    # confirm against the helper's definition.
    remove_and_create_dir(PLOT_DIR)

    for z_dim in z_dims:
        plot_loss_tradeoff(z_dim)
    
