# Compute per-trajectory mixing weights (lambda) and blend heuristic targets with value estimates.
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import rankdata

from src.util import DEFAULT_DEVICE


def get_lambda(lambda_method, future_rewards, heuristic_discount):
    """Map per-trajectory future rewards to mixing weights (lambdas).

    Args:
        lambda_method: one of ``'portion'`` (normalized rank), ``'sigmoid'``
            (logistic of min-max-normalized score) or ``'cut'`` (0/1 threshold).
        future_rewards: 1-D array-like of per-trajectory scores.
        heuristic_discount: for ``'portion'``/``'sigmoid'`` a multiplier on the
            weights; for ``'cut'`` the quantile (as a fraction) used for the
            threshold, which is then also scaled by this value.

    Returns:
        ``np.ndarray`` of float weights, one per entry of ``future_rewards``.

    Raises:
        NotImplementedError: for an unrecognised ``lambda_method``.
    """
    if lambda_method == 'portion':
        # np.asarray: the helper may hand back a plain list (single-trajectory
        # case), and ``list * float`` raises TypeError.
        lambdas = np.asarray(portion_lambda(future_rewards), dtype=float) * heuristic_discount
    elif lambda_method == 'sigmoid':
        lambdas = np.asarray(sigmoid_lambda(future_rewards), dtype=float) * heuristic_discount
    elif lambda_method == 'cut':
        # Accept lists as well as arrays; ``list > float`` would raise.
        future_rewards = np.asarray(future_rewards)
        # NOTE(review): the threshold is the heuristic_discount-quantile and is
        # then scaled again by heuristic_discount; this extra scaling looks
        # unusual — confirm it is intended.
        cut = np.percentile(future_rewards, heuristic_discount * 100) * heuristic_discount
        lambdas = (future_rewards > cut).astype(float)
    else:
        raise NotImplementedError
    return lambdas

def portion_lambda(future_rewards):
    """Return each entry's normalized rank, with the largest value mapped to 1.0.

    Ranks come from ``scipy.stats.rankdata`` (ties receive their average
    rank). They are shifted down by 1 and divided by their maximum, so the
    largest value maps to 1.0 and an untied smallest value maps to 0.0.

    Args:
        future_rewards: 1-D array-like of per-trajectory scores.

    Returns:
        ``np.ndarray`` of floats with the same length as ``future_rewards``.
        A single entry maps to ``[1.0]``. Empty input gives an empty array.
    """
    n = len(future_rewards)
    if n == 0:
        # rankdata([]).max() would fail on an empty reduction.
        return np.zeros(0)
    if n == 1:
        # Return an ndarray (not a list) so callers can scale it by a float.
        return np.array([1.0])
    shifted_ranks = rankdata(future_rewards) - 1
    # For n >= 2 the largest rank is at least 1.5, so this never divides by 0.
    return shifted_ranks / shifted_ranks.max()

def sigmoid_lambda(future_rewards):
    """Squash min-max-normalized scores through a logistic curve.

    Scores are mapped linearly onto [-6, 6] (min -> -6, max -> 6) and then
    passed through ``1 / (1 + exp(-x))``. The results lie in roughly
    [0.0025, 0.9975].

    Args:
        future_rewards: 1-D array-like of per-trajectory scores.

    Returns:
        ``np.ndarray`` with the same length as ``future_rewards``. A single
        entry, or entries that are all equal, map to 1.0 each; without this,
        the min-max mapping would divide 0 by 0 and yield NaN. Empty input
        gives an empty array.
    """
    future_rewards = np.asarray(future_rewards)
    n = len(future_rewards)
    if n == 0:
        # .max()/.min() would fail on an empty array.
        return np.zeros(0)
    max_future_rewards = future_rewards.max()
    min_future_rewards = future_rewards.min()
    if n == 1 or max_future_rewards == min_future_rewards:
        # Zero spread: treat every entry as the maximum.
        return np.ones(n)
    # Rescale the values into [-6, 6] before applying the logistic.
    normalized_future_rewards = (
        12 * (future_rewards - min_future_rewards) / (max_future_rewards - min_future_rewards) - 6
    )
    return 1 / (1 + np.exp(-normalized_future_rewards))

def add_lambda(traj_data, heuristic_discount, discount, lambda_method):
    """Attach a per-step mixing weight ``'lambda'`` to every trajectory, in place.

    Args:
        traj_data: iterable of per-trajectory dicts. Each dict must contain
            ``'returns'`` and ``'rewards'``, which must support elementwise
            subtraction and ``.mean()``. They are presumably numpy arrays of
            equal length (TODO confirm against the data loader).
        heuristic_discount: for ``'constant'``, the lambda value itself.
            Otherwise it is forwarded to ``get_lambda``.
        discount: discount factor. ``(returns - rewards) / discount`` gives the
            return seen from the next state.
        lambda_method: ``'constant'``, or ``'traj_<rule>'`` where ``<rule>`` is
            a method name accepted by ``get_lambda``.

    Raises:
        NotImplementedError: for an unrecognised ``lambda_method``.
    """
    if lambda_method == 'constant':
        for temp_data in traj_data:
            temp_data['lambda'] = np.ones(len(temp_data['returns'])) * heuristic_discount
    elif 'traj' in lambda_method:
        # Score each trajectory by its mean next-state return. We care about
        # the next-state heuristics rather than the current-state ones, hence
        # removing the current reward and undoing one step of discounting.
        future_rewards = np.array([
            ((temp_data['returns'] - temp_data['rewards']) / discount).mean()
            for temp_data in traj_data
        ])
        lambdas = get_lambda(lambda_method.split('traj_')[1], future_rewards, heuristic_discount)
        # One lambda per trajectory, broadcast to all of its steps. Pairing
        # with zip avoids re-allocating the array on every step (np.delete).
        for temp_data, traj_lambda in zip(traj_data, lambdas):
            temp_data['lambda'] = np.ones(len(temp_data['returns'])) * traj_lambda
    else:
        raise NotImplementedError


def get_heuristic_mix_h_v(returns, rewards, discount, heuristic_method, next_v, temperature, lambda_values):
    """Blend a next-state heuristic with the value estimate ``next_v``.

    The heuristic is built from ``future_rewards = (returns - rewards) / discount``
    according to ``heuristic_method``:

    * ``'return'``:  ``future_rewards`` itself.
    * ``'max'``:     elementwise ``torch.max(future_rewards, next_v)``.
    * ``'softmax'``: per element, a temperature-scaled soft maximum of
      ``future_rewards`` and ``next_v``. This is the average of the two,
      weighted by ``softmax([future_rewards, next_v] / temperature)``.
    * ``'zero'``:    zeros. The mix then reduces to ``(1 - lambda) * next_v``,
      so the return is left unchanged and only the discount factor shrinks.

    Args:
        returns, rewards, next_v, lambda_values: tensors that broadcast
            together elementwise. For ``'softmax'`` they are flattened to one
            column each.
        discount: discount factor used to compute ``future_rewards``.
        heuristic_method: one of the names above.
        temperature: softmax temperature (``'softmax'`` only).

    Returns:
        ``(heuristics, mix)`` where
        ``mix = lambda_values * heuristics + (1 - lambda_values) * next_v``.

    Raises:
        NotImplementedError: for an unrecognised ``heuristic_method``.
    """
    future_rewards = (returns - rewards) / discount
    if heuristic_method == 'return':
        heuristics = future_rewards
    elif heuristic_method == 'max':
        heuristics = torch.max(future_rewards, next_v)
    elif heuristic_method == 'softmax':
        two_v = torch.cat([future_rewards.reshape(-1, 1), next_v.reshape(-1, 1)], 1)
        # Scale on two_v's own device and dtype rather than building a
        # float32 tensor on a global device each call. That avoided a
        # legacy constructor, a host->device copy, and a device-mismatch
        # error when the inputs live elsewhere.
        softmax_weights = F.softmax(two_v / temperature, dim=1)
        heuristics = (softmax_weights * two_v).sum(1)
    elif heuristic_method == 'zero':
        # Do not change the return; only reduce the discount factor.
        heuristics = 0 * future_rewards
    else:
        raise NotImplementedError
    mix_hu_v = lambda_values * heuristics + (1 - lambda_values) * next_v
    return heuristics, mix_hu_v