import numpy as np
import os, sys
import pandas
import argparse

# Root of the multitask-compression checkout. It is read straight from the
# environment, so an unset MULTITASK_COMPRESSION_DIR raises KeyError here.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR']

# Put the root and its utils/ and multitask/ folders on the import path
# (presumably where the project-local modules imported below live).
sys.path.extend([
    MULTITASK_COMPRESSION_DIR,
    MULTITASK_COMPRESSION_DIR + '/utils/',
    MULTITASK_COMPRESSION_DIR + '/multitask/',
])

# Scratch area under the root; the script resolves per-model paths against it.
SCRATCH_DIR = '{}/scratch/'.format(MULTITASK_COMPRESSION_DIR)

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Column labels of the scatter-plot DataFrames. The first two are also passed
# to seaborn as the x / y variables, the third as the hue variable.
dim_1_var, dim_2_var = "Mean", "Variance"
tag_var = "Tag"
# Not referenced in the visible part of this script.
weight_var = "Weight"


def _save_scatter_plot(data_dict, split, tags, plot_file):
    """Save a mean-vs-variance scatter plot for one dataset split.

    Args:
        data_dict: Mapping loaded from the dataset pickle. The entries
            '<split>_dataset_mean', '<split>_dataset_var' and
            'num_clusters' are read from it.
        split: Split name ('train' or 'test'). It selects the entries above
            and, capitalized, becomes the plot title.
        tags: One tag per row of the split; used as the hue of the points.
        plot_file: Destination path handed to ``plt.savefig``.
    """
    scatter_df = pandas.DataFrame(columns=[dim_1_var, dim_2_var, tag_var])
    scatter_df[dim_1_var] = data_dict['{}_dataset_mean'.format(split)]
    scatter_df[dim_2_var] = data_dict['{}_dataset_var'.format(split)]
    scatter_df[tag_var] = list(tags)

    # The palette holds exactly num_clusters colours.
    # NOTE(review): this presumably expects the split's tags to take
    # num_clusters distinct values -- confirm against the dataset builder.
    sns.scatterplot(
        data=scatter_df, x=dim_1_var, y=dim_2_var, hue=tag_var,
        palette=sns.color_palette(n_colors=data_dict['num_clusters']))

    plt.legend()
    plt.title(split.capitalize())
    plt.savefig(plot_file)
    # Close the current figure so the next split does not draw on top of it.
    plt.close()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # Both options are concatenated into paths below. Left optional they
    # default to None, which only fails later with an unrelated-looking
    # error, so argparse is asked to reject a missing value up front.
    parser.add_argument('--model_name', type=str, required=True,
                        help='name of the model directory under the scratch dir')
    parser.add_argument('--fig_ext', type=str, required=True,
                        help='figure file extension without the leading dot, e.g. png')
    args = parser.parse_args()

    model_name = args.model_name
    fig_ext = args.fig_ext

    BASE_DIR = SCRATCH_DIR + model_name
    DATA_DIR = BASE_DIR + '/dataset/'
    data_dict = load_pkl(DATA_DIR + '/data.pkl')

    PLOT_DIR = DATA_DIR + 'scatter_plot_mean_var/'

    splits = ('train', 'test')

    # Look the tags up before PLOT_DIR is touched, so a pickle without them
    # fails with KeyError before remove_and_create_dir runs.
    tags_by_split = {
        split: data_dict['{}_tags'.format(split)] for split in splits
    }

    # Presumably wipes and recreates the output directory -- helper comes
    # from the project star-imports; verify there.
    remove_and_create_dir(PLOT_DIR)

    for split in splits:
        plot_file = PLOT_DIR + '/{}_{}_scatter_mean_var.{}'.format(
            model_name, split, fig_ext)
        _save_scatter_plot(data_dict, split, tags_by_split[split], plot_file)

    
