import numpy as np

from sklearn.pipeline import Pipeline

from multiprocessing import Pool
from itertools import repeat

from typing import Tuple, List
from tqdm import tqdm
from math import comb

def _fit_and_score(pipe, X_train, y_train, X_test, y_test):
    """Fit ``pipe`` on the training data and return its score on the test data.

    Defined at module level (not nested inside ``decode``) so that worker
    processes can resolve it by name when the task is unpickled, whatever
    the multiprocessing start method is.
    """
    pipe.fit(X_train, y_train)
    return pipe.score(X_test, y_test)


def _draw_unique_subsets(rng, pop_full, pop_size, num_reps):
    """Draw ``num_reps`` distinct index subsets of size ``pop_size``.

    Subsets are compared as sets, so two draws holding the same indices in
    a different order count as duplicates. The caller must not request more
    subsets than ``comb(pop_full, pop_size)``, otherwise the rejection loop
    cannot terminate.
    """
    seen = set()
    subsets = []
    while len(subsets) < num_reps:
        idxs = rng.choice(pop_full, size=pop_size, replace=False)
        key = frozenset(idxs.tolist())
        if key in seen:
            continue
        seen.add(key)
        subsets.append(idxs)
    return subsets


def decode(
    acts : Tuple[np.ndarray, np.ndarray], # Network activations to stimuli
    lbls : Tuple[np.ndarray, np.ndarray], # Labels for stimuli
    pipe : Pipeline,                      # Decoding pipeline (from sklearn)
    pop_size : int = None,                # Size of population to use for decoding
    max_reps : int = 5                    # Number of repetitions of decoding experiment
    ) -> Tuple[Tuple[List[float], List[float]], Tuple[float, float], Tuple[float, float]]:
    """Score a decoding pipeline on random sub-populations of the activations.

    For each repetition a random subset of ``pop_size`` rows is selected
    from the activation arrays, the pipeline is fitted on the (transposed)
    training subset and scored twice: once on that same training subset and
    once on the matching subset of the test activations. Repetitions run in
    parallel in a ``multiprocessing.Pool``.

    Parameters
    ----------
    acts
        ``(X_train, X_test)`` activation arrays. ``X_train`` must be 2-D and
        its first axis is the population axis that gets sub-sampled; the same
        row indices are applied to ``X_test``. The arrays are transposed
        before reaching the pipeline, so the second axis is presumably the
        sample axis -- confirm against the caller.
    lbls
        ``(y_train, y_test)`` labels, passed unchanged to ``pipe.fit`` and
        ``pipe.score``.
    pipe
        Object exposing ``fit(X, y)`` and ``score(X, y)``. It is sent to the
        worker processes with every task, so it must be picklable; fitting
        happens on the workers' copies, not on the object passed in.
    pop_size
        Number of rows to use per repetition. ``None`` (or any other falsy
        value, e.g. ``0``) selects the whole population.
    max_reps
        Upper bound on the number of repetitions. The actual number is also
        capped at the number of distinct subsets of the requested size.

    Returns
    -------
    tuple
        ``((train_scores, valid_scores), (train_mean, valid_mean),
        (train_std, valid_std))`` where the score entries are lists with one
        value per repetition.

    Raises
    ------
    ValueError
        If ``pop_size`` exceeds the population size (or is negative, via
        ``math.comb``).
    """
    rng = np.random.default_rng()

    # Unpack train & test activations and labels
    X_train, X_test = acts
    y_train, y_test = lbls

    pop_full, _ = X_train.shape

    # A falsy pop_size (None or 0) means "use the whole population"
    pop_size = pop_size if pop_size else pop_full
    if pop_size > pop_full:
        # comb() would return 0 here and the function would silently
        # produce empty score lists and NaN statistics
        raise ValueError(
            f"pop_size ({pop_size}) cannot exceed the population size ({pop_full})"
        )

    # Never ask for more repetitions than there are distinct sub-populations
    num_reps = min(comb(pop_full, pop_size), max_reps)
    rep_idxs = _draw_unique_subsets(rng, pop_full, pop_size, num_reps)

    # Pipelines receive the selected sub-population transposed
    train_sets = [X_train[idxs].T for idxs in rep_idxs]
    test_sets = [X_test[idxs].T for idxs in rep_idxs]

    with Pool() as P:
        # First we train and test on the training set
        train_out = P.starmap(
            _fit_and_score,
            zip(repeat(pipe), train_sets, repeat(y_train), train_sets, repeat(y_train)),
        )

        # Then we train on training set and test on validation set
        valid_out = P.starmap(
            _fit_and_score,
            zip(repeat(pipe), train_sets, repeat(y_train), test_sets, repeat(y_test)),
        )

    return (train_out, valid_out),\
           (np.mean(train_out, axis = 0), np.mean(valid_out, axis = 0)),\
           (np.std (train_out, axis = 0), np.std (valid_out, axis = 0))