import math
import os
import pickle
from itertools import combinations

import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator, MultipleLocator
from PIL import Image
from tqdm import tqdm



def plot_CE_difference_1d(CE_result, text_list, experiment_types, img_dir, img_g_dir, n_point, n_attr):
    """Plot per-attribute KL and JSD differences between two image sets.

    Each row of ``CE_result`` is paired with the matching entry of
    ``text_list``. Rows are sorted by their first column in descending order.
    Two bar charts are written to
    ``figs/CE_result/<experiment_types>/<set1>_<set2>/``: one for column 0
    (``KL_1d``) and one for column 1 (``JSD_1d``).

    Args:
        CE_result: 2-D array-like with one row per entry of ``text_list``.
            Column 0 is plotted as KL and column 1 as JSD.
        text_list: Attribute names, one per row of ``CE_result``.
        experiment_types: Sub-directory name used for the output figures.
        img_dir: Path of the first image set. Only its last path component
            is used, in the output directory name and the plot title.
        img_g_dir: Path of the second image set. Used the same way.
        n_point: Value embedded in the output file names.
        n_attr: Value embedded in the output file names.

    Returns:
        ``np.ndarray``: the rows of ``CE_result`` with the label appended as
        the last column, sorted by column 0 in descending order. Concatenating
        with the labels makes numpy promote all columns to one (string) dtype.
    """
    # rstrip('/') so "a/b/" yields "b" instead of an empty name.
    set1_name = img_dir.rstrip('/').split('/')[-1]
    set2_name = img_g_dir.rstrip('/').split('/')[-1]
    plot_path = f"figs/CE_result/{experiment_types}/{set1_name}_{set2_name}"
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(plot_path, exist_ok=True)
    font_size = 8

    labels = np.expand_dims(np.array(text_list), 1)
    result = np.concatenate((CE_result, labels), axis=1)
    # Numeric sort on column 0, descending. A stable ascending sort followed by
    # reversal orders ties exactly like sorted(...)[::-1], and it keeps the
    # 2-D shape.
    order = np.argsort(result[:, 0].astype(np.float64), kind="stable")[::-1]
    result_sorted = result[order]
    stats = result_sorted[:, :2].astype(np.float64)
    KL, JSD = stats[:, 0], stats[:, 1]
    sorted_labels = result_sorted[:, -1]

    def save_img_1d(outputs, types):
        """Draw one bar chart of ``outputs`` and save it as ``<types>.png``."""
        fig, ax = plt.subplots()
        try:
            positions = np.arange(len(sorted_labels))
            # Numeric positions rather than categorical strings, so duplicate
            # labels cannot collapse into a single bar.
            ax.bar(positions, outputs)
            ax.set_xticks(positions)
            ax.set_xticklabels(sorted_labels, rotation=90, fontsize=font_size)

            ymax = ax.get_ylim()[1]
            ax.yaxis.set_major_locator(MaxNLocator(nbins=10, integer=True, prune='both', min_n_ticks=10, symmetric=True, steps=[1, 2, 5, 10]))
            magnitude = abs(ymax)
            # log10 is undefined for 0; skip the minor ticks in that case.
            if magnitude > 0 and math.isfinite(magnitude):
                # Minor ticks one decade below ymax's order of magnitude.
                # floor (not int truncation) is correct for ymax < 1.
                exponent = math.floor(math.log10(magnitude))
                ax.yaxis.set_minor_locator(MultipleLocator(10.0 ** (exponent - 1)))

            # Mean of the outputs, rounded, then scaled for display.
            mean_scaled = np.round(np.sum(outputs) / stats.shape[0], 10) * 100000
            ax.set_title(
                f"{types}, task: set1 : {set1_name}, set2 : {set2_name},\n"
                f"         total_difference ={mean_scaled} ",
                fontsize=5,
            )
            ax.set_ylabel('each attribute-pair difference between 2 dataset', fontsize=5)
            fig.tight_layout()
            # Save before show(). Interactive backends tear the figure down
            # when show() returns, so a later savefig would write a blank image.
            fig.savefig(f"{plot_path}/n_point_{n_point}_n_attr_{n_attr}_{types}.png", dpi=500)
            plt.show()
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

    save_img_1d(KL, "KL_1d")
    save_img_1d(JSD, "JSD_1d")

    return result_sorted




def plot_CE_difference_2d(CE_result, text_list, experiment_types, img_dir, img_g_dir, n_point, n_attr):
    """Plot per-attribute-pair KL and JSD differences between two image sets.

    Bar labels are ``"P(a, b)"`` for every unordered pair of ``text_list``,
    in ``itertools.combinations`` order. Each label is paired with the
    matching row of ``CE_result``. Rows are sorted by their first column in
    descending order. Two bar charts are written to
    ``figs/CE_result/<experiment_types>/<set1>_<set2>/``: one for column 0
    (``KL_2d``) and one for column 1 (``JSD_2d``).

    Args:
        CE_result: 2-D array-like with one row per attribute pair; the
            concatenation fails if the row count differs. Column 0 is
            plotted as KL and column 1 as JSD.
        text_list: Attribute names that the pair labels are built from.
        experiment_types: Sub-directory name used for the output figures.
        img_dir: Path of the first image set. Only its last path component
            is used, in the output directory name and the plot title.
        img_g_dir: Path of the second image set. Used the same way.
        n_point: Value embedded in the output file names.
        n_attr: Value embedded in the output file names.

    Returns:
        ``np.ndarray``: the rows of ``CE_result`` with the pair label appended
        as the last column, sorted by column 0 in descending order.
        Concatenating with the labels makes numpy promote all columns to one
        (string) dtype.
    """
    # rstrip('/') so "a/b/" yields "b" instead of an empty name.
    set1_name = img_dir.rstrip('/').split('/')[-1]
    set2_name = img_g_dir.rstrip('/').split('/')[-1]
    plot_path = f"figs/CE_result/{experiment_types}/{set1_name}_{set2_name}"
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(plot_path, exist_ok=True)
    # Many attributes mean many pairs, so shrink the tick labels.
    font_size = 2 if len(text_list) > 7 else 5

    pair_text_list = [f"P({a[0]}, {a[1]})" for a in combinations(text_list, 2)]
    labels = np.expand_dims(np.array(pair_text_list), 1)
    result = np.concatenate((CE_result, labels), axis=1)
    # Numeric sort on column 0, descending. A stable ascending sort followed by
    # reversal orders ties exactly like sorted(...)[::-1], and it keeps the
    # 2-D shape.
    order = np.argsort(result[:, 0].astype(np.float64), kind="stable")[::-1]
    result_sorted = result[order]
    stats = result_sorted[:, :2].astype(np.float64)
    KL, JSD = stats[:, 0], stats[:, 1]
    sorted_labels = result_sorted[:, -1]

    def save_img(outputs, types):
        """Draw one bar chart of ``outputs`` and save it as ``<types>.png``."""
        fig, ax = plt.subplots()
        try:
            positions = np.arange(len(pair_text_list))
            # Numeric positions rather than categorical strings, so duplicate
            # labels cannot collapse into a single bar.
            ax.bar(positions, outputs)
            ax.set_xticks(positions)
            ax.set_xticklabels(sorted_labels, rotation=90, fontsize=font_size)

            ymax = ax.get_ylim()[1]
            ax.yaxis.set_major_locator(MaxNLocator(nbins=10, integer=True, prune='both', min_n_ticks=10, symmetric=True, steps=[1, 2, 5, 10]))
            magnitude = abs(ymax)
            # log10 is undefined for 0; skip the minor ticks in that case.
            if magnitude > 0 and math.isfinite(magnitude):
                # Minor ticks one decade below ymax's order of magnitude.
                # floor (not int truncation) is correct for ymax < 1.
                exponent = math.floor(math.log10(magnitude))
                ax.yaxis.set_minor_locator(MultipleLocator(10.0 ** (exponent - 1)))

            # Mean of the outputs, rounded, then scaled for display.
            mean_scaled = np.round(np.sum(outputs) / stats.shape[0], 10) * 10000000
            ax.set_title(
                f"{types}, task: set1 : {set1_name}, set2 : {set2_name},\n"
                f"         total_difference ={mean_scaled} ",
                fontsize=5,
            )
            ax.set_ylabel('each attribute-pair difference between 2 dataset', fontsize=5)
            fig.tight_layout()
            # Save before show(). Interactive backends tear the figure down
            # when show() returns, so a later savefig would write a blank image.
            fig.savefig(f"{plot_path}/n_point_{n_point}_n_attr_{n_attr}_{types}.png", dpi=2000)
            plt.show()
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

    save_img(KL, "KL_2d")
    save_img(JSD, "JSD_2d")

    return result_sorted
