import numpy as np
import os, sys
import pandas
import argparse

import torch

# Root of the project checkout, taken from the environment; indexing
# os.environ raises KeyError at import time if the variable is unset.
# The root and its utils/ and multitask/ sub-directories are appended to
# sys.path (in that order) so the project-local star-imports resolve.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR']
sys.path.extend([
    MULTITASK_COMPRESSION_DIR,
    MULTITASK_COMPRESSION_DIR + '/utils/',
    MULTITASK_COMPRESSION_DIR + '/multitask/',
])

# Experiment outputs live under <root>/scratch/<model_name>/.
SCRATCH_DIR = MULTITASK_COMPRESSION_DIR + '/scratch/'

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Column / axis labels for the loss-tradeoff DataFrame and plots.
# size_var uses matplotlib mathtext ($Z$) for the bottleneck symbol.
size_var, loss_1_var, loss_2_var = (
    r"Bottleneck Dimension $Z$",
    r"Task Loss 1",
    r"Task Loss 2",
)

def plot_loss_tradeoff():
    """Plot the task-1 vs. task-2 loss tradeoff for the train and test splits.

    Loads ``z_<max(z_dims)>.pkl`` from ``<SUB_DIR>analytical_noise``. For each
    of ``'train'`` and ``'test'`` it draws a two-point line (with markers)
    through ``(<split>_loss_1, <split>_loss_2)`` and
    ``(<split>_loss_1_max, <split>_loss_2_max)``, titles the figure with the
    split name, and saves it to
    ``<PLOT_DIR>/<model_name>_<split>_loss_tradeoff.<fig_ext>``.

    This function takes no arguments. It reads the module-level globals
    ``SUB_DIR``, ``PLOT_DIR``, ``model_name``, ``z_dims`` and ``fig_ext``,
    which are assigned in this file's ``__main__`` block, so they must exist
    before it is called.

    Raises:
        ValueError: if ``z_dims`` is empty (raised by ``max``).
        KeyError: if the loaded results dict lacks one of the loss keys.
    """
    sub_test_data_dir = SUB_DIR + 'analytical_noise'
    pkl_file = "z_{}.pkl".format(max(z_dims))

    # The results file path does not depend on the split, so load it once
    # instead of re-reading the same pickle on every loop iteration.
    result_dict = load_pkl(sub_test_data_dir + '/' + pkl_file)

    for loss_type in ['train', 'test']:
        # NOTE(review): assumes each entry is a scalar loss value -- confirm
        # against the code that writes the pickle.
        rows = [
            [result_dict[loss_type + '_loss_1'],
             result_dict[loss_type + '_loss_2']],
            [result_dict[loss_type + '_loss_1_max'],
             result_dict[loss_type + '_loss_2_max']],
        ]
        # Build the frame in one call rather than enlarging an empty frame
        # row by row via .loc; this also lets pandas infer numeric dtypes
        # instead of leaving object columns.
        loss_results_df = pandas.DataFrame(rows, columns=[loss_1_var, loss_2_var])

        plot_file = PLOT_DIR + '/{}_{}_loss_tradeoff.{}'.format(model_name, loss_type, fig_ext)

        sns.lineplot(x=loss_1_var, y=loss_2_var, data=loss_results_df, marker='o')
        plt.title(loss_type)
        plt.savefig(plot_file)
        # Close the figure so the next split starts from a clean canvas.
        plt.close()


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Plot the task-1 vs. task-2 loss tradeoff for a model.')
    # Every argument is required: previously an omitted flag surfaced later
    # as an opaque AttributeError (None.split) or a ".None" figure extension.
    parser.add_argument('--model_name', type=str, required=True,
                        help='name of the model directory under scratch/')
    parser.add_argument('--noise_vars', type=str, required=True,
                        help='comma-separated noise variances (parsed as floats)')
    parser.add_argument('--z_dims', type=str, required=True,
                        help='comma-separated bottleneck dimensions; the largest '
                             'selects the results file')
    parser.add_argument('--fig_ext', type=str, required=True,
                        help='figure file extension, e.g. png or pdf')
    args = parser.parse_args()

    # These stay module-level globals: plot_loss_tradeoff() reads them directly.
    model_name = args.model_name
    noise_vars = [float(noise_var) for noise_var in args.noise_vars.split(',')]
    z_dims = [int(z_dim) for z_dim in args.z_dims.split(',')]
    fig_ext = args.fig_ext

    BASE_DIR = SCRATCH_DIR + model_name
    SUB_DIR = BASE_DIR + '/compression_with_DP/'
    PLOT_DIR = SUB_DIR + 'loss_tradeoff_plot'
    # NOTE(review): by its name this helper presumably clears any existing
    # plot directory before recreating it -- confirm in its definition.
    remove_and_create_dir(PLOT_DIR)

    plot_loss_tradeoff()
    
