import os
import torch
import pickle

import numpy as np
import torch.optim as optim

from tqdm import tqdm
from scipy import stats
from copy import deepcopy

import matplotlib
import matplotlib.patches as mpatches
from matplotlib import pyplot as plt

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

import utils
import model.net as net

from dataloader import TestDataset
from dataloader import TrainEvalDataset

# Default resolution for every figure created by this module.
matplotlib.rcParams['figure.dpi'] = 200

# Plotting: Seaborn style and matplotlib params
import seaborn as sns
sns.set_style("white")

# Matplotlib params
from matplotlib import rcParams
from matplotlib import rc

rcParams['legend.loc'] = 'best'
# Font type 42 makes the PDF/PS backends embed TrueType fonts (matplotlib's
# other option is Type 3).
rcParams['pdf.fonttype'] = 42
rcParams['ps.fonttype'] = 42
# The global font size is left at matplotlib's default here; the plotting
# helpers below set rcParams['font.size'] themselves.

# Do not route text rendering through an external LaTeX installation.
rc('text', usetex=False)

# Print numpy arrays with 4 digits of precision.
np.set_printoptions(precision=4)

def set_axes_label(ax, xlabel, ylabel, labelpad=0):
    """Set the x and y axis labels of ``ax``.

    Args:
        ax: Axes-like object exposing ``set_xlabel`` and ``set_ylabel``.
        xlabel: Text for the x axis label.
        ylabel: Text for the y axis label.
        labelpad: Padding forwarded to both label setters. Defaults to 0,
            the value this helper has always used.
    """
    ax.set_xlabel(xlabel, labelpad=labelpad)
    ax.set_ylabel(ylabel, labelpad=labelpad)

def set_axes(ax, xlim, ylim, xlabel, ylabel, labelpad=14):
    """Set the limits and labels of both axes of ``ax``.

    Args:
        ax: Axes-like object exposing ``set_xlim``, ``set_ylim``,
            ``set_xlabel`` and ``set_ylabel``.
        xlim: Limits forwarded to ``ax.set_xlim``.
        ylim: Limits forwarded to ``ax.set_ylim``.
        xlabel: Text for the x axis label.
        ylabel: Text for the y axis label.
        labelpad: Padding forwarded to both label setters. Defaults to 14,
            the value this helper has always used.
    """
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel, labelpad=labelpad)
    ax.set_ylabel(ylabel, labelpad=labelpad)
 
def set_ticks(ax, xticks, xticklabels, yticks, yticklabels):
    """Place the ticks of both axes and attach their labels.

    For each axis the tick positions are set first and the tick labels
    second, x axis before y axis.
    """
    per_axis = (
        (ax.set_xticks, ax.set_xticklabels, xticks, xticklabels),
        (ax.set_yticks, ax.set_yticklabels, yticks, yticklabels),
    )
    for place_ticks, name_ticks, positions, names in per_axis:
        place_ticks(positions)
        name_ticks(names)

def decorate_axis(ax, wrect=10, hrect=10, labelsize='large'):
    """Apply the "open box" look used by the plots in this module.

    Args:
        ax: Axes-like object exposing ``spines`` and ``tick_params``.
        wrect: Outward offset applied to the bottom spine.
        hrect: Outward offset applied to the left spine.
        labelsize: Tick label size forwarded to ``ax.tick_params``.
    """
    spines = ax.spines

    # Only the left and bottom spines stay visible, drawn with a thicker line.
    for hidden_side in ('right', 'top'):
        spines[hidden_side].set_visible(False)
    for kept_side in ('left', 'bottom'):
        spines[kept_side].set_linewidth(2)

    # Shrink the tick marks to almost nothing; only their labels remain.
    ax.tick_params(length=0.1, width=0.1, labelsize=labelsize)

    # Push the two remaining spines away from the data area, which leaves a
    # blank gap where they would otherwise meet at the origin.
    for kept_side, offset in (('left', hrect), ('bottom', wrect)):
        spines[kept_side].set_position(('outward', offset))

def dump_hist(res_path, loss_dict):
    """Save a log-log curve of the histogram of ``|labels[:, 0]|`` as a PDF.

    The absolute values of the first column of ``loss_dict['labels']`` are
    binned into 1000 equal-width bins and the counts are drawn against the
    bin centres. The figure is written to ``res_path + 'traffic_gt_lol.pdf'``.

    Side effects: resets the seaborn style and overwrites several global
    ``rcParams`` entries (legend location, PDF font type, DPI, font size).

    Args:
        res_path: Path prefix for the output file; it is concatenated as-is,
            so it should end with a path separator.
        loss_dict: Mapping with a 2-D array under ``'labels'``.

    Raises:
        KeyError: If ``loss_dict`` has no ``'labels'`` entry.
    """
    total_gt = loss_dict['labels'][:, 0]

    sns.set_style("white")
    rcParams['legend.loc'] = 'best'
    rcParams['pdf.fonttype'] = 42
    rcParams['figure.dpi'] = 200
    rcParams.update({'font.size': 30})
    colors = sns.color_palette('colorblind')

    # Bin with numpy directly: going through plt.hist would draw a throwaway
    # figure just to obtain the counts. NaNs are dropped explicitly because
    # np.histogram cannot auto-detect a range when they are present.
    n_bins = 1000
    magnitudes = np.abs(total_gt).flatten()
    magnitudes = magnitudes[~np.isnan(magnitudes)]
    counts, bin_edges = np.histogram(magnitudes, bins=n_bins)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2

    fig, ax = plt.subplots(ncols=1, figsize=(12, 6))
    try:
        ax.plot(bin_centres, counts, linewidth=2, color=colors[1], label='Ground Truth Value')

        ax.grid(True, alpha=0.3)
        decorate_axis(ax, wrect=10, hrect=10, labelsize=30)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel("Ground Truth")
        ax.set_ylabel("Count")

        fig.tight_layout()
        fig.savefig(res_path + 'traffic_gt_lol.pdf', format='pdf')
    finally:
        # Release the figure even when saving fails (e.g. missing directory).
        plt.close(fig)

def pretty_plot(x_dict, y_dict, x_label, y_label, title_txt, legend=False, conf_dict=None,
                legendsize=20, ticklabelsize=25, fontsize=30, legend_o='h', color_start=0):
    """Draw one line per method on a new 12x9 figure and return its axes.

    Side effect: sets the global ``rcParams['font.size']`` to ``fontsize``.
    The figure is left open; the caller is expected to save and close it.

    Args:
        x_dict: Maps method name -> x values.
        y_dict: Maps method name -> y values; its keys decide which methods
            are drawn and in which order colours are assigned.
        x_label: X axis label.
        y_label: Y axis label.
        title_txt: Axes title.
        legend: Whether to add a figure-level legend.
        conf_dict: Optional confidence bands, stored under
            ``method + '_low'`` and ``method + '_high'``. ``None`` means no
            bands.
        legendsize: Legend font size.
        ticklabelsize: Tick label size passed to ``decorate_axis``.
        fontsize: Global matplotlib font size.
        legend_o: ``'h'`` for a one-row legend at the top centre, ``'v'`` for
            a one-column legend at the upper right. Any other value draws no
            legend.
        color_start: Index of the first 'colorblind' palette colour to use.

    Returns:
        The matplotlib axes that was drawn on.

    Raises:
        KeyError: If a method in ``y_dict`` is missing from ``x_dict``, if a
            ``_low`` band has no matching ``_high`` band, or if
            ``color_start`` leaves no colours to assign.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import cycle

    if conf_dict is None:
        conf_dict = {}

    rcParams['font.size'] = fontsize
    linewidth = 2

    fig, ax = plt.subplots(figsize=(12, 9))

    palette = list(sns.color_palette('colorblind'))[color_start:]
    if len(y_dict) == 3:
        # With exactly three methods, skip the 3rd and 4th remaining colours.
        palette = palette[:2] + palette[4:]

    # Cycle so that more methods than colours reuse colours instead of
    # leaving the extra methods without an entry.
    colors = dict(zip(y_dict, cycle(palette)))

    for method, y_arr in y_dict.items():
        x_arr = x_dict[method]
        ax.plot(x_arr, y_arr, color=colors[method], linewidth=linewidth, label=method)

        low_arr = conf_dict.get(method + '_low')
        if low_arr is not None:
            high_arr = conf_dict.get(method + '_high')
            if high_arr is None:
                raise KeyError("conf_dict has '" + method + "_low' but no '" + method + "_high'")
            ax.fill_between(x_arr, low_arr, high_arr, color=colors[method], alpha=0.1)

    ax.grid(True, alpha=0.3)
    decorate_axis(ax, wrect=10, hrect=10, labelsize=ticklabelsize)

    if legend:
        labels = list(y_dict.keys())
        # Solid patches are used as legend handles instead of the thin lines.
        handles = [mpatches.Patch(color=colors[name], alpha=0.75) for name in labels]
        if legend_o == 'h':
            fig.legend(handles, labels, loc='upper center', fancybox=True,
                       fontsize=legendsize, ncol=len(labels), bbox_to_anchor=(0.52, 1),
                       columnspacing=1.0)
        elif legend_o == 'v':
            fig.legend(handles, labels, loc='upper right', fancybox=True,
                       fontsize=legendsize, ncol=1, bbox_to_anchor=(0.9, 0.95),
                       columnspacing=1.0)

    set_axes_label(ax, x_label, y_label)
    ax.set_title(title_txt)
    fig.tight_layout()

    return ax


def _format_metric_row(mean_val, sample_vals):
    """Return one table row: mean, 95/98/99/100th percentiles, kurtosis, skew, joined by ' & '."""
    cells = np.round(
        np.r_[[mean_val],
              np.nanpercentile(sample_vals, [95, 98, 99, 100]),
              [stats.kurtosis(sample_vals, nan_policy='omit'),
               stats.skew(sample_vals, nan_policy='omit')]], 4).astype(str)
    return ' & '.join(cells)


def calculate_metrics(model, data_loader, params):
    """Run ``model`` over ``data_loader`` and collect forecast-error statistics.

    For every batch the model is first stepped through the conditioning
    range (the first ``params.test_predict_start`` time steps) to build up
    its hidden and cell state, then ``model.test`` produces the forecast.
    Errors are measured against ``labels[:, params.test_predict_start:]``.
    A two-row summary table (ND and NRMSE) is printed to stdout.

    Side effects: switches ``model`` to eval mode and overwrites zero entries
    of feature 0 in the conditioning range of the batch tensor in place.

    Args:
        model: Object providing ``eval()``, ``init_hidden(batch_size)``,
            ``init_cell(batch_size)``, a call returning a 5-tuple whose first,
            third and fourth items are used as ``mu``, ``hidden`` and
            ``cell``, and ``test(...)`` returning ``(sample_mu, sample_sigma)``.
        data_loader: Iterable of ``(test_batch, id_batch, v, labels)`` tensor
            tuples. After the permute below, ``test_batch`` is indexed as
            (time, batch, feature); ``labels`` is indexed as (batch, time).
        params: Object with ``device``, ``test_predict_start`` and
            ``predict_steps`` attributes.

    Returns:
        dict with per-window arrays ``'samples'`` (conditioning-range labels),
        ``'labels'``, ``'preds'``, ``'pred_std'``, ``'mae'``, ``'mse'``,
        ``'summation'``, ``'nzc'``, ``'nd'``, ``'rmse'`` and the scalars
        ``'mean_mae'``, ``'mean_nd'``, ``'mean_rmse'``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
        ZeroDivisionError: From ``np.average`` when no target value is
            non-zero, because the weights then sum to zero.
    """
    model.eval()

    # Per-batch results are collected in lists and concatenated once at the
    # end; growing an array batch by batch would re-copy everything each time.
    chunks = {key: [] for key in ('samples', 'labels', 'pred_std', 'preds',
                                  'mae', 'mse', 'summation', 'nzc')}

    with torch.no_grad():

        for test_batch, id_batch, v, labels in tqdm(data_loader):

            test_batch = test_batch.permute(1, 0, 2).to(torch.float32).to(params.device)
            id_batch = id_batch.unsqueeze(0).to(params.device)
            v_batch = v.to(torch.float32).to(params.device)
            labels = labels.to(torch.float32).to(params.device)

            batch_size = test_batch.shape[1]

            hidden = model.init_hidden(batch_size)
            cell = model.init_cell(batch_size)

            # Warm up the recurrent state on the conditioning range.
            mu = None
            for t in range(params.test_predict_start):

                # A zero in feature 0 is replaced by the previous step's
                # predicted mean (there is no previous step at t == 0).
                zero_index = (test_batch[t, :, 0] == 0)
                if t > 0 and torch.sum(zero_index) > 0:
                    test_batch[t, zero_index, 0] = mu[zero_index]

                mu, _, hidden, cell, _ = model(test_batch[t].unsqueeze(0), id_batch, hidden, cell)

            sample_mu, sample_sigma = model.test(test_batch, v_batch, id_batch, hidden, cell)

            # Calculate the per-window losses for this batch.
            pred_arr = sample_mu.cpu().detach().numpy()
            gt_arr = labels[:, params.test_predict_start:].cpu().detach().numpy()
            error_arr = gt_arr - pred_arr

            chunks['samples'].append(labels[:, :params.test_predict_start].cpu().detach().numpy())
            chunks['labels'].append(gt_arr)
            chunks['pred_std'].append(sample_sigma.cpu().detach().numpy())
            chunks['preds'].append(pred_arr)
            chunks['mae'].append(np.mean(np.abs(error_arr), axis=1))
            chunks['mse'].append(np.mean(np.power(error_arr, 2), axis=1))
            chunks['summation'].append(np.sum(np.abs(gt_arr), axis=1))
            chunks['nzc'].append(np.sum(gt_arr != 0, axis=1))

    if not chunks['labels']:
        raise ValueError("data_loader yielded no batches; cannot compute metrics")

    merged = {key: np.concatenate(parts, axis=0) for key, parts in chunks.items()}
    all_mae = merged['mae']
    all_mse = merged['mse']
    all_summation = merged['summation']
    all_nzc = merged['nzc']

    print("--------------------------------Final metrics--------------------------------")

    # Normaliser: total |target| divided by the number of non-zero targets.
    den_val = np.sum(all_summation) / np.sum(all_nzc)

    # Overall metrics: per-window errors weighted by their non-zero target count.
    mean_nd = np.average(all_mae, weights=all_nzc) / den_val
    mean_rmse = np.sqrt(np.average(all_mse, weights=all_nzc)) / den_val

    # Per-window metrics: windows without any non-zero target keep the value 0.
    # The predict_steps / nzc factor rescales the per-step mean to a mean over
    # non-zero targets; this assumes each window spans params.predict_steps
    # steps -- TODO confirm.
    non_zero_bool = all_nzc > 0

    sample_nd = np.zeros_like(all_mae)
    sample_nd[non_zero_bool] = params.predict_steps * np.divide(all_mae[non_zero_bool], all_nzc[non_zero_bool]) / den_val

    sample_rmse = np.zeros_like(all_mae)
    sample_rmse[non_zero_bool] = np.sqrt(params.predict_steps * np.divide(all_mse[non_zero_bool], all_nzc[non_zero_bool])) / den_val

    print("Metric & Mean & VaR_95 & VaR_98 & VaR_99 & Max & Kurtosis & Skew")
    print("ND &", _format_metric_row(mean_nd, sample_nd))
    print("NRMSE &", _format_metric_row(mean_rmse, sample_rmse))

    losses_dict = {
        'samples': merged['samples'],
        'labels': merged['labels'],
        'pred_std': merged['pred_std'],
        'preds': merged['preds'],
        'mae': all_mae,
        'mse': all_mse,
        'summation': all_summation,
        'nzc': all_nzc,
        'nd': sample_nd,
        'rmse': sample_rmse,
        'mean_mae': np.mean(all_mae),
        'mean_nd': mean_nd,
        'mean_rmse': mean_rmse,
    }

    return losses_dict


def dump_results(res_path, loss_dict, metric, data_type='test'):
    """Save a histogram and a log-log count curve of one per-window metric.

    Two images are written:
      * ``res_path + data_type + '_' + metric + '_hist'``: a 300-bin
        histogram whose title lists mean, std, skew, kurtosis and percentiles.
      * ``res_path + metric + '_lol'``: the same counts drawn as a line on
        log-log axes via ``pretty_plot``.
    When ``metric == 'mae'`` the whole ``loss_dict`` is also pickled to
    ``res_path + data_type + '.pb'``.

    Args:
        res_path: Path prefix for all outputs; concatenated as-is, so it
            should end with a path separator.
        loss_dict: Mapping holding an array under ``metric`` and, optionally,
            a precomputed mean under ``'mean_' + metric``.
        metric: Key of the array to plot (e.g. ``'nd'``, ``'rmse'``, ``'mae'``).
        data_type: Tag used in the output file names.

    Raises:
        KeyError: If ``loss_dict`` has no entry for ``metric``.
    """
    if metric == 'mae':
        # The context manager closes the file even if pickling fails.
        with open(res_path + data_type + '.pb', 'wb') as handle:
            pickle.dump(loss_dict, handle)

    plot_data = loss_dict[metric]

    # Not every metric has a precomputed mean (e.g. 'mse'); fall back to the
    # plain NaN-ignoring mean of the plotted values.
    mean_val = loss_dict.get('mean_' + metric)
    if mean_val is None:
        mean_val = np.nanmean(plot_data)

    ptile_vals = np.round(np.nanpercentile(plot_data, [0, 50, 75, 95, 98, 99, 100]), 3)
    ptile_string = ','.join(ptile_vals.astype(str))

    n_bins = 300
    fig = plt.figure(figsize=(15, 10))
    try:
        counts, bins, _ = plt.hist(plot_data[~np.isnan(plot_data)], bins=n_bins)

        plt.title('Mean: ' + str(np.round(mean_val, 3)) + ' | Std: ' + str(np.round(np.nanstd(plot_data), 3)) +
            ' | Skew: ' + str(np.round(stats.skew(plot_data, nan_policy='omit'), 3)) +
            ' | Kurt: ' + str(np.round(stats.kurtosis(plot_data, nan_policy='omit'), 3)) +
            '\nPercentile Values (0, 50, 75, 95, 98, 99, 100): ' + ptile_string)

        plt.grid()
        plt.savefig(res_path + data_type + '_' + metric + '_hist')
    finally:
        # Close this figure explicitly: pretty_plot opens its own figure, so
        # a later bare plt.close() would never reach this one.
        plt.close(fig)

    y_val = counts
    x_val = (bins[:-1] + bins[1:]) / 2

    # Label the x axis after the metric actually being plotted.
    axis_labels = {
        'nd': 'Normalized Deviation (ND)',
        'rmse': 'Normalized RMSE (NRMSE)',
        'mae': 'Mean Absolute Error (MAE)',
        'mse': 'Mean Squared Error (MSE)',
    }
    x_label = axis_labels.get(metric, metric)
    series = metric.upper()

    ax = pretty_plot({series: x_val}, {series: y_val}, x_label, 'Count', "", legend=False, fontsize=20)

    ax.set_xscale('log')
    ax.set_yscale('log')
    try:
        ax.figure.savefig(res_path + metric + '_lol')
    finally:
        plt.close(ax.figure)



def dump_gt(res_path, loss_dict):
    """Save histograms of the first forecast step of predictions and labels.

    Two images are written:
      * ``res_path + 'gt_0_hist'``: a 300-bin histogram of
        ``loss_dict['preds'][:, 0]``.
      * ``res_path + 'traffic_gt_lol'``: the 300-bin counts of
        ``loss_dict['labels'][:, 0]`` drawn as a line on log-log axes.

    Args:
        res_path: Path prefix for both outputs; concatenated as-is, so it
            should end with a path separator.
        loss_dict: Mapping with 2-D arrays under ``'labels'`` and ``'preds'``.

    Raises:
        KeyError: If ``'labels'`` or ``'preds'`` is missing from ``loss_dict``.
    """
    plt_ts = 0
    n_bins = 300

    gt_vals = loss_dict['labels'][:, plt_ts]
    pred_vals = loss_dict['preds'][:, plt_ts]

    # NOTE(review): this histogram shows the predictions although the file
    # name starts with 'gt_' -- confirm which one is intended.
    fig = plt.figure(figsize=(15, 10))
    try:
        plt.hist(pred_vals[~np.isnan(pred_vals)], bins=n_bins)
        plt.grid()
        plt.savefig(res_path + 'gt_' + str(plt_ts) + '_hist')
    finally:
        plt.close(fig)

    # Only the counts are needed for the line plot, so bin with numpy instead
    # of plt.hist, which would open an extra figure that is never closed.
    counts, bin_edges = np.histogram(gt_vals[~np.isnan(gt_vals)], bins=n_bins)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2

    x_dict = {
        'Ground Truth': bin_centres,
    }

    y_dict = {
        'Ground Truth': counts,
    }

    ax = pretty_plot(x_dict, y_dict, 'Value', 'Count', "", legend=False, color_start=1, fontsize=35, ticklabelsize=35)

    ax.set_xscale('log')
    ax.set_yscale('log')
    try:
        ax.figure.savefig(res_path + 'traffic_gt_lol')
    finally:
        plt.close(ax.figure)


if __name__ == '__main__':

    # Which experiment / checkpoint to evaluate, and on which dataset.
    experiment_name = 'base_model'
    model_name = 'best'
    dataset = 'traffic'

    data_dir = 'data/traffic/'
    model_dir = 'experiments/' + experiment_name + '/'
    res_dir = 'losses/' + experiment_name + '/'

    # Hyper-parameters are stored next to the checkpoint.
    params = utils.Params(os.path.join(model_dir, 'params.json'))

    # Use the GPU when one is available. params.device is set before the
    # network is constructed from params.
    use_cuda = torch.cuda.is_available()
    params.device = torch.device('cuda' if use_cuda else 'cpu')
    model = net.Net(params)
    if use_cuda:
        model = model.cuda()

    # An optimizer is built only so it can be handed to load_checkpoint.
    optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
    checkpoint_file = os.path.join(model_dir, model_name + '.pth.tar')
    utils.load_checkpoint(checkpoint_file, model, optimizer)

    # Output prefix for the (currently disabled) dump_* calls below; the
    # directory is not created here.
    res_path = res_dir + model_name + '/'

    test_set = TestDataset(data_dir, dataset, params.num_class)
    test_loader = DataLoader(
        test_set,
        batch_size=params.predict_batch,
        sampler=SequentialSampler(test_set),
        num_workers=2,
    )

    print(experiment_name)

    test_loss_dict = calculate_metrics(model, test_loader, params)

    # Optional outputs, enable as needed (res_path must exist first):
    # dump_results(res_path, test_loss_dict, metric='rmse', data_type='test')
    # dump_results(res_path, test_loss_dict, metric='nd', data_type='test')
    # dump_results(res_path, test_loss_dict, metric='mae', data_type='test')
    # dump_gt(res_path, test_loss_dict)
    # dump_hist(res_path, test_loss_dict)


