import numpy as np
import os, sys
import pandas
import argparse

# Project root, read from the environment; a missing variable raises KeyError
# at import time.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR']

# Make the project root and its utils/ and multitask/ sub-directories
# importable so the project-local modules imported below can be resolved.
sys.path.extend([
    MULTITASK_COMPRESSION_DIR,
    MULTITASK_COMPRESSION_DIR + '/utils/',
    MULTITASK_COMPRESSION_DIR + '/multitask/',
])

# Directory under the project root where per-model inputs and outputs live.
SCRATCH_DIR = '{}/scratch/'.format(MULTITASK_COMPRESSION_DIR)

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Column names of the scatter-plot DataFrames; the first two are also used as
# the x/y axis variables and the third as the hue variable when plotting.
dim_1_var = "Dimension 1"
dim_2_var = "Dimension 2"
tag_var = "Tag"
# Column name for a weight value (not referenced in the visible part of this file).
weight_var = "Weight"


def _parse_csv_values(raw, cast):
    """Split a comma-separated command-line string and convert each item.

    Args:
        raw: String such as ``"0.1,1.0"``.
        cast: Callable applied to every item (e.g. ``float`` or ``int``).

    Returns:
        List of converted values, in the order given.

    Raises:
        ValueError: If ``cast`` rejects an item (e.g. an empty entry).
    """
    return [cast(item) for item in raw.split(',')]


def _save_scatter_plot(Y_hat_2, tags, plot_file):
    """Save a tag-coloured scatter plot of the two rows of ``Y_hat_2``.

    Args:
        Y_hat_2: Array-like with a ``shape`` attribute and 2-D indexing whose
            first axis must have length 2; row 0 is plotted on the x axis and
            row 1 on the y axis.
        tags: Sequence with one tag per column of ``Y_hat_2``; used as hue.
        plot_file: Output path handed to ``plt.savefig``.

    Raises:
        ValueError: If the first axis of ``Y_hat_2`` does not have length 2,
            or if the rows and ``tags`` differ in length (raised by pandas).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if Y_hat_2.shape[0] != 2:
        raise ValueError(
            'expected 2 rows to scatter, got {}'.format(Y_hat_2.shape[0]))

    scatter_df = pandas.DataFrame({
        dim_1_var: list(Y_hat_2[0, :]),
        dim_2_var: list(Y_hat_2[1, :]),
        tag_var: list(tags),
    })

    # NOTE(review): `sns` and `plt` are not imported explicitly in this file;
    # presumably they come from the star import of plotting_utils -- confirm.
    sns.scatterplot(data=scatter_df, x=dim_1_var, y=dim_2_var, hue=tag_var)

    plt.legend()
    plt.savefig(plot_file)
    # Close the figure so the next plot starts from a clean canvas.
    plt.close()


if __name__ == '__main__':
    # For every (weight, z_dim) pair, load the stored result pickle and write
    # one train and one test scatter plot into
    # <scratch>/<model_name>/compression_with_ADV/scatter_plot/.

    # Every option is needed below, so declare them required: argparse then
    # reports a clear usage error instead of a later AttributeError/TypeError
    # on a None value.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True,
                        help='name of the model directory under scratch/')
    parser.add_argument('--weights', type=str, required=True,
                        help='comma-separated list of weights (floats)')
    parser.add_argument('--z_dims', type=str, required=True,
                        help='comma-separated list of z dimensions (ints)')
    parser.add_argument('--fig_ext', type=str, required=True,
                        help='file extension of the saved figures')
    args = parser.parse_args()

    model_name = args.model_name
    weights = _parse_csv_values(args.weights, float)
    z_dims = _parse_csv_values(args.z_dims, int)
    fig_ext = args.fig_ext

    BASE_DIR = SCRATCH_DIR + model_name
    DATA_DIR = BASE_DIR + '/dataset/'
    data_dict = load_pkl(DATA_DIR + '/data.pkl')

    SUB_DIR = BASE_DIR + '/compression_with_ADV/'
    PLOT_DIR = SUB_DIR + 'scatter_plot/'

    train_tags = data_dict['train_tags']
    test_tags = data_dict['test_tags']

    # NOTE(review): judging by its name this wipes any existing plot
    # directory before recreating it -- confirm in the utils module.
    remove_and_create_dir(PLOT_DIR)

    for weight in weights:
        # Results are read from a directory named after the float's str(),
        # e.g. "weight_1.0".
        SUB_TEST_DATA_DIR = SUB_DIR + 'weight_' + str(weight)

        for z_dim in z_dims:
            pkl_file = "z_{}.pkl".format(z_dim)
            result_dict = load_pkl(SUB_TEST_DATA_DIR + '/' + pkl_file)

            # The z_dim recorded inside the pickle (not the requested one)
            # decides whether to skip and is what appears in the file names.
            stored_z_dim = result_dict['z_dim']
            if stored_z_dim == 0:
                continue

            train_plot_file = PLOT_DIR + '/{}_train_scatter_weight_{}_zdim_{}.{}'.format(
                model_name, weight, stored_z_dim, fig_ext)
            _save_scatter_plot(
                result_dict['train_Y_hat_2'], train_tags, train_plot_file)

            test_plot_file = PLOT_DIR + '/{}_test_scatter_weight_{}_zdim_{}.{}'.format(
                model_name, weight, stored_z_dim, fig_ext)
            _save_scatter_plot(
                result_dict['test_Y_hat_2'], test_tags, test_plot_file)
    
