import itertools
import os
import time
import numpy as np
import scipy
import torch
import sys
import sdeint
import pickle
from tqdm import tqdm

from mpi4py import MPI

from FCCA.fcca import LQGComponentsAnalysis as LQGCA

# Experiment sizes: `reps` indexes the outer dimension of the loaded A-matrix
# list, and `inner_reps` is used as the LQGCA seed for each task.
reps, inner_reps = 20, 10
# NOTE(review): M, p and g are not referenced anywhere in this script --
# presumably parameters of whatever generated the A matrices; confirm or drop.
M, p, g = 100, 0.25, 2
  
if __name__ == '__main__':

    comm = MPI.COMM_WORLD
    Amats = sys.argv[1]
    savepath = sys.argv[2]
    #if not os.path.exists(savepath):
    #    os.makedirs(savepath)

    dt = 1
    d = 2

    # First load generated A matrics
    if comm.rank == 0:
        with open(Amats, 'rb') as f:
            Alist = pickle.load(f)
            R = pickle.load(f)
    else:
        Alist = None
        R = None
    
    Alist = comm.bcast(Alist)
    R = comm.bcast(R)

    tasks = list(itertools.product(np.arange(reps), np.arange(inner_reps), R))
    tasks = np.array_split(tasks, comm.size)[comm.rank]

    print(len(tasks))

    for i, task in enumerate(tasks):
        t0 = time.time()
        rep, inner_rep, r = task
        rep = int(rep)
        inner_rep = int(inner_rep)
        ridx = list(R).index(r)
        A = Alist[rep][ridx]
        nn = np.linalg.norm(A @ A.T - A.T @ A)

        # Solve for the exact covarinace function and evaluate it at intervals of dt
        Pi = scipy.linalg.solve_continuous_lyapunov(A, -np.eye(A.shape[0]))
        t_ = [j * dt for j in range(10)]
        cross_covs = [scipy.linalg.expm(tau * A) @ Pi for tau in t_]

        cross_covs_rev = [np.linalg.inv(cross_covs[0]) @ c.T @ np.linalg.inv(cross_covs[0]) for c in cross_covs]

        cross_covs = torch.tensor(cross_covs)
        cross_covs_rev = torch.tensor(cross_covs_rev)

        e, Upca = np.linalg.eig(cross_covs[0])
        eigorder = np.argsort(e)[::-1]
        Upca = Upca[:, eigorder][:, 0:d]

        lqgmodel = LQGCA(d=d, T=3, rng_or_seed=int(inner_rep))
        lqgmodel.cross_covs = cross_covs
        lqgmodel.cross_covs_rev = cross_covs_rev

        coef_, score = lqgmodel._fit_projection()
        phi = np.mean(scipy.linalg.subspace_angles(Upca, coef_))
        scores = score            

        # save to file (append)
        print('Rank %d completed task in %f' % (comm.rank, t0 - time.time()))
        with open('%s/rank%d.pkl' % (savepath, comm.rank), 'ab') as f:
            f.write(pickle.dumps(task))
            # f.write(pickle.dumps(A))
            f.write(pickle.dumps(nn))
            f.write(pickle.dumps(Upca))
            f.write(pickle.dumps(coef_))
            f.write(pickle.dumps(phi))
            f.write(pickle.dumps(scores))
