import time
import random
import os
import numpy as np
import re
from multiprocessing import Pool


def time_str():
    """Return the current local time as a ``YYYYmmdd-HHMMSS`` string."""
    now = time.localtime()
    return time.strftime("%Y%m%d-%H%M%S", now)


def seed_everything(seed):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``.

    ``PYTHONHASHSEED`` is exported as well. Python reads that variable only at
    interpreter start-up, so it influences child interpreters launched later,
    not string hashing in the current process.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)


def sorted_alphanumeric(data):
    """Return ``data`` sorted in natural order, e.g. ``"run2"`` before ``"run10"``.

    Each string is split into digit runs and non-digit text. Digit runs compare
    as integers. Everything else compares case-insensitively via ``str.lower``.
    """
    def natural_key(key):
        pieces = re.split("([0-9]+)", key)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(data, key=natural_key)


def str_to_list(model_str):
    """Parse a dash-separated string of integers (e.g. neuron counts ``"64-64-32"``)."""
    return list(map(int, model_str.split("-")))


def get_run_arrays(metric, subpath):
    """Load and stack per-run arrays for ``metric`` and ``"selected_index"``.

    For each of the two directories ``subpath/<metric>`` and
    ``subpath/selected_index``:
    - every entry is loaded with ``np.load``, in natural filename order;
    - the loaded arrays are truncated along axis 0 to the shortest one;
    - the truncated arrays are stacked with ``np.array``.

    Returns a two-element list ``[metric_array, selected_index_array]``.
    """
    def load_stacked(name):
        folder = os.path.join(subpath, name)
        runs = [
            np.load(os.path.join(folder, fname))
            for fname in sorted_alphanumeric(os.listdir(folder))
        ]
        shortest = min(run.shape[0] for run in runs)
        return np.array([run[:shortest] for run in runs])

    return [load_stacked(name) for name in (metric, "selected_index")]


def get_n_percent_bounds(len_array, percent):
    """Return ``[lo, hi]`` slice indices that trim ``percent`` % from each end.

    Each bound is ``quantile * len_array`` rounded half-up to an int, using the
    quantiles ``percent / 100`` and ``1 - percent / 100``. Slicing a sorted
    array of length ``len_array`` with ``[lo:hi]`` keeps the central part.

    >>> get_n_percent_bounds(10, 25.0)
    [3, 8]
    """
    fraction = percent / 100
    return [int(q * len_array + 0.5) for q in (fraction, 1 - fraction)]


def _iqm_and_bootstrap_ci(column, num_seeds, iqm_bounds, bootstraps, cut_off, rng):
    """Return ``(iqm, conf)`` for one 1-D column of values.

    ``iqm`` is the mean of ``sorted(column)[lo:hi]``. ``conf`` holds the
    ``cut_off`` and ``100 - cut_off`` percentiles of that statistic over
    ``bootstraps`` resamples (with replacement, size ``num_seeds``).

    ``rng`` must provide NumPy's ``choice``. Pass either the ``np.random``
    module (the global legacy state) or a ``np.random.RandomState`` instance.
    """
    lo, hi = iqm_bounds
    iqm = np.mean(np.sort(column)[lo:hi])

    # One (bootstraps, num_seeds) draw replaces a Python loop of per-resample
    # draws. Row b is bootstrap sample b.
    samples = rng.choice(column, size=(bootstraps, num_seeds))
    samples.sort(axis=1)
    bootstrap_iqms = samples[:, lo:hi].mean(axis=1)
    conf = np.percentile(bootstrap_iqms, [cut_off, 100 - cut_off])
    return iqm, conf


def get_iqm_mean_conf(array, bootstraps=2000, percentile=95.0):
    """Compute the per-epoch interquartile mean (IQM) with a bootstrap confidence interval.

    Args:
        array: 2-D array unpacked as ``(num_seeds, epochs)``. Each column is
            summarised independently.
        bootstraps: Number of bootstrap resamples per epoch.
        percentile: Confidence level in percent. The interval spans the
            ``(100 - percentile) / 2`` and ``100 - (100 - percentile) / 2``
            percentiles of the bootstrap IQMs.

    Returns:
        ``(mean, conf)``. ``mean`` has shape ``(epochs,)``. ``conf`` has shape
        ``(2, epochs)``, with row 0 the lower bound and row 1 the upper bound.

    Resampling uses NumPy's global RNG, so results are reproducible after
    ``np.random.seed``.
    """
    cut_off = (100.0 - percentile) / 2
    num_seeds, epochs = array.shape

    iqm_percent = 25.0  # trim the lowest and highest quarter -> interquartile mean
    iqm_bounds = get_n_percent_bounds(num_seeds, iqm_percent)

    mean = np.zeros(epochs)
    conf = np.zeros((2, epochs))
    for e in range(epochs):
        mean[e], conf[:, e] = _iqm_and_bootstrap_ci(
            array[:, e], num_seeds, iqm_bounds, bootstraps, cut_off, np.random
        )
    return mean, conf


def compute_epoch_mean_conf(args):
    """Compute the IQM and bootstrap interval for column ``e`` of ``array``.

    ``args`` is ``(array, e, num_seeds, iqm_bounds, bootstraps, cut_off)``,
    optionally followed by a seventh element ``seed``:
    - with a seed, resampling uses a private ``np.random.RandomState(seed)``,
      so the result does not depend on which process runs the call;
    - without one, it uses the process-global NumPy RNG.

    Returns ``(mean, conf)``, where ``conf`` holds ``[lower, upper]``.
    """
    seed = None
    if len(args) == 7:
        args, seed = args[:6], args[6]
    array, e, num_seeds, iqm_bounds, bootstraps, cut_off = args
    rng = np.random if seed is None else np.random.RandomState(seed)
    return _iqm_and_bootstrap_ci(
        array[:, e], num_seeds, iqm_bounds, bootstraps, cut_off, rng
    )


def _bootstrap_column_task(task):
    """Pool worker: unpack ``(column, num_seeds, iqm_bounds, bootstraps, cut_off, seed)``."""
    column, num_seeds, iqm_bounds, bootstraps, cut_off, seed = task
    return _iqm_and_bootstrap_ci(
        column, num_seeds, iqm_bounds, bootstraps, cut_off, np.random.RandomState(seed)
    )


def get_iqm_mean_conf_parallel(array, bootstraps=2000, percentile=95.0, processes=None):
    """Multiprocess variant of ``get_iqm_mean_conf``; one ``Pool`` task per epoch.

    Arguments and return value match ``get_iqm_mean_conf``. ``processes`` is
    passed to ``multiprocessing.Pool`` (``None`` = one worker per CPU).

    Each epoch gets its own seed, drawn from NumPy's global RNG in this
    process. The output is therefore reproducible after ``np.random.seed`` and
    does not depend on the worker count. Individual values differ from the
    serial function, which consumes the global RNG stream directly.
    """
    cut_off = (100.0 - percentile) / 2
    num_seeds, epochs = array.shape

    iqm_percent = 25.0  # trim the lowest and highest quarter -> interquartile mean
    iqm_bounds = get_n_percent_bounds(num_seeds, iqm_percent)

    if epochs == 0:
        # Match the serial version instead of failing on zip(*[]).
        return np.zeros(0), np.zeros((2, 0))

    # Workers must not draw from their inherited global RNG. Under "fork"
    # every worker starts from an identical copy, so different epochs get
    # identical resamples. Under "spawn" workers are unseeded. Per-epoch seeds
    # avoid both problems.
    seeds = np.random.randint(np.iinfo(np.int32).max, size=epochs)
    # Ship only the epoch's column. Sending the whole array with every task
    # would pickle O(epochs) copies of it.
    tasks = [
        (array[:, e], num_seeds, iqm_bounds, bootstraps, cut_off, int(seeds[e]))
        for e in range(epochs)
    ]

    with Pool(processes) as pool:
        results = pool.map(_bootstrap_column_task, tasks)

    means, confs = zip(*results)
    return np.array(means), np.stack(confs, axis=1)
