import itertools
import time
import os
import numpy as np
import scipy
import torch
import sys
import sdeint
import pickle
from tqdm import tqdm
from FCCA.cov_util import calc_cross_cov_mats_from_data
from FCCA.fcca import LQGComponentsAnalysis as LQGCA
from soc import stabilize, gen_init_W, stabilize_discrete, comm_mat

from mpi4py import MPI

# Task grid sizes. Together with each value in R these index the activity
# files x_<ridx>_<rep>_<inner_rep>.pkl read below.
reps = 20
inner_reps = 10


def _mean_nonnormality(Aseq):
    """Return the mean Frobenius norm of the commutator A A^T - A^T A over Aseq.

    The commutator is zero exactly when A is a normal matrix, so this is a
    scalar measure of how far the matrices in ``Aseq`` are from normality.
    """
    # np.linalg.norm on a 2-D array defaults to the Frobenius norm.
    return np.mean([np.linalg.norm(A @ A.T - A.T @ A) for A in Aseq])


def _reversed_cross_covs(cross_covs):
    """Return the stacked matrices C0^{-1} C_k^T C0^{-1} for every lag k.

    ``cross_covs[0]`` (C0) is inverted once and reused for every lag. The
    result is a single stacked ndarray so it can be passed straight to
    ``torch.tensor``. It is assigned to ``LQGCA.cross_covs_rev`` by the script.
    """
    cov_inv = np.linalg.inv(np.asarray(cross_covs[0]))
    return np.stack([cov_inv @ np.asarray(c).T @ cov_inv for c in cross_covs])


def _pca(cov, d):
    """Return the top-``d`` principal axes and eigenvalues of covariance ``cov``.

    ``cov`` is symmetrised and decomposed with ``np.linalg.eigh``. This
    guarantees real eigenvalues and orthonormal eigenvectors, which
    ``np.linalg.eig`` does not for a numerically asymmetric or degenerate
    matrix.

    Returns
    -------
    coef : ndarray, shape (N, d)
        Eigenvectors as columns, ordered by descending eigenvalue.
    eig : ndarray, shape (d,)
        The corresponding eigenvalues, descending.
    """
    cov = np.asarray(cov)
    cov = (cov + cov.T) / 2
    e, U = np.linalg.eigh(cov)
    order = np.argsort(e)[::-1][:d]
    return U[:, order], e[order]


if __name__ == '__main__':

    comm = MPI.COMM_WORLD
    Alist_path = sys.argv[1]
    activitypath = sys.argv[2]
    savepath = sys.argv[3]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(savepath):
        raise FileNotFoundError('savepath does not exist: %s' % savepath)
    d = 2
    T = 3

    # Only rank 0 reads the input pickle; its contents are broadcast to all
    # ranks. NOTE: pickle must only be used on trusted, self-generated files.
    if comm.rank == 0:
        with open(Alist_path, 'rb') as f:
            Alist = pickle.load(f)
            R = pickle.load(f)
    else:
        Alist = None
        R = None

    Alist = comm.bcast(Alist)
    R = comm.bcast(R)
    R_list = list(R)

    # All (rep, inner_rep, r) combinations, split into one contiguous chunk per rank.
    tasks = list(itertools.product(np.arange(reps), np.arange(inner_reps), R))
    tasks = np.array_split(tasks, comm.size)[comm.rank]

    print(len(tasks))

    for i, task in enumerate(tasks):
        t0 = time.time()
        rep, inner_rep, r = task
        # np.array_split coerces every task tuple to one numpy dtype, so the
        # integer indices are cast back and r is located in R by value.
        rep = int(rep)
        inner_rep = int(inner_rep)
        ridx = R_list.index(r)

        # Load the activity for the task
        with open('%s/x_%d_%d_%d.pkl' % (activitypath, ridx, rep, inner_rep), 'rb') as f:
            x = pickle.load(f)
            seq_idxs = pickle.load(f)

        Aseq = [Alist[idx][ridx] for idx in seq_idxs]
        nn = _mean_nonnormality(Aseq)

        # Fit PCA/FCCA
        cross_covs = np.asarray(calc_cross_cov_mats_from_data(x, 5))
        cross_covs_rev = _reversed_cross_covs(cross_covs)
        pca_coef, pca_eig = _pca(cross_covs[0], d)

        lqgmodel = LQGCA(d=d, T=T, rng_or_seed=inner_rep)
        lqgmodel.cross_covs = torch.tensor(cross_covs)
        lqgmodel.cross_covs_rev = torch.tensor(cross_covs_rev)
        coef_, _ = lqgmodel._fit_projection()

        rd = {
            'rep': rep,
            'inner_rep': inner_rep,
            'R': r,
            'pca_coef': pca_coef,
            'fcca_coef': coef_,
            'pca_eig': pca_eig,
            'dim': d,
            'T': T,
            'nn': nn,
        }

        print('Rank %d Completed task %d/%d in %f' % (comm.rank, i + 1, len(tasks), time.time() - t0))

        # Append this task's records to the per-rank results file. Readers
        # must call pickle.load four times per task, in this order.
        with open('%s/rank%d.pkl' % (savepath, comm.rank), 'ab') as f:
            for obj in (task, seq_idxs, nn, rd):
                pickle.dump(obj, f)
