import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.gridspec import GridSpec
from PIL import Image, ImageDraw
import seaborn as sns
sns.set_style("whitegrid")  # module import side effect: sets the global seaborn style (alternative: "darkgrid")


def plot_score_distr(concept_dict, scores_in, preds_in, scores_out, preds_out, confidence, thresh, save_plot):
    """Plot, for every class, the density of each concept score from three sources.

    One figure is written per class ``c`` with one subplot per concept, overlaying
    kernel-density estimates of:
      * the class profile scores ``concept_dict[c]['scores']`` (grey),
      * in-distribution samples predicted as ``c`` (blue),
      * out-of-distribution samples predicted as ``c`` (red).

    Args:
        concept_dict: indexed by class ``0 .. len(concept_dict) - 1``; each entry
            holds a ``'scores'`` array of shape ``(n_samples, n_concepts)``
            (other keys such as ``'mean'``/``'std'`` are not used here).
        scores_in: in-distribution concept scores, shape ``(n_in, n_concepts)``.
        preds_in: predicted class for each row of ``scores_in``.
        scores_out: out-of-distribution concept scores, shape ``(n_out, n_concepts)``.
        preds_out: predicted class for each row of ``scores_out``.
        confidence: only used to build the output filename.
        thresh: only used to build the output filename.
        save_plot: directory the PNG files are written to, one per class.
    """
    from matplotlib.lines import Line2D  # only needed for the legend proxies below

    scores_in = np.asarray(scores_in)
    scores_out = np.asarray(scores_out)
    # Flatten predictions so a list or an (n, 1) array both yield a 1-D row mask.
    preds_in = np.ravel(np.asarray(preds_in))
    preds_out = np.ravel(np.asarray(preds_out))

    n_classes = len(concept_dict)
    n_concepts = np.shape(scores_in)[1]

    # Near-square grid: ceil(sqrt(n)) columns and the fewest rows that fit all concepts.
    n_cols = max(1, int(np.ceil(np.sqrt(n_concepts))))
    n_rows = max(1, int(np.ceil(n_concepts / n_cols)))

    # Explicit legend handles keep colours and labels matched even when a class
    # has no in/out samples and kdeplot therefore draws no line for them.
    legend_handles = [Line2D([], [], color=colour) for colour in ('grey', 'blue', 'red')]
    legend_labels = ['in-distribution (train)', 'in-distribution (test)', 'out-of-distribution']

    for c in range(n_classes):
        scores_in_ = scores_in[preds_in == c, :]
        scores_out_ = scores_out[preds_out == c, :]
        profile_scores = np.asarray(concept_dict[c]['scores'])

        # squeeze=False keeps `axes` 2-D even for a 1x1 or 1xN grid.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(30, 25),
                                 sharex=True, sharey=True, squeeze=False)
        fig.suptitle('Class {}: distribution of {} concept scores'.format(c, n_concepts))
        for i in range(n_concepts):
            ax = axes[i // n_cols, i % n_cols]
            ax.set_title('Concept {}'.format(i))
            sns.kdeplot(profile_scores[:, i], color='grey', ax=ax)
            sns.kdeplot(scores_in_[:, i], color='blue', ax=ax)
            sns.kdeplot(scores_out_[:, i], color='red', ax=ax)

        # Hide the trailing grid cells that have no concept assigned.
        for ax in axes.ravel()[n_concepts:]:
            ax.set_axis_off()

        fig.legend(legend_handles, legend_labels)
        fig.savefig("{}/scores_class{}_{}%confidence*{}.png".format(save_plot, c, confidence * 100, thresh))
        plt.close(fig)


def debug_with_plots(profiles, in_scores, out_scores, concept_idx, savepath, in_filename, out_image):
    """Save a figure comparing one ID and one OOD sample against training score profiles.

    Top panel: a seaborn bar plot of the training profile scores for each selected
    concept, overlaid with point plots of the ID ("in", blue) and OOD ("out", red)
    scores. Bottom panels: the ID image and the OOD image.

    Args:
        profiles: 2-D array of training concept scores, shape ``(n_training, n_concepts)``.
        in_scores: concept scores of the ID sample; flattened, so either
            ``(n_concepts,)`` or ``(1, n_concepts)`` is accepted.
        out_scores: concept scores of the OOD sample; same shape rules as ``in_scores``.
        concept_idx: integer indices of the concepts to visualize.
        savepath: path the figure is saved to.
        in_filename: path of the ID image; it is resized to 224x224.
        out_image: either a path (``str``), opened and resized to 224x224, or an
            array converted with ``Image.fromarray(out_image, 'RGB')`` and shown as is.
    """
    profiles = np.asarray(profiles)
    idx = np.asarray(concept_idx, dtype=int)
    labels = ['concept' + str(c) for c in concept_idx]
    n_sel = len(labels)

    # Long-format profile table, concept-major: all samples of the first concept,
    # then all samples of the second, and so on.
    n_data = profiles.shape[0]
    df_profile = pd.DataFrame({
        'concept': np.repeat(labels, n_data),
        'score': profiles[:, idx].T.reshape(-1).astype(float),
    })

    # Long-format in/out table: n_sel "in" rows followed by n_sel "out" rows.
    in_vals = np.ravel(in_scores)[idx]
    out_vals = np.ravel(out_scores)[idx]
    df = pd.DataFrame({
        'in/out': ['in'] * n_sel + ['out'] * n_sel,
        'concept': labels + labels,
        'score': np.concatenate([in_vals, out_vals]).astype(float),
    })

    fig = plt.figure(constrained_layout=True)
    gs = GridSpec(2, 2, figure=fig)

    ax1 = fig.add_subplot(gs[0, :])
    ax1.set_title("In/Out concept scores compared to score profiles", size=10)
    sns.barplot(x="concept", y="score", data=df_profile, ax=ax1)
    sns.pointplot(x="concept", y="score", hue="in/out",
                  palette={"in": "b", "out": "r"}, markers=["^", "o"], linestyles=["--", "-"],
                  data=df, ax=ax1)
    ax1.legend()
    ax1.tick_params(axis='x', labelsize=6, labelrotation=45)

    ax2 = fig.add_subplot(gs[1, 0])
    # `with` closes the file handle; resize() loads the pixels into a new image first.
    with Image.open(in_filename) as img:
        in_img = img.resize((224, 224), Image.LANCZOS)  # LANCZOS == the removed ANTIALIAS alias
    ax2.imshow(in_img)
    ax2.set_title("ID image", size=10, color='b')
    ax2.axis('off')

    ax3 = fig.add_subplot(gs[1, 1])
    if isinstance(out_image, str):
        with Image.open(out_image) as img:
            out_img = img.resize((224, 224), Image.LANCZOS)
    else:
        out_img = Image.fromarray(out_image, 'RGB')
    ax3.imshow(out_img)
    ax3.set_title("OOD image", size=10, color='r')
    ax3.axis('off')

    fig.savefig(savepath)
    plt.close(fig)


def debug_with_patch(img_filename, score_all, score_diff, profile_mean, filenames_IN,
                    scores_IN, scores_IN_all, savepath, k=5, n_examples=3,
                    example_dir='./data/Animals_with_Attributes2/test/',
                    receptive_size=37 * 2, jump=32):
    """Visualize the image patches that most activate the top-k concepts.

    Figure layout (``2 + n_examples`` rows x ``k`` columns):
      * row 0: the test image, in column ``k // 2``;
      * row 1: for each selected concept, the test-image patch at the location
        with the highest score for that concept;
      * rows 2..: the ``n_examples`` in-distribution images with the highest score
        for that concept, with the highest-scoring patch outlined in red.
    Concepts are chosen as the ``k`` with the largest ``|mean(scores_IN, axis=0)|``.
    Each selected concept's spatial score map is printed to stdout.

    Args:
        img_filename: path of the test image; it is resized to 224x224.
        score_all: spatial concept scores of the test image, indexed
            ``[row, col, concept]``; a leading singleton axis ``(1, H, W, C)`` is dropped.
        score_diff: per-concept values shown in the patch titles; flattened, so
            ``(n_concepts,)`` or ``(1, n_concepts)`` is accepted.
        profile_mean: not used; kept for interface compatibility.
        filenames_IN: image filenames, relative to ``example_dir``, aligned with
            the rows of ``scores_IN``.
        scores_IN: in-distribution concept scores, shape ``(n_images, n_concepts)``.
        scores_IN_all: spatial scores of the in-distribution images, indexed
            ``[image, row, col, concept]``.
        savepath: output prefix; the figure is saved to ``savepath + '_significant.jpg'``.
        k: number of concepts (figure columns).
        n_examples: number of in-distribution example rows per concept.
        example_dir: directory that contains the files named in ``filenames_IN``.
        receptive_size: side length, in pixels, of a feature-map cell's patch.
        jump: pixel stride between neighbouring feature-map cells.
            The defaults follow the original note "input size = (224,224) with Inception-V3".
    """
    import os  # local import: only needed to build example paths

    score_all = np.asarray(score_all)
    if score_all.ndim == 4 and score_all.shape[0] == 1:
        score_all = score_all[0]  # accept the (1, H, W, n_concepts) form
    score_diff = np.ravel(score_diff)

    with Image.open(img_filename) as img:
        test_img = img.resize((224, 224), Image.LANCZOS)  # LANCZOS == the removed ANTIALIAS alias

    # squeeze=False keeps `ax` 2-D even when k == 1.
    fig, ax = plt.subplots(2 + n_examples, k, sharey=True, squeeze=False)
    ax[0, k // 2].imshow(test_img)
    ax[0, k // 2].set_title("test image")

    # Top-k concepts by absolute mean in-distribution score, largest first.
    s_mean = np.mean(scores_IN, axis=0)
    idxs = np.argsort(np.abs(s_mean))[::-1][:k]
    for i, idx in enumerate(idxs):
        concept_map = score_all[:, :, idx]
        a, b = np.unravel_index(np.argmax(concept_map), concept_map.shape)
        print('scores for concept {}...'.format(idx))
        print(concept_map)
        left, top = jump * b, jump * a
        right, bottom = left + receptive_size, top + receptive_size

        region = test_img.crop((left, top, right, bottom))
        patch = Image.new("RGB", (receptive_size, receptive_size))
        # NOTE(review): the (1, 1) offset leaves a 1-px black top/left border and
        # drops the patch's last row/column; confirm whether (0, 0) was intended.
        patch.paste(region, (1, 1))
        patch = patch.resize((224, 224), Image.LANCZOS)
        ax[1, i].imshow(patch)
        ax[1, i].set_title("concept {}:\n {:.3f}".format(idx, score_diff[idx]), fontsize=10)

        # In-distribution images with the highest score for this concept.
        idxs_examples = np.argsort(scores_IN[:, idx])[::-1][:n_examples]
        for j, idx_example in enumerate(idxs_examples):
            with Image.open(os.path.join(example_dir, filenames_IN[idx_example])) as img:
                example = img.resize((224, 224), Image.LANCZOS)
            draw = ImageDraw.Draw(example)
            example_map = scores_IN_all[idx_example, :, :, idx]
            aa, bb = np.unravel_index(np.argmax(example_map), example_map.shape)
            left, top = jump * bb, jump * aa
            right, bottom = left + receptive_size, top + receptive_size
            draw.rectangle((left, top, right, bottom), outline="red", width=3)
            ax[2 + j, i].imshow(example)
            ax[2 + j, i].set_title("example {}".format(j), fontsize=7)

    for axi in ax.ravel():
        axi.set_axis_off()
    fig.savefig(savepath + '_significant.jpg')
    plt.close(fig)
