import itertools
import pickle
import os
import argparse
from omegaconf import OmegaConf
import itertools



import seaborn as sns
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl 
from scipy.stats import gaussian_kde
from tqdm import tqdm

# Global matplotlib styling, applied at import time: route all text through
# LaTeX and use a Times serif face.
# NOTE: text.usetex=True needs a working LaTeX installation when figures are
# drawn.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'Times New Roman',
})

# Nine-colour palette as hex strings. Not referenced in this part of the
# file; presumably consumed by plotting code elsewhere — TODO confirm.
color_hexes = [
    '#377eb8', '#ff7f00', '#4daf4a',
    '#f781bf', '#a65628', '#984ea3',
    '#999999', '#e41a1c', '#dede00',
]

def compute_paide_mat(pair_dist_mat, weight=0.5):
    """Combine directed pairwise distances into a symmetric matrix.

    For every sample ``i`` and every pair of distinct groups ``(a, b)``, the two
    directed distance vectors ``d_ab = pair_dist_mat[i, a, b, :]`` and
    ``d_ba = pair_dist_mat[i, b, a, :]`` are combined element-wise as::

        paide = -(log(w) + w * log(exp(-d_ab) + exp(-d_ba)))

    The result is written to both ``[i, a, b, :]`` and ``[i, b, a, :]``.
    Diagonal entries (``a == b``) are left at zero.

    NOTE(review): ``w`` multiplies only the log-sum term. For ``d_ab == d_ba == d``
    this gives ``d/2 + log(2)/2`` rather than ``d``. If a weighted log-mean-exp
    was intended, the expression would be ``log(w * sum(...))`` — TODO confirm.

    Args:
        pair_dist_mat: 4-D array indexed as ``[sample, group, group, k]``. The
            two group axes must have equal length.
        weight: the scalar ``w`` above. Defaults to 1/2, the value the script
            always used.

    Returns:
        float64 ndarray with the same shape as ``pair_dist_mat``.

    Raises:
        ValueError: if the input is not 4-D or its two group axes differ.
    """
    dist = np.asarray(pair_dist_mat, dtype=float)
    if dist.ndim != 4 or dist.shape[1] != dist.shape[2]:
        raise ValueError(
            'pair_dist_mat must have shape (N, G, G, K); got %r' % (dist.shape,)
        )
    # Stable log(exp(-d_ab) + exp(-d_ba)). The naive exp underflows to 0 for
    # large distances, and log(0) = -inf would then yield inf in the output.
    log_sum = np.logaddexp(-dist, -np.swapaxes(dist, 1, 2))
    paide = -(np.log(weight) + weight * log_sum)
    # logaddexp is symmetric in its arguments, so [a, b] and [b, a] already
    # agree. Only the diagonal, which the pairwise definition excludes, needs
    # resetting to zero.
    diag = np.arange(dist.shape[1])
    paide[:, diag, diag, :] = 0.0
    return paide


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Build paide_mat.npy from pair_dist_mat.npy.')
    parser.add_argument(
        'path',
        help='directory containing pair_dist_mat.npy; '
             'paide_mat.npy is written to the same directory')
    args = parser.parse_args()

    pair_dist_mat = np.load(os.path.join(args.path, 'pair_dist_mat.npy'))
    paide_mat = compute_paide_mat(pair_dist_mat)
    np.save(os.path.join(args.path, 'paide_mat.npy'), paide_mat)
