import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from utils.toolkit import compute_hit_matrix, information_metrics


# Global Matplotlib styling, applied at import time to every figure this
# module produces.
plt.rcParams.update({
    'font.size': 14,
    'font.family': 'Times New Roman',
})

def getweight(path,ntask=10):
    """Plot per-task classifier-weight histograms for iCaRL vs. iCaRL with DeL.

    For every task index ``t`` in ``range(ntask)`` two checkpoints are read
    from ``path``:

    * ``icarl_cub_{t}_10.ckpt``: its ``'fc.weight'`` tensor is flattened and
      drawn in the top row of the figure (row label "iCaRL");
    * ``icarl_cub_{t}_10_dnm.ckpt``: its ``'fc.sw'`` tensor is flattened and
      drawn in the bottom row (row label "iCaRL w DeL").

    Each panel is a 20-bin probability histogram with a KDE overlay, limited
    to x in [-0.5, 0.5] and y in [0, 0.5]. The figure is saved as
    ``weight_distibution.png`` in the current working directory and then
    closed.

    Args:
        path: Directory holding the checkpoint files.
        ntask: Number of tasks, i.e. number of subplot columns.
    """
    palette = ['#8de5a1', '#ff9f9b', '#a1c9f4', '#b5b5ac', '#409140',
               '#e06666', '#7abacc', '#8d8d8d', "#ce0c0c", "#4272c4"]

    fig, axes = plt.subplots(2, ntask, figsize=(20, 5))
    # Row-major flatten: axes[:ntask] is the top row, axes[ntask:] the bottom.
    axes = axes.flatten()

    for task in range(ntask):
        print(task)
        # map_location='cpu' lets checkpoints saved from a GPU run load on a
        # CPU-only machine; the tensors are moved to CPU below regardless.
        data_dnm = torch.load(os.path.join(path, f'icarl_cub_{task}_10_dnm.ckpt'),
                              map_location='cpu')
        data = torch.load(os.path.join(path, f'icarl_cub_{task}_10.ckpt'),
                          map_location='cpu')

        # Wrap around so more tasks than colours do not raise IndexError.
        color = palette[task % len(palette)]
        panels = (
            (axes[task], data['fc.weight']),          # original weights
            (axes[ntask + task], data_dnm['fc.sw']),  # DeL weights
        )
        for ax, tensor in panels:
            sns.histplot(
                tensor.cpu().numpy().flatten(),
                bins=20,
                kde=True,
                stat='probability',
                ax=ax,
                color=color,
            )
            ax.set_xlim([-.5, 0.5])
            ax.set_ylim([0, 0.5])
            ax.set_xticks([-.5, 0, .5])
            # Only the first column keeps its y axis; all panels share ylim.
            if task > 0:
                ax.set_ylabel('')
                ax.set_yticks([])

        axes[ntask + task].set_xlabel(f'$T_{{{task + 1}}}$', fontsize=16)

    # Row labels live on the first column of each row; the bottom row starts
    # at flat index ntask.
    axes[0].set_ylabel('iCaRL', fontsize=16)
    axes[ntask].set_ylabel('iCaRL w DeL', fontsize=16)

    fig.tight_layout()
    fig.savefig('weight_distibution.png', dpi=300)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def get_entropy(path,dnm,name,task,classes_per_task=20):
    """Plot per-task entropy and class-selectivity histograms for iCaRL vs. DeL.

    For each task index ``i`` in ``range(task)`` three arrays are loaded from
    ``path``:

    * ``{dnm}_labels_{i}.npy``    -- labels passed to ``information_metrics``;
    * ``{dnm}_sa_x_{i}.npy``      -- DeL activations; only the slice
      ``[:, 0, 1, :]`` is analysed;
    * ``{name}_features_{i}.npy`` -- features of the original model.

    ``information_metrics`` is run on both feature sets with theta=0.5 and
    threshold=0.5. Its first output (entropy) is histogrammed into
    ``entropy_disribution{task}.png`` and its fourth output (selectivity) into
    ``class_disribution{task}.png``. In both figures the top row shows the
    original model (row label "iCaRL"), the bottom row the DeL model (row
    label "iCaRL w/ DeL"), and column ``i`` corresponds to task ``i``.

    Args:
        path: Directory holding the ``.npy`` dumps.
        dnm: File-name prefix of the DeL run (also used for the labels file).
        name: File-name prefix of the original run.
        task: Number of tasks to plot.
        classes_per_task: Classes added per task. For task ``i`` the
            selectivity panels use the x window
            ``[i * classes_per_task, (i + 1) * classes_per_task + 1]`` and
            ``(i + 1) * classes_per_task`` bins. Defaults to 20.
    """
    palette = ['#8de5a1', '#ff9f9b', '#a1c9f4', '#b5b5ac', '#409140',
               '#e06666', '#7abacc', '#8d8d8d', "#ce0c0c", "#4272c4"]

    # At least 10 columns; grow the grid instead of raising IndexError when
    # more than 10 tasks are requested.
    ncols = max(task, 10)
    figsize = (2 * ncols, 5)
    fig, axes = plt.subplots(2, ncols, figsize=figsize)    # entropy
    fig1, axes1 = plt.subplots(2, ncols, figsize=figsize)  # class selectivity
    axes = axes.flatten()
    axes1 = axes1.flatten()

    for i in range(task):
        print(i)
        labels_arr = np.load(os.path.join(path, f'{dnm}_labels_{i}.npy'))
        features_arr_del = np.load(os.path.join(path, f'{dnm}_sa_x_{i}.npy'))
        features_arr = np.load(os.path.join(path, f'{name}_features_{i}.npy'))

        entropy, _, _, selectivity = information_metrics(
            features_arr, labels_arr, theta=0.5, thershold=0.5)
        # NOTE(review): assumes [:, 0, 1, :] selects the DeL activations to
        # compare against the original features -- confirm.
        entropy_del, _, _, selectivity_del = information_metrics(
            features_arr_del[:, 0, 1, :], labels_arr, theta=0.5, thershold=0.5)

        color = palette[i % len(palette)]
        top, bottom = i, ncols + i  # column i of the top / bottom row
        class_bins = (i + 1) * classes_per_task
        class_xlim = [i * classes_per_task, (i + 1) * classes_per_task + 1]
        panels = (
            (axes[top], entropy, 20, [0, 8]),
            (axes[bottom], entropy_del, 20, [0, 8]),
            # Both selectivity rows use the same bin count so their per-bin
            # probabilities are directly comparable.
            (axes1[top], selectivity, class_bins, class_xlim),
            (axes1[bottom], selectivity_del, class_bins, class_xlim),
        )
        for ax, values, bins, xlim in panels:
            sns.histplot(
                values,
                bins=bins,
                kde=True,
                stat='probability',
                ax=ax,
                color=color,
            )
            ax.set_xlim(xlim)
            # Only the first column keeps its y axis.
            if i > 0:
                ax.set_ylabel('')
                ax.set_yticks([])

        tick = f'$T_{{{i + 1}}}$'
        axes[bottom].set_xlabel(tick)
        axes1[bottom].set_xlabel(tick)

    # Row labels on the first column of each row (flat indices 0 and ncols).
    if task > 0:
        for row_axes in (axes, axes1):
            row_axes[0].set_ylabel('iCaRL')
            row_axes[ncols].set_ylabel('iCaRL w/ DeL')

    fig.tight_layout()
    fig1.tight_layout()
    fig.savefig(f'entropy_disribution{task}.png', dpi=300)
    fig1.savefig(f'class_disribution{task}.png', dpi=300)
    # Release both figures so repeated calls do not accumulate open figures.
    plt.close(fig)
    plt.close(fig1)
    

def sa_da_feature(path, dnm,original,palette):
    """Visualise task-0 DeL activations as heat maps and KDE plots.

    Loads ``{dnm}_sa_x_0.npy`` and ``{original}_features_0.npy`` from
    ``path``. It prints the first row of the original features and the shape
    of the DeL array, then writes three images to the current directory:

    * ``dnm_sa1_9.png`` / ``dnm_sa2_9.png`` -- heat maps of
      ``features_arr[:1, 0:10, k, 9]`` for ``k`` = 0 and 1 respectively;
    * ``da_.png`` -- two stacked KDE plots of ``features_arr[9, 0, 0, :]``
      and ``features_arr[10, 1, 0, :]``, filled with ``palette[2]``.

    Every figure is closed after saving.

    Args:
        path: Directory holding the ``.npy`` dumps.
        dnm: File-name prefix of the DeL run (4-D array, per the indexing).
        original: File-name prefix of the original run.
        palette: Sequence of colours; only ``palette[2]`` is used.
    """
    features_arr = np.load(os.path.join(path, f'{dnm}_sa_x_0.npy'))
    features_ = np.load(os.path.join(path, f'{original}_features_0.npy'))
    print(features_[0, :])

    fes_ind = 9
    # One heat map per index 0/1 on axis 2, for the first sample.
    for slot, fname in ((0, f'dnm_sa1_{fes_ind}.png'),
                        (1, f'dnm_sa2_{fes_ind}.png')):
        heat_fig = plt.figure()
        plt.imshow(features_arr[:1, 0:10, slot, fes_ind])
        plt.xticks([])
        plt.yticks([])
        plt.tight_layout()
        plt.savefig(fname, dpi=300, bbox_inches='tight')
        plt.close(heat_fig)

    print(features_arr.shape)

    fig, axes = plt.subplots(2, 1, figsize=(2.5, 3))
    axes = axes.flatten()
    # Row i plots sample `ind` at index i on axis 1.
    for i, ind in enumerate([9, 10]):
        sns.kdeplot(
            features_arr[ind, i, 0, :],
            fill=True,
            ax=axes[i],
            alpha=.5,
            color=palette[2],
        )
        axes[i].set_ylabel('')
        axes[i].set_yticks([])
        axes[i].set_xticks([])
        axes[i].set_xlim([0.45, 0.556])
    fig.tight_layout(pad=0.1)
    fig.savefig('da_.png', dpi=300)
    plt.close(fig)


if __name__ == "__main__":
    # Built with os.path.join so the separator is correct on every OS; a
    # literal mixing '/' and '\' only resolves on Windows.
    path = os.path.join('.', 'results', 'icarl', 'cub_0_20')
    dnm = 'reproduce_1993_resnet18_2_sigmoid_none_none_True_False'
    original = 'reproduce_1993_resnet18'
    palette = ['#8de5a1', '#ff9f9b', '#a1c9f4', "#ffffff", '#b5b5ac',
               '#409140', '#e06666', '#7abacc', '#8d8d8d', "#ce0c0c", "#4272c4"]
    # Other figures produced by this script; enable as needed:
    # getweight(path)
    # get_entropy(path, dnm, original, 10)
    sa_da_feature(path, dnm, original, palette)