#%%
import os
import sys
import pickle as pic
import numpy as np
import pandas as pd
import torch as to
import matplotlib.pyplot as plt
from numpy.random import Generator, PCG64

# Both paths below are built from the current working directory, so the
# script must be launched from the directory holding src/ and plot_params.dms.
pwd = os.getcwd()
sys.path.append(f"{pwd}/src")  # presumably where DatasetTools / MLP (imported below) live
plt.style.use(f"{pwd}/plot_params.dms")


import DatasetTools as dts
from MLP import NN

def load_pickles(folder):
    """
    Load every ``.pkl`` file found in ``folder``.

    Parameters
    ----------
    folder : str
        Directory to scan. A relative path is resolved against the current
        working directory.

    Returns
    -------
    file_names : list of str
        Sorted names of the pickle files that were loaded.
    pickles : list
        The unpickled objects, in the same order as ``file_names``, so
        ``file_names[i]`` is the file ``pickles[i]`` was read from.

    Notes
    -----
    Only ``.pkl`` files are listed. Previously every non-``.DS_Store`` entry
    was returned, so any stray file shifted ``file_names`` out of step with
    ``pickles``. ``pickle.load`` can execute arbitrary code, so only point
    this at trusted data.
    """
    file_names = sorted(name for name in os.listdir(folder) if name.endswith('.pkl'))

    pickles = []
    for name in file_names:
        with open(os.path.join(folder, name), 'rb') as f:
            pickles.append(pic.load(f))

    return file_names, pickles


def single_model_data(summary, file_name):
    """
    Turn one run summary into a one-row DataFrame (sub-function of init_dataframe).

    Scalar entries go straight into the constructor. The 'masks', 'hvar' and
    'history' columns start as None placeholders and are then filled cell by
    cell with ``.at``, so each stored object is kept whole rather than being
    spread across rows by the constructor.

    Parameters
    ----------
    summary : dict
        Must provide the keys 'unmasked_acc', 'masked_acc', 'lr', 'seedw',
        'learned_params', 'mask', 'hvar' and 'history'.
    file_name : str
        Stored in the 'file_name' column.
    """
    scalar_fields = {
        'file_name': file_name,
        'history': None,
        'unmasked_acc': summary['unmasked_acc'],
        'masked_acc': summary['masked_acc'],
        'masks': None,
        'hvar': None,
        'lr': summary['lr'],
        'seed_w': summary['seedw'],
        'learned_params': summary['learned_params'],
    }
    row = pd.DataFrame(scalar_fields, index=[0])

    # (destination column, summary key) for the object-valued cells
    for column, key in (('masks', 'mask'), ('hvar', 'hvar'), ('history', 'history')):
        row.at[0, column] = summary[key]

    return row

def init_dataframe(folder):
    """
    Compile the data generated by generate-data.py into one DataFrame for plotting.

    Each pickle in ``folder`` (see load_pickles) becomes one row built by
    single_model_data.

    Parameters
    ----------
    folder : str
        Directory holding the ``.pkl`` summaries.

    Returns
    -------
    pandas.DataFrame
        One row per pickle with a fresh 0..n-1 index. The columns 'hvar',
        'masks', 'unmasked_acc', 'masked_acc', 'history', 'seed_w' and 'lr'
        come first, followed by the remaining row columns in their original
        order. With no pickles, an empty frame holding just those leading
        columns is returned.
    """
    leading_columns = ['hvar', 'masks', 'unmasked_acc', 'masked_acc', 'history', 'seed_w', 'lr']

    file_names, summarys = load_pickles(folder)
    rows = [single_model_data(summary, name) for name, summary in zip(file_names, summarys)]
    if not rows:
        return pd.DataFrame(columns=leading_columns)

    # One concat is linear in the number of rows; the old per-row concat was
    # quadratic. It also avoids concatenating with an empty template frame,
    # which pandas >= 2.1 deprecates.
    df = pd.concat(rows, axis=0, ignore_index=True)

    # Keep the original column order: template columns first, then the rest.
    trailing = [c for c in df.columns if c not in leading_columns]
    return df.reindex(columns=leading_columns + trailing)

def get_samples(df, weight_seed):
    """
    Return the rows of ``df`` whose 'seed_w' equals ``weight_seed``.

    Sub-function for build_arrays. The original index labels are kept.
    """
    matches_seed = df['seed_w'] == weight_seed
    return df[matches_seed]

def build_arrays(df, num_samples, datatype=to.float32, num_models=3, len_vec=10000, len_curve=300):
    """
    Build the tensors used for plotting from a DataFrame made by init_dataframe.

    Rows are grouped by weight seed ('seed_w'). Each distinct seed, in
    ascending order (np.unique), fills one position along the sample axis.
    Within a seed, rows are stably sorted by 'learned_params', so ties keep
    their file-name order, and that rank is the position on the model axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of init_dataframe.
    num_samples : int
        Size of the sample axis; must be >= the number of distinct seeds.
    datatype : torch.dtype
        dtype of every returned tensor.
    num_models : int
        Size of the model axis; must be >= the number of rows per seed.
    len_vec : int
        Length of each 'hvar' / 'masks' vector.
    len_curve : int
        Length of each 'history' learning curve.

    Returns
    -------
    masked_accs, unmasked_accs : tensors of shape (num_models, num_samples)
    hvars, masks : tensors of shape (num_models, num_samples, len_vec)
    learning_curves : tensor of shape (num_models, num_samples, len_curve)
        Slots with no matching row stay zero.
    """
    learning_curves = to.zeros((num_models, num_samples, len_curve), dtype=datatype)
    unmasked_accs = to.zeros((num_models, num_samples), dtype=datatype)
    masked_accs = to.zeros((num_models, num_samples), dtype=datatype)
    hvars = to.zeros((num_models, num_samples, len_vec), dtype=datatype)
    masks = to.zeros((num_models, num_samples, len_vec), dtype=datatype)

    weight_seeds = np.unique(df.seed_w.values)

    for ix_sample, val in enumerate(weight_seeds):
        # Sort *before* resetting the index: the reset index is what places each
        # row on the model axis. The old code called df.sort_values(...) and
        # discarded the result, so this ordering was never applied.
        df_sub = (
            get_samples(df, val)
            .sort_values(by=['learned_params'], kind='stable')
            .reset_index(drop=True)
        )
        for ix, row in df_sub.iterrows():
            # as_tensor accepts Python scalars, lists, arrays and tensors alike;
            # to.Tensor(int) would instead allocate an uninitialised tensor of
            # that size, and to.Tensor(float) raises.
            learning_curves[ix, ix_sample, :] = to.as_tensor(row.history, dtype=datatype)
            unmasked_accs[ix, ix_sample] = to.as_tensor(row.unmasked_acc, dtype=datatype)
            masked_accs[ix, ix_sample] = to.as_tensor(row.masked_acc, dtype=datatype)
            hvars[ix, ix_sample, :] = to.as_tensor(row.hvar, dtype=datatype)
            masks[ix, ix_sample, :] = to.as_tensor(row.masks, dtype=datatype)

    return masked_accs, unmasked_accs, hvars, masks, learning_curves

def scatter_hist(x, y, ax, ax_histx, ax_histy, binwidth=0.02):
    """
    Draw a scatter of ``y`` against ``x`` with marginal histograms.

    Sub-function for plot_scatter.

    Parameters
    ----------
    x, y : torch.Tensor
        Paired values. The axes are labelled bias-trained (x) and
        mask-trained (y) unit variance.
    ax : matplotlib Axes
        Main axes for the scatter.
    ax_histx, ax_histy : matplotlib Axes
        Axes for the histogram of ``x`` and the horizontal histogram of ``y``.
    binwidth : float
        Histogram bin width (0.02 was the previously hard-coded value).
    """
    # no labels on the marginal axes' edges that touch the scatter
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # the scatter plot:
    ax.scatter(x, y, color='grey', marker='o', alpha=0.4, s=5)
    ax.set_xlabel('Bias-Trained Unit Variance')
    ax.set_ylabel('Mask-Trained Unit Variance')

    # Bins start at 0; negative values are not binned, which is assumed fine
    # for variances. The top edge must lie one bin past the largest magnitude.
    # Without the "+ 1", values above the last whole multiple of binwidth fell
    # outside every bin and were silently dropped.
    xymax = max(to.max(to.abs(x)), to.max(to.abs(y)))
    lim = (int(xymax / binwidth) + 1) * binwidth

    bins = to.arange(0, lim + binwidth, binwidth)
    ax_histx.hist(x, bins=bins, color='orange')
    ax_histy.hist(y, bins=bins, orientation='horizontal', color='k')
    

def plot_scatter(x, y):
    """
    Draw the supplementary scatter plot of ``y`` vs. ``x`` with marginal histograms.

    The square main axes occupies the lower-left 75% of a constrained-layout
    figure. The histograms are insets just above it and to its right, sharing
    its x- and y-axis respectively. The drawing itself is done by scatter_hist.
    """
    figure = plt.figure(figsize=(1.5, 1.5), layout='constrained')
    main_ax = figure.add_gridspec(top=0.75, right=0.75).subplots()
    main_ax.set(aspect=1)

    top_hist = main_ax.inset_axes([0, 1.05, 1, 0.25], sharex=main_ax)
    side_hist = main_ax.inset_axes([1.05, 0, 0.25, 1], sharey=main_ax)

    scatter_hist(x, y, main_ax, top_hist, side_hist)

    plt.plot()

def plot_learning_curves(m1, sd1, m2, sd2, ma1, sda1, ma2, sda2):
    """
    Plot the figure-1 learning curves with an inset bar chart.

    Parameters
    ----------
    m1, sd1, m2, sd2
        Centre line and half-width of two shaded curves, drawn in orange
        (``m1``) and black (``m2``) against a 0-30 epoch axis labelled
        'Cross Entropy'.
    ma1, sda1, ma2, sda2
        Bar heights and error bars for the inset, labelled 'Bias' (orange) and
        'Mask' (black). Presumably final accuracies, given the inset's
        0.65-1 y-range; confirm against the caller.
    """
    palette = ('orange', 'k')
    bar_names = ['Bias', 'Mask']
    bar_heights = [ma1, ma2]

    fig = plt.figure(figsize=(1.5, 1), layout='constrained')
    ax = fig.add_axes((0, 0, 1, 1))

    # curve samples are spread evenly over a fixed 30-epoch x-axis
    epochs = 30
    steps = m1.shape[-1]
    x = to.linspace(0, epochs, steps)

    for mean, spread, colour in ((m1, sd1, palette[0]), (m2, sd2, palette[1])):
        ax.plot(x, mean, color=colour)
        ax.fill_between(x, y1=mean - spread, y2=mean + spread, color=colour, alpha=0.5)

    inset = ax.inset_axes([0.6, 0.6, 0.35, 0.35])
    bars = inset.bar(range(len(bar_names)), bar_heights, yerr=[sda1, sda2],
                     tick_label=bar_names, color=list(palette))
    bar_text = [np.round(height, decimals=3) for height in bar_heights]
    inset.bar_label(bars, labels=bar_text, padding=3, size=5)
    inset.set_ylim((0.65, 1))

    ax.set_ylim(0, 1.5)
    ax.set_ylabel('Cross Entropy')
    ax.set_xlabel('Epochs')

    for axis in (ax, inset):
        axis.spines['top'].set_visible(False)
        axis.spines['right'].set_visible(False)

    plt.plot()

def plot_grouped_bar(bvm, mvm, null):
    """
    Plot the supplementary grouped bar chart of normalized mask matches.

    Parameters
    ----------
    bvm, mvm, null : array-like, indexed ``[group, stat]``
        Row ``g`` holds (bar height, error bar) for the groups "On Units",
        "Off Units" and "All Units", in that order. Column 0 gives the bar
        height and column 1 gives ``yerr``. They are drawn as
        'Bias vs. Mask', 'Mask vs. Mask' and 'Null Model' respectively.
    """
    groups = ("On Units", "Off Units", "All Units")
    group_means = {
        'Null Model': null,
        'Bias vs. Mask': bvm,
        'Mask vs. Mask': mvm,
    }

    cols = ['grey', 'orange', 'k']

    x = np.arange(len(groups))
    width = 0.25

    # Create exactly one full-figure axes. The old plt.subplots() call made an
    # extra axes that was immediately shadowed by add_axes and left behind.
    fig = plt.figure(figsize=(1.5, 1), layout='constrained')
    ax = fig.add_axes((0, 0, 1, 1))

    for multiplier, ((attribute, measurement), colour) in enumerate(zip(group_means.items(), cols)):
        offset = width * multiplier
        ax.bar(x + offset, measurement[:, 0], width, label=attribute,
               yerr=measurement[:, 1], color=colour)

    ax.set_ylabel('Normalized Matches')
    ax.set_xticks(x + width, groups)  # tick under the middle bar of each group
    ax.legend(loc='lower right')
    ax.set_ylim(0, 1.05)

    plt.plot()

def plot_histogram(h, bins=50, min_val=0.1, max_val=1.2):
    """
    Plot the figure-2 histogram of unit variances, bias vs. mask learning.

    For every sample, a histogram is taken of ``h[0, sample]`` (drawn as
    'Bias Learning') and ``h[1, sample]`` (drawn as 'Mask Learning'). The mean
    count per bin across samples is drawn as stairs, with a +/- 1 sd band.

    Parameters
    ----------
    h : torch.Tensor
        Indexed ``h[model, sample]`` -> 1-D tensor of unit variances. The
        number of samples is taken from ``h.shape[1]``. Previously it came from
        the global ``num_samples``, which only exists when run as a script, and
        the histogram buffers were hard-coded to 5 columns.
    bins : int
        Number of histogram bins.
    min_val, max_val : float
        Histogram range; values outside it are not counted.

    Raises
    ------
    ValueError
        If ``h`` holds no samples.
    """
    n_samples = h.shape[1]
    if n_samples == 0:
        raise ValueError('h contains no samples to histogram')

    histogram_weights_b = to.zeros((bins, n_samples))
    histogram_weights_m = to.zeros((bins, n_samples))
    for ix in range(n_samples):
        hist_weights_b, _ = to.histogram(h[0, ix], range=(min_val, max_val), bins=bins, density=False)
        hist_weights_m, edges = to.histogram(h[1, ix], range=(min_val, max_val), bins=bins, density=False)
        histogram_weights_b[:, ix] = hist_weights_b
        histogram_weights_m[:, ix] = hist_weights_m

    # mean/sd over samples (axis 1), one value per bin
    m_hb, sd_hb = stats(histogram_weights_b)
    m_hm, sd_hm = stats(histogram_weights_m)
    x_axis = edges[1:] - to.diff(edges)[0] / 2  # bin centres, for the sd bands

    fig = plt.figure(figsize=(1.5, 1))
    ax = fig.add_axes((0, 0, 1, 1))

    ax.stairs(m_hb, edges, fill=True, color='orange', alpha=0.7, label='Bias Learning')
    ax.fill_between(x_axis, m_hb + sd_hb, m_hb - sd_hb, color='orange', alpha=0.35)
    ax.stairs(m_hm, edges, fill=True, color='k', alpha=0.6, label='Mask Learning')
    ax.fill_between(x_axis, m_hm + sd_hm, m_hm - sd_hm, color='k', alpha=0.3)

    ax.set_ylabel('Counts')
    ax.set_xlabel('Unit Variance')
    ax.legend()

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.plot()

def stats(array1, axis=1):
    """
    Return the mean and standard deviation of a tensor along ``axis``.

    Uses torch's defaults, so the standard deviation is the Bessel-corrected
    (sample) estimate.
    """
    return array1.mean(dim=axis), array1.std(dim=axis)

def stats1d(array1):
    """
    Return the mean and (sample) standard deviation over every element of a tensor.
    """
    overall_mean = array1.mean()
    overall_sd = array1.std()
    return overall_mean, overall_sd


def my_invert_bool(array):
    """
    Return the element-wise inverse of a boolean array or tensor.

    This applies Python's ``~`` operator. It is a logical NOT only for boolean
    dtypes; on integer dtypes ``~`` is bitwise (``~1 == -2``), so pass booleans.
    """
    inverted = ~array
    return inverted

def bool_index(array, bool):
    """
    Apply a different boolean selection to each row of ``array``.

    Parameters
    ----------
    array
        Indexed as ``array[row, mask]``.
    bool
        Boolean array/tensor with at least ``array.shape[0]`` rows; row ``r``
        selects the entries of ``array[r]`` to keep. (The parameter name
        shadows the builtin but is kept for caller compatibility.)

    Returns
    -------
    list
        One selection per row of ``array``. Rows may keep different numbers of
        entries, hence a list rather than a stacked tensor.
    """
    n_rows = array.shape[0]
    return [array[r, bool[r, :]] for r in range(n_rows)]

def cc(array1, array2):
    """
    Compute the Pearson correlation between two 1-D tensors.

    Returns the 2x2 correlation matrix of the pair. The coefficient itself is
    the off-diagonal entry ``cc(a, b)[0, 1]``, and the diagonal is 1. The old
    implementation divided the covariance matrix by ``std(a) * std(b)``. That
    gave the same [0, 1] entry, but the diagonal held ``var/(std(a)*std(b))``
    rather than 1.
    """
    return to.corrcoef(to.stack((array1, array2)))

def mask_stats(mask0, mask1):
    """
    Compute agreement statistics between two boolean masks.

    Counts are taken along the last axis, so any leading axes (e.g. one row
    per sample) are preserved in the outputs.

    Parameters
    ----------
    mask0, mask1 : boolean tensors/arrays of identical shape
        ``mask0`` is the reference; the on/off totals used for normalisation
        come from it.

    Returns
    -------
    on_matchs_normalized
        Fraction of mask0's on units that are also on in mask1.
    off_matchs_normalized
        Fraction of mask0's off units that are also off in mask1.
    total_matchs_normalized
        Fraction of all units on which the two masks agree.

    Raises
    ------
    ValueError
        If the masks differ in shape. Previously this printed a message and
        returned None, so callers failed later with a confusing unpacking
        TypeError.
    """
    if mask0.shape != mask1.shape:
        raise ValueError('masks should be the same shape')

    total = mask0.shape[-1]
    off0 = my_invert_bool(mask0)
    off1 = my_invert_bool(mask1)

    totals_on = mask0.sum(axis=-1)
    totals_off = off0.sum(axis=-1)

    on_matchs = (mask0 & mask1).sum(axis=-1)
    off_matchs = (off0 & off1).sum(axis=-1)
    on_matchs_normalized = on_matchs / totals_on
    off_matchs_normalized = off_matchs / totals_off
    total_matchs = on_matchs + off_matchs
    total_matchs_normalized = total_matchs / total

    return on_matchs_normalized, off_matchs_normalized, total_matchs_normalized


if __name__ == "__main__":

    # Unseeded generator (fresh OS entropy each run) used for the null-model draws.
    rng = Generator(PCG64(None))
    df = init_dataframe('data')  # loads dataset

    #%% BUILD ARRAYS TO PLOT
    num_samples = 5
    m_a, um_a, h, m, lc = build_arrays(df, num_samples)
    m = m > 0  # masks as booleans: a unit is "on" where its mask value is positive

    #%% CURVE + ACCURACY STATS
    # mean/sd across samples (axis 1); the leading axis indexes the model
    m_lc, sd_lc = stats(lc)
    m_ma, sd_ma = stats(m_a)
    m_ua, sd_ua = stats(um_a)

    #%% CALCULATE STATS FOR MASKS
    # model 1 vs model 2 ("Mask vs. Mask") and model 1 vs model 0
    on_matches_normalized0, off_matches_normalized0, total_matches_normalized0 = mask_stats(m[1], m[2])
    on_matches_normalized, off_matches_normalized, total_matches_normalized = mask_stats(m[1], m[0])

    on_m0, on_sd0 = stats1d(on_matches_normalized0)
    off_m0, off_sd0 = stats1d(off_matches_normalized0)
    total_m0, total_sd0 = stats1d(total_matches_normalized0)
    on_m, on_sd = stats1d(on_matches_normalized)
    off_m, off_sd = stats1d(off_matches_normalized)
    total_m, total_sd = stats1d(total_matches_normalized)

    #%% H STATS
    # per-sample correlation of model 0 vs model 1 unit variances, restricted
    # to the units that are on in both masks
    on_match_indx = (m[0] * m[1]) > 0
    h_b = bool_index(h[0, :, :], on_match_indx)
    h_m = bool_index(h[1, :, :], on_match_indx)

    cor = np.zeros(h.shape[1])
    for h_index in range(h[1].shape[0]):
        cor[h_index] = cc(h_b[h_index], h_m[h_index])[0, 1]
    print(cor.mean())
    print(cor.std())

    #%% PLOT SUBFIGURES (FIG2 SUBFIGURES A, C; SUPP SCATTER)

    plot_learning_curves(m_lc[0], sd_lc[0], m_lc[1], sd_lc[1], m_ua[0].numpy(), sd_ua[0].numpy(), m_ua[1].numpy(), sd_ua[1].numpy())

    h_index = 0
    plot_scatter(h_b[h_index], h_m[h_index])

    bins = 50
    plot_histogram(h, bins=bins)

    #%% DETERMINE ON/OFF UNITS FOR BIASES VIA THRESHOLDING
    # Bias-model units count as "on" when their variance reaches 0.05. This
    # overwrites the on_m/off_m/total_m values computed from m[0] above.
    h_bias = h[0, :, :]
    m_thresh = h_bias >= 0.05
    on_thresh, off_thresh, total_thresh = mask_stats(m[1], m_thresh)
    on_m, on_sd = stats1d(on_thresh)
    off_m, off_sd = stats1d(off_thresh)
    total_m, total_sd = stats1d(total_thresh)

    #%% NULL COMPARISON
    # Null model: independent Bernoulli masks whose on-rates match the mean
    # fraction of "on" units in the real masks.
    N = m.shape[-1]  # units per mask
    pb1 = (m_thresh.sum(axis=-1)/N).mean(axis=-1)  # fraction of units on, averaged over samples
    pb0 = (my_invert_bool(m_thresh).sum(axis=-1)/N).mean(axis=-1)
    pm1 = (m[1].sum(axis=-1)/N).mean(axis=-1)
    pm0 = (my_invert_bool(m[1]).sum(axis=-1)/N).mean(axis=-1)

    null_mask = to.from_numpy(rng.binomial(1, pm1, (num_samples, N))) > 0
    null_bias = to.from_numpy(rng.binomial(1, pb1, (num_samples, N))) > 0
    on_matches_null, off_matches_null, total_matches_null = mask_stats(null_mask, null_bias)
    on_null_m, on_null_sd = stats1d(on_matches_null)
    off_null_m, off_null_sd = stats1d(off_matches_null)
    total_null_m, total_null_sd = stats1d(total_matches_null)

    # PLOT BAR PLOT FROM SUPPLEMENTARY SECTION
    bvm = np.array([[on_m, on_sd],[off_m, off_sd],[total_m, total_sd]])
    mvm = np.array([[on_m0, on_sd0],[off_m0, off_sd0],[total_m0, total_sd0]])
    null = np.array([[on_null_m, on_null_sd],[off_null_m, off_null_sd],[total_null_m, total_null_sd]])
    plot_grouped_bar(bvm, mvm, null)

    #%% PLOT FIGURE 2 SUBFIGURE B
    # sweep the bias "on" threshold and compare against the mask and null models
    vthreshs = to.arange(0.001, 0.6, 0.02)
    p_matches = to.zeros((2, len(vthreshs)))
    p_null = to.zeros((2, len(vthreshs)))
    bias_on = to.zeros((2, len(vthreshs)))

    for ix, thresh in enumerate(vthreshs):
        h_bias = h[0, :, :]
        m_thresh = h_bias >= thresh
        on, off, total_thresh = mask_stats(m[1], m_thresh)
        p_matches[0, ix], p_matches[1, ix] = stats1d(total_thresh)
        bias_on[0, ix], bias_on[1, ix] = stats1d(m_thresh.sum(axis=-1) * 1.0)

        pb1 = (m_thresh.sum(axis=-1)/N).mean(axis=-1)  # fraction of units on, averaged over samples
        pb0 = (my_invert_bool(m_thresh).sum(axis=-1)/N).mean(axis=-1)
        pm1 = (m[1].sum(axis=-1)/N).mean(axis=-1)
        pm0 = (my_invert_bool(m[1]).sum(axis=-1)/N).mean(axis=-1)
        null_mask = to.from_numpy(rng.binomial(1, pm1, (num_samples, N))) > 0
        null_bias = to.from_numpy(rng.binomial(1, pb1, (num_samples, N))) > 0
        on2, off2, total_matches_null = mask_stats(null_mask, null_bias)
        p_null[0, ix], p_null[1, ix] = stats1d(total_matches_null)

    p0_m = total_m0
    p0_sd = total_sd0

    fig = plt.figure(figsize=(1.5, 1))
    ax = fig.add_axes((0, 0, 1, 1))

    xlims = (vthreshs[0] - 0.01, vthreshs[-1] + 0.01)
    ax.plot(vthreshs, p_matches[0], color='orange', label='Bias vs. Mask')
    ax.fill_between(vthreshs, p_matches[0] - p_matches[1], p_matches[0] + p_matches[1], color='orange', alpha=0.5)
    ax.plot(vthreshs, p_null[0], color='grey', label='Null Model')
    ax.fill_between(vthreshs, p_null[0] - p_null[1], p_null[0] + p_null[1], color='grey', alpha=0.5)
    ax.hlines(p0_m, xmin=vthreshs[0], xmax=vthreshs[-1], color='k', ls='--', label='Mask vs. Mask')
    # +/- 1 sd band around the constant Mask vs. Mask level. The upper bound
    # used to repeat the lower one, which drew a zero-width band.
    ax.fill_between(vthreshs, np.repeat(p0_m - p0_sd, len(vthreshs)), np.repeat(p0_m + p0_sd, len(vthreshs)), color='k', alpha=0.5)
    ax.set_xlim(xlims[0], xlims[1])
    ax.set_xlabel('Task Variance Threshold')
    ax.set_ylabel('Probability of Match')
    ax.legend(loc='lower right', bbox_to_anchor=(0.98, 0.5))

# %%
