import numpy as np
import os, sys
import pandas
import argparse

import torch

# Root of the project checkout, read from the environment.  A missing
# MULTITASK_COMPRESSION_DIR variable raises KeyError here, at import time.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR'] 
# Extend the module search path with the project root and two of its
# sub-directories; this runs before the star imports further down.
sys.path.append(MULTITASK_COMPRESSION_DIR)
sys.path.append(MULTITASK_COMPRESSION_DIR + '/utils/')
sys.path.append(MULTITASK_COMPRESSION_DIR + '/multitask/')

# Base directory under which the per-model result folders are looked up
# (see BASE_DIR / SUB_DIR in the __main__ block).
SCRATCH_DIR = MULTITASK_COMPRESSION_DIR + '/scratch/'

from textfile_utils import *
from plotting_utils import *
from collections import OrderedDict
from utils import *


# Column names of the results DataFrames built by plot_loss() / plot_cost().
# The same strings are passed as x / y / hue to sns.pointplot, which is why
# size_var carries LaTeX markup.
size_var = r"Bottleneck Dimension $Z$"
loss_var = r"Loss"
cost_var = r"Cost"
# Holds the per-run legend text (lambda for ADV, epsilon/delta for LDP).
weight_var = r"Weight"

def plot_loss():
    """Save one loss-vs-bottleneck-dimension point plot per loss type.

    For every loss type, the function walks over all configured runs
    (``label_length`` of them) and all ``z_dims``, loads the matching
    ``z_<dim>.pkl`` result file from ``SUB_DIR + label`` and collects
    ``(z_dim, loss, legend label)`` rows.  The rows are drawn with
    ``sns.pointplot`` (one hue per run) and saved to
    ``PLOT_DIR/<model_name>_<loss_type>_bottleneck.<fig_ext>``.

    Reads the module-level names ``label_length``, ``train_type``,
    ``ADV_weights`` or ``LDP_epsilons``/``LDP_deltas``, ``SUB_DIR``,
    ``z_dims``, ``PLOT_DIR``, ``model_name`` and ``fig_ext`` that the
    ``__main__`` block sets.  ``load_pkl``, ``sns`` and ``plt`` are not
    imported explicitly in this file; presumably they come from the star
    imports at the top -- TODO confirm.

    Plotting is best-effort: any exception raised while handling one loss
    type is printed and the function moves on to the next loss type.
    """
    for loss_type in [ 'overall_train_loss', 'train_loss_1', 'train_loss_2',
                       'overall_test_loss', 'test_loss_1', 'test_loss_2',
                       'train_task_agnostic_loss_1', 'train_task_agnostic_loss_2',
                       'test_task_agnostic_loss_1', 'test_task_agnostic_loss_2',
                            ]:
        try:

            loss_results_df = pandas.DataFrame(columns = [size_var, loss_var, weight_var])

            for i in range(label_length):
                if train_type == "ADV":
                    weight = ADV_weights[i]
                    label = 'weight_{}'.format(weight)
                    label_legend = r'$\lambda$=' + str(weight)
                elif train_type == "LDP":
                    epsilon = LDP_epsilons[i]
                    delta = LDP_deltas[i]
                    label = 'e_{}_d_{}'.format(epsilon, delta)
                    label_legend = r'$\epsilon$=' + str(epsilon) + r',$\delta$=' + str(delta)
                else:
                    # Without this branch the failure would surface as an
                    # opaque unbound-local error on `label` below.
                    raise ValueError('Unsupported train_type: {}'.format(train_type))

                SUB_TEST_DATA_DIR = SUB_DIR + label

                for z_dim in z_dims:
                    pkl_file = "z_{}.pkl".format(z_dim)

                    result_dict = load_pkl(SUB_TEST_DATA_DIR + '/' + pkl_file)

                    if loss_type not in result_dict:
                        raise Exception('Cannot plot {}'.format(loss_type))

                    # The dimension stored in the pickle is what gets plotted;
                    # the loop variable is only used to locate the file.
                    stored_z_dim = result_dict['z_dim']
                    loss = result_dict[loss_type]

                    loss_results_df.loc[len(loss_results_df.index)] = [stored_z_dim, loss, label_legend]

            plot_file = PLOT_DIR + '/{}_{}_bottleneck.{}'.format(model_name, loss_type, fig_ext)

            sns.pointplot(x=size_var, y=loss_var, data=loss_results_df, hue=weight_var)

            plt.legend()
            plt.title(loss_type)
            plt.savefig(plot_file)

        except Exception as e:
            print(e)
        finally:
            # Close the current figure on failure as well, so a half-drawn
            # plot cannot leak into the figure of the next loss type.
            plt.close()


def plot_cost():
    """Save one cost-vs-bottleneck-dimension point plot per cost type.

    For every cost type, the function walks over all configured runs
    (``label_length`` of them) and all ``z_dims``, loads the matching
    ``z_<dim>.pkl`` result file from ``SUB_DIR + label`` and adds the
    per-sample total cost (``opt_cost + extra_cost``) to a long-format
    DataFrame.  The data is drawn with ``sns.pointplot`` (one hue per run),
    a dashed horizontal "Optimal" line is added, and the figure is saved to
    ``PLOT_DIR/<model_name>_<cost_type>_bottleneck.<fig_ext>``.

    Reads the module-level names ``label_length``, ``train_type``,
    ``ADV_weights`` or ``LDP_epsilons``/``LDP_deltas``, ``SUB_DIR``,
    ``z_dims``, ``PLOT_DIR``, ``model_name`` and ``fig_ext`` that the
    ``__main__`` block sets.  ``load_pkl``, ``sns`` and ``plt`` are not
    imported explicitly in this file; presumably they come from the star
    imports at the top -- TODO confirm.

    Plotting is best-effort: any exception raised while handling one cost
    type is printed and the function moves on to the next cost type.
    """
    for cost_type in [ 'train_cost_1', 'train_cost_2',
                       'test_cost_1', 'test_cost_2' ]:

        try:
            # One small frame per (run, z_dim); they are combined once at the
            # end.  DataFrame.append was removed in pandas 2.0, and calling it
            # in a loop re-copied the accumulated data on every iteration.
            per_file_frames = []
            opt_cost_mean = None

            for i in range(label_length):
                if train_type == "ADV":
                    weight = ADV_weights[i]
                    label = 'weight_{}'.format(weight)
                    label_legend = r'$\lambda$=' + str(weight)
                elif train_type == "LDP":
                    epsilon = LDP_epsilons[i]
                    delta = LDP_deltas[i]
                    label = 'e_{}_d_{}'.format(epsilon, delta)
                    label_legend = r'$\epsilon$=' + str(epsilon) + r',$\delta$=' + str(delta)
                else:
                    # Without this branch the failure would surface as an
                    # opaque unbound-local error on `label` below.
                    raise ValueError('Unsupported train_type: {}'.format(train_type))

                SUB_TEST_DATA_DIR = SUB_DIR + label

                for z_dim in z_dims:
                    pkl_file = "z_{}.pkl".format(z_dim)

                    result_dict = load_pkl(SUB_TEST_DATA_DIR + '/' + pkl_file)

                    # Identity check: `== None` on a tensor is an element-wise
                    # style comparison, not a "was it stored?" test.
                    if cost_type not in result_dict or result_dict[cost_type]["opt_cost"] is None:
                        raise Exception('Cannot plot {}'.format(cost_type))

                    opt_cost = result_dict[cost_type]["opt_cost"]
                    extra_cost = result_dict[cost_type]["extra_cost"]

                    num_samples = opt_cost.shape[0]

                    # NOTE(review): overwritten for every file, so the
                    # "Optimal" line reflects the last result loaded --
                    # confirm that opt_cost is identical across runs.
                    opt_cost_mean = torch.mean(opt_cost).item()
                    total_cost = (opt_cost + extra_cost).reshape(num_samples).tolist()

                    per_file_frames.append(pandas.DataFrame({
                        cost_var: total_cost,
                        size_var: [result_dict['z_dim']] * num_samples,
                        weight_var: [label_legend] * num_samples,
                    }))

            if not per_file_frames:
                # pandas.concat rejects an empty list and there would be no
                # optimal cost to draw either.
                raise Exception('Cannot plot {}: no results loaded'.format(cost_type))

            # ignore_index gives every sample a unique row label instead of
            # repeating 0..num_samples-1 once per file.
            cost_results_df = pandas.concat(per_file_frames, ignore_index=True)

            plot_file = PLOT_DIR + '/{}_{}_bottleneck.{}'.format(model_name, cost_type, fig_ext)

            sns.pointplot(x=size_var, y=cost_var, data=cost_results_df, hue=weight_var)
            plt.axhline(
                y = opt_cost_mean, linewidth = 2.0, ls = '--', color = 'black', label = 'Optimal')
            plt.legend()
            plt.title(cost_type)
            plt.savefig(plot_file)

        except Exception as e:
            print(e)
        finally:
            # Close the current figure on failure as well, so a half-drawn
            # plot cannot leak into the figure of the next cost type.
            plt.close()


if __name__ == '__main__':

    # Command-line interface.  List-valued options are comma-separated
    # strings, e.g. --z_dims 2,4,8 --adv_weights 0.1,1.0
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    # Only these two modes are handled by plot_loss() / plot_cost(); anything
    # else used to leave label_length undefined and every plot failed.
    parser.add_argument('--train_type', type=str, required=True,
                        choices=['ADV', 'LDP'])
    parser.add_argument('--adv_weights', type=str, default="")
    parser.add_argument('--ldp_epsilons', type=str, default="")
    parser.add_argument('--ldp_deltas', type=str, default="")
    parser.add_argument('--z_dims', type=str, required=True)
    parser.add_argument('--fig_ext', type=str, required=True)
    args = parser.parse_args()

    # These module-level names are read directly by plot_loss()/plot_cost().
    model_name = args.model_name
    train_type = args.train_type
    fig_ext = args.fig_ext
    z_dims = [int(z_dim) for z_dim in args.z_dims.split(',')]

    if train_type == "ADV":
        if not args.adv_weights:
            parser.error('--adv_weights is required when --train_type is ADV')
        ADV_weights = [float(weight) for weight in args.adv_weights.split(',')]
        label_length = len(ADV_weights)
    elif train_type == "LDP":
        if not args.ldp_epsilons or not args.ldp_deltas:
            parser.error('--ldp_epsilons and --ldp_deltas are required '
                         'when --train_type is LDP')
        LDP_epsilons = [float(ldp_epsilon)
                        for ldp_epsilon in args.ldp_epsilons.split(',')]
        LDP_deltas = [float(ldp_delta)
                      for ldp_delta in args.ldp_deltas.split(',')]
        # The plot loops index both lists with the same i.
        if len(LDP_epsilons) != len(LDP_deltas):
            parser.error('--ldp_epsilons and --ldp_deltas must have the '
                         'same number of entries')
        label_length = len(LDP_epsilons)

    # Results are read from <scratch>/<model>/compression_with_<type>/<label>/
    # and the figures go to the loss_plot folder next to those label folders.
    BASE_DIR = SCRATCH_DIR + model_name
    SUB_DIR = BASE_DIR + '/compression_with_{}/'.format(train_type)
    PLOT_DIR = SUB_DIR + 'loss_plot'
    # NOTE(review): going by its name this helper wipes any existing
    # PLOT_DIR before re-creating it -- confirm in the utils module.
    remove_and_create_dir(PLOT_DIR)

    plot_loss()
    plot_cost()
