import torch
from robustopt_torch.distributions import DiscreteDist
# import torch.multiprocessing as mp
# from torch.multiprocessing.spawn import _wrap as _mp_error_wrapper

# Return a fresh generator for a subprocess, restored from a saved RNG state
def init_generator_for_subproc(state):
    """Create a new ``torch.Generator`` whose RNG state is ``state``.

    Args:
        state: a state tensor as returned by ``torch.Generator.get_state()``.

    Returns:
        A freshly constructed generator, independent of whichever generator
        produced ``state``.
    """
    subproc_gen = torch.Generator()
    subproc_gen.set_state(state)
    return subproc_gen

# Get nearly even sample sizes between subprocesses
def split_samples(num_samp, procs):
    """Split ``num_samp`` samples into ``procs`` nearly equal counts.

    The first ``num_samp % procs`` entries receive one extra sample, so the
    counts differ by at most one and sum to ``num_samp``.

    Raises:
        ZeroDivisionError: if ``procs`` is 0.
    """
    base = num_samp // procs
    leftover = num_samp % procs
    sizes = []
    for rank in range(procs):
        sizes.append(base + 1 if rank < leftover else base)
    return sizes

# Get the (begin, end) index range of each process's sample batch
def batch_indexes(data_size, procs):
    """Return one ``(begin, end)`` index pair per process covering ``data_size`` items.

    Batch sizes come from ``split_samples``. The ranges are consecutive:
    the first begins at 0, each ``end`` equals the next ``begin``, and the
    last ``end`` equals ``data_size``.
    """
    bounds = []
    cursor = 0
    # Running total instead of re-summing every prefix of the size list.
    for size in split_samples(data_size, procs):
        bounds.append((cursor, cursor + size))
        cursor += size
    return bounds

# Generate seeds for each process
def generate_seeds(procs, seed_generator = None):
    """Draw ``procs`` integer seeds in ``[0, 1e8)``.

    Args:
        procs: number of seeds to draw.
        seed_generator: ``torch.Generator`` used for the draws. When ``None``,
            torch's default global generator is used.

    Returns:
        A list of Python ints, drawn one at a time in order.
    """
    upper = int(1e8)
    seeds = []
    for _ in range(procs):
        draw = torch.randint(upper, (1,), generator=seed_generator)
        seeds.append(draw.item())
    return seeds

def generate_states(procs, seed_generator = None):
    """Produce one RNG state per process from consecutive seeds.

    A single random offset in ``[0, 1e8)`` is drawn from ``seed_generator``
    (torch's default generator when ``None``). Process ``k`` receives the
    state of a fresh generator seeded with ``offset + k``. The offset is drawn
    even when ``procs`` is 0.

    Returns:
        A list of ``procs`` state tensors from ``torch.Generator.get_state()``.
    """
    base_seed = torch.randint(int(1e8), (1,), generator=seed_generator).item()
    states = []
    for rank in range(procs):
        rank_gen = torch.Generator()
        rank_gen.manual_seed(base_seed + rank)
        states.append(rank_gen.get_state())
    return states

# Serialize a distribution to be sent to another process
def serialize_mu(mu, procs, vals_q, weights_q):
    """Package a distribution so subprocesses can rebuild it.

    Only ``DiscreteDist`` is supported. ``mu.vals`` and ``mu.weights`` are
    each put onto their queue ``procs`` times, alternating vals then weights,
    presumably so that every subprocess can take its own copy (verify
    against the consumer).

    Returns:
        ``((vals_q, weights_q, mu.sampling_kernel), "DiscreteDist")``.

    Raises:
        ValueError: if ``mu`` is not a ``DiscreteDist``.
    """
    if not isinstance(mu, DiscreteDist):
        raise ValueError(f"Serialization for distribution type {type(mu)} is not implemented")

    for _ in range(procs):
        vals_q.put(mu.vals)
        weights_q.put(mu.weights)
    return ((vals_q, weights_q, mu.sampling_kernel), "DiscreteDist")

# Deserialize a distribution from another process
def deserialize_mu(serialized_mu):
    """Rebuild a distribution from the tuple produced by ``serialize_mu``.

    For the ``"DiscreteDist"`` tag, one vals object and then one weights
    object are taken from the queues in ``serialized_mu[0]``. A new
    ``DiscreteDist`` is then built with the sampling kernel stored in
    ``serialized_mu[0][2]``.

    Raises:
        ValueError: if the type tag ``serialized_mu[1]`` is not supported.
    """
    dist_type = serialized_mu[1]
    if dist_type != "DiscreteDist":
        raise ValueError(f"Deserialization for distribution type {dist_type}"
                         " is not implemented")

    payload = serialized_mu[0]
    vals = payload[0].get()
    weights = payload[1].get()
    return DiscreteDist(vals, weights=weights, sampling_kernel=payload[2])

# Serialize a function that has been curried with pytorch tensors
def serialize_func(func, procs, args_q, keywords_q):
    """Prepare a curried callable (e.g. ``functools.partial``) for subprocesses.

    If ``func.args`` is a tuple, it is put on ``args_q`` ``procs`` times.
    If ``func.keywords`` is a dict, it is put on ``keywords_q`` ``procs``
    times. For callables without these attributes, nothing is enqueued.
    A later ``.get()`` on those queues (as in ``deserialize_func``) would
    then find nothing to read.

    Returns:
        ``{"func": func, "args": args_q, "keywords": keywords_q}``.
    """
    bound_args = getattr(func, "args", None)
    if isinstance(bound_args, tuple):
        for _ in range(procs):
            args_q.put(bound_args)

    bound_keywords = getattr(func, "keywords", None)
    if isinstance(bound_keywords, dict):
        for _ in range(procs):
            keywords_q.put(bound_keywords)

    return {"func": func, "args": args_q, "keywords": keywords_q}

# Replace the args and keywords of a curried function
def replace_args(part, new_args, new_keywords):
    """Swap the bound positional and keyword arguments of a partial in place.

    The partial's pickle state (from ``__reduce__``) supplies the wrapped
    function and the instance ``__dict__``. Both are kept, and the state is
    reinstalled via ``__setstate__`` with ``new_args`` / ``new_keywords``.
    Returns ``None``; ``part`` itself is mutated.
    """
    _cls, _ctor_args, pickled_state = part.__reduce__()
    wrapped, _old_args, _old_keywords, namespace = pickled_state
    part.__setstate__((wrapped, new_args, new_keywords, namespace))

# Deserialize a curried function
def deserialize_func(serialized_func):
    """Restore a curried callable from the dict built by ``serialize_func``.

    One args tuple and then one keywords dict are taken from the queues
    under ``"args"`` and ``"keywords"``. They are installed into
    ``serialized_func["func"]`` in place via ``replace_args``, and that
    same object is returned.
    """
    curried = serialized_func["func"]
    fresh_args = serialized_func["args"].get()
    fresh_keywords = serialized_func["keywords"].get()
    replace_args(curried, fresh_args, fresh_keywords)
    return curried

# Calculate a mean across processes with different sample sizes
def calculate_mean_across_procs(proc_sample_sizes, results_list):
    """Combine per-process means into an overall sample-size-weighted mean.

    Each element of ``results_list`` is unpacked as a pair ``(i, result)``.
    The weight of ``result`` is ``proc_sample_sizes[i]``, so the pairs may
    arrive in any order. NOTE(review): presumably these are
    ``(rank, result)`` pairs collected from worker processes; confirm
    against the caller.

    Returns:
        ``sum(proc_sample_sizes[i] * result) / sum(proc_sample_sizes)``.
    """
    weighted_total = sum(proc_sample_sizes[rank] * value
                         for rank, value in results_list)
    return weighted_total / sum(proc_sample_sizes)

# def my_startprocesses(rank, results_queue, exit_queue, procs, lam, mu_vals,
#                       optimizer, optimizer_args, context = None):
#     if context is None:
#     mp = mp.get_context(start_method)
#     error_queues = []
#     processes = []
#     for i in range(nprocs):
#         error_queue = mp.SimpleQueue()
#         process = mp.Process(
#             target=_wrap,
#             args=(fn, i, args, error_queue),
#             daemon=daemon,
#         )
#         process.start()
#         error_queues.append(error_queue)
#         processes.append(process)

#     context = ProcessContext(processes, error_queues)
#     if not join:
#         return context

#     # Loop on join until it returns True or raises an exception.
#     while not context.join():
#         pass
