import numpy as np
import torch
import pdb

from .. import utils
from .sampling import sample_n, get_logp, sort_2d, filter_cdf_prob, sample_n_continuous

# Number of columns the reward and the value each occupy in a transition
# vector; both are added into transition_dim in cmaes_plan_continuous.
REWARD_DIM = VALUE_DIM = 1

@torch.no_grad()
def model_rollout_continuous(model, x, latent, denormalize_rew, denormalize_val, discount, prob_penalty_weight=1e4):
    """Score continuous latent codes by the discounted return of their decoded plans.

    Args:
        model: object exposing ``decode(latent, state)``, ``observation_dim``
            and ``transition_dim``.
        x: context tensor; ``x[:, -1, :model.observation_dim]`` is passed to
            the decoder and ``x.shape[0]`` is taken as the number of candidates.
        latent: latent codes to evaluate (assumed one row per candidate —
            TODO confirm against callers).
        denormalize_rew: optional callable applied to the reward column;
            ``None`` uses the raw decoder output.
        denormalize_val: optional callable applied to the value column;
            ``None`` uses the raw decoder output.
        discount: per-step discount factor.
        prob_penalty_weight: weight of the mean-squared-latent penalty.

    Returns:
        Tuple ``(objective, prediction)`` of numpy arrays: one objective per
        candidate, and the decoded prediction flattened to rows of
        ``model.transition_dim`` columns.
    """
    batch_size = x.shape[0]
    prediction = model.decode(latent, x[:, -1, :model.observation_dim])
    prediction = prediction.reshape([-1, model.transition_dim])

    # the last three columns of a transition are read as (reward, value, terminal)
    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # reshape unconditionally: the 2-D indexing below must also work when no
    # denormalizer is supplied
    r_t = r_t.reshape([batch_size, -1])
    V_t = V_t.reshape([batch_size, -1])

    # discounts with terminal flag: a terminal value of 1 zeroes the weight of
    # that step and of every later step
    terminal = prediction[:, -1].reshape([batch_size, -1])
    discounts = torch.cumprod(torch.ones_like(r_t) * discount * (1 - terminal), dim=-1)
    # rewards of all but the last step, plus the value estimate of the last step
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    prob_penalty = prob_penalty_weight * torch.mean(torch.square(latent), dim=-1)
    objective = values - prob_penalty
    return objective.cpu().numpy(), prediction.cpu().numpy()


from functools import partial
import contextlib
import os
import random
from .core import beam_plan
from .utils import extract_actions_continuous

import numpy as np
import cma


@torch.no_grad()
def sample(model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8):
    """Draw uniformly random latent code sequences and return the best decoded plan.

    Args:
        model: object exposing ``decode_from_indices(indices, state)``,
            ``model.K``, ``latent_step``, ``observation_dim`` and
            ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.
        denormalize_rew: optional callable applied to the reward column.
        denormalize_val: optional callable applied to the value column.
        discount: per-step discount factor.
        steps: planning horizon; ``steps // model.latent_step`` codes are drawn
            per candidate.
        nb_samples: number of candidate sequences.
        rounds: unused; kept for interface compatibility.

    Returns:
        numpy array holding the decoded plan with the highest estimated return.
    """
    n_latents = steps // model.latent_step
    # torch.randint excludes its upper bound, so pass K to cover codes 0..K-1
    indicies = torch.randint(0, model.model.K, size=[nb_samples, n_latents],
                             device=x.device, dtype=torch.int32)
    prediction_raw = model.decode_from_indices(indicies, x[:, 0, :model.observation_dim])
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    # the last three columns of a transition are (reward, value, terminal)
    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # reshape unconditionally so the 2-D indexing below also works without
    # denormalizers
    n_candidates = indicies.shape[0]
    r_t = r_t.reshape([n_candidates, -1])
    V_t = V_t.reshape([n_candidates, -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    # rewards of all but the last step, plus the value estimate of the last step
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    optimal = prediction_raw[values.argmax()]
    print(values.max().item())
    return optimal.cpu().numpy()


@torch.no_grad()
def sample_with_prior(prior, value_fn, model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8,
                      likelihood_weight=5e2, prob_threshold=0.05, uniform=False, return_info=False):
    """Sample latent code sequences from the prior and return the best decoded plan.

    Each round draws ``nb_samples // rounds`` first codes and extends every
    sequence by one sampled code per remaining latent step. Candidates are
    ranked by discounted return plus a clamped log-likelihood bonus; the plan
    with the highest return among the per-round winners is returned.

    Args:
        prior: callable ``prior(context, state) -> (logits, _)``; the logits of
            the last position are turned into a categorical distribution.
        value_fn: optional callable ``value_fn(observations, actions)`` working
            on numpy arrays; when ``None`` the decoded value column is used.
        model: object exposing ``decode_from_indices``, ``latent_step``,
            ``observation_dim``, ``action_dim`` and ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.
        denormalize_rew: optional callable applied to the rewards.
        denormalize_val: optional callable applied to the values.
        discount: per-step discount factor.
        steps: planning horizon in environment steps.
        nb_samples: total number of first-step samples across all rounds.
        rounds: number of independent sampling rounds.
        likelihood_weight: weight of the log-likelihood bonus.
        prob_threshold: per-step probability cap used by the bonus and, when
            ``uniform`` is set, the cut-off for admissible codes.
        uniform: sample uniformly among codes whose prior probability exceeds
            ``prob_threshold`` instead of sampling from the prior itself.
        return_info: also return a dict with ``probs``, ``returns`` and
            ``predictions`` of the LAST round only.

    Returns:
        The best plan as a numpy array, or ``(plan, info)`` if ``return_info``.
    """
    state = x[:, 0, :model.observation_dim]
    n_latents = steps // model.latent_step
    optimals = []
    optimal_values = []
    info = {}
    for _round in range(rounds):
        contex = None
        acc_probs = torch.ones([1]).to(x)
        for step in range(n_latents):
            logits, _ = prior(contex, state)  # [B x t x K]
            probs = raw_probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
            if uniform:
                valid = probs > prob_threshold
                probs = valid / valid.sum(dim=-1)[:, None]
            # fan out only at the first step; afterwards extend each sequence by one code
            num_samples = nb_samples // rounds if step == 0 else 1
            samples = torch.multinomial(probs, num_samples=num_samples, replacement=True)  # [B, M]
            # likelihood is always tracked under the true prior, even when sampling uniformly
            samples_prob = torch.gather(raw_probs, 1, samples)  # [B, M]
            acc_probs = acc_probs * samples_prob.reshape([-1])
            if contex is None:
                contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
            else:
                contex = torch.cat([contex, samples.reshape([-1, 1])], dim=1)

        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])

        # the last three columns of a transition are (reward, value, terminal)
        r_t = prediction[:, -3]
        if value_fn is None:
            V_t = prediction[:, -2]
        else:
            # NOTE(review): the reward shift and the observation slice below look
            # environment-specific — confirm against the value_fn in use.
            r_t = r_t - 1
            V_t = value_fn(prediction[:, :model.observation_dim-2].cpu().numpy(),
                           prediction[:, model.observation_dim:model.observation_dim+model.action_dim].cpu().numpy())
            V_t = torch.tensor(V_t, device=r_t.device)
        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # reshape unconditionally so the 2-D indexing below also works without
        # denormalizers
        n_candidates = contex.shape[0]
        r_t = r_t.reshape([n_candidates, -1])
        V_t = V_t.reshape([n_candidates, -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
        # clamping from above means sequences more likely than the threshold get
        # no extra bonus; only unlikely sequences are penalized
        likelihood_bonus = likelihood_weight * torch.log(
            torch.clamp(acc_probs, 0, prob_threshold ** n_latents))
        info["probs"] = acc_probs.cpu().numpy()
        info["returns"] = values.cpu().numpy()
        info["predictions"] = prediction_raw.cpu().numpy()
        max_idx = (values + likelihood_bonus).argmax()
        optimals.append(prediction_raw[max_idx])
        optimal_values.append(values[max_idx].item())

    # across rounds, compare winners by their return alone (without the bonus)
    max_idx = np.array(optimal_values).argmax()
    optimal = optimals[max_idx]
    print(f"predicted max value {optimal_values[max_idx]}")
    if return_info:
        return optimal.cpu().numpy(), info
    return optimal.cpu().numpy()


@torch.no_grad()
def sample_with_prior_tree(prior, model, x, denormalize_rew, denormalize_val, discount, steps, samples_per_latent=16, likelihood_weight=0.0):
    """Expand a full sampling tree under the prior and return the best decoded plan.

    Every partial sequence is extended by ``samples_per_latent`` sampled codes
    at each latent step, so the number of candidates grows as
    ``samples_per_latent ** (steps // model.latent_step)``.

    Args:
        prior: callable ``prior(context, state) -> (logits, _)``.
        model: object exposing ``decode_from_indices``, ``latent_step``,
            ``observation_dim`` and ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.
        denormalize_rew: optional callable applied to the reward column.
        denormalize_val: optional callable applied to the value column.
        discount: per-step discount factor.
        steps: planning horizon in environment steps.
        samples_per_latent: branching factor at every latent step.
        likelihood_weight: weight of the (unclamped) log-likelihood bonus.

    Returns:
        numpy array holding the decoded plan with the highest score.
    """
    contex = None
    state = x[:, 0, :model.observation_dim]
    acc_probs = torch.ones([1]).to(x)
    for step in range(steps // model.latent_step):
        logits, _ = prior(contex, state)  # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        samples = torch.multinomial(probs, num_samples=samples_per_latent, replacement=True)  # [B, M]
        samples_prob = torch.gather(probs, 1, samples)  # [B, M]
        # each parent contributes its probability to all of its children
        acc_probs = acc_probs.repeat_interleave(samples_per_latent, 0) * samples_prob.reshape([-1])
        if contex is None:
            contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, samples_per_latent, 0), samples.reshape([-1, 1])],
                               dim=1)

    prediction_raw = model.decode_from_indices(contex, state)
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    # the last three columns of a transition are (reward, value, terminal)
    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # reshape unconditionally so the 2-D indexing below also works without
    # denormalizers
    n_candidates = contex.shape[0]
    r_t = r_t.reshape([n_candidates, -1])
    V_t = V_t.reshape([n_candidates, -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    likelihood_bonus = likelihood_weight * torch.log(acc_probs)
    max_idx = (values + likelihood_bonus).argmax()
    optimal = prediction_raw[max_idx]
    print(f"predicted max value {values[max_idx]}, likelihood {acc_probs[max_idx]} with bouns {likelihood_bonus[max_idx]}")
    return optimal.cpu().numpy()

@torch.no_grad()
def beam_with_prior(prior, value_fn, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product"):
    """Beam search over latent codes sampled from the prior.

    At every latent step each beam entry is expanded with ``n_expand`` sampled
    codes (``beam_width * n_expand`` at the first step), all candidates are
    decoded and scored, and the best ``beam_width`` survive (a single one at
    the last step).

    Args:
        prior: callable ``prior(context, state) -> (logits, _)``.
        value_fn: optional callable ``value_fn(observations, actions)`` working
            on numpy arrays; when ``None`` the decoded value column is used.
        model: object exposing ``decode_from_indices``, ``latent_step``,
            ``observation_dim``, ``action_dim`` and ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.
        denormalize_rew: optional callable applied to the rewards.
        denormalize_val: optional callable applied to the values.
        discount: per-step discount factor.
        steps: planning horizon in environment steps.
        beam_width: number of sequences kept after each step.
        n_expand: number of codes sampled per kept sequence.
        prob_threshold: per-step probability cap used by the likelihood bonus.
        likelihood_weight: weight of the log-likelihood bonus.
        prob_acc: how sequence probability is accumulated and used:
            ``"product"`` multiplies step probabilities and adds a clamped
            log bonus, ``"min"`` keeps the smallest step probability and adds
            a clamped log bonus, ``"expect"`` multiplies step probabilities
            and ranks by ``value * probability``.

    Returns:
        numpy array holding the decoded plan of the top-ranked sequence.

    Raises:
        ValueError: if ``prob_acc`` is not one of the supported modes.
    """
    if prob_acc not in ("product", "expect", "min"):
        raise ValueError(f"unknown prob_acc {prob_acc!r}; expected 'product', 'expect' or 'min'")

    n_latents = steps // model.latent_step
    contex = None
    state = x[:, 0, :model.observation_dim]
    acc_probs = torch.ones([1]).to(x)
    for step in range(n_latents):
        logits, _ = prior(contex, state)  # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        nb_samples = beam_width * n_expand if step == 0 else n_expand
        samples = torch.multinomial(probs, num_samples=nb_samples, replacement=True)  # [B, M]
        samples_prob = torch.gather(probs, 1, samples).reshape([-1])  # [(B*M)]
        parent_probs = acc_probs.repeat_interleave(nb_samples, 0)
        if prob_acc == "min":
            acc_probs = torch.minimum(parent_probs, samples_prob)
        else:
            acc_probs = parent_probs * samples_prob
        if contex is None:
            contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, nb_samples, 0), samples.reshape([-1, 1])],
                               dim=1)

        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])
        # the last three columns of a transition are (reward, value, terminal)
        r_t = prediction[:, -3]
        if value_fn is None:
            V_t = prediction[:, -2]
        else:
            # NOTE(review): the reward shift and the observation slice below look
            # environment-specific — confirm against the value_fn in use.
            r_t = r_t - 1
            V_t = value_fn(prediction[:, :model.observation_dim-2].cpu().numpy(),
                           prediction[:, model.observation_dim:model.observation_dim+model.action_dim].cpu().numpy())
            V_t = torch.tensor(V_t, device=r_t.device)

        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # reshape unconditionally so the 2-D indexing below also works without
        # denormalizers
        n_candidates = contex.shape[0]
        r_t = r_t.reshape([n_candidates, -1])
        V_t = V_t.reshape([n_candidates, -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
        if prob_acc == "expect":
            scores = values * acc_probs
        else:
            # clamping from above: sequences more likely than the cap get no
            # extra bonus, only unlikely ones are penalized
            cap = prob_threshold if prob_acc == "min" else prob_threshold ** n_latents
            scores = values + likelihood_weight * torch.log(torch.clamp(acc_probs, 0, cap))
        nb_top = beam_width if step < (n_latents - 1) else 1
        _, index = torch.topk(scores, nb_top)
        contex = contex[index]
        acc_probs = acc_probs[index]

    optimal = prediction_raw[index[0]]
    # report the return of the selected candidate, not of candidate 0
    print(f"predicted max value {values[index[0]]}")
    return optimal.cpu().numpy()

@torch.no_grad()
def beam_with_uniform(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand,  prob_threshold=0.05):
    """Beam search that samples uniformly among codes the prior deems likely enough.

    At every latent step, codes whose prior probability exceeds
    ``prob_threshold`` are sampled with equal probability; candidates are
    ranked by discounted return alone.

    Args:
        prior: callable ``prior(context, state) -> (logits, _)``.
        model: object exposing ``decode_from_indices``, ``latent_step``,
            ``observation_dim`` and ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.
        denormalize_rew: optional callable applied to the reward column.
        denormalize_val: optional callable applied to the value column.
        discount: per-step discount factor.
        steps: planning horizon in environment steps.
        beam_width: number of sequences kept after each step.
        n_expand: number of codes sampled per kept sequence.
        prob_threshold: minimum prior probability for a code to be eligible.

    Returns:
        numpy array holding the decoded plan of the top-ranked sequence.
    """
    n_latents = steps // model.latent_step
    contex = None
    state = x[:, 0, :model.observation_dim]
    acc_probs = torch.ones([1]).to(x)
    for step in range(n_latents):
        logits, _ = prior(contex, state)  # [B x t x K]
        probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        nb_samples = beam_width * n_expand if step == 0 else n_expand
        valid = probs > prob_threshold
        # keepdim so each row is divided by its own count; without it the
        # division fails to broadcast as soon as the beam has more than one row
        uniform_probs = valid / valid.sum(dim=-1, keepdim=True)
        samples = torch.multinomial(uniform_probs, num_samples=nb_samples, replacement=True)  # [B, M]
        samples_prob = torch.gather(probs, 1, samples)  # [B, M]
        acc_probs = acc_probs.repeat_interleave(nb_samples, 0) * samples_prob.reshape([-1])
        if contex is None:
            contex = samples.reshape([-1, step + 1])  # [(B*M) x t]
        else:
            contex = torch.cat([torch.repeat_interleave(contex, nb_samples, 0), samples.reshape([-1, 1])],
                               dim=1)

        prediction_raw = model.decode_from_indices(contex, state)
        prediction = prediction_raw.reshape([-1, model.transition_dim])
        # the last three columns of a transition are (reward, value, terminal)
        r_t, V_t = prediction[:, -3], prediction[:, -2]

        if denormalize_rew is not None:
            r_t = denormalize_rew(r_t)
        if denormalize_val is not None:
            V_t = denormalize_val(V_t)
        # reshape unconditionally so the 2-D indexing below also works without
        # denormalizers
        n_candidates = contex.shape[0]
        r_t = r_t.reshape([n_candidates, -1])
        V_t = V_t.reshape([n_candidates, -1])

        discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
        values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
        nb_top = beam_width if step < (n_latents - 1) else 1
        values, index = torch.topk(values, nb_top)
        contex = contex[index]
        acc_probs = acc_probs[index]

    optimal = prediction_raw[index[0]]
    print(f"predicted max value {values[0]}")
    return optimal.cpu().numpy()

@torch.no_grad()
def beam_mimic(prior, model, x, denormalize_rew, denormalize_val, discount, steps,
                    beam_width, n_expand,  prob_threshold=0.05):
    """Beam search that keeps the most probable code sequences under the prior.

    No returns are estimated here: candidates are ranked purely by the product
    of their sampled step probabilities, and only the single surviving
    sequence is decoded. ``denormalize_rew``, ``denormalize_val``, ``discount``
    and ``prob_threshold`` are accepted but not used.

    Returns:
        numpy array holding the decoded plan of the most probable sequence.
    """
    first_state = x[:, 0, :model.observation_dim]
    horizon = steps // model.latent_step
    beam = None
    beam_probs = torch.ones([1]).to(x)
    for depth in range(horizon):
        logits, _ = prior(beam, first_state)  # [B x t x K]
        next_probs = torch.softmax(logits[:, -1, :], dim=-1)  # [B x K]
        # the first step seeds the whole beam, later steps only expand it
        fanout = n_expand if depth > 0 else beam_width * n_expand
        drawn = torch.multinomial(next_probs, num_samples=fanout, replacement=True)  # [B, M]
        drawn_probs = torch.gather(next_probs, 1, drawn)  # [B, M]
        beam_probs = beam_probs.repeat_interleave(fanout, 0) * drawn_probs.reshape([-1])
        if beam is None:
            beam = drawn.reshape([-1, depth + 1])  # [(B*M) x t]
        else:
            parents = torch.repeat_interleave(beam, fanout, 0)
            beam = torch.cat([parents, drawn.reshape([-1, 1])], dim=1)

        # keep a full beam until the final step, then only the winner
        keep = 1 if depth >= horizon - 1 else beam_width
        values, index = torch.topk(beam_probs, keep)
        beam = beam[index]
        beam_probs = beam_probs[index]

    prediction_raw = model.decode_from_indices(beam, first_state)
    optimal = prediction_raw[0]
    print(f"value {values[0]}, prob {beam_probs[0]}")
    return optimal.cpu().numpy()


@torch.no_grad()
def enumerate_all(model, x, denormalize_rew, denormalize_val, discount):
    """Decode every single latent code and return the plan with the highest return.

    Args:
        model: object exposing ``decode_from_indices``, ``model.K``,
            ``observation_dim`` and ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.
        denormalize_rew: optional callable applied to the reward column.
        denormalize_val: optional callable applied to the value column.
        discount: per-step discount factor.

    Returns:
        numpy array holding the decoded plan with the highest estimated return.
    """
    # codes 0..K-1; torch.arange replaces the deprecated, end-inclusive torch.range
    # NOTE(review): indices are 1-D here while the other planners pass
    # [N x t] contexts — confirm decode_from_indices accepts this shape.
    indicies = torch.arange(0, model.model.K, device=x.device, dtype=torch.int32)
    prediction_raw = model.decode_from_indices(indicies, x[:, 0, :model.observation_dim])
    prediction = prediction_raw.reshape([-1, model.transition_dim])

    # (reward, value, terminal) are the last three columns, matching the other
    # planners in this module and the transition_dim formula
    r_t, V_t = prediction[:, -3], prediction[:, -2]
    if denormalize_rew is not None:
        r_t = denormalize_rew(r_t)
    if denormalize_val is not None:
        V_t = denormalize_val(V_t)
    # reshape unconditionally so the 2-D indexing below also works without
    # denormalizers
    n_candidates = indicies.shape[0]
    r_t = r_t.reshape([n_candidates, -1])
    V_t = V_t.reshape([n_candidates, -1])

    discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1)
    values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1]
    optimal = prediction_raw[values.argmax()]
    return optimal.cpu().numpy()


@torch.no_grad()
def propose_plan_continuous(model, x):
    """Decode the all-zero latent code into a plan.

    Args:
        model: object exposing ``decode(latent, state)``, ``trajectory_embd``,
            ``observation_dim`` and ``transition_dim``.
        x: context tensor; ``x[:, 0, :model.observation_dim]`` is the
            conditioning state.

    Returns:
        numpy array of the decoded plan, reshaped to rows of
        ``model.transition_dim`` columns.
    """
    # build the latent on the same device as the input rather than assuming CUDA
    latent = torch.zeros([1, model.trajectory_embd], device=x.device)
    prediction = model.decode(latent, x[:, 0, :model.observation_dim])
    prediction = prediction.reshape([-1, model.transition_dim])

    return prediction.cpu().numpy()

@torch.no_grad()
def cmaes_plan_continuous(model, x, observation_dim, action_dim,
    denormalize_rew, denormalize_val, discount=0.99, iteration=1000,
):
    """Search the continuous latent space with CMA-ES and return the best decoded plan.

    Args:
        model: object exposing ``trajectory_embd`` plus whatever
            ``model_rollout_continuous`` needs for decoding.
        x: context tensor; only ``x[0]`` is used and it is repeated once per
            CMA-ES candidate (assumed ``[1 x T x D]`` — TODO confirm).
        observation_dim: number of observation columns in a transition.
        action_dim: number of action columns in a transition.
        denormalize_rew: optional callable applied to the rewards.
        denormalize_val: optional callable applied to the values.
        discount: per-step discount factor.
        iteration: passed to CMA-ES as ``maxfevals``.

    Returns:
        numpy array of shape ``[-1, transition_dim]`` holding the decoded plan
        with the lowest cost seen during the search.
    """
    inp = x.clone()

    # observation + action + reward + value + terminal flag
    transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM + 1

    latent_init = np.zeros([1, model.trajectory_embd])
    es = cma.CMAEvolutionStrategy(latent_init,
                                  0.2, {'maxfevals': iteration, 'bounds': [-1.0, 1.0], 'verbose': -3, "popsize": 4096})
    rollout = partial(model_rollout_continuous, model=model, discount=discount,
                      denormalize_rew=denormalize_rew, denormalize_val=denormalize_val)

    def evaluate(candidates):
        """Return (costs, plans) for a list of latent candidates."""
        latent = torch.tensor(np.stack(candidates), device=inp.device, dtype=torch.float32)
        objective, plans = rollout(latent=latent, x=inp[0].repeat(latent.shape[0], 1, 1))
        # CMA-ES minimizes, so the objective is negated into a cost
        return -objective, plans.reshape([latent.shape[0], -1, plans.shape[-1]])

    all_fX = []
    all_sequence = []
    while not es.stop():
        new_samples = es.ask()
        costs, plans = evaluate(new_samples)
        new_fX = list(costs)
        es.tell(new_samples, new_fX)
        all_fX += new_fX
        all_sequence += list(plans)

    print(min(all_fX), es.result.fbest)
    argmin = all_fX.index(min(all_fX))
    best_response = all_sequence[argmin]
    ## [ horizon x transition_dim ]
    best_response = best_response.reshape(-1, transition_dim)

    return best_response