"""Cross-entropy method baseline."""
import torch
import time
import dill
from typing import Callable
import concurrent.futures


def get_best(
        samples: list[torch.Tensor],
        costs: list[float],
        k: int,
        ) -> list[torch.Tensor]:
    """Select the `k` lowest-cost samples.

    `costs[i]` is the cost of `samples[i]`. The result is ordered from
    lowest to highest cost; samples with equal cost keep their original
    relative order because the sort is stable. If fewer than `k` samples
    are available, all of them are returned.
    """
    # Rank sample positions by cost instead of sorting the samples
    # themselves, so the tensors are never compared with each other.
    ranking = sorted(range(len(samples)), key=costs.__getitem__)
    return [samples[position] for position in ranking[:k]]


def fit_distribution(
        samples: list[torch.Tensor],
        ) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate an element-wise Gaussian from equally shaped samples.

    The samples are stacked along a new leading dimension, and the mean
    and standard deviation are computed across it, so both results have
    the shape of a single sample. Returns `(mean, standard deviation)`.
    """
    stacked = torch.stack(samples)
    # torch.std_mean yields (std, mean); swap to the documented order.
    spread, center = torch.std_mean(stacked, dim=0)
    return center, spread


def cost_helper(
        x,
        dilled_cost,
        ) -> float:
    """Evaluate a dill-serialized cost function on `x`.

    `dilled_cost` is deserialized with `dill.loads` on every call; `x` is
    converted with `torch.tensor` before being handed to the restored
    function, whose result is returned unchanged.
    """
    cost_fn = dill.loads(dilled_cost)
    candidate = torch.tensor(x)
    return cost_fn(candidate)


def cem(
        max_timesteps: int,
        action_size: int,
        sample_n: int,
        elite_n: int,
        horizon_len: int,
        init_stdev: float,
        cem_inner_iter_n: int,
        timeout_s: float,
        get_quality: Callable[[torch.Tensor], float],
        get_cost: Callable[[torch.Tensor], float],
        worker_n: int,
        verbose: bool,
        ) -> tuple[list[tuple[float, torch.Tensor]], bool]:
    """Optimize an action trajectory with CEM in an MPC loop.

    This implements CEM_MPC as described in:
    Sample-efficient Cross-Entropy Method for Real-time Planning, 2020

    Cost and quality are separate measures for flexibility. Only
    `get_cost` drives the optimization; `get_quality` is evaluated on the
    executed trajectory after every time step and serves solely as the
    success criterion.

    Args:
        max_timesteps: Maximum number of actions to execute (outer loop).
        action_size: Dimension of a single action.
        sample_n: Number of candidate horizons sampled per inner iteration.
        elite_n: Number of lowest-cost candidates used to refit the
            sampling distribution.
        horizon_len: Number of future actions planned at every step.
        init_stdev: Standard deviation the sampling distribution is reset
            to at the start of every time step.
        cem_inner_iter_n: Number of sample/refit iterations per time step.
        timeout_s: Time budget in seconds. It is checked once per time
            step, after the step has completed.
        get_quality: Maps the executed actions (one row per action) to a
            quality value; a value of at least 1.0 ends the optimization.
        get_cost: Maps the executed actions followed by a candidate
            horizon to a cost; lower is better.
        worker_n: Number of worker processes for cost evaluation. `None`
            or a value below 2 evaluates costs in the calling process.
        verbose: If true, print progress information.

    Returns:
        A pair `(log, success)`. `log` holds one `(elapsed seconds,
        executed actions so far)` entry per completed time step.
        `success` is `True` if and only if a trajectory with quality of
        at least 1.0 was found, at which point the optimization stops.
        Otherwise the time-step or time budget ran out and it is `False`.
    """
    pool = None
    dilled_cost = None
    if worker_n is not None and worker_n >= 2:
        # Serialize before starting the pool so that a serialization
        # failure cannot leave worker processes behind. In serial mode the
        # cost function is called directly and never serialized.
        dilled_cost = dill.dumps(get_cost)
        pool = concurrent.futures.ProcessPoolExecutor(max_workers=worker_n)

    # Monotonic clock: elapsed time must not jump with wall-clock changes.
    start_t = time.monotonic()

    # Log of all executed trajectories
    log: list[tuple[float, torch.Tensor]] = []

    # Empty trajectory that will be extended MPC style
    actions: list[list[float]] = []

    # Horizon action distribution mean
    mu = torch.zeros((horizon_len, action_size))

    try:
        for t in range(max_timesteps):
            # Shift-initialize mu from the previous step: drop the action
            # that was just executed and pad the horizon end with zeros.
            mu = torch.cat(
                (mu[1:], torch.zeros((1, action_size), dtype=mu.dtype))
            )

            # Standard deviation is reset on every time step
            sigma = torch.full_like(mu, init_stdev)

            for i in range(cem_inner_iter_n):
                try:
                    distribution = torch.distributions.normal.Normal(
                        mu, sigma)
                except ValueError:
                    # The fitted distribution may have degenerated (e.g.
                    # zero or NaN standard deviation). Keep the current
                    # mean and stop refining for this time step.
                    break
                samples = [distribution.sample() for _ in range(sample_n)]

                # Each candidate is the executed prefix plus a sampled
                # horizon, as nested lists so it can cross processes.
                candidates = [actions + x.tolist() for x in samples]
                if pool is None:
                    costs = [
                        get_cost(torch.tensor(candidate))
                        for candidate in candidates
                    ]
                else:
                    cost_futures = [
                        pool.submit(
                            cost_helper,
                            x=candidate,
                            dilled_cost=dilled_cost,
                        )
                        for candidate in candidates
                    ]
                    costs = [future.result() for future in cost_futures]

                elite_set = get_best(samples, costs, k=elite_n)
                if verbose:
                    print(f"CEM inner best cost ({i}/{cem_inner_iter_n}): {min(costs)}")
                mu, sigma = fit_distribution(elite_set)

            # Execute first action of mean sequence mu
            actions.append(mu[0].tolist())

            # Log trajectory
            current_t = time.monotonic() - start_t
            current_actions = torch.tensor(actions)
            log.append((current_t, current_actions))

            quality = get_quality(current_actions)
            if verbose:
                print(f"CEM: quality t={t} T={max_timesteps}: {quality}")
            if quality >= 1.0:
                return log, True

            # Check timeout
            if current_t > timeout_s:
                if verbose:
                    print("CEM: Timeout!")
                break
        return log, False
    finally:
        # Release the worker processes on every exit path (success,
        # timeout, or an exception raised by a cost/quality function).
        if pool is not None:
            pool.shutdown(cancel_futures=True)
