import numpy as np
import os, sys
import pandas
import argparse

import torch

# Root of the project checkout. A KeyError here means the
# MULTITASK_COMPRESSION_DIR environment variable is not set.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR']

# Put the project root and its utils/ and multitask/ folders on the import
# path (presumably where the helper modules imported below live).
sys.path.extend([
    MULTITASK_COMPRESSION_DIR,
    MULTITASK_COMPRESSION_DIR + '/utils/',
    MULTITASK_COMPRESSION_DIR + '/multitask/',
])

# Base folder under which per-model directories are looked up.
SCRATCH_DIR = MULTITASK_COMPRESSION_DIR + '/scratch/'

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Human-readable labels shared by the result tables and plots.
# loss_1_var / loss_2_var double as DataFrame column names and as the
# x / y variables handed to seaborn in plot_loss_tradeoff().
size_var = 'Bottleneck Dimension $Z$'
loss_1_var = 'Task Loss 1'
loss_2_var = 'Task Loss 2'
weight_var = 'Weight'

def plot_loss_tradeoff():
    """Plot task loss 2 against task loss 1 across the weight sweep.

    For each of the 'train' and 'test' splits, one result pickle is loaded
    per entry of ``weights`` (always the file for the largest value in
    ``z_dims``). The two losses stored under ``<split>_loss_1`` and
    ``<split>_loss_2`` are collected into a DataFrame and drawn as a seaborn
    line plot, which is saved to
    ``PLOT_DIR/<model_name>_<split>_loss_tradeoff.<fig_ext>``.

    The function takes no arguments; it reads the module-level names
    ``weights``, ``z_dims``, ``SUB_DIR``, ``PLOT_DIR``, ``model_name`` and
    ``fig_ext``, which the ``__main__`` block assigns before calling it.
    """
    # Only the largest bottleneck size is plotted, so the pickle file name is
    # identical for every weight and both splits: compute it once.
    pkl_file = "z_{}.pkl".format(max(z_dims))

    for loss_type in ['train', 'test']:
        loss_type_1 = loss_type + '_loss_1'
        loss_type_2 = loss_type + '_loss_2'

        loss_results_df = pandas.DataFrame(columns=[loss_1_var, loss_2_var])

        for weight in weights:
            weight_dir = SUB_DIR + 'weight_' + str(weight)
            result_dict = load_pkl(weight_dir + '/' + pkl_file)

            # Append one row per weight, in sweep order.
            loss_results_df.loc[len(loss_results_df.index)] = [
                result_dict[loss_type_1],
                result_dict[loss_type_2],
            ]

        plot_file = PLOT_DIR + '/{}_{}_loss_tradeoff.{}'.format(model_name, loss_type, fig_ext)

        try:
            # NOTE(review): seaborn's lineplot sorts by x by default, so the
            # points are joined in loss-1 order rather than sweep order --
            # confirm this is the intended curve.
            sns.lineplot(x=loss_1_var, y=loss_2_var, data=loss_results_df, marker='o')
            plt.title(loss_type)
            plt.savefig(plot_file)
        finally:
            # Always release the current figure, even when drawing or saving
            # fails, so a later plot never lands on a stale figure.
            plt.close()


if __name__ == '__main__':

    # Every option is used unconditionally below, so all four are declared
    # required: argparse then rejects an incomplete command line with a usage
    # message before the plot directory is touched.
    parser = argparse.ArgumentParser(
        description='Plot the task-loss tradeoff across a sweep of loss weights.')
    parser.add_argument('--model_name', type=str, required=True,
                        help='name of the model directory under the scratch folder')
    parser.add_argument('--weights', type=str, required=True,
                        help='comma-separated list of weights, e.g. 0.1,0.5,0.9')
    parser.add_argument('--z_dims', type=str, required=True,
                        help='comma-separated list of bottleneck sizes, e.g. 2,4,8')
    parser.add_argument('--fig_ext', type=str, required=True,
                        help='file extension for the saved figures, e.g. pdf or png')
    args = parser.parse_args()

    # These stay module-level on purpose: plot_loss_tradeoff() reads them as
    # globals.
    model_name = args.model_name
    weights = [float(weight) for weight in args.weights.split(',')]
    z_dims = [int(z_dim) for z_dim in args.z_dims.split(',')]
    fig_ext = args.fig_ext

    BASE_DIR = SCRATCH_DIR + model_name
    SUB_DIR = BASE_DIR + '/compression_with_ADV/'
    PLOT_DIR = SUB_DIR + 'loss_tradeoff_plot'
    # Presumably wipes any previous plot folder and recreates it empty --
    # verify against the helper's definition.
    remove_and_create_dir(PLOT_DIR)

    plot_loss_tradeoff()
    
