import numpy as np
import os, sys
import pandas
import argparse

import torch

MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR'] 
sys.path.append(MULTITASK_COMPRESSION_DIR)
sys.path.append(MULTITASK_COMPRESSION_DIR + '/utils/')
sys.path.append(MULTITASK_COMPRESSION_DIR + '/multitask/')

SCRATCH_DIR = MULTITASK_COMPRESSION_DIR + '/scratch/'

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Column names of the result DataFrames built below; seaborn also uses them as
# the axis / legend labels.  Matplotlib renders the $...$ parts as mathtext.
size_var = "Bottleneck Dimension $Z$"    # x-axis: bottleneck size
loss_var = "Loss"                        # y-axis of plot_loss()
cost_var = "Cost"                        # y-axis of plot_cost()
noise_vars_var = "Noise Variation"       # hue of plot_loss(): one curve per noise level

def plot_loss():
    """Plot each train/test loss against the bottleneck dimension.

    For every loss key in ``('train_loss_1', 'train_loss_2', 'test_loss_1',
    'test_loss_2')`` this loads ``<SUB_DIR>noise_var_<noise_var>/z_<z_dim>.pkl``
    for every combination of the module-level ``noise_vars`` and ``z_dims``.
    It collects one ``(z_dim, loss, legend label)`` row per pickle and draws a
    seaborn point plot (x = bottleneck dimension, hue = noise level). The plot
    is saved to ``<PLOT_DIR>/<model_name>_<loss_type>_bottleneck.<fig_ext>``.

    Reads the module-level globals ``noise_vars``, ``z_dims``, ``SUB_DIR``,
    ``PLOT_DIR``, ``model_name`` and ``fig_ext``. These are only bound when the
    file is executed as a script (see the ``__main__`` block).
    """
    for loss_type in ['train_loss_1', 'train_loss_2',
                      'test_loss_1', 'test_loss_2']:
        # Collect plain rows and build the DataFrame once: growing it with
        # ``df.loc[len(df.index)] = ...`` reallocates on every row and leaves
        # every column with ``object`` dtype.
        rows = []
        for noise_var in noise_vars:
            sub_test_data_dir = SUB_DIR + 'noise_var_' + str(noise_var)
            noise_var_legend = r'noise_var=' + str(noise_var)

            for z_dim in z_dims:
                pkl_file = "z_{}.pkl".format(z_dim)
                result_dict = load_pkl(sub_test_data_dir + '/' + pkl_file)

                # The x value is the z_dim recorded inside the pickle, not the
                # loop value used to locate the file.
                rows.append([result_dict['z_dim'],
                             result_dict[loss_type],
                             noise_var_legend])

        loss_results_df = pandas.DataFrame(
            rows, columns=[size_var, loss_var, noise_vars_var])

        plot_file = PLOT_DIR + '/{}_{}_bottleneck.{}'.format(
            model_name, loss_type, fig_ext)

        sns.pointplot(x=size_var, y=loss_var, data=loss_results_df,
                      hue=noise_vars_var)
        plt.legend()
        plt.title(loss_type)
        plt.savefig(plot_file)
        # Close the figure so the next loss type starts on a fresh canvas.
        plt.close()


def plot_cost():
    """Plot each train/test cost against the bottleneck dimension.

    For every cost key in ``('train_cost_1', 'train_cost_2', 'test_cost_1',
    'test_cost_2')`` this loads ``<SUB_DIR>weight_<weight>/z_<z_dim>.pkl`` for
    every combination of the module-level ``weights`` and ``z_dims``. It plots
    the per-sample total cost (``opt_cost + extra_cost``) as a seaborn point
    plot (x = bottleneck dimension, hue = weight). A dashed black "Optimal"
    line is drawn at the mean optimal cost. The plot is saved to
    ``<PLOT_DIR>/<model_name>_<cost_type>_bottleneck.<fig_ext>``.

    A cost type is skipped entirely if any run stores ``opt_cost`` as None, or
    if there is nothing to plot (empty ``weights`` / ``z_dims``).

    NOTE(review): ``weights`` and ``weight_var`` are not defined anywhere in
    this file's visible code, and the ``__main__`` block does not set
    ``weights``. Presumably they come from a star import or still need to be
    added. Confirm before enabling this function.
    """
    for cost_type in ['train_cost_1', 'train_cost_2',
                      'test_cost_1', 'test_cost_2']:
        collected = _collect_cost_frames(cost_type)
        if collected is None:
            # Some run has no optimal cost for this split: skip the plot.
            continue

        frames, opt_cost_mean = collected
        if not frames:
            # Nothing was loaded, so there is no data and no reference value.
            continue

        # One concat instead of repeated DataFrame.append (removed in
        # pandas 2.0); ignore_index avoids duplicate row labels.
        cost_results_df = pandas.concat(frames, ignore_index=True)

        plot_file = PLOT_DIR + '/{}_{}_bottleneck.{}'.format(
            model_name, cost_type, fig_ext)

        sns.pointplot(x=size_var, y=cost_var, data=cost_results_df,
                      hue=weight_var)
        plt.axhline(y=opt_cost_mean, linewidth=2.0, ls='--', color='black',
                    label='Optimal')
        plt.legend()
        plt.title(cost_type)
        plt.savefig(plot_file)
        # Close the figure so the next cost type starts on a fresh canvas.
        plt.close()


def _collect_cost_frames(cost_type):
    """Load per-sample total costs of *cost_type* for every (weight, z_dim) run.

    Returns ``(frames, opt_cost_mean)``:

    - ``frames`` is a list of DataFrames with columns
      ``[size_var, cost_var, weight_var]``, one per run.
    - ``opt_cost_mean`` is the mean ``opt_cost`` of the LAST run loaded. This
      matches the original behaviour, where that value is reused for the
      reference line. It is None when no run was loaded.

    Returns None as soon as a run stores ``opt_cost`` as None.
    """
    frames = []
    opt_cost_mean = None
    for weight in weights:
        sub_test_data_dir = SUB_DIR + 'weight_' + str(weight)
        weight_legend = r'$w_1$=' + str(weight)

        for z_dim in z_dims:
            pkl_file = "z_{}.pkl".format(z_dim)
            result_dict = load_pkl(sub_test_data_dir + '/' + pkl_file)
            cost_entry = result_dict[cost_type]

            opt_cost = cost_entry["opt_cost"]
            # Identity test: ``== None`` is ambiguous on array-like values.
            if opt_cost is None:
                return None
            extra_cost = cost_entry["extra_cost"]

            # Assumes opt_cost is a torch tensor whose first dimension indexes
            # samples and that holds exactly num_samples elements (required by
            # the reshape below). TODO confirm against the producer.
            num_samples = opt_cost.shape[0]
            opt_cost_mean = torch.mean(opt_cost).item()
            total_cost = (opt_cost + extra_cost).reshape(num_samples).tolist()

            frames.append(pandas.DataFrame({
                size_var: [result_dict['z_dim']] * num_samples,
                cost_var: total_cost,
                weight_var: [weight_legend] * num_samples,
            }))
    return frames, opt_cost_mean


if __name__ == '__main__':
    # Command-line entry point. Every name bound here is a module-level global
    # read by plot_loss() / plot_cost(), so this code must stay at module
    # scope rather than move into a main() function.
    parser = argparse.ArgumentParser(
        description='Plot losses against the bottleneck dimension.')
    # All four options are used unconditionally below. Requiring them turns a
    # missing option into a usage error instead of a crash on None.
    parser.add_argument('--model_name', type=str, required=True,
                        help='results sub-directory under SCRATCH_DIR; '
                             'also used as the plot file prefix')
    parser.add_argument('--noise_vars', type=str, required=True,
                        help='comma-separated noise levels, e.g. 0.0,0.1')
    parser.add_argument('--z_dims', type=str, required=True,
                        help='comma-separated bottleneck dimensions, e.g. 2,4,8')
    parser.add_argument('--fig_ext', type=str, required=True,
                        help='figure file extension, e.g. png or pdf')
    args = parser.parse_args()

    model_name = args.model_name
    noise_vars = [float(noise_var) for noise_var in args.noise_vars.split(',')]
    z_dims = [int(z_dim) for z_dim in args.z_dims.split(',')]
    fig_ext = args.fig_ext

    BASE_DIR = SCRATCH_DIR + model_name
    SUB_DIR = BASE_DIR + '/compression_with_DP/'
    PLOT_DIR = SUB_DIR + 'loss_plot'
    # Presumably wipes any previous plots in PLOT_DIR and recreates it
    # (helper from the star imports; verify in utils).
    remove_and_create_dir(PLOT_DIR)

    plot_loss()
    # plot_cost() stays disabled: it reads `weights` / `weight_var`, which
    # this entry point does not set.
    # plot_cost()

    
