import itertools
import os
import time
import numpy as np
import scipy
import sys
import pickle
from tqdm import tqdm
from FCCA.cov_util import calc_cross_cov_mats_from_data
from mpi4py import MPI

# Experiment grid: `reps` outer repetitions (first index into Alist) and
# `inner_reps` activity realisations per (rep, r) pair.
reps, inner_reps = 20, 10
# NOTE(review): M, p and g are not referenced in this script's visible code;
# presumably parameters of the A-matrix generator -- confirm before changing.
M, p, g = 100, 0.25, 2
 
def boxcox_transform(counts, lmbda):
    """Return the Box-Cox transform of *counts*, computed elementwise.

    Uses the closed form ``(counts**lmbda - 1) / lmbda`` (``log(counts)`` when
    ``lmbda == 0``). The closed form is used instead of ``scipy.stats.boxcox``
    because, depending on the SciPy version, that function validates its input
    and raises on zero (or constant) data even when ``lmbda`` is supplied --
    and sampled Poisson counts are frequently zero.

    Parameters
    ----------
    counts : array_like
        Non-negative values of any shape.
    lmbda : float
        Box-Cox exponent.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as *counts*. With ``lmbda == 0``,
        zero counts map to ``-inf``.
    """
    counts = np.asarray(counts, dtype=float)
    if lmbda == 0:
        return np.log(counts)
    return (np.power(counts, lmbda) - 1) / lmbda


def sample_boxcox_rates(x, n_samples=100, lmbda=0.5):
    """Draw Poisson counts with mean ``exp(x)`` and Box-Cox transform them.

    Parameters
    ----------
    x : array_like
        Values whose exponential is used as the Poisson mean for each entry.
    n_samples : int
        Number of independent draws.
    lmbda : float
        Box-Cox exponent passed to :func:`boxcox_transform`.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_samples,) + np.shape(x)``; one transformed draw per row.
    """
    poisson_means = np.exp(x)  # loop-invariant, computed once
    trials = []
    for _ in tqdm(range(n_samples)):
        spike_counts = np.random.poisson(poisson_means)
        trials.append(boxcox_transform(spike_counts, lmbda))
    return np.array(trials)


def load_amats(path):
    """Read the A-matrix list and the R values pickled back-to-back in *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only pass
    trusted files.
    """
    with open(path, 'rb') as f:
        Alist = pickle.load(f)
        R = pickle.load(f)
    return Alist, R


def append_pickles(path, *objs):
    """Append each object in *objs* to *path* as a separate pickle record."""
    with open(path, 'ab') as f:
        for obj in objs:
            f.write(pickle.dumps(obj))


if __name__ == '__main__':

    comm = MPI.COMM_WORLD
    Amats = sys.argv[1]         # pickle file holding Alist followed by R
    activitypath = sys.argv[2]  # directory of x_<ridx>_<rep>_<inner_rep>.pkl files
    savepath = sys.argv[3]      # directory for the per-rank rank<k>.pkl outputs

    # Only rank 0 reads the generated A matrices; every rank gets a copy via bcast.
    if comm.rank == 0:
        Alist, R = load_amats(Amats)
    else:
        Alist, R = None, None
    Alist = comm.bcast(Alist)
    R = comm.bcast(R)
    R_values = list(R)  # hoisted: maps each task's r back to its index

    # Every (rep, inner_rep, r) combination, split into contiguous chunks per
    # rank. np.array_split turns the tuples into one array of a common dtype,
    # which is why rep / inner_rep are cast back to int below.
    tasks = list(itertools.product(np.arange(reps), np.arange(inner_reps), R))
    tasks = np.array_split(tasks, comm.size)[comm.rank]

    print(len(tasks))

    ccm_T = 5              # second argument to calc_cross_cov_mats_from_data
    boxcox_lambda = 0.5    # Box-Cox exponent applied to sampled spike counts
    n_spike_samples = 100  # Poisson resamples drawn per activity file

    for i, task in enumerate(tasks):
        t0 = time.time()
        rep, inner_rep, r = task
        rep = int(rep)
        inner_rep = int(inner_rep)
        ridx = R_values.index(r)
        A = Alist[rep][ridx]
        # Frobenius norm of A A^T - A^T A: zero when A commutes with its transpose.
        nn = np.linalg.norm(A @ A.T - A.T @ A)

        fname = '%s/x_%d_%d_%d.pkl' % (activitypath, ridx, rep, inner_rep)
        if not os.path.exists(fname):
            print('activity file not found, skipping')
            continue

        with open(fname, 'rb') as f:
            x = pickle.load(f)

        # Cross-covariances of the activity itself, then of Box-Cox transformed
        # Poisson spike counts sampled from it.
        ccm1 = calc_cross_cov_mats_from_data(x, ccm_T)
        spike_rates_trials = sample_boxcox_rates(x, n_spike_samples, boxcox_lambda)
        print('Calculating ccm')
        ccm2 = calc_cross_cov_mats_from_data(spike_rates_trials, ccm_T)

        print('Rank %d Completed task %d/%d in %f'
              % (comm.rank, i + 1, len(tasks), time.time() - t0))

        # Append this task's results as five consecutive pickle records.
        append_pickles('%s/rank%d.pkl' % (savepath, comm.rank),
                       task, A, nn, ccm1, ccm2)
