import numpy as np
from tqdm import tqdm
import pickle
import glob
import torch
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import torch
import torch.nn as nn

import sys
sys.path.append('..')
from utils import apply_df_filters

# Load the dimensionality-reduction results. The file holds several pickled
# iterables of records written back to back, so keep unpickling until EOF and
# flatten them into a single list of records.
# NOTE: pickle can execute arbitrary code on load -- only use trusted files.
def _iter_pickled(path):
    """Yield each object stored consecutively in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        while True:
            try:
                chunk = pickle.load(handle)
            except EOFError:
                return
            yield chunk


rl = [record
      for chunk in _iter_pickled('dimreduc_results_dale_concat_reps.pkl')
      for record in chunk]
df = pd.DataFrame(rl)


# Load A matrices
with open('Alist_rnn.pkl', 'rb') as handle:
    Alist = pickle.load(handle)

# For every trained Dale RNN, compute the principal angles between the FCCA
# and PCA projection matrices (`coef`) stored for it in `df`, for each
# combination of T, projection dimension and FCCA rep. Results go into
# `ssafp`, an object array indexed as
#     ssafp[A-matrix rep, network, T index, dim index, rep]
# whose entries are the arrays returned by scipy.linalg.subspace_angles.
import os
import re

import scipy.linalg  # plain `import scipy` does not load scipy.linalg on older SciPy

T = np.array([1, 2, 3])
dims = np.array([2, 8])
reps = np.arange(10)

# Size the first two axes from Alist so they line up with the
# Alist[rep][network] indexing used when plotting against non-normality
# (previously hard-coded as 5 reps x 20 networks).
n_areps = len(Alist)
n_networks = len(Alist[0])
ssafp = np.zeros((n_areps, n_networks, T.size, dims.size, reps.size), dtype=object)


def _network_index(path):
    """Return the integer N from a model filename of the form model_<rep>_<N>...pkl.

    NOTE(review): assumes N is the network index matching the Alist[rep][N]
    ordering; confirm against the script that wrote Dale_models_Areps/.

    Raises:
        ValueError: if the filename does not follow that pattern.
    """
    match = re.match(r'model_\d+_(\d+)', os.path.basename(path))
    if match is None:
        raise ValueError('Unexpected model filename: %r' % path)
    return int(match.group(1))


for i in range(n_areps):
    # glob returns files in arbitrary (filesystem) order, and a plain string
    # sort would put model_0_10 before model_0_2. Sort numerically so that
    # position j corresponds to the j-th network of this rep.
    models = sorted(glob.glob('Dale_models_Areps/model_%d_*.pkl' % i),
                    key=_network_index)
    if len(models) > n_networks:
        raise ValueError('Found %d models for rep %d but Alist has only %d networks'
                         % (len(models), i, n_networks))
    for j, model in enumerate(models):
        # Only the dimreduc results keyed by the model path are needed; the
        # network weights themselves are not used, so the model file is not
        # loaded. Filter on the model once instead of in the inner loop.
        df_model = apply_df_filters(df, model=model)
        for k, t in enumerate(T):
            for h, d in enumerate(dims):
                for r in reps:
                    df_ = apply_df_filters(df_model, T=t, dim=d, rep=r)

                    # Exactly one FCCA fit is expected per (model, T, dim, rep);
                    # otherwise iloc[0] would silently pick an arbitrary row.
                    df_fcca = apply_df_filters(df_, dimreduc_method='FCCA')
                    assert df_fcca.shape[0] == 1
                    C_fcca = df_fcca.iloc[0]['coef']
                    assert C_fcca.shape[1] == d

                    df_pca = apply_df_filters(df_, dimreduc_method='PCA')
                    assert df_pca.shape[0] == 1
                    C_pca = df_pca.iloc[0]['coef']
                    assert C_pca.shape[1] == d

                    ssafp[i, j, k, h, r] = scipy.linalg.subspace_angles(C_fcca, C_pca)

# Plot the mean FCCA/PCA principal angle against the non-normality of each
# network's A matrix. The angles are averaged over A-matrix reps and rep; the
# band shows +/- one standard deviation.
T_IDX = -1   # last entry of T, i.e. T = 3
DIM_IDX = 0  # first entry of dims, i.e. d = 2

n_areps = len(Alist)
n_networks = len(Alist[0])

# ssa_fpall[network, A-matrix rep, rep] = mean principal angle for that fit
ssa_fpall = np.zeros((n_networks, n_areps, reps.size))
# Non-normality ||A A^T - A^T A||_F (np.linalg.norm of a matrix defaults to
# the Frobenius norm). Named to avoid shadowing the `torch.nn as nn` import.
nonnormality = np.zeros((n_networks, n_areps))
for i1 in range(n_areps):
    for i2 in range(n_networks):
        A = Alist[i1][i2]
        nonnormality[i2, i1] = np.linalg.norm(A @ A.T - A.T @ A)
        for k in reps:
            ssa_fpall[i2, i1, k] = np.mean(ssafp[i1, i2, T_IDX, DIM_IDX, k])

nonnormality = np.mean(nonnormality, axis=1)
ssa_mean = np.mean(ssa_fpall, axis=(1, 2))
ssa_std = np.std(ssa_fpall, axis=(1, 2))

# plot/fill_between connect points in array order, so sort by x; unsorted
# values would draw a zig-zag line and a self-intersecting band.
order = np.argsort(nonnormality, kind='stable')
nonnormality = nonnormality[order]
ssa_mean = ssa_mean[order]
ssa_std = ssa_std[order]

fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.plot(nonnormality, ssa_mean, color='k')
ax.fill_between(nonnormality, ssa_mean - ssa_std, ssa_mean + ssa_std,
                alpha=0.25, color='k')
ax.set_ylim([0, np.pi/2])
ax.set_yticks([0, np.pi/8, np.pi/4, 3*np.pi/8, np.pi/2])
ax.set_xscale('log')
fig.savefig('dale_rnn_ssa.pdf')
plt.close(fig)
# TODO: perhaps omit the data point giving rise to the spike.