import numpy as np


def risk_seeking_policy(trees, epsilon):
    """Keep only the top ``epsilon`` fraction of samples, ranked by reward.

    The rewards are read from ``trees.rewards``, the indices of the highest
    rewards are selected (best first), and ``trees.reduce`` is called with
    those indices.

    Args:
        trees: Object exposing a ``rewards`` sequence and a
            ``reduce(indices)`` method.
        epsilon: Fraction of samples to keep. The number kept is
            ``int(epsilon * len(trees.rewards))``, but never fewer than one,
            so a small batch or a small ``epsilon`` cannot produce an empty
            selection.

    Returns:
        A tuple ``(trees, baseline_reward, (best, median, baseline_reward))``
        where the statistics are computed over the kept rewards only and
        ``baseline_reward`` is the smallest kept reward.

    Raises:
        ValueError: If ``trees.rewards`` is empty. Raised before
            ``trees.reduce`` is called, so ``trees`` is left untouched.
    """
    # Snapshot the rewards before reducing, in case ``reduce`` alters them.
    rewards = np.array(trees.rewards)
    if len(rewards) == 0:
        raise ValueError("risk_seeking_policy requires at least one reward")

    # Truncation can give 0 for small batches; always keep the best sample.
    keep = max(1, int(epsilon * len(rewards)))
    indices = np.argsort(rewards)[::-1][:keep]

    trees.reduce(indices)

    kept_rewards = rewards[indices]
    baseline_reward = np.min(kept_rewards)
    best_reward = np.max(kept_rewards)
    median = np.median(kept_rewards)
    return trees, baseline_reward, (best_reward, median, baseline_reward)


def dpo_policy(trees):
    """Reorder consecutive sample pairs so the higher reward comes first.

    Samples are taken two at a time (indices ``0,1``, ``2,3``, ...). For each
    pair the index with the larger reward is listed first; ties keep the
    original order. A trailing unpaired sample is left out of the index list.
    The resulting list is passed to ``trees.reduce``.

    Args:
        trees: Object exposing a ``rewards`` sequence and a
            ``reduce(indices)`` method.

    Returns:
        A tuple ``(trees, mean, (max, median, min))``. The statistics are
        computed after ``trees.reduce`` runs, from the ``rewards`` object
        that was read before the call.
        NOTE(review): whether they reflect the reduced rewards depends on
        whether ``reduce`` mutates that object in place -- confirm.
    """
    rewards = trees.rewards
    pair_count = len(rewards) // 2

    ordering = []
    for first in range(0, 2 * pair_count, 2):
        second = first + 1
        if rewards[first] >= rewards[second]:
            ordering.extend((first, second))
        else:
            ordering.extend((second, first))

    trees.reduce(ordering)

    mean_reward = np.mean(rewards)
    summary = (np.max(rewards), np.median(rewards), np.min(rewards))
    return trees, mean_reward, summary
