import numpy as np
import pickle
import pdb
import scipy
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from RNNvanilla import *
import glob
from run_Dale import get_inputs, dt, start_index

from sklearn.decomposition import PCA
from FCCA.fcca import LQGComponentsAnalysis as FCCA

# Explicit submodule import: on older SciPy releases a plain `import scipy`
# does not make `scipy.io` available as an attribute.
import scipy.io


def _record(model, method, d, t, coef, rep):
    """Build one result row in the format appended to the results pickle."""
    return {
        'model': model,
        'dimreduc_method': method,
        'dim': d,
        'T': t,
        'coef': coef,
        'rep': rep,
    }


models = glob.glob('Dale_models_Areps/*.pkl')
dat = scipy.io.loadmat('data/condsForSimJ2moMuscles.mat')
u, targets, n_conditions, n_delays = get_inputs(dat)

# Pad inputs with 200 trailing zero steps (along axis 1) to allow the network
# to relax after the input ends.
n_relax_steps = 200
upad = np.concatenate([u, np.zeros((u.shape[0], n_relax_steps, u.shape[-1]))], axis=1)

# Sweep of FCCA time horizons and projection dimensions.
T = np.array([1, 2, 3])
dims = np.array([2, 8])
n_reps = 20  # FCCA fits (one per seed k) for each (T, dim) pair
results_path = 'dimreduc_results_dale_concat_reps.pkl'

# Fall back to CPU instead of crashing when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The network input is identical for every model, so build the tensor once.
inputs = torch.tensor(upad, dtype=torch.float32).to(device)

for model in tqdm(models):
    # NOTE(review): torch.load unpickles an arbitrary object - only load
    # trusted files. On torch>=2.6 the default weights_only=True may reject
    # full-model pickles; pass weights_only=False there if loading fails.
    rnn = torch.load(model, map_location=device)
    # Ensure rnn is in evaluation mode and on the target device.
    rnn.eval()
    rnn.to(device)
    rnn.to_device(device)

    # Forward pass; only the returned state `x` is used below.
    with torch.no_grad():
        y_seq, x = rnn.forward(inputs, dt=dt, disable_progress_bar=True, return_state=True)

    # Bring to CPU. x is indexed as rank 3 below; presumably
    # (trials, time, units) - TODO confirm against RNNvanilla.
    x = x.detach().cpu().numpy()

    # NOTE(review): PCA is fit on the full padded trajectory, including the
    # steps before start_index, whereas FCCA below only sees steps from
    # start_index on. Kept as in the original; fit on x[:, start_index:] if
    # the two methods should see the same data.
    pca = PCA().fit(x.reshape((-1, x.shape[-1])))

    # Drop the steps before start_index and stack trials end to end along axis 0.
    x = np.concatenate(x[:, start_index:, :], axis=0)

    # Fit LQGCA (FCCA) for every horizon / dimension / seed.
    for t in T:
        for d in tqdm(dims):
            pca_coef = pca.components_.T[:, 0:d]
            for k in range(n_reps):
                # Use the swept horizon; this was previously hard-coded to
                # T=3, so rows labelled T=1 and T=2 held T=3 fits.
                fcca = FCCA(T=int(t), d=d, n_init=1, rng_or_seed=k)
                fcca.fit(x)

                rl = [
                    _record(model, 'PCA', d, t, pca_coef, k),
                    _record(model, 'FCCA', d, t, fcca.coef_, k),
                ]
                # One pickle is appended per (T, dim, rep), so the file holds
                # a sequence of pickled lists; readers must call pickle.load
                # repeatedly until EOFError. Appending also checkpoints
                # progress if the run is interrupted.
                with open(results_path, 'ab') as f:
                    pickle.dump(rl, f)