import numpy as np
import scipy
import os
import sys
import pandas as pd
from tqdm import tqdm
import time
import pickle
import pdb
import matplotlib.pyplot as plt 
from sklearn.decomposition import PCA
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression

from FCCA.fcca import LQGComponentsAnalysis as fcca
from data_loader import load_peanut
from decoders import lr_decoder

# Hyperparameters shared by every analysis below.
FCCA_params = {'T': 3, 'n_init': 10}
loader_params = {
    'bin_width': 50,
    'filter_fn': 'none',
    'filter_kwargs': {},
    'boxcox': 0.5,
    'spike_threshold': 100,
}
decoder_params = {'trainlag': 0, 'testlag': 0, 'decoding_window': 6}
# Epoch indices passed to load_peanut (2, 4, ..., 16) and projection dimensions (1..30).
epochs = np.arange(2, 18, 2)
dimvals = np.arange(1, 31)


class PCA_wrapper:
    """Thin adapter around sklearn's PCA exposing a fixed-dimension projection.

    After ``fit``, ``coef_`` holds (up to) the leading ``dim`` principal axes
    as columns, so data can be projected with ``X @ coef_``.
    """

    def __init__(self, d):
        # A full PCA is fit; truncation to ``d`` components happens afterwards.
        self.pcaobj = PCA()
        self.dim = d

    def fit(self, X):
        """Fit PCA on ``X``; a 3-D array is first flattened over its leading axes."""
        data = X
        if np.ndim(data) == 3:
            n_features = data.shape[-1]
            data = np.reshape(data, (-1, n_features))
        self.pcaobj.fit(data)
        axes = self.pcaobj.components_.T
        self.coef_ = axes[:, :self.dim]

    def score(self):
        """Return the fraction of variance explained by the first ``dim`` components."""
        ratios = self.pcaobj.explained_variance_ratio_
        return sum(ratios[:self.dim])

# Computes PCA-vs-FCCA subspace angles from a pickled dimreduc results file
def get_ssa(dimreduc_df, dim):
    """Compute PCA-vs-FCCA principal angles for one projection dimension.

    Parameters
    ----------
    dimreduc_df : str
        Path to a pickle holding a list of result dicts (as written by
        ``dimreduc_``) with keys 'dim', 'epoch', 'fold_idx',
        'dimreduc_method' and 'coef'.
    dim : int
        Projection dimension to select.

    Returns
    -------
    np.ndarray of shape (epochs.size, 5, dim)
        ``scipy.linalg.subspace_angles`` between the PCA and FCCA projection
        matrices for every epoch and cross-validation fold.
    """
    with open(dimreduc_df, 'rb') as f:
        records = pickle.load(f)

    frame = pd.DataFrame(records)
    frame = frame.loc[frame['dim'] == dim]
    angles = np.zeros((epochs.size, 5, dim))

    for i, epoch in enumerate(epochs):
        in_epoch = frame['epoch'] == epoch
        for fold in range(5):
            selector = in_epoch & (frame['fold_idx'] == fold)
            pca_rows = frame.loc[selector & (frame['dimreduc_method'] == 'PCA')]
            fcca_rows = frame.loc[selector & (frame['dimreduc_method'] == 'FCCA')]

            # Exactly one fitted projection per (epoch, fold, method) is expected.
            assert(pca_rows.shape[0] == 1)
            assert(fcca_rows.shape[0] == 1)

            angles[i, fold, :] = scipy.linalg.subspace_angles(
                pca_rows.iloc[0]['coef'], fcca_rows.iloc[0]['coef'])

    return angles

def dimreduc_(data_path, dims=None):
    """Fit PCA and FCCA projections for every epoch, CV fold and dimension.

    For each epoch the spike rates are loaded and split into 5 contiguous
    (unshuffled) folds. Both methods are fit on the training portion, and one
    record per (method, dim, epoch, fold) is collected. Each record has the
    keys 'dimreduc_method', 'dim', 'epoch', 'fold_idx' and 'coef'. 'coef' is
    the projection matrix, applied downstream as ``X @ coef``.

    Records are appended to 'dimreduc_tmp.pkl' as they are produced. The
    cumulative list is re-written to 'dimreduc.pkl' after every epoch.

    Parameters
    ----------
    data_path : str
        Path forwarded to ``load_peanut``.
    dims : sequence of int, optional
        Projection dimensions to fit. Defaults to the module-level
        ``dimvals``, which covers every dimension that ``plot_decoding``,
        ``plot_ssa``, ``ssa_supp_plot`` and ``plot_initvar_ssa`` look up.
    """
    # Use the module-level dimension grid unless told otherwise. A hard-coded
    # local override here would leave downstream plots asserting on missing rows.
    dims = dimvals if dims is None else dims
    max_dim = max(dims)

    rl = []
    for epoch in epochs:
        dat = load_peanut(data_path, epoch, **loader_params)
        X = dat['spike_rates'].squeeze()
        for fold_idx, (train_idxs, _) in enumerate(KFold(n_splits=5, shuffle=False).split(X)):
            Xtrain = X[train_idxs]

            # PCA is fit once at the largest dimension; smaller projections
            # are the leading columns of its coefficient matrix.
            pcamodel = PCA_wrapper(d=max_dim)
            pcamodel.fit(Xtrain)

            # Estimate FCCA's data statistics once per fold, then fit one
            # projection per dimension from them.
            fccamodel = fcca(**FCCA_params)
            fccamodel.estimate_data_statistics(Xtrain)

            for dim in tqdm(dims):
                fccamodel.fit_projection(d=dim)

                r1 = {
                    'dimreduc_method': 'FCCA',
                    'dim': dim,
                    'epoch': epoch,
                    'fold_idx': fold_idx,
                    'coef': fccamodel.coef_,
                }
                rl.append(r1)

                r2 = {
                    'dimreduc_method': 'PCA',
                    'dim': dim,
                    'epoch': epoch,
                    'fold_idx': fold_idx,
                    'coef': pcamodel.coef_[:, 0:dim],
                }
                rl.append(r2)

                # Save as we go along (one pickle per record, appended).
                with open('dimreduc_tmp.pkl', 'ab') as f:
                    f.write(pickle.dumps(r1))
                    f.write(pickle.dumps(r2))

        # Checkpoint the cumulative results after every epoch.
        with open('dimreduc.pkl', 'wb') as f:
            f.write(pickle.dumps(rl))

def decoding_(data_path, dimreduc_df, save_name='decoding.pkl'):
    """Decode behavior from every stored projection and pickle the scores.

    Parameters
    ----------
    data_path : str
        Path forwarded to ``load_peanut``.
    dimreduc_df : str
        Path to a pickled list of projection records (e.g. written by
        ``dimreduc_`` or ``init_var_dimreduc``). Each record must carry
        'epoch', 'fold_idx' and 'coef'.
    save_name : str
        Output pickle. It receives one dict per record, holding the decoder's
        'r2' followed by every field of the source record.
    """
    with open(dimreduc_df, 'rb') as f:
        records = pickle.load(f)
    projections = pd.DataFrame(records)

    results = []
    for _, epoch in tqdm(enumerate(epochs)):
        dat = load_peanut(data_path, epoch, **loader_params)
        X = dat['spike_rates'].squeeze()
        Y = dat['behavior'].squeeze()
        splitter = KFold(n_splits=5, shuffle=False)
        for fold_idx, (train_idxs, test_idxs) in enumerate(splitter.split(X)):
            # Only projections fit on this very (epoch, fold) split are decoded here.
            mask = (projections['fold_idx'] == fold_idx) & (projections['epoch'] == epoch)
            subset = projections.loc[mask]
            X_tr, X_te = X[train_idxs], X[test_idxs]
            Y_tr, Y_te = Y[train_idxs], Y[test_idxs]

            for row_pos in range(subset.shape[0]):
                record = subset.iloc[row_pos]
                coef = record['coef']
                r2_pos, _ = lr_decoder(X_te @ coef, X_tr @ coef, Y_te, Y_tr, **decoder_params)
                entry = {'r2': r2_pos}
                entry.update(record.to_dict())
                results.append(entry)

    with open(save_name, 'wb') as f:
        f.write(pickle.dumps(results))

def plot_decoding(decoding_df):
    """Plot PCA vs FCCA position-decoding r^2 as a function of dimension.

    Parameters
    ----------
    decoding_df : str
        Pickle written by ``decoding_``: one record per (epoch, fold_idx,
        dim, dimreduc_method), each carrying an 'r2' value.

    Side effects: prints three things. The first is the dimension with the
    largest mean FCCA-minus-PCA r^2. The second is the maximal fractional
    improvement. The third is the p-value of a one-sided Wilcoxon signed-rank
    test on per-epoch AUCs. The figure is saved to 'hpc_r2.pdf'.
    """
    with open(decoding_df, 'rb') as f:
        records = pickle.load(f)
    results = pd.DataFrame(records)

    methods = ['PCA', 'FCCA']  # last axis of r2: 0 -> PCA, 1 -> FCCA
    n_folds = 5
    r2 = np.zeros((epochs.size, n_folds, dimvals.size, len(methods)))

    for i, epoch in enumerate(epochs):
        for fold in range(n_folds):
            for d, dim in enumerate(dimvals):
                for k, method in enumerate(methods):
                    hit = results.loc[(results['epoch'] == epoch) &
                                      (results['fold_idx'] == fold) &
                                      (results['dim'] == dim) &
                                      (results['dimreduc_method'] == method)]
                    assert(hit.shape[0] == 1)
                    r2[i, fold, d, k] = hit.iloc[0]['r2']

    pca_r2 = r2[..., 0]
    fca_r2 = r2[..., 1]
    # Standard error over all epoch x fold samples at each dimension.
    sem_norm = np.sqrt(np.prod(r2.shape[0:2]))
    fca_mean = np.mean(fca_r2, axis=(0, 1))
    fca_std = np.std(fca_r2.reshape((-1, dimvals.size)), axis=0) / sem_norm
    pca_mean = np.mean(pca_r2, axis=(0, 1))
    pca_std = np.std(pca_r2.reshape((-1, dimvals.size)), axis=0) / sem_norm

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    # Bands first, then lines, with FCCA before PCA each time. The label-only
    # ax.legend call below relies on this artist creation order.
    curves = [(fca_mean, fca_std, 'r'), (pca_mean, pca_std, 'k')]
    for mean, sem, color in curves:
        ax.fill_between(dimvals, mean - sem, mean + sem, color=color, alpha=0.25)
    for mean, _, color in curves:
        ax.plot(dimvals, mean, color=color)

    delta = fca_mean - pca_mean
    print('Peak delta: %d' % dimvals[np.argmax(delta)])
    print('Fractional improvement: %f' % np.max(np.divide(delta, pca_mean)))
    axin = ax.inset_axes([0.6, 0.1, 0.35, 0.35])

    # Per-epoch area under the fold-averaged r^2-vs-dimension curve.
    pca_auc = np.sum(np.mean(pca_r2, axis=1), axis=1)
    fca_auc = np.sum(np.mean(fca_r2, axis=1), axis=1)

    # Paired one-sided signed-rank test (alternative='less' on PCA vs FCCA AUCs).
    _, p = scipy.stats.wilcoxon(pca_auc, fca_auc, alternative='less')
    print(p)

    n_epochs = epochs.size
    axin.scatter(np.zeros(n_epochs), pca_auc, color='k', alpha=0.75, s=3)
    axin.scatter(np.ones(n_epochs), fca_auc, color='r', alpha=0.75, s=3)
    # One segment per epoch joining its PCA AUC (x=0) and FCCA AUC (x=1).
    seg_x = np.array([(0, 1)] * n_epochs).T
    seg_y = np.array(list(zip(pca_auc, fca_auc))).T
    axin.plot(seg_x, seg_y, color='k', alpha=0.5)
    axin.set_yticks([])
    axin.set_ylabel('Decoding AUC', fontsize=10)
    axin.set_xlim([-0.5, 1.5])
    axin.set_xticks([0, 1])
    axin.set_xticklabels(['FFC', 'FBC'], fontsize=10)

    ax.legend(['FBC', 'FFC'], loc='upper left', fontsize=14)
    ax.set_xlabel('Dimension', fontsize=14)
    ax.set_ylabel('Position Decoding ' + r'$r^2$', fontsize=14)
    ax.tick_params(axis='both', labelsize=12)

    fig.tight_layout()
    fig.savefig('hpc_r2.pdf', bbox_inches='tight', pad_inches=0)

def plot_ssa(dimreduc_df):
    """Box plot of the average PCA/FCCA subspace angle for 2-D projections.

    Each (epoch, fold) pair contributes one value: the mean of its principal
    angles as returned by ``get_ssa``. The figure is saved to 'hpc_ssa.pdf'.
    """
    DIM = 2
    per_split_mean = np.mean(get_ssa(dimreduc_df, DIM), axis=-1).ravel()

    pi_ticks = [0, np.pi/8, np.pi/4, 3 * np.pi/8, np.pi/2]
    pi_labels = ['0', r'$\pi/8$', r'$\pi/4$', r'$3\pi/8$', r'$\pi/2$']

    fig, ax = plt.subplots(1, 1, figsize=(2, 5))
    bplot = ax.boxplot(per_split_mean, patch_artist=True,
                       medianprops={'linewidth': 0},  # hide the median line
                       notch=True, vert=True, showfliers=False)
    ax.set_ylim([0, np.pi/2])
    ax.tick_params(axis='both', labelsize=12)
    ax.set_xticks([])
    ax.set_ylabel('FCCA/PCA avg. subspace angle (rads)', fontsize=14)
    ax.set_yticks(pi_ticks)
    ax.set_yticklabels(pi_labels)
    for box in bplot['boxes']:
        box.set_facecolor('k')
        box.set_alpha(0.75)

    fig.tight_layout()
    fig.savefig('hpc_ssa.pdf', bbox_inches='tight', pad_inches=0)

# Initialization variance
def init_var_dimreduc(data_path):
    """Refit FCCA from 20 random seeds to measure initialization variability.

    The loop covers every epoch, CV fold (5 contiguous folds) and odd
    dimension 1..29. In each case FCCA is fit with ``n_init=1`` and seeds
    0..19, reusing cross-covariances estimated once per fold. Each record
    stores the projection ('coef'), its training-set 'score' and the seed
    ('rep'). The full list is pickled to 'initvar_dimreduc.pkl'.
    """
    init_dims = np.arange(1, 31, 2)
    n_reps = 20
    rl = []
    for epoch in epochs:
        dat = load_peanut(data_path, epoch, **loader_params)
        X = dat['spike_rates'].squeeze()
        for fold_idx, (train_idxs, _) in enumerate(KFold(n_splits=5, shuffle=False).split(X)):
            Xtrain = X[train_idxs]
            # Data statistics are estimated once per fold and shared by every
            # seed and dimension below.
            stats_model = fcca(**FCCA_params)
            stats_model.estimate_data_statistics(Xtrain)
            cross_covs = stats_model.cross_covs
            cross_covs_rev = stats_model.cross_covs_rev
            for dim in tqdm(init_dims):
                for rep in range(n_reps):
                    # One initialization per seed. T must match the T used to
                    # estimate the shared cross-covariances above.
                    fccamodel = fcca(T=FCCA_params['T'], n_init=1, rng_or_seed=rep)
                    fccamodel.cross_covs = cross_covs
                    fccamodel.cross_covs_rev = cross_covs_rev
                    fccamodel.fit_projection(d=dim)
                    fcca_score = fccamodel.score(X=Xtrain)

                    rl.append({
                        'dimreduc_method': 'FCCA',
                        'dim': dim,
                        'epoch': epoch,
                        'fold_idx': fold_idx,
                        'coef': fccamodel.coef_,
                        'score': fcca_score,
                        'rep': rep,
                    })

    # Save results
    with open('initvar_dimreduc.pkl', 'wb') as f:
        f.write(pickle.dumps(rl))

def plot_initvar(decoding_df):
    """Box plot of the spread of FCCA decoding r^2 across random initializations.

    Parameters
    ----------
    decoding_df : str
        Pickle produced by ``decoding_`` from 'initvar_dimreduc.pkl'. Its
        records are keyed by 'epoch', 'dim', 'fold_idx' and 'rep', and each
        carries an 'r2' value.

    For each (epoch, dimension), the fold-averaged r^2 of every seed is
    centred on the median over seeds. One box per dimension pools all epochs
    and seeds. The figure is saved to 'hpc_initvar.pdf'.
    """
    with open(decoding_df, 'rb') as f:
        rl = pickle.load(f)

    df = pd.DataFrame(rl)
    dimvals = np.arange(1, 31, 2)
    reps = np.arange(20)

    r2 = np.zeros((epochs.size, dimvals.size, reps.size, 5))
    for i, epoch in enumerate(epochs):
        for j, d in tqdm(enumerate(dimvals)):
            for fold_idx in range(5):
                for r in reps:
                    df_ = df.loc[(df['epoch'] == epoch) &
                                 (df['dim'] == d) &
                                 (df['fold_idx'] == fold_idx) &
                                 (df['rep'] == r)]
                    assert(df_.shape[0] == 1)
                    r2[i, j, r, fold_idx] = df_.iloc[0]['r2']

    # Average over folds -> (epoch, dim, rep), then centre each (epoch, dim)
    # on its median across seeds.
    r2avg = np.mean(r2, axis=-1)
    r2avgctr = r2avg - np.median(r2avg, axis=-1, keepdims=True)

    fig, ax = plt.subplots(figsize=(4, 4))
    # Reorder to (dim, epoch, rep) so each box pools all epochs/seeds of one dimension.
    r2avgctr = r2avgctr.transpose((1, 0, 2))
    ax.boxplot(r2avgctr.reshape((r2avgctr.shape[0], -1)).T, medianprops={'linewidth': 0})
    # Boxes sit at x = 1..len(dimvals). Label every other box with its own
    # dimension value; the "+ 1" offset belongs to the tick positions only.
    ax.set_xticks(np.arange(r2avgctr.shape[0])[::2] + 1)
    ax.set_xticklabels(dimvals[::2])
    ax.set_xlabel('Dimension')
    ax.set_ylabel('FCCA Decoding spread around median')
    ax.set_ylim([-0.25, 0.25])
    fig.savefig('hpc_initvar.pdf', bbox_inches='tight', pad_inches=0)

def plot_initvar_ssa(dimreduc1_df, dimreduc2_df):
    """Box plot of subspace angles between PCA and randomly-initialised FCCA.

    The loop covers each epoch, odd dimension 1..29, fold and seed. In each
    case it computes the mean principal angle between the PCA projection
    (from ``dimreduc2_df``) and the single-initialization FCCA projection
    (from ``dimreduc1_df``). Angles are averaged over folds and pooled over
    epochs and seeds into one box per dimension. The figure is saved to
    'hpc_initvar2.pdf'.

    Parameters
    ----------
    dimreduc1_df : str
        Pickle written by ``init_var_dimreduc``, with records keyed by
        'epoch', 'dim', 'fold_idx' and 'rep'.
    dimreduc2_df : str
        Pickle containing PCA records keyed by 'epoch', 'dim', 'fold_idx' and
        'dimreduc_method' (e.g. 'dimreduc.pkl').
    """
    with open(dimreduc1_df, 'rb') as f:
        rl = pickle.load(f)
    df1 = pd.DataFrame(rl)

    with open(dimreduc2_df, 'rb') as f:
        rl = pickle.load(f)
    df2 = pd.DataFrame(rl)

    dimvals = np.arange(1, 31, 2)
    reps = np.arange(20)

    ssa = np.zeros((epochs.size, dimvals.size, reps.size, 5))
    for i, epoch in enumerate(epochs):
        for j, d in tqdm(enumerate(dimvals)):
            for fold_idx in range(5):
                df2_ = df2.loc[(df2['epoch'] == epoch) &
                               (df2['dim'] == d) &
                               (df2['fold_idx'] == fold_idx) &
                               (df2['dimreduc_method'] == 'PCA')]
                assert(df2_.shape[0] == 1)
                pca_coef = df2_.iloc[0]['coef']
                for r in reps:
                    df_ = df1.loc[(df1['epoch'] == epoch) &
                                  (df1['dim'] == d) &
                                  (df1['fold_idx'] == fold_idx) &
                                  (df1['rep'] == r)]
                    assert(df_.shape[0] == 1)

                    ssa_ = scipy.linalg.subspace_angles(pca_coef, df_.iloc[0]['coef'])
                    ssa[i, j, r, fold_idx] = np.mean(ssa_)

    # Average over folds -> (epoch, dim, rep).
    ssaavg = np.mean(ssa, axis=-1)
    # Box plot of the FCCA/PCA angle for every (epoch, seed), one box per dimension.
    fig, ax = plt.subplots(figsize=(4, 4))
    ssaavg = ssaavg.transpose((1, 0, 2))
    ax.boxplot(ssaavg.reshape((ssaavg.shape[0], -1)).T, medianprops={'linewidth': 0})
    # Boxes sit at x = 1..len(dimvals). Label every other box with its own
    # dimension value; the "+ 1" offset belongs to the tick positions only.
    ax.set_xticks(np.arange(ssaavg.shape[0])[::2] + 1)
    ax.set_xticklabels(dimvals[::2])
    ax.set_xlabel('Dimension')
    ax.set_ylim([0, np.pi/2])
    ax.set_ylabel('FCCA/PCA avg. subspace angle (rads)', fontsize=14)
    ax.set_yticks([0, np.pi/8, np.pi/4, 3 * np.pi/8, np.pi/2])
    ax.set_yticklabels(['0', r'$\pi/8$', r'$\pi/4$', r'$3\pi/8$', r'$\pi/2$'])
    fig.savefig('hpc_initvar2.pdf', bbox_inches='tight', pad_inches=0)

def T_var_dimreduc(data_path):
    """Fit 2-D FCCA projections for T = 2..8 to study sensitivity to T.

    Cross-covariances are estimated once per fold with T = max_T and shared by
    every T. Each epoch's records are appended as a separate pickle to
    'Tvar_dimreduc_d2.pkl', which ``plot_Tvar`` reads back pickle by pickle.
    The records have the keys 'dimreduc_method', 'dim', 'epoch', 'fold_idx',
    'coef' and 'T'.
    """
    out_name = 'Tvar_dimreduc_d2.pkl'
    dimvals = np.array([2])
    T_values = np.array([2, 3, 4, 5, 6, 7, 8])
    max_T = 8

    # Truncate the output up front. Results are appended once per epoch as a
    # checkpoint, and records left over from an earlier run would otherwise
    # duplicate rows and break plot_Tvar's one-row-per-key assertions.
    with open(out_name, 'wb'):
        pass

    for epoch in epochs:
        rl = []
        dat = load_peanut(data_path, epoch, **loader_params)
        X = dat['spike_rates'].squeeze()
        for fold_idx, (train_idxs, _) in enumerate(KFold(n_splits=5, shuffle=False).split(X)):
            fccamodel = fcca(**FCCA_params)
            # Set T to the maximum T so enough cross-covariance matrices are estimated
            fccamodel.T = max_T
            fccamodel.estimate_data_statistics(X[train_idxs])
            cross_covs = fccamodel.cross_covs
            cross_covs_rev = fccamodel.cross_covs_rev
            for dim in tqdm(dimvals):
                for T in T_values:
                    fccamodel = fcca(T=T, n_init=FCCA_params['n_init'])
                    fccamodel.cross_covs = cross_covs
                    fccamodel.cross_covs_rev = cross_covs_rev
                    fccamodel.fit_projection(d=dim)

                    rl.append({
                        'dimreduc_method': 'FCCA',
                        'dim': dim,
                        'epoch': epoch,
                        'fold_idx': fold_idx,
                        'coef': fccamodel.coef_,
                        'T': T,
                    })

        # Append this epoch's results as one pickled list.
        with open(out_name, 'ab') as f:
            f.write(pickle.dumps(rl))

def plot_Tvar(dimreduc_df1, dimreduc_df2):
    """Plot how the 2-D FCCA subspace changes with the horizon T.

    The left panel shows the mean principal angle between the FCCA projection
    at each T and the one at Tref = 3. The right panel shows the mean
    principal angle between each T's FCCA projection and the PCA projection of
    the same dimension. One box per T pools all epochs and folds. The figure
    is saved to 'Tvar.pdf'.

    Parameters
    ----------
    dimreduc_df1 : str
        File written by ``T_var_dimreduc``: a sequence of pickled lists, read
        until EOF.
    dimreduc_df2 : str
        A single pickled list of records containing the PCA projections
        (e.g. 'dimreduc.pkl').
    """
    rl = []
    with open(dimreduc_df1, 'rb') as f:
        # The file is a concatenation of pickled lists (one per epoch).
        while True:
            try:
                rl.extend(pickle.load(f))
            except EOFError:
                break
    df = pd.DataFrame(rl)

    with open(dimreduc_df2, 'rb') as f:
        rl = pickle.load(f)
    df2 = pd.DataFrame(rl)

    dimvals = np.array([2])
    T = np.array([2, 4, 5, 6, 7, 8])
    Tref = 3
    ssa = np.zeros((epochs.size, dimvals.size, T.size, 5))
    ssa_ref = np.zeros((epochs.size, dimvals.size, T.size, 5))
    for i, epoch in enumerate(epochs):
        for j, d in tqdm(enumerate(dimvals)):
            for fold_idx in range(5):
                # Reference 1: the FCCA fit at T = Tref.
                df_ref = df.loc[(df['epoch'] == epoch) &
                                (df['dim'] == d) &
                                (df['fold_idx'] == fold_idx) &
                                (df['T'] == Tref)]
                # Fail with a message rather than dropping into a debugger, so
                # unattended runs do not hang. Duplicates can appear if the
                # append-mode T-sweep file holds more than one run.
                assert df_ref.shape[0] == 1, (
                    'expected 1 FCCA row with T=%d for epoch=%s, dim=%s, fold=%d; got %d'
                    % (Tref, epoch, d, fold_idx, df_ref.shape[0]))

                # Reference 2: the PCA projection of the same dimension.
                df_ref2 = df2.loc[(df2['epoch'] == epoch) &
                                  (df2['dim'] == d) &
                                  (df2['fold_idx'] == fold_idx) &
                                  (df2['dimreduc_method'] == 'PCA')]
                assert(df_ref2.shape[0] == 1)

                coef_ref = df_ref.iloc[0]['coef']
                coef_pca = df_ref2.iloc[0]['coef']

                for k, t in enumerate(T):
                    df_ = df.loc[(df['epoch'] == epoch) &
                                 (df['dim'] == d) &
                                 (df['fold_idx'] == fold_idx) &
                                 (df['T'] == t)]
                    assert(df_.shape[0] == 1)
                    coef = df_.iloc[0]['coef']
                    ssa[i, j, k, fold_idx] = \
                        np.mean(scipy.linalg.subspace_angles(coef_ref, coef))
                    ssa_ref[i, j, k, fold_idx] = \
                        np.mean(scipy.linalg.subspace_angles(coef_pca, coef))

    DIM_IDX = 0
    medianprops = {'linewidth': 0}
    ytick_vals = [0, np.pi/8, np.pi/4, 3 * np.pi/8, np.pi/2]
    ytick_labels = ['0', r'$\pi/8$', r'$\pi/4$', r'$3\pi/8$', r'$\pi/2$']

    # Boxplots of subspace angles for each T.
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    # (epoch, T, fold) -> (epoch, fold, T) -> rows of per-T values.
    x = ssa[:, DIM_IDX, ...].transpose((0, 2, 1)).reshape((-1, T.size))
    ax[0].boxplot(x, patch_artist=True, medianprops=medianprops)
    ax[0].set_xticklabels(T)
    ax[0].set_xlabel('T')
    ax[0].set_ylabel('Subspace angle (rads) to T=3', fontsize=14)
    ax[0].set_yticks(ytick_vals)
    ax[0].set_yticklabels(ytick_labels)

    x = ssa_ref[:, DIM_IDX, ...].transpose((0, 2, 1)).reshape((-1, T.size))
    bplot = ax[1].boxplot(x, patch_artist=True, medianprops=medianprops)
    ax[1].set_xticklabels(T)
    ax[1].set_xlabel('T')
    ax[1].set_ylabel('Subspace angle (rads) to PCA', fontsize=14)
    ax[1].set_yticks(ytick_vals)
    ax[1].set_yticklabels(ytick_labels)
    for patch in bplot['boxes']:
        patch.set_facecolor('white')

    fig.tight_layout()
    fig.savefig('Tvar.pdf', bbox_inches='tight', pad_inches=1)


def ssa_supp_plot(dimreduc_df):
    """Plot the mean PCA/FCCA subspace angle as a function of dimension.

    For every dimension in ``dimvals``, the principal angles from ``get_ssa``
    are averaged per (epoch, fold). The curve shows the mean over those
    splits with a standard-error band, and is saved to
    'ssa_across_dims_hpc.png'.
    """
    # Shape (dim, epoch, fold): per-split mean principal angle at each dimension.
    ss_angles = np.array([np.mean(get_ssa(dimreduc_df, d), axis=-1)
                          for d in dimvals])

    n_splits = np.prod(ss_angles.shape[1:])
    ssa_avg = np.mean(ss_angles, axis=(1, 2))
    ssa_std = np.std(ss_angles, axis=(1, 2)) / np.sqrt(n_splits)

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.plot(dimvals, ssa_avg, linestyle=':')
    ax.fill_between(dimvals, ssa_avg - ssa_std, ssa_avg + ssa_std, alpha=0.5)
    ax.set_ylabel('FFC/FBC Subspace angle (rads)', fontsize=14)
    ax.set_yticks([0, np.pi/8, np.pi/4, 3 * np.pi/8, np.pi/2])
    ax.set_yticklabels(['0', r'$\pi/8$', r'$\pi/4$', r'$3\pi/8$', r'$\pi/2$'])
    ax.set_xlabel('Dimension')
    fig.tight_layout()
    fig.savefig('ssa_across_dims_hpc.png')

if __name__ == '__main__':

    data_path = 'hpc_data.obj'
    dimreduc_file = 'dimreduc.pkl'
    decoding_file = 'decoding.pkl'
    initvar_dimreduc_file = 'initvar_dimreduc.pkl'
    initvar_decoding_file = 'initvar_decoding.pkl'
    tvar_file = 'Tvar_dimreduc_d2.pkl'

    # 1. Main comparison: fit PCA/FCCA, decode, and plot r^2 and subspace angles.
    dimreduc_(data_path)
    decoding_(data_path, dimreduc_file)
    plot_decoding(decoding_file)
    plot_ssa(dimreduc_file)
    ssa_supp_plot(dimreduc_file)

    # 2. Sensitivity of FCCA to its random initialization.
    init_var_dimreduc(data_path)
    decoding_(data_path, initvar_dimreduc_file, initvar_decoding_file)
    plot_initvar(initvar_decoding_file)
    plot_initvar_ssa(initvar_dimreduc_file, dimreduc_file)

    # 3. Sensitivity of FCCA to the horizon T.
    T_var_dimreduc(data_path)
    plot_Tvar(tvar_file, dimreduc_file)
