import numpy as np
import torch
import pdb

from .. import utils
from .sampling import sample_n, get_logp, sort_2d, filter_cdf_prob

REWARD_DIM = VALUE_DIM = 1


@torch.no_grad()
def model_rollout_discrete(trajectory_vector, state, model, observation_dim, action_dim, transition_dim, discount,
                           denormalize_rew, denormalize_val, denormalize_action, discretizer):
    '''
        Decode `trajectory_vector` against a private copy of `state` and
        normalise the resulting logits over their last axis.

        Only `trajectory_vector`, `state` and `model` are read; every other
        argument is accepted but unused. The probabilities are computed and
        then dropped, so the function implicitly returns `None`.
    '''
    ## work on a copy so the caller's tensor is never handed to the model
    state_copy = state.clone()
    decoded_logits = model.decode(trajectory_vector, state_copy)
    token_probs = decoded_logits.softmax(dim=-1)


@torch.no_grad()
def propose_plan_discrete(model, value_fn, x,
    n_steps, beam_width, n_expand,
    observation_dim, action_dim, discretizer,
    denormalize_rew, denormalize_val,
    denormalize_act, normalize_act,
    discount=0.99, max_context_transitions=None,
    k_obs=None, k_act=None, k_rew=1,
    cdf_obs=None, cdf_act=None, cdf_rew=None,
    verbose=True, previous_actions=None,
    cmaes_sigma_mult=0.05,
    iteration=100
):
    '''
        Greedily decode a token sequence from an all-zero latent and hand the
        whole transitions to `discretizer.reconstruct`.

        Only `model`, `x`, `observation_dim`, `action_dim` and `discretizer`
        are read; the remaining parameters are accepted but unused.

        NOTE(review): the reconstructed trajectory is currently discarded and
        the function always returns `None` -- this looks like unfinished code;
        the return value is left untouched so existing callers are unaffected.
    '''
    transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM

    state = model.tok_emb(x)
    ## allocate the latent next to the embedded context rather than
    ## hard-coding "cuda", so the function also runs on CPU-only hosts
    latent = torch.zeros([1, model.embedding_dim * 2], device=state.device)
    logits = model.decode(latent, state)
    logits = logits.reshape([1, -1, model.vocab_size + 1])
    probs = torch.nn.functional.softmax(logits, dim=-1)
    max_idx = torch.argmax(probs, dim=2)

    ## keep only whole transitions. Slicing with `:-(length % transition_dim)`
    ## degenerates to `:0` (nothing kept) when the remainder is zero, so the
    ## cut-off is computed explicitly instead.
    n_transitions = max_idx.shape[1] // transition_dim
    complete_idx = max_idx[0, :n_transitions * transition_dim]
    ## explicit row count: `-1` is ambiguous for an empty tensor
    trajectory = discretizer.reconstruct(complete_idx.reshape([n_transitions, transition_dim]))

    return None

@torch.no_grad()
def rollout(action_seq, model, observation_dim, action_dim, discount, n_steps, k_rew, cdf_rew, k_obs, cdf_obs, k_act, cdf_act, value_fn, beam_width, x_inp, max_block, transition_dim, discretizer):
    '''
        Score a batch of candidate action sequences by rolling them through
        the sequence model.

        action_seq : indexed as `action_seq[:, t]` for t in range(n_steps);
            its first axis is the number of candidates
            (presumably [ n_candidates x n_steps x action_dim ] -- confirm
            against the caller)
        x_inp : context tokens, repeated once per candidate along dim 0
        beam_width : ignored; it is overwritten with `action_seq.shape[0]`

        At every step the proposed action is discretized and appended to the
        token sequence, reward/value tokens are sampled with `sample_n` and
        turned into estimates by `value_fn`, and (except on the last step) the
        next observation is sampled. A candidate whose least likely action
        token has probability <= 1e-8 under the filtered action distribution
        gets -1000 added to that step's reward.

        returns (value, best_sequence, final_x) as numpy arrays:
            value : discounted sum of the per-step rewards plus the value
                estimate written after the last step, one entry per candidate
            best_sequence : the token sequences after the last step
            final_x : last `observation_dim` columns returned by one more
                `sample_n` call on those sequences
    '''
    ## pass in max number of tokens to sample function
    sample_kwargs = {
        'max_block': max_block,
        'crop_increment': transition_dim,
    }
    ## actions whose probability is at or below the threshold are penalised
    act_threshold = 1e-8
    act_penalty = -1000

    ## one rollout per candidate action sequence
    beam_width = action_seq.shape[0]
    ## repeat input for search (`repeat` returns a new tensor, `x_inp` is untouched)
    x = x_inp.repeat(beam_width, 1)
    ## construct reward and discount tensors for estimating values
    rewards = torch.zeros(beam_width, n_steps + 1, device=x.device)
    discounts = discount ** torch.arange(n_steps + 1, device=x.device)

    for t in range(n_steps):
        ## discretize the proposed action for this step
        act_discrete = discretizer.discretize(action_seq[:, t], subslice=[observation_dim, observation_dim+action_dim])
        action = torch.tensor(act_discrete).to(x.device)

        ## look up the probability of each proposed action token in the
        ## second return value of `sample_n`, after cdf filtering
        _, p = sample_n(model, x, action_dim, topk=k_act, cdf=cdf_act, **sample_kwargs)
        p = p.reshape(beam_width*action_dim, -1)
        p = filter_cdf_prob(p, cdf_act)
        p = p.reshape(beam_width, action_dim, -1)
        ## drop only the gathered axis: a bare `squeeze()` would also drop the
        ## candidate axis (one candidate) or the action axis (action_dim == 1),
        ## making the `min` below reduce over the wrong dimension
        prob = torch.gather(p, dim=-1, index=action.unsqueeze(-1)).squeeze(-1)
        #TODO: add a penalty to cdf_act
        conf = (torch.min(prob, dim=-1)[0] <= act_threshold).float() * act_penalty

        ## TODO: potentially discretize action
        x = torch.cat((x, action), dim=1)

        ## sample reward and value estimate
        x, r_probs = sample_n(model, x, REWARD_DIM + VALUE_DIM, returnx=False, topk=k_rew, cdf=cdf_rew, **sample_kwargs)

        ## optionally, use a percentile or mean of the reward and
        ## value distributions instead of sampled tokens
        r_t, V_t = value_fn(r_probs)

        ## update rewards tensor; the value estimate in column t+1 is
        ## overwritten by the next step's reward unless this is the last step
        rewards[:, t] = r_t + conf
        rewards[:, t+1] = V_t

        ## sample next observation (unless we have reached the end of the planning horizon)
        if t < n_steps - 1:
            x, _ = sample_n(model, x, observation_dim, returnx=False, topk=k_obs, cdf=cdf_obs, **sample_kwargs)

    ## estimate values from all rewards plus the terminal value; computed once
    ## here so it is defined even when `n_steps` is 0
    values = (rewards * discounts).sum(dim=-1)

    best_sequence = x.detach().cpu().numpy()
    value = values.detach().cpu().numpy()

    ## sample the observation that follows the final planned transition
    final_x, _ = sample_n(model, x, observation_dim, topk=k_obs, cdf=cdf_obs, **sample_kwargs)
    final_x = final_x[:, -observation_dim:]
    final_x = final_x.detach().cpu().numpy()
    return value, best_sequence, final_x

from functools import partial
import contextlib
import os
import random
from .core import beam_plan
from .utils import extract_actions

import numpy as np
import cma

@torch.no_grad()
def cmaes_plan(model, value_fn, x,
    n_steps, beam_width, n_expand,
    observation_dim, action_dim, discretizer,
    discount=0.99, max_context_transitions=None,
    k_obs=None, k_act=None, k_rew=1,
    cdf_obs=None, cdf_act=None, cdf_rew=None,
    verbose=True, previous_actions=None,
    cmaes_sigma_mult=0.05,
    iteration=100,
):
    '''
        Refine a beam-search plan with CMA-ES over the flattened action
        sequence.

        x : tensor[ 1 x input_sequence_length ]

        `beam_plan` is run first; the actions of all returned candidates seed
        the optimiser (their mean is the initial CMA-ES mean and they are fed
        to the first `tell` together with their negated values). Every later
        population is scored with `rollout`, and the token sequence of the
        lowest-cost candidate evaluated is returned.

        cmaes_sigma_mult : initial CMA-ES step size
        iteration : passed to CMA-ES as `maxfevals`
        verbose : forwarded to `beam_plan`; also gates the summary print
        previous_actions : accepted but unused

        returns the best candidate's tokens reshaped to
        [ -1 x transition_dim ] and cropped to the last `n_steps` rows

        raises RuntimeError if CMA-ES stops before any candidate was rolled out
    '''
    inp = x.clone()
    _, sequence, values = beam_plan(
                model, value_fn, inp,
                n_steps, beam_width, n_expand, observation_dim, action_dim,
                discount, max_context_transitions, verbose=verbose,
                k_obs=k_obs, k_act=k_act, cdf_obs=cdf_obs, cdf_act=cdf_act,
                return_all=True,
            )
    values = values.detach().cpu().numpy()

    ## convert sampled tokens to continuous values and flatten the actions of
    ## each beam candidate into one row
    sequence_recon = discretizer.reconstruct(sequence.reshape(beam_width*n_steps, -1))
    xs = extract_actions(sequence_recon, observation_dim, action_dim)
    xs = xs.reshape(-1, action_dim * n_steps)

    ## initial mean and per-coordinate std multipliers for CMA-ES
    x0 = np.mean(xs, axis=0)
    sigma0 = np.ones(action_dim * n_steps)

    # convert max number of transitions to max number of tokens
    transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM
    max_block = max_context_transitions * transition_dim - 1 if max_context_transitions else None

    def evaluate(candidates):
        ## CMA-ES minimises, so the rollout values are negated into costs
        action_seq = np.array(candidates).reshape((-1, n_steps, action_dim))
        value, tokens, _ = rollout(
            action_seq=action_seq, model=model, observation_dim=observation_dim,
            action_dim=action_dim, discount=discount, n_steps=n_steps,
            k_rew=k_rew, cdf_rew=cdf_rew, k_obs=k_obs, cdf_obs=cdf_obs,
            k_act=k_act, cdf_act=cdf_act, value_fn=value_fn,
            beam_width=beam_width, x_inp=inp, max_block=max_block,
            transition_dim=transition_dim, discretizer=discretizer,
        )
        return -value, tokens

    es = cma.CMAEvolutionStrategy(x0, cmaes_sigma_mult, {
        'maxfevals': iteration,
        'tolx': 0.02, 'CMA_stds': sigma0, 'bounds': [-1, 1], 'verbose': -3})

    ## warm start: the first asked population is discarded and the
    ## beam-search candidates are told in its place
    es.ask()
    es.tell(list(xs), [-v for v in values])

    all_fX = []
    all_sequence = []
    while not es.stop():
        new_samples = es.ask()
        costs, tokens = evaluate(new_samples)
        new_fX = list(costs)
        es.tell(new_samples, new_fX)
        all_fX += new_fX
        all_sequence += list(tokens)

    if not all_fX:
        raise RuntimeError('CMA-ES stopped before evaluating any candidate')

    if verbose:
        print(max(values), min(all_fX), es.result.fbest)

    ## first occurrence of the lowest cost over every evaluated candidate
    best_response = all_sequence[int(np.argmin(all_fX))]
    ## [ (n_context + n_steps) x transition_dim ]
    best_response = best_response.reshape(-1, transition_dim)

    ## crop out context transitions
    ## [ n_steps x transition_dim ]
    best_response = best_response[-n_steps:]
    return best_response

import contextlib
import os
import numpy as np