import numpy as np
import jacinle.io as io
import matplotlib.pyplot as plt
import scipy.stats as sps
import sys
sys.path.append('..')
from configs.cfg import CfgGMM
# Shared experiment configuration object; its fields are overwritten in place
# further down before each result path is requested from it.
cfg = CfgGMM()

# NOTE(review): the meaning of T_linear is defined by CfgGMM, which is not
# visible in this file — confirm that 1 is the intended setting.
cfg.T_linear = 1
# Input dimensions that are swept; also used as the x-coordinates of the plot.
dim_list = [2, 6, 12, 16, 32]
# Trial identifiers 1..15; each one is cast to float when assigned to cfg.TRIAL.
trial_list = list(range(1, 16))
# [1.0, 2.0, 3.0, 4.0, 5.0]
# NOTE(review): the list above looks like a leftover from an earlier, shorter
# trial set — confirm and delete if it is no longer needed.


def get_mean_med_int(arr, alpha=0.95):
    """Summarise a sample by mean, median and a normal-approximation CI half-width.

    NaN entries are treated as missing values and dropped before any statistic
    is computed.

    Args:
        arr: Array-like of numbers. Any shape is accepted; all non-NaN values
            are pooled into a single sample.
        alpha: Two-sided confidence level in (0, 1), e.g. 0.95 for a 95%
            interval (note: this is the coverage, not the tail probability).

    Returns:
        Tuple ``(mean, median, half_width)`` where ``half_width`` equals
        ``z * std / sqrt(n)``, ``std`` is the population standard deviation
        (ddof=0) and ``z`` is the standard-normal quantile at
        ``(1 + alpha) / 2``. All three values are NaN when no non-NaN sample
        remains.
    """
    arr = np.asarray(arr, dtype=float)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        # Nothing to summarise: return NaNs explicitly rather than letting
        # numpy emit "mean of empty slice" warnings on the way to the same result.
        return np.nan, np.nan, np.nan
    _mean = np.mean(arr)
    _med = np.median(arr)
    # np.std is the same ddof=0 estimator as sqrt(E[x^2] - E[x]^2), but computed
    # from squared deviations, so rounding can never push the variance below
    # zero and turn the interval into NaN for (nearly) constant samples.
    S = np.std(arr)
    z = sps.norm.ppf((1. + alpha) / 2.)
    _int = z * S / np.sqrt(arr.size)
    return _mean, _med, _int


def draw_stats(stats, ax, color, label, legend_position='best'):
    """Plot the median curve and a shaded mean +/- interval band on ``ax``.

    Args:
        stats: Mapping with the keys ``'meds'``, ``'means'`` and ``'ints'``,
            each holding one value per entry of the module-level ``dim_list``.
        ax: Axes-like object providing ``plot``, ``fill_between`` and ``legend``.
        color: Colour used for both the line and the band.
        label: Legend label attached to the median line.
        legend_position: Value passed to ``ax.legend`` as ``loc``.
    """
    # The line follows the medians ...
    ax.plot(dim_list, stats['meds'], color=color, label=label)

    # ... while the translucent band is centred on the means.
    centre = np.asarray(stats['means'])
    half_width = np.asarray(stats['ints'])
    lower = centre - half_width
    upper = centre + half_width
    ax.fill_between(dim_list, lower, upper, color=color, alpha=.1)

    ax.legend(loc=legend_position)


# Aggregate the stored log10 SymKL values over trials for every dimension and
# produce one statistics dump plus one figure per value of num_k.
for num_k in [10, 18]:
    # Fresh NaN-filled buffer for every num_k: a result that fails to load
    # stays NaN and is dropped by get_mean_med_int, instead of silently
    # contributing a 0 or a stale value left over from the previous num_k.
    log10symkl_array = np.full([len(dim_list), len(trial_list)], np.nan)
    conv_stats = {
        'means': [],
        'meds': [],
        'ints': []
    }
    for idx_dim, dim in enumerate(dim_list):
        cfg.INPUT_DIM = dim
        for idx_trial, trial in enumerate(trial_list):
            cfg.TRIAL = float(trial)
            result_path = cfg.get_save_path() + f'/storing_P/log10symkl_{num_k}.pt'
            try:
                log10symkl_array[idx_dim, idx_trial] = io.load(result_path)
            except Exception as exc:
                # Best-effort loading: report the missing/unreadable result and
                # carry on. KeyboardInterrupt/SystemExit are deliberately not
                # caught so the script can still be aborted.
                print('\n', "non_existing!!!!!!!!!", result_path, repr(exc))
        _mean, _med, _int = get_mean_med_int(log10symkl_array[idx_dim])
        conv_stats['means'].append(_mean)
        conv_stats['meds'].append(_med)
        conv_stats['ints'].append(_int)

    # Both outputs share the same stem; the suffix is num_k halved (10 -> 5, 18 -> 9).
    out_stem = f'./test/ou_sym_kl_0_{num_k // 2}'
    io.dump(out_stem + '.pt', conv_stats)

    fig, ax = plt.subplots(figsize=(8, 5.7), dpi=80)
    ax.set_xlabel('D, dimension', fontsize=18)
    ax.set_ylabel(r'$\log_{10}$SymKL', fontsize=18)

    draw_stats(conv_stats, ax,
               'blue', 'Ours', legend_position='best')

    ax.grid(which='major')
    ax.minorticks_on()
    fig.tight_layout()
    ax.set_ylim(-4.5, 0.9)
    fig.savefig(out_stem + '.png', bbox_inches='tight', dpi=200)
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(fig)
