"""
Metrics for off-policy evaluation.
"""
from d4rl_ext import infos
import numpy as np


# Reference ("true") undiscounted returns of the evaluated policies, keyed by
# policy identifier (keys look like '<env>-<dataset>'). get_returns() reads this
# table when discounted=False, and the metrics below treat these values as the
# ground truth that estimates and rankings are compared against.
# NOTE(review): how these numbers were measured (number of rollouts, episode
# horizon) is not recorded here -- TODO document the source.
UNDISCOUNTED_POLICY_RETURNS = {
    'halfcheetah-medium' : 3985.8150261686337,
    'halfcheetah-random' : -199.26067391425954,
    'halfcheetah-expert' : 12330.945945279545,
    'hopper-medium' : 2260.1983114487352,
    'hopper-random' : 1257.9757846810203,
    'hopper-expert' : 3624.4696022560997,
    'walker2d-medium' : 2760.3310101980005,
    'walker2d-random' : 896.4751989935487,
    'walker2d-expert' : 4005.89370727539,
}


# Reference ("true") discounted returns of the same policies, keyed by the same
# identifiers as UNDISCOUNTED_POLICY_RETURNS. get_returns() reads this table when
# discounted=True.
# NOTE(review): the discount factor used to compute these values is not recorded
# here -- TODO document it so estimators can be run with the matching gamma.
DISCOUNTED_POLICY_RETURNS = {
    'halfcheetah-medium' : 324.83583782709877,
    'halfcheetah-random' : -16.836944753939207,
    'halfcheetah-expert' : 827.7278887047698,
    'hopper-medium' : 235.7441494727478,
    'hopper-random' : 215.04955086664955,
    'hopper-expert' : 271.6925087260701,
    'walker2d-medium' : 202.23983424823822,
    'walker2d-random' : 78.46052021427765,
    'walker2d-expert' : 396.8752247768766
}


def get_returns(policy_id, discounted=False):
    """
    Looks up the stored reference return of a policy.

    Args:
        policy_id (str): Policy identifier, e.g. 'hopper-medium'.
        discounted (bool): If truthy, read DISCOUNTED_POLICY_RETURNS;
            otherwise read UNDISCOUNTED_POLICY_RETURNS.

    Returns:
        The stored return for policy_id.

    Raises:
        KeyError: If policy_id is not a key of the selected table.
    """
    table = DISCOUNTED_POLICY_RETURNS if discounted else UNDISCOUNTED_POLICY_RETURNS
    return table[policy_id]


def normalize(policy_id, score):
    """
    Rescales a score using the reference min/max scores from `d4rl_ext.infos`.

    The reference scores are looked up under the key `policy_id + '-v0'`.
    The mapping is affine: the reference minimum maps to 0.0 and the
    reference maximum maps to 1.0.

    Args:
        policy_id (str): Policy identifier, e.g. 'hopper-medium'.
        score: Raw score (return) to rescale.

    Returns:
        (score - ref_min) / (ref_max - ref_min).

    Raises:
        KeyError: If `policy_id + '-v0'` has no reference score.
    """
    dataset_key = policy_id + '-v0'
    low = infos.REF_MIN_SCORE[dataset_key]
    span = infos.REF_MAX_SCORE[dataset_key] - low
    return (score - low) / span


def ranking_correlation_metric(policies, discounted=False):
    """
    Computes Spearman's rank correlation coefficient.
    A score of 1.0 means the policies are ranked correctly according to their values.
    A score of -1.0 means the policies are ranked inversely.

    Args:
        policies: A list of policy string identifiers, ordered by the
            *estimated* ranking (best first). Valid identifiers must be keys
            of UNDISCOUNTED_POLICY_RETURNS / DISCOUNTED_POLICY_RETURNS.
        discounted (bool): Rank by discounted instead of undiscounted
            true returns.

    Returns:
        A correlation value between [-1, 1], or NaN when fewer than two
        policies are given (the coefficient is undefined there).

    Note:
        Uses the closed form 1 - 6 * sum(d^2) / (N * (N^2 - 1)), which is
        exact only when the true returns contain no ties. Ties are broken
        by input order.
    """
    # Look up every return first so unknown identifiers still raise KeyError
    # even for a single-policy input.
    return_values = np.array([get_returns(policy_key, discounted=discounted) for policy_key in policies])
    n = len(policies)
    if n < 2:
        # N * (N^2 - 1) == 0, so the formula would be 0/0; return NaN
        # explicitly instead of relying on a numpy division warning.
        return np.nan
    # argsort(-values) gives the *order* (index of best, second best, ...).
    # Argsorting that order gives each policy's true rank (0 = best).
    order = np.argsort(-return_values, kind='stable')
    true_ranks = np.argsort(order, kind='stable')
    # A policy's position in `policies` is its estimated rank.
    diff = true_ranks - np.arange(n)
    return 1.0 - (6 * np.sum(diff ** 2)) / (n * (n ** 2 - 1))


def precision_at_k_metric(policies, k=1, n_rel=None, discounted=False):
    """
    Computes precision@k.

    Args:
        policies: A list of policy string identifiers, ordered by the
            estimated ranking (best first).
        k (int): Number of top items of `policies` to inspect.
            Must satisfy 1 <= k <= len(policies).
        n_rel (int): Number of relevant items, i.e. how many of the truly
            best policies count as relevant. Must be >= 1. Default is k.
        discounted (bool): Rank by discounted instead of undiscounted
            true returns.

    Returns:
        Fraction of the top k policies that are in the top n_rel of the
        true rankings.

    Raises:
        ValueError: If k is not in [1, len(policies)] or n_rel < 1.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not 1 <= k <= len(policies):
        raise ValueError(
            'k must be in [1, len(policies)]; got k=%r for %d policies' % (k, len(policies)))
    if n_rel is None:
        n_rel = k
    if n_rel < 1:
        raise ValueError('n_rel must be >= 1; got %r' % (n_rel,))
    # The n_rel policies with the highest true returns form the relevant set.
    ranked_by_truth = sorted(policies, reverse=True,
                             key=lambda x: get_returns(x, discounted=discounted))
    relevant = set(ranked_by_truth[:n_rel])
    hits = sum(policy in relevant for policy in policies[:k])
    return float(hits) / k


def recall_at_k_metric(policies, k=1, n_rel=None, discounted=False):
    """
    Computes recall@k.

    Args:
        policies: A list of policy string identifiers, ordered by the
            estimated ranking (best first).
        k (int): Number of top items of `policies` that are retrieved.
            Must satisfy 1 <= k <= len(policies).
        n_rel (int): Number of relevant items, i.e. how many of the truly
            best policies count as relevant. Must be >= 1. Default is k.
        discounted (bool): Rank by discounted instead of undiscounted
            true returns.

    Returns:
        Fraction of the top n_rel true policies that appear in the top k of
        the given policies. The denominator is the number of relevant
        policies (n_rel, or len(policies) if that is smaller), not k.

    Raises:
        ValueError: If k is not in [1, len(policies)] or n_rel < 1.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not 1 <= k <= len(policies):
        raise ValueError(
            'k must be in [1, len(policies)]; got k=%r for %d policies' % (k, len(policies)))
    if n_rel is None:
        n_rel = k
    if n_rel < 1:
        raise ValueError('n_rel must be >= 1; got %r' % (n_rel,))
    ranked_by_truth = sorted(policies, reverse=True,
                             key=lambda x: get_returns(x, discounted=discounted))
    relevant = ranked_by_truth[:n_rel]
    retrieved = set(policies[:k])
    hits = sum(policy in retrieved for policy in relevant)
    # Recall normalizes by the number of relevant items. The previous code
    # divided by k, which is only correct when n_rel == k.
    return float(hits) / len(relevant)


def value_error_metric(policy, value, discounted=False):
    """
    Returns the absolute error of an estimated value on the normalized scale.

    Both the estimate and the policy's stored reference return are passed
    through `normalize` before taking the absolute difference.

    Args:
        policy (str): A policy string identifier.
        value (float): Estimated value of the policy.
        discounted (bool): Compare against the discounted (True) or
            undiscounted (False) reference return.

    Returns:
        |normalize(policy, value) - normalize(policy, reference_return)|.
    """
    # NOTE(review): normalize() uses one pair of reference scores for both
    # settings, so in the discounted case this is the raw error divided by
    # that reference range -- confirm this scaling is intended.
    estimated = normalize(policy, value)
    reference = normalize(policy, get_returns(policy, discounted))
    return abs(estimated - reference)


def policy_regret_metric(policy, expert_policies, discounted=False):
    """
    Returns the regret of the given policy against a set of expert policies.

    Regret is measured on the normalized scale of `policy`: the best stored
    return among `expert_policies`, minus the stored return of `policy`,
    both passed through normalize(policy, ...).

    Args:
        policy (str): A policy string identifier.
        expert_policies (list[str]): A non-empty list of expert policy identifiers.
        discounted (bool): Use discounted instead of undiscounted returns.

    Returns:
        The regret, which is the normalized value of the best expert minus
        the normalized value of the policy.

    Raises:
        ValueError: If expert_policies is empty (raised by max()).
    """
    best_expert_return = max(
        get_returns(expert, discounted=discounted) for expert in expert_policies
    )
    normalized_best = normalize(policy, best_expert_return)
    normalized_own = normalize(policy, get_returns(policy, discounted=discounted))
    return normalized_best - normalized_own

