import os
import json
import pickle
from tempfile import tempdir
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt

import pandas as pd
import gc

import datetime

import tensorflow as tf
from utils.experiment import Experiment
from utils.preprocessing import get_texture_ds, get_natural_ds
from sklearn.preprocessing import OneHotEncoder, scale

import tqdm
import random

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

import tensorflow as tf
# Short alias for the Keras namespace (used by scrambling_decodability).
tfk=tf.keras

# Enable memory growth on the first visible GPU only; any additional GPUs are
# left with TensorFlow's default allocation behaviour.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Number of class labels iterated over in generate_active_dim_plots
# (labels 0 .. NUM_TEXT_FAMILY - 1).
NUM_TEXT_FAMILY = 5
    
def prepare_experiment(exp_path):
    """Build an ``Experiment`` from the ``configs.json`` found in *exp_path*.

    Args:
        exp_path: Directory containing the experiment's ``configs.json``.

    Returns:
        The ``Experiment`` constructed from the parsed configuration.
    """
    config_file = os.path.join(exp_path, "configs.json")
    with open(config_file) as handle:
        parsed_configs = json.load(handle)
    return Experiment(parsed_configs)

def get_Z_for_dataset(model, dataset, filter_dict=None):
    """Collect posterior means of z1 and z2, plus labels, over a whole dataset.

    Args:
        model: Object exposing ``q_z1_x_model`` and ``q_z2_z1_model``; each is
            called on a batch and must return something with ``.mean()``.
        dataset: Sized iterable of batches; element 0 of a batch is the input
            and element 1 the label.
        filter_dict: Optional dict. When truthy, only the z1 columns listed in
            ``filter_dict['filter_dims']`` are kept. z2 is always computed
            from the full, unfiltered z1.

    Returns:
        Tuple ``(Z1, Z2, y)`` of arrays concatenated over all batches.
    """
    z1_chunks, z2_chunks, label_chunks = [], [], []

    for batch_index, batch in enumerate(dataset):
        # Progress line every 100 batches, wiped again on the next output.
        if batch_index % 100 == 0:
            print (f"{batch_index} / {len(dataset)}")
            clear_output(wait=True)

        images, labels = batch[0], batch[1]

        # The posterior mean (not a sample) is used as the representation.
        z1_mean = model.q_z1_x_model(images).mean().numpy()
        z2_mean = model.q_z2_z1_model(z1_mean).mean().numpy()

        z1_chunks.append(
            z1_mean[:, filter_dict['filter_dims']] if filter_dict else z1_mean
        )
        z2_chunks.append(z2_mean)
        label_chunks.append(labels)

    print("Converting to numpy...")
    gc.collect()

    # NOTE: the original author reports that this concatenation "blows up"
    # for 40px training data (presumably memory -- not verified here).
    Z1 = np.concatenate(z1_chunks)
    Z2 = np.concatenate(z2_chunks)
    # Drop the per-batch lists before building the label array.
    del z1_chunks, z2_chunks
    gc.collect()

    return Z1, Z2, np.concatenate(label_chunks)

def get_Z_data(model, image_size=20, filter_dict=None):
    """Encode the texture train and test sets into z1/z2 posterior means.

    Args:
        model: Model forwarded to ``get_Z_for_dataset``.
        image_size: Passed to ``get_texture_ds`` as ``image_size``.
        filter_dict: Optional filter description forwarded to
            ``get_Z_for_dataset``.

    Returns:
        ``(Z1_train, Z2_train, y_train, Z1_test, Z2_test, y_test)``.
    """
    ds_train, ds_test = get_texture_ds(batch_size=512, image_size=image_size)

    # The test split is processed first, then the train split.
    print ("Processing test set...")
    test_latents = get_Z_for_dataset(model, ds_test, filter_dict=filter_dict)
    print ("Processing train set...")
    train_latents = get_Z_for_dataset(model, ds_train, filter_dict=filter_dict)

    del ds_train, ds_test
    gc.collect()

    print("Returning...")
    return (*train_latents, *test_latents)


def get_log_reg_model(X_train, y_train, X_test, y_test):
    """Fit a single dense-layer Keras classifier and report its test accuracy.

    Args:
        X_train: Training features, shape ``(n_train, n_features)``.
        y_train: Training class labels (reshaped to a single column).
        X_test: Test features.
        y_test: Test class labels (reshaped to a single column).

    Returns:
        Tuple ``(model, accuracy)`` where ``accuracy`` is the ``"accuracy"``
        entry of ``model.evaluate`` on the test data.
    """
    train_labels = np.reshape(y_train, (-1, 1))
    test_labels = np.reshape(y_test, (-1, 1))

    # A single encoder fitted on the labels of both splits guarantees that a
    # given class maps to the same one-hot column in train and test, even if
    # one split is missing a class. ``.toarray()`` on the default sparse
    # output avoids the ``sparse=`` keyword, which was renamed to
    # ``sparse_output`` and later removed from scikit-learn.
    encoder = OneHotEncoder()
    encoder.fit(np.concatenate([train_labels, test_labels]))
    y_train_p = encoder.transform(train_labels).toarray()
    y_test_p = encoder.transform(test_labels).toarray()

    print("Creating model...")
    number_of_classes = y_train_p.shape[1]
    number_of_features = X_train.shape[1]

    # NOTE(review): with a sigmoid output and binary_crossentropy loss, Keras
    # is documented to resolve the "accuracy" metric to per-output binary
    # accuracy rather than categorical accuracy -- confirm this is the
    # intended decodability measure before comparing against chance level.
    model = Sequential()
    model.add(Dense(number_of_classes, activation='sigmoid', input_dim=number_of_features))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=["accuracy"])

    print("Fitting model...")
    model.fit(X_train, y_train_p, epochs=1, validation_data=(X_test, y_test_p))
    accuracy = model.evaluate(X_test, y_test_p, return_dict=True)["accuracy"]
    return model, accuracy

def decodability(model, image_size=20, filter_dict=None):
    """Train one classifier on z1 and one on z2 and return both with scores.

    Args:
        model: Model forwarded to ``get_Z_data``.
        image_size: Forwarded to ``get_Z_data``.
        filter_dict: Forwarded to ``get_Z_data``.

    Returns:
        ``(Z1_model, Z1_accuracy, Z2_model, Z2_accuracy)``.
    """
    latents = get_Z_data(model, image_size=image_size, filter_dict=filter_dict)
    Z1_train, Z2_train, y_train, Z1_test, Z2_test, y_test = latents
    gc.collect()

    # z1 classifier is fitted first, then the z2 classifier.
    z1_result = get_log_reg_model(Z1_train, y_train, Z1_test, y_test)
    z2_result = get_log_reg_model(Z2_train, y_train, Z2_test, y_test)

    return (*z1_result, *z2_result)

def scrambling_decodability(model,filters, non_filters):
    """Measure z2 decodability after shuffling the filter dimensions of z1.

    For every sample the values at the ``filters`` positions of z1 are
    shuffled among themselves while the ``non_filters`` positions are copied
    unchanged (both cast to float16). The scrambled z1 is decoded to an image
    mean, re-encoded to z1 and then to z2, and a classifier is fitted on the
    resulting z2.

    Args:
        model: Model exposing ``p_x_z1_model``, ``q_z1_x_model`` and
            ``q_z2_z1_model``.
        filters: Index of the z1 dimensions to shuffle within each row.
        non_filters: Index of the z1 dimensions to copy through.

    Returns:
        ``(Z2_model, Z2_accuracy)`` from ``get_log_reg_model``.
    """

    class Model_flow (tfk.Model):
        """Maps z1 -> decoded x mean -> re-encoded z1 mean -> z2 mean."""

        def __init__(self, original_model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.original_model=original_model

        def call(self,I):
            decoded = self.original_model.p_x_z1_model(I).mean()
            reencoded = self.original_model.q_z1_x_model(decoded).mean()
            return self.original_model.q_z2_z1_model(reencoded).mean()

    def scramble_rows(latents):
        # Positions covered by neither index stay at zero.
        scrambled = np.zeros(latents.shape, dtype=np.float16)
        for row in range(latents.shape[0]):
            shuffled = latents[row, filters]
            np.random.shuffle(shuffled)
            scrambled[row, filters] = shuffled.astype(np.float16)
            scrambled[row, non_filters] = latents[row, non_filters].astype(np.float16)
        return scrambled

    model_flow = Model_flow(model)

    # Only Z1 and the labels of the original encodings are used below.
    Z1_train, _, y_train, Z1_test, _, y_test = get_Z_data(model)

    print("Scrambling train set")
    scrambled_Z1_train = scramble_rows(Z1_train)

    print(scrambled_Z1_train.shape)
    print("Predicting Z2 train")
    # The train set is pushed through the model in 20 chunks.
    z2_pieces = []
    for piece in tqdm.tqdm(np.array_split(scrambled_Z1_train, 20, axis=0)):
        z2_pieces.append(model_flow(piece))
        tfk.backend.clear_session()
        gc.collect()
    Z2_train = np.concatenate(z2_pieces)

    print("Scrambling test set")
    scrambled_Z1_test = scramble_rows(Z1_test)

    print("Predicting Z2 test")
    # The test set is processed in a single call.
    Z2_test = model_flow(scrambled_Z1_test)

    gc.collect()

    return get_log_reg_model(Z2_train, y_train, Z2_test, y_test)


def plot_full_latent_space(self, n_cols, shape=None, diff=0, step_size=1, save_path=None, z1_units=450):
    """Plot the decoder output for a step along every z1 dimension.

    The first image of the first training batch is encoded, one latent sample
    is drawn, and for each dimension ``i`` the decoder mean is rendered after
    adding ``step_size`` to coordinate ``i`` of that sample.

    NOTE(review): this is a module-level function that takes ``self``; it
    reads ``self.ds_train``, ``self.model``, ``self.directory`` and
    ``self.experiment_configs``, so it presumably expects an Experiment-like
    object -- confirm against callers.

    Args:
        self: Object providing the attributes listed above.
        n_cols: Number of columns in the subplot grid.
        shape: Optional shape each decoded output is reshaped to for display.
        diff: Multiplier of the unperturbed output that is subtracted from
            every panel (0 shows the raw output).
        step_size: Amount added to the perturbed latent coordinate.
        save_path: Output file; defaults to ``latent_space.png`` inside
            ``self.directory``.
        z1_units: Number of z1 dimensions to traverse.
    """
    for samples in self.ds_train.take(1):
        # First image of the batch, kept with a leading batch axis.
        sample = np.array([samples[0][0]])

    encoding = self.model.q_z1_x_model(sample)

    input_0 = encoding.sample()
    output_0 = tf.squeeze(self.model.p_x_z1_model(input_0).mean())

    if shape:
        output_0 = tf.reshape(output_0, shape)

    num_dims = z1_units
    n_rows = int(np.ceil(num_dims / n_cols))
    # squeeze=False keeps ``ax`` two-dimensional even when the grid has a
    # single row or column, so the [row][col] indexing below always works.
    fig, ax = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(n_cols*2, n_rows*2), squeeze=False
    )

    for i in range(num_dims):
        input_i = np.zeros([1, num_dims])
        input_i[0, i] = step_size
        input_i += input_0

        output_i = tf.squeeze(self.model.p_x_z1_model(input_i).mean())

        if shape:
            output_i = tf.reshape(output_i, shape)

        panel = ax[i // n_cols][i % n_cols]
        panel.imshow(output_i - diff*output_0, interpolation='none', cmap='gray')
        panel.set_title(f"{i}")
        panel.axis('off')

    if save_path:
        path = save_path
    else:
        path = os.path.join(self.directory, "latent_space.png")

    # Figure-level title: plt.title would replace the title of the most
    # recently created panel instead of labelling the whole figure.
    fig.suptitle(self.experiment_configs["experiment_directory"].split("/")[-1])

    fig.savefig(path, facecolor="white")
        
        
def plot_reconstruction(experiment,split,shape=(20,20)):
    """Save a grid comparing inputs with sampled and mean reconstructions.

    Up to 10 inputs from the first batch of the chosen split are passed
    through ``experiment.model``; the figure has one row each for the
    originals, a sample of the output distribution and its mean.

    The original author marked this function as not tested; the stated goal
    is to feed the same dataset to different models / setups in the future.

    Args:
        experiment: Object providing ``ds_val``, ``ds_train``, ``model`` and
            ``directory``.
        split: ``"val"`` or ``"train"``.
        shape: Per-image shape used for display.

    Raises:
        ValueError: If ``split`` is neither ``"val"`` nor ``"train"``.
    """
    if split=="val":
        x = next(iter(experiment.ds_val))[0][:10]
    elif split=="train":
        x = next(iter(experiment.ds_train))[0][:10]
    else:
        raise ValueError(f"Plotting reconstruction for split {split} is not possible.")

    xhat = experiment.model(x)

    # Use the number of images actually available: a first batch with fewer
    # than 10 elements would make a reshape to a hard-coded 10 fail.
    n_images = int(x.shape[0])
    shape = (n_images, *shape)
    x = tf.reshape(x, shape)

    xs = [x, tf.reshape(xhat.sample(), shape), tf.reshape(xhat.mean(), shape)]

    row_titles = ["Original", "Sampled", "Mean"]
    n = x.shape[0]
    m = len(row_titles)

    # squeeze=False keeps ``axes`` two-dimensional even for a single column.
    fig, axes = plt.subplots(nrows=m, ncols=n, figsize=(12, 8), squeeze=False)

    # The row name is placed as the title of the first panel in each row.
    for ax, row in zip(axes[:,0], row_titles):
        ax.set_title(row,  size='large')

    for i in range(m):
        for j in range(n):
            axes[i,j].imshow(xs[i][j], interpolation='none', cmap='gray')
            axes[i,j].axis('off')

    fig.tight_layout()

    fig.savefig(os.path.join(experiment.directory,"analysis",f"reconstruction_{split}.png"),facecolor="white")
    
def Z1_traversal(experiment, n_cols, shape=None, diff=0, step_size=1, save_path=None, num_dims=450):
    """Plot and pickle the batch-averaged decoder response per z1 dimension.

    The first training batch is encoded to z1 and sampled once. For each
    dimension ``i``, ``step_size`` is added to coordinate ``i`` of every
    sample in the batch, the decoder mean is averaged over the batch, and the
    result minus ``diff`` times the unperturbed average is shown in a panel.

    Side effects: writes ``Z1_trav.pkl`` to ``<experiment.directory>/analysis``,
    saves the figure and calls ``plt.show()``.

    Args:
        experiment: Object providing ``ds_train``, ``model``, ``directory``
            and ``experiment_configs``.
        n_cols: Number of columns in the subplot grid.
        shape: Optional shape each averaged output is reshaped to.
        diff: Multiplier of the unperturbed output subtracted from each panel.
        step_size: Amount added to the perturbed latent coordinate.
        save_path: Figure path; defaults to ``analysis/Z1_trav.png`` inside
            the experiment directory.
        num_dims: Number of z1 dimensions.
    """
    # Only the inputs of the first batch are used.
    for samples in experiment.ds_train.take(1):
        sample = samples[0]

    encoding = experiment.model.q_z1_x_model(sample)

    input_0 = encoding.sample()
    output_0 = tf.reduce_mean(experiment.model.p_x_z1_model(input_0).mean(), axis=0)

    if shape:
        output_0 = tf.reshape(output_0, shape)

    n_rows = int(np.ceil(num_dims / n_cols))
    # squeeze=False keeps ``ax`` two-dimensional even when the grid has a
    # single row or column, so the [row][col] indexing below always works.
    fig, ax = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(n_cols*2, n_rows*2), squeeze=False
    )

    receptive_field_list = []

    for i in range(num_dims):
        input_i = np.zeros([1, num_dims])
        input_i[0, i] = step_size
        # Broadcasts the one-hot step over every sample of the batch.
        input_i = input_0 + input_i

        output_i = tf.reduce_mean(experiment.model.p_x_z1_model(input_i).mean(), axis=0)

        if shape:
            output_i = tf.reshape(output_i, shape)

        response = output_i - diff*output_0
        receptive_field_list.append(response.numpy())

        panel = ax[i // n_cols][i % n_cols]
        panel.imshow(response, interpolation='none', cmap='gray')
        panel.set_title(f"{i}")
        panel.axis('off')

    # NOTE(review): concatenate joins the per-dimension arrays along their
    # first axis rather than stacking them on a new one; kept as is because
    # readers of Z1_trav.pkl may rely on that layout -- confirm.
    receptive_field_tensor = np.concatenate(receptive_field_list)

    pickle_path = os.path.join(experiment.directory, "analysis", "Z1_trav.pkl")
    # Context manager so the file handle is closed (and flushed) reliably.
    with open(pickle_path, "wb") as pickle_file:
        pickle.dump(receptive_field_tensor, pickle_file)

    if save_path:
        path = save_path
    else:
        path = os.path.join(experiment.directory, "analysis", "Z1_trav.png")

    # Figure-level title: plt.title would replace the title of the most
    # recently created panel instead of labelling the whole figure.
    fig.suptitle(experiment.experiment_configs["experiment_directory"].split("/")[-1])

    fig.savefig(path, facecolor="white")
    plt.show()

def Z2_traversal (experiment, n_cols, shape=None, diff=0, step_size=1, save_path=None, num_dims=70):
    """Plot the batch-averaged decoded image response per z2 dimension.

    The first training batch is encoded to z1 and then z2, and z2 is sampled
    once. For each dimension ``i``, ``step_size`` is added to coordinate ``i``
    of every z2 sample; the result is mapped through the mean of
    ``p_z1_z2_model`` and the mean of ``p_x_z1_model``, averaged over the
    batch, and shown minus ``diff`` times the unperturbed average.

    Side effects: saves the figure and calls ``plt.show()``.

    Args:
        experiment: Object providing ``ds_train``, ``model``, ``directory``
            and ``experiment_configs``.
        n_cols: Number of columns in the subplot grid.
        shape: Optional shape each averaged output is reshaped to.
        diff: Multiplier of the unperturbed output subtracted from each panel.
        step_size: Amount added to the perturbed latent coordinate.
        save_path: Figure path; defaults to ``analysis/Z2_trav.png`` inside
            the experiment directory.
        num_dims: Number of z2 dimensions.
    """
    # Only the inputs of the first batch are used.
    for samples in experiment.ds_train.take(1):
        sample = samples[0]

    z1 = experiment.model.q_z1_x_model(sample)
    z2 = experiment.model.q_z2_z1_model(z1)

    input_0 = z2.sample()

    # NOTE: the original author flagged ("???") uncertainty about using the
    # mean of p(z1|z2) here rather than something else.
    z1 = experiment.model.p_z1_z2_model(input_0).mean()
    x = experiment.model.p_x_z1_model(z1).mean()
    output_0 = tf.reduce_mean(x, axis=0)

    if shape:
        output_0 = tf.reshape(output_0, shape)

    n_rows = int(np.ceil(num_dims / n_cols))
    # squeeze=False keeps ``ax`` two-dimensional even when the grid has a
    # single row or column, so the [row][col] indexing below always works.
    fig, ax = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(n_cols*2, n_rows*2), squeeze=False
    )

    for i in range(num_dims):
        input_i = np.zeros([1, num_dims])
        input_i[0, i] = step_size
        # Broadcasts the one-hot step over every sample of the batch.
        input_i = input_0 + input_i

        z1 = experiment.model.p_z1_z2_model(input_i).mean()
        x = experiment.model.p_x_z1_model(z1).mean()
        output_i = tf.reduce_mean(x, axis=0)

        if shape:
            output_i = tf.reshape(output_i, shape)

        panel = ax[i // n_cols][i % n_cols]
        panel.imshow(output_i - diff*output_0, interpolation='none', cmap='gray')
        panel.set_title(f"{i}")
        panel.axis('off')

    if save_path:
        path = save_path
    else:
        path = os.path.join(experiment.directory, "analysis", "Z2_trav.png")

    # Figure-level title: plt.title would replace the title of the most
    # recently created panel instead of labelling the whole figure.
    fig.suptitle(experiment.experiment_configs["experiment_directory"].split("/")[-1])

    fig.savefig(path, facecolor="white")
    plt.show()


def shuffle_along_axis(a, axis):
    """Return *a* with its entries independently permuted along *axis*.

    A uniform random key is drawn for every element; sorting those keys along
    ``axis`` yields the permutation applied to each 1-D slice.
    """
    random_keys = np.random.rand(*a.shape)
    permutation = np.argsort(random_keys, axis=axis)
    return np.take_along_axis(a, permutation, axis=axis)

def get_Z2_post(model, ex, filter_dict=None, class_label=None):
    """Return mean and standard deviation of the z2 posterior for one batch.

    Args:
        model: Object exposing ``q_z1_x_model`` and ``q_z2_z1_model``.
        ex: Batch whose element 0 holds the inputs and element 1 the labels.
        filter_dict: Optional dict with key ``"filter_dims"``. When truthy,
            the z1 values at those columns are shuffled within each row
            before z2 is computed.
        class_label: When not ``None``, only inputs whose label equals this
            value are used.

    Returns:
        Tuple ``(Z2_mean, Z2_std)``.
    """
    if class_label is not None:
        image_batch = ex[0][ex[1] == class_label]
    else:
        image_batch = ex[0]

    Z1 = model.q_z1_x_model(image_batch).mean()

    if filter_dict:
        filters = filter_dict["filter_dims"]

        # Switch to NumPy so the filter columns can be assigned in place.
        Z1 = Z1.numpy()
        Z1[:, filters] = shuffle_along_axis(Z1[:, filters], axis=1)

    # Evaluate the posterior network once and read both statistics from the
    # same distribution object instead of running the network twice.
    Z2_posterior = model.q_z2_z1_model(Z1)

    return Z2_posterior.mean(), Z2_posterior.stddev()


def get_Z2_mean_std(model, ds, filter_dict=None, class_label=None):
    """Summarise the z2 posterior over a dataset, per latent dimension.

    Posterior means and standard deviations are collected for every batch via
    ``get_Z2_post`` and reduced over all samples. Two histograms (of the
    per-dimension average mean and average std) are drawn on a new figure and
    three summary counts are printed.

    Args:
        model: Model forwarded to ``get_Z2_post``.
        ds: Iterable of batches.
        filter_dict: Forwarded to ``get_Z2_post``.
        class_label: Forwarded to ``get_Z2_post``.

    Returns:
        Tuple ``(average_Z2_means, average_Z2_stds, average_abs_Z2_means,
        std_Z2_means)``, each reduced over the sample axis.
    """
    Z2_means = []
    Z2_stds = []

    for ex in ds:
        next_mean, next_std = get_Z2_post(model, ex, filter_dict, class_label=class_label)
        Z2_means.append(next_mean)
        Z2_stds.append(next_std)

    # Concatenate once. Taking np.abs of the *list* of batches (as opposed to
    # the concatenated array) fails when batches differ in size, which is the
    # normal case after per-class filtering.
    all_means = np.concatenate(Z2_means, axis=0)
    all_stds = np.concatenate(Z2_stds, axis=0)

    average_Z2_means = np.mean(all_means, axis=0)
    average_abs_Z2_means = np.mean(np.abs(all_means), axis=0)
    std_Z2_means = np.std(all_means, axis=0)
    average_Z2_stds = np.mean(all_stds, axis=0)

    # NOTE: a new figure is created on every call and is not closed here.
    fig, axs = plt.subplots(ncols=2, figsize=(12,6))

    axs[0].hist(average_Z2_means.flatten(), bins=30)
    axs[0].set_title("mean")

    axs[1].hist(average_Z2_stds.flatten(), bins=30)
    axs[1].set_title("std")

    print(f"Dims with std < 0.95 : {np.count_nonzero(average_Z2_stds<0.95)}")
    print(f"Dims with std of means < 0.95 : {np.count_nonzero(std_Z2_means<0.95)}")
    print(f"Dims with abs(mean)>0.05  : {np.count_nonzero(np.abs(average_Z2_means)>0.05)}")

    return average_Z2_means, average_Z2_stds, average_abs_Z2_means, std_Z2_means
    

def generate_active_dim_plots(experiment_path, image_size=20):
    """Save per-class bar charts of z2 posterior statistics for an experiment.

    For each class label in ``range(NUM_TEXT_FAMILY)`` the texture test set is
    summarised with ``get_Z2_mean_std``. Three bar charts (mean of posterior
    means, std of posterior means, mean of posterior std; one bar group per
    z2 dimension, one bar per class) are written to
    ``<experiment_path>/analysis/z2_dim_plots``.

    Args:
        experiment_path: Experiment directory containing ``configs.json``.
        image_size: Passed to ``get_texture_ds`` as ``image_size``.
    """
    # NOTE(review): these function-level imports look like they refer to this
    # module and to the preprocessing module already imported at file level;
    # kept unchanged because the file's own module path is not visible here.
    import utils.analysis as ua
    import utils.preprocessing as up

    exp = ua.prepare_experiment(experiment_path)
    _, ds_test = up.get_texture_ds(batch_size=20000, image_size=image_size)

    text_mean, text_std, std_mean = {}, {}, {}
    for label in range(NUM_TEXT_FAMILY):
        # The third return value (mean of absolute means) is not plotted.
        text_mean[label], text_std[label], _, std_mean[label] = get_Z2_mean_std(
            exp.model, ds_test, class_label=label
        )

    text_mean_df = pd.DataFrame(data=text_mean)       # mean of posterior means
    std_mean_df = pd.DataFrame(data=std_mean)         # std of posterior means
    text_std_df = pd.DataFrame(data=text_std)         # mean of posterior std

    save_dir = os.path.join(experiment_path, "analysis", "z2_dim_plots")
    os.makedirs(save_dir, exist_ok=True)
    text_mean_path = os.path.join(save_dir, "text_mean.png")
    std_mean_path = os.path.join(save_dir, "std_mean.png")
    text_std_path = os.path.join(save_dir, "text_std.png")

    fig1 = text_mean_df.plot.bar(figsize=(20,10), title='Average posterior means by texture families for each z2 latent variable')
    fig1.set_xlabel('z2 latent variable index')
    fig1.figure.savefig(text_mean_path)

    fig2 = std_mean_df.plot.bar(figsize=(20,10), title='Std of posterior means by texture families for each z2 latent variable')
    fig2.set_xlabel('z2 latent variable index')
    fig2.figure.savefig(std_mean_path)

    fig3 = text_std_df.plot.bar(figsize=(20,10), title='Average posterior standard deviation by texture families for each z2 latent variable')
    # The label must be set before saving, otherwise it is missing from the
    # written file (as it is for the two charts above).
    fig3.set_xlabel('z2 latent variable index')
    fig3.figure.savefig(text_std_path)
    
    
def infer_sizes_from_config(config):
    """Derive the image side length and latent sizes from a model config.

    Args:
        config: Mapping with ``config["q_z1_x_configs"]["input_shape"]`` and
            ``config["p_z1_z2_configs"]`` holding ``"output_shape"`` and
            ``"input_shape"``.

    Returns:
        Tuple ``(image_size, num_z1_dims, num_z2_dims)``.

    Raises:
        ValueError: If the encoder input size is not 400, 1600 or 2500.
    """
    flat_input_size = config["q_z1_x_configs"]["input_shape"]

    # Known flattened input sizes and the image side length they map to.
    known_sizes = ((400, 20), (1600, 40), (2500, 50))
    for pixel_count, side_length in known_sizes:
        if flat_input_size == pixel_count:
            image_size = side_length
            break
    else:
        raise ValueError('unknown image size')

    # The latent sizes are read off the p(z1|z2) network: its output is z1
    # and its input is z2 (the original author calls this "very evil").
    generator_configs = config["p_z1_z2_configs"]
    num_z1_dims = generator_configs["output_shape"]
    num_z2_dims = generator_configs["input_shape"]

    return image_size, num_z1_dims, num_z2_dims
    
def full_analysis(experiment_path, seed=None):
    """Run the complete analysis pipeline for one experiment directory.

    Steps: load the experiment, attach the natural-image datasets, plot
    reconstructions and z1/z2 traversals, detect z1 filter dimensions, measure
    z1/z2 decodability, write ``analysis/results.json`` (last row of
    ``log.csv`` plus the decodability scores and filter count) and finally
    generate the per-class z2 plots.

    Args:
        experiment_path: Experiment directory containing ``configs.json``.
        seed: Optional seed applied to TensorFlow, ``random`` and NumPy.
    """
    if seed is not None:
        tf.random.set_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    exp = prepare_experiment(experiment_path)
    config = exp.model_configs

    image_size, num_z1_dims, num_z2_dims = infer_sizes_from_config(config)

    ds_train, ds_test = get_natural_ds(image_size=image_size)

    os.makedirs(os.path.join(exp.directory, "analysis"), exist_ok=True)

    exp.set_datasets(ds_train, ds_test)

    plot_reconstruction(exp, "val", shape=(image_size, image_size))
    plot_reconstruction(exp, "train", shape=(image_size, image_size))

    Z1_traversal(exp, 10, shape=(image_size, image_size), diff=1, step_size=1, num_dims=num_z1_dims)
    Z2_traversal(exp, 10, shape=(image_size, image_size), diff=1, step_size=1, num_dims=num_z2_dims)

    filter_dict = Z1_filters(exp, diff=0, step_size=1, num_dims=num_z1_dims)

    # Keep only what is still needed so the datasets and the experiment
    # object can be garbage-collected before the decodability step.
    model = exp.model
    directory = exp.directory
    del ds_train, ds_test, exp
    gc.collect()

    Z1_model, Z1_accuracy, Z2_model, Z2_accuracy = decodability(model, image_size=image_size, filter_dict=filter_dict)

    last_metrics = pd.read_csv(os.path.join(directory, "log.csv")).iloc[-1]

    model_results = {
        "name": os.path.normpath(directory).split(os.sep)[-1]
    }
    model_results.update(last_metrics)

    model_results["Z1_decodability"] = Z1_accuracy
    model_results["Z2_decodability"] = Z2_accuracy
    model_results["filter_count"] = int(filter_dict["filter_count"])

    with open(os.path.join(directory, "analysis", "results.json"), "w") as f:
        # NpEncoder converts NumPy scalars coming from the CSV row, which the
        # default encoder rejects for integer types.
        json.dump(model_results, f, cls=NpEncoder)

    # The interactive pdb breakpoints that used to sit here were removed:
    # they stopped every run before the final plots were generated.
    gc.collect()

    generate_active_dim_plots(experiment_path, image_size=image_size)

        
def create_results_table(paths, save_dir="/home/documentation/analysis"):
    """Collect ``analysis/results.json`` of several experiments into one sheet.

    Args:
        paths: Iterable of experiment directories, each containing
            ``analysis/results.json``.
        save_dir: Directory the spreadsheet is written to; created if
            missing. Defaults to the previously hard-coded location.

    The output file is named ``results_<yymmdd-HHMMSS>.xlsx`` using the
    current local time.
    """
    records = []
    for path in paths:
        result_path = os.path.join(path, "analysis", "results.json")
        with open(result_path) as f:
            records.append(json.load(f))

    # Build the frame in one step: DataFrame.append was removed in pandas 2.0
    # and copied the whole frame on every iteration before that.
    results_df = pd.DataFrame(records)

    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, f"results_{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}.xlsx")

    results_df.to_excel(save_path)


def Z1_filters(experiment, shape=None,diff=0 , step_size=1, num_dims=450):
    """Split the z1 dimensions into "filter" and "non-filter" sets.

    For every dimension a one-hot latent scaled by ``step_size`` is decoded
    and the mean squared value of the decoder mean is recorded. A histogram
    of these values is clustered into two groups with k-means (bin left edges
    as samples, bin counts as weights, bins with fewer than 3 entries
    ignored); the right edge of the first bin whose cluster differs from the
    first bin's cluster is used as the threshold. Dimensions above the
    threshold are filters, all others are non-filters.

    The result is also written to ``analysis/Z1_filters.json`` inside the
    experiment directory.

    Args:
        experiment: Object providing ``model`` and ``directory``.
        shape: Unused; kept for interface compatibility.
        diff: Unused; kept for interface compatibility.
        step_size: Value placed at the probed latent coordinate.
        num_dims: Number of z1 dimensions.

    Returns:
        Dict with keys ``"filter_count"``, ``"filter_dims"`` and
        ``"non_filter_dims"``.
    """
    from sklearn.cluster import k_means

    MSPs = []
    for i in range(num_dims):
        input_i = np.zeros([1, num_dims])
        input_i[0, i] = step_size

        output_i = tf.reduce_mean(experiment.model.p_x_z1_model(input_i).mean(), axis=0)

        MSPs.append(tf.reduce_mean(tf.math.square(output_i)))

    hist = np.histogram(MSPs, bins=int(num_dims/10))

    # Sparsely populated bins get zero weight in the clustering.
    sample_weight = hist[0]
    sample_weight[hist[0] < 3] = 0
    _, label, _ = k_means(hist[1][:-1].reshape(-1, 1), 2, sample_weight=sample_weight)

    # Index of the first bin assigned to a different cluster than bin 0
    # (0 when every bin carries the same label).
    bin_threshold = np.argmax(label != label[0])

    threshold = hist[1][bin_threshold+1]
    print(threshold)
    msp_array = np.array(MSPs)

    # One mask and its complement, so the two index sets always partition
    # the dimensions; a value exactly equal to the threshold used to fall
    # into neither set.
    is_filter = msp_array > threshold
    all_dims = np.arange(num_dims)

    filter_count = np.sum(is_filter)
    filter_dims = all_dims[is_filter]
    non_filter_dims = all_dims[~is_filter]

    filter_dict = {
        "filter_count": filter_count,
        "filter_dims": filter_dims,
        "non_filter_dims": non_filter_dims
    }

    save_path = os.path.join(experiment.directory, "analysis", "Z1_filters.json")
    with open(save_path, "w") as f:
        json.dump(filter_dict, f, cls=NpEncoder)

    return filter_dict

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types."""

    def default(self, obj):
        """Return a JSON-serialisable equivalent of *obj*.

        NumPy integers become ``int``, NumPy floats become ``float`` and
        arrays become nested lists; anything else is delegated to the base
        encoder, which raises ``TypeError``.
        """
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda array: array.tolist()),
        )
        for numpy_type, convert in conversions:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)

