"""Joint Policy-Space Response Oracles."""
import itertools
import numpy as np
from absl import logging
import mb_jpsro.mb_torch_rl_policy as rl_policy
from mb_jpsro.evaluation import find_best_response as evaluation
from mb_jpsro.meta_solver import INIT_POLICIES, _FLAG_TO_FUNC, UPDATE_PLAYERS_STRATEGY, BRS

# Expected returns whose absolute value is below this are stored as exactly 0.0.
RETURN_TOL = 1e-12
# Meta-distribution probabilities below this are zeroed before renormalizing.
DIST_TOL = 1e-8
# Deviation incentives (gaps) below this are reported as exactly 0.0.
GAP_TOL = 1e-8


# Per-iteration summary template, filled with str.format at the end of every iteration of run_loop.
# NOTE: run_loop passes a (train gap sum, eval gap sum) tuple for the `train_gap_sum` field.
LOG_STRING = """
Iteration {iteration: 6d}
=== ({game})
Player            {player}
BRs               {brs}
Num Policies      {num_policies}
Unique Policies   {unique}
--- ({train_meta_solver})
Train Value       {train_value}
Train Gap         {train_gap}
Train Gap Sum     {train_gap_sum}
Train Gap Sum List{train_gap_sum_list}
Eval Gap Sum List{eval_gap_sum_list}
"""


# PSRO Functions.
def initialize_policy(env_model, player, policy_init, agent_kwargs):
    """Builds the starting policy for a single player.

    Args:
        env_model: Environment model handed to the policy constructor.
        player: Index of the player the policy acts for.
        policy_init: Name of the initialization strategy; only "dqn" is handled here.
        agent_kwargs: Keyword arguments forwarded to `rl_policy.DQNPolicy`.

    Returns:
        A newly constructed `rl_policy.DQNPolicy`.

    Raises:
        ValueError: If `policy_init` is anything other than "dqn".
    """
    if policy_init != "dqn":
        raise ValueError(
            "policy_init must be a valid initialization strategy: %s. Received: %s" % (INIT_POLICIES, policy_init))
    return rl_policy.DQNPolicy(env_model, player, **agent_kwargs)


def add_new_policies(per_player_new_policies, per_player_gaps, per_player_repeats, per_player_policies, joint_policies,
                     joint_returns, env_model, br_selection, evaluate_number, eval_joint_returns):
    """Adds selected new policies to each player's population and evaluates new joint policies.

    No duplicate detection is performed: every incoming policy is treated as novel, so each
    selected policy is appended to the population with a repeat count of 1.

    All container arguments are updated in place; nothing is returned.

    Args:
        per_player_new_policies: For each player, a list of candidate policies.
        per_player_gaps: For each player, the gap associated with each candidate (same order).
        per_player_repeats: For each player, a list of repeat counts; extended in place.
        per_player_policies: For each player, the current population; extended in place.
        joint_policies: Dict mapping a tuple of per-player policy indices to the list of
            policies it denotes; missing combinations are added.
        joint_returns: Dict filled with returns estimated by `expected_returns` on `env_model`.
        env_model: Environment model; its `env` attribute is used for the eval returns.
        br_selection: One of "all", "all_novel", "random", "random_novel", "largest_gap".
        evaluate_number: Number of episodes used to estimate each joint policy's returns.
        eval_joint_returns: Dict filled with returns estimated by `eval_expected_returns`
            on `env_model.env`.

    Raises:
        ValueError: If `br_selection` is not a recognized selection method.
    """
    # Fail fast on a bad configuration, even for players with no candidates this round.
    if br_selection not in ("all", "all_novel", "random", "random_novel", "largest_gap"):
        raise ValueError("Unrecognized br_selection method: %s"
                         % br_selection)

    # Update policies and policy counts.
    for player, new_policies in enumerate(per_player_new_policies):
        new_gaps = per_player_gaps[player]
        logging.info(new_gaps)

        novel_policies = []
        novel_gaps = []
        for new_policy, new_gap in zip(new_policies, new_gaps):
            logging.info("Player %d's new policy is novel.", player)
            novel_policies.append(new_policy)
            novel_gaps.append(new_gap)

        if not novel_policies:
            continue

        if br_selection in ("all", "all_novel"):
            selected_policies = novel_policies
        elif br_selection in ("random", "random_novel"):
            # With no repeated policies to choose from, both modes draw uniformly from the candidates.
            selected_policies = [novel_policies[np.random.randint(0, len(novel_policies))]]
        else:  # "largest_gap"
            index = np.argmax(novel_gaps)
            if novel_gaps[index] == 0.0:  # Fall back to random when zero.
                index = np.random.randint(0, len(novel_policies))
            selected_policies = [novel_policies[index]]

        for selected_policy in selected_policies:
            per_player_policies[player].append(selected_policy)  # Add new policy.
            per_player_repeats[player].append(1)  # Add new count.

    # Add new joint policies: evaluate every index combination not seen before.
    for pids in itertools.product(*[range(len(policies)) for policies in per_player_policies]):
        if pids in joint_policies:
            continue
        policies = [player_policies[pid] for pid, player_policies in zip(pids, per_player_policies)]
        joint_policies[pids] = policies
        # Returns with magnitude below RETURN_TOL are stored as exactly 0.0.
        joint_returns[pids] = [0.0 if abs(er) < RETURN_TOL else er
                               for er in expected_returns(env_model, policies, num_episodes=evaluate_number)]
        eval_joint_returns[pids] = [0.0 if abs(er) < RETURN_TOL else er
                                    for er in eval_expected_returns(env_model.env, policies,
                                                                    num_episodes=evaluate_number)]


def expected_returns(env_model, policies, num_episodes):
    """Averages the per-player episode returns obtained by rolling out in the environment model.

    Args:
        env_model: Model providing `reset()`; also passed on to `sample_episode`.
        policies: One policy per player; its length sets the size of the result.
        num_episodes: Number of episodes to sample and average over.

    Returns:
        Array with one averaged return per player.
    """
    accumulated = np.zeros(len(policies))
    for _ in range(num_episodes):
        first_step = env_model.reset()
        episode_rewards = sample_episode(first_step, env_model, policies)
        np.add(accumulated, episode_rewards.reshape(-1), out=accumulated)
    return accumulated / num_episodes


def sample_episode(time_step, env_model, policies):
    """Plays one episode to termination in the environment model.

    At each step the policy of the current player acts in evaluation mode and the
    model is advanced with that single action.

    Args:
        time_step: Starting time step; must expose `last()`, `observations` and `rewards`.
        env_model: Model whose `step(time_step, actions)` returns the next time step.
        policies: Policies indexed by player id.

    Returns:
        The terminal time step's `rewards` as a float32 array.
    """
    # Iterate rather than recurse: one stack frame per environment step would hit
    # Python's recursion limit on long episodes.
    while not time_step.last():
        player = time_step.observations["current_player"]
        agent_output = policies[player].step(time_step, is_evaluation=True)
        time_step = env_model.step(time_step, [agent_output.action])
    return np.array(time_step.rewards, dtype=np.float32)


def eval_expected_returns(env, policies, num_episodes):
    """Averages the per-player episode returns obtained by rolling out in `env`.

    Args:
        env: Environment providing `reset()`; also passed on to `eval_sample_episode`.
        policies: One policy per player; its length sets the size of the result.
        num_episodes: Number of episodes to sample and average over.

    Returns:
        Array with one averaged return per player.
    """
    accumulated = np.zeros(len(policies))
    for _ in range(num_episodes):
        first_step = env.reset()
        episode_rewards = eval_sample_episode(first_step, env, policies)
        np.add(accumulated, episode_rewards.reshape(-1), out=accumulated)
    return accumulated / num_episodes


def eval_sample_episode(time_step, env, policies):
    """Plays one episode to termination in `env`.

    At each step the policy of the current player acts in evaluation mode and the
    environment is advanced with that single action.

    Args:
        time_step: Starting time step; must expose `last()`, `observations` and `rewards`.
        env: Environment whose `step(actions)` returns the next time step.
        policies: Policies indexed by player id.

    Returns:
        The terminal time step's `rewards` as a float32 array.
    """
    # Iterate rather than recurse: one stack frame per environment step would hit
    # Python's recursion limit on long episodes.
    while not time_step.last():
        player = time_step.observations["current_player"]
        agent_output = policies[player].step(time_step, is_evaluation=True)
        time_step = env.step([agent_output.action])
    return np.array(time_step.rewards, dtype=np.float32)


def add_meta_game(meta_games, per_player_policies, joint_returns):
    """Builds the payoff tensor for the current populations and appends it to `meta_games`.

    The tensor has one leading axis indexing the player whose return is stored, followed
    by one axis per player indexing that player's policy.

    Args:
        meta_games: List of payoff tensors; the new tensor is appended in place.
        per_player_policies: For each player, the list of policies in its population.
        joint_returns: Dict mapping each tuple of per-player policy indices to the
            per-player returns of that joint policy.

    Returns:
        The same `meta_games` list that was passed in.
    """
    policy_counts = [len(player_policies) for player_policies in per_player_policies]
    num_players = len(policy_counts)
    payoff_tensor = np.zeros([num_players, *policy_counts])

    for joint_index in itertools.product(*map(range, policy_counts)):
        # Fill the whole player axis for this combination of policy indices.
        payoff_tensor[(slice(None),) + joint_index] = joint_returns[joint_index]

    meta_games.append(payoff_tensor)
    return meta_games


def add_meta_dist(meta_dists, meta_values, meta_solver, meta_game, per_player_repeats, ignore_repeats):
    """Solves the meta-game, cleans the resulting joint distribution and records it.

    Args:
        meta_dists: List of joint distributions; the cleaned distribution is appended.
        meta_values: List of per-player values; the value under the new distribution is appended.
        meta_solver: Key into `_FLAG_TO_FUNC` selecting the solver function.
        meta_game: Payoff tensor whose first axis indexes players.
        per_player_repeats: Per-player repeat counts, passed through to the solver.
        ignore_repeats: Passed through to the solver as a keyword argument.

    Returns:
        The cleaned joint distribution as a float64 array.
    """
    num_players = meta_game.shape[0]
    solve = _FLAG_TO_FUNC[meta_solver]
    solver_dist, _ = solve(meta_game, per_player_repeats, ignore_repeats=ignore_repeats)

    # Clean dist: drop tiny entries, cap at 1, renormalize, then cap again.
    cleaned = solver_dist.astype(np.float64)
    too_small = cleaned < DIST_TOL
    cleaned[too_small] = 0.0
    cleaned[cleaned > 1.0] = 1.0
    cleaned /= np.sum(cleaned)
    cleaned[cleaned > 1.0] = 1.0
    meta_dists.append(cleaned)

    # Expected payoff per player: sum over every policy axis, keeping the player axis.
    policy_axes = tuple(range(1, num_players + 1))
    meta_values.append(np.sum(cleaned * meta_game, axis=policy_axes))
    return cleaned


def compute_best_response(env_model, joint_policies_slice, meta_dist_slice, player, agent_kwargs, number_train,
                          simulate_number):
    """Trains a new DQN policy for `player` against opponents drawn from `meta_dist_slice`.

    Each training episode samples one joint policy from `joint_policies_slice` with the
    probabilities in `meta_dist_slice`. The new policy acts (and learns) on `player`'s turns,
    while the sampled joint policy supplies every other player's actions in evaluation mode.

    Args:
        env_model: Environment model used for training rollouts and for the final evaluation.
        joint_policies_slice: List of joint policies, each indexable by player id.
        meta_dist_slice: Array of sampling probabilities; flattened, it must align with
            `joint_policies_slice`.
        player: Index of the player being trained.
        agent_kwargs: Keyword arguments forwarded to `rl_policy.DQNPolicy`.
        number_train: Number of training episodes.
        simulate_number: Number of episodes used by `best_response_returns` to estimate the value.

    Returns:
        A `(new_policy, value)` pair, where `value` is `player`'s estimated return.
    """
    new_policy = rl_policy.DQNPolicy(env_model, player, **agent_kwargs)

    # Loop invariants: the sampling distribution and the number of candidates do not change.
    sampling_probs = meta_dist_slice.ravel()
    num_joint_policies = len(joint_policies_slice)

    for _ in range(number_train):
        time_step = env_model.reset()
        # An int first argument samples from range(num_joint_policies).
        policy_id = np.random.choice(num_joint_policies, p=sampling_probs)
        sample_agent = joint_policies_slice[policy_id]

        while not time_step.last():
            player_id = time_step.observations["current_player"]
            if player_id == player:
                agent_output = new_policy.step(time_step, is_evaluation=False)
            else:
                agent_output = sample_agent[player_id].step(time_step, is_evaluation=True)

            time_step = env_model.step(time_step, [agent_output.action])

        # Presumably lets the learner observe the terminal transition -- confirm against DQNPolicy.step.
        new_policy.step(time_step, is_evaluation=False)

    reward = best_response_returns(env_model, new_policy, player, joint_policies_slice, meta_dist_slice, simulate_number)
    return new_policy, reward[player]


def best_response_returns(env_model, new_policy, player, joint_policies_slice, meta_dist_slice, num_episodes):
    """Estimates per-player returns when `player` uses `new_policy` against sampled opponents.

    Args:
        env_model: Environment model; `player_number` sets the number of players.
        new_policy: Policy played by `player`.
        player: Index of the player using `new_policy`.
        joint_policies_slice: List of joint policies, each indexable by player id.
        meta_dist_slice: Array of sampling probabilities over `joint_policies_slice`.
        num_episodes: Number of episodes to sample and average over.

    Returns:
        Array with one averaged return per player.
    """
    num_players = int(env_model.player_number)
    totals = np.zeros(num_players)
    for _ in range(num_episodes):
        candidate_ids = list(range(len(joint_policies_slice)))
        chosen_id = np.random.choice(candidate_ids, p=meta_dist_slice.ravel())
        opponents = joint_policies_slice[chosen_id]
        # `player` uses the new policy; everyone else plays the sampled joint policy.
        lineup = [new_policy if seat == player else opponents[seat] for seat in range(num_players)]
        episode_returns = sample_episode(env_model.reset(), env_model, lineup)
        totals += episode_returns.reshape(-1)
    return totals / num_episodes


def find_best_response(env_model, meta_dist, meta_game, iteration, joint_policies, target_equilibrium,
                       update_players_strategy, agent_kwargs, number_train, simulate_number):
    """Trains best responses for the selected players and measures their deviation incentives.

    Args:
        env_model: Environment model passed to `compute_best_response`.
        meta_dist: Joint distribution over policy indices, one axis per player.
        meta_game: Payoff tensor whose first axis indexes players.
        iteration: Current iteration; selects the player under the "cycle" strategy.
        joint_policies: Dict mapping tuples of per-player policy indices to joint policies.
        target_equilibrium: "cce" (one best response per updated player) or "ce" (one best
            response per policy of the updated player that has positive probability).
        update_players_strategy: "all", "cycle" or "random".
        agent_kwargs: Keyword arguments for the newly trained policies.
        number_train: Number of training episodes per best response.
        simulate_number: Number of episodes used to estimate each best response's value.

    Returns:
        `(per_player_new_policies, per_player_deviation_incentives)`: two lists with one entry
        per player, each entry a list (empty for players that were not updated).

    Raises:
        ValueError: If `update_players_strategy` or `target_equilibrium` is not recognized.
    """
    num_players = meta_game.shape[0]
    per_player_num_policies = meta_dist.shape[:]

    # Player update strategy.
    if update_players_strategy == "all":
        players = list(range(num_players))
    elif update_players_strategy == "cycle":
        players = [iteration % num_players]
    elif update_players_strategy == "random":
        players = [np.random.randint(0, num_players)]
    else:
        raise ValueError("update_players_strategy must be a valid player update strategy: "
                         "%s. Received: %s" % (UPDATE_PLAYERS_STRATEGY, update_players_strategy))

    if target_equilibrium not in ("cce", "ce"):
        raise ValueError("target_equilibrium must be a valid best response strategy: %s. "
                         "Received: %s" % (BRS, target_equilibrium))

    # Find best response.
    per_player_new_policies = []
    per_player_deviation_incentives = []

    for player in range(num_players):
        responses = []
        incentives = []
        per_player_new_policies.append(responses)
        per_player_deviation_incentives.append(incentives)

        if player not in players:
            continue  # Players that are not updated keep empty entries.

        if target_equilibrium == "cce":
            # Own index is pinned to the last policy; every opponent combination is enumerated.
            index_ranges = [(count - 1,) if seat == player else range(count)
                            for seat, count in enumerate(per_player_num_policies)]
            joint_policies_slice = [joint_policies[jpid] for jpid in itertools.product(*index_ranges)]

            # Marginalize the joint distribution over this player's own axis.
            marginal = np.sum(meta_dist, axis=player)
            marginal[marginal < DIST_TOL] = 0.0
            marginal[marginal > 1.0] = 1.0
            marginal /= np.sum(marginal)
            flat_marginal = np.array(list(marginal.flat))

            new_policy, best_response_values = compute_best_response(env_model, joint_policies_slice,
                                                                     flat_marginal, player, agent_kwargs,
                                                                     number_train, simulate_number)

            on_policy_value = np.sum(meta_game[player] * meta_dist)
            deviation_incentive = max(best_response_values - on_policy_value, 0)
            if deviation_incentive < GAP_TOL:
                deviation_incentive = 0.0

            responses.append(new_policy)
            incentives.append(deviation_incentive)
            continue

        # "ce": one best response per own policy, conditioned on that policy being recommended.
        for pid in range(per_player_num_policies[player]):
            index_ranges = [(pid,) if seat == player else range(count)
                            for seat, count in enumerate(per_player_num_policies)]
            joint_policies_slice = [joint_policies[jpid] for jpid in itertools.product(*index_ranges)]
            inds = tuple((pid,) if player == seat else slice(None) for seat in range(num_players))

            conditional = np.ravel(meta_dist[inds]).copy()
            conditional[conditional < DIST_TOL] = 0.0
            conditional[conditional > 1.0] = 1.0
            conditional_mass = np.sum(conditional)

            if not conditional_mass > 0.0:
                continue  # No probability mass on this policy: nothing to respond to.

            conditional /= conditional_mass
            new_policy, best_response_values = compute_best_response(env_model, joint_policies_slice,
                                                                     conditional, player, agent_kwargs,
                                                                     number_train, simulate_number)

            on_policy_value = np.sum(np.ravel(meta_game[player][inds]) * conditional)
            deviation_incentive = max(best_response_values - on_policy_value, 0)
            if deviation_incentive < GAP_TOL:
                deviation_incentive = 0.0

            responses.append(new_policy)
            # Weight the gap by the probability of this policy being recommended.
            incentives.append(conditional_mass * deviation_incentive)

    return per_player_new_policies, per_player_deviation_incentives


# Main Loop.

def initialize(env_model, agent_kwargs, train_meta_solver, policy_init, ignore_repeats, br_selection, evaluate_number):
    """Creates the JPSRO bookkeeping structures and seeds them with one policy per player.

    Args:
        env_model: Environment model; `player_number` sets the number of players.
        agent_kwargs: Keyword arguments for the initial policies.
        train_meta_solver: Meta-solver key used to compute the first meta-distribution.
        policy_init: Initialization strategy name passed to `initialize_policy`.
        ignore_repeats: Passed through to the meta-solver.
        br_selection: Selection method passed to `add_new_policies`.
        evaluate_number: Number of episodes used to evaluate each joint policy.

    Returns:
        Tuple `(iteration, per_player_repeats, per_player_policies, joint_policies,
        eval_joint_returns, joint_returns, meta_games, eval_meta_games, train_meta_dists,
        train_meta_values, train_meta_gaps)`.
    """
    num_players = env_model.player_number

    # Empty state, filled in by the helper calls below.
    iteration = 0
    per_player_repeats = [[] for _ in range(num_players)]
    per_player_policies = [[] for _ in range(num_players)]
    joint_policies = {}  # Eg. (1, 0): Joint policy.
    joint_returns, eval_joint_returns = {}, {}
    meta_games, eval_meta_games = [], []
    train_meta_dists, train_meta_values, train_meta_gaps = [], [], []

    # One starting policy per player, each with a placeholder gap of 1.0.
    initial_policies = []
    for player in range(num_players):
        initial_policies.append([initialize_policy(env_model, player, policy_init, agent_kwargs)])
    initial_gaps = [[1.0] for _ in range(num_players)]

    add_new_policies(initial_policies, initial_gaps, per_player_repeats, per_player_policies,
                     joint_policies, joint_returns, env_model, br_selection, evaluate_number, eval_joint_returns)

    for games, returns in ((meta_games, joint_returns), (eval_meta_games, eval_joint_returns)):
        add_meta_game(games, per_player_policies, returns)

    add_meta_dist(train_meta_dists, train_meta_values, train_meta_solver, meta_games[-1], per_player_repeats,
                  ignore_repeats)

    return (iteration, per_player_repeats, per_player_policies, joint_policies, eval_joint_returns, joint_returns,
            meta_games, eval_meta_games, train_meta_dists, train_meta_values, train_meta_gaps)


def initialize_callback_(
        iteration,
        per_player_repeats,
        per_player_policies,
        joint_policies,
        eval_joint_returns,
        joint_returns,
        meta_games,
        eval_meta_games,
        train_meta_dists,
        train_meta_values,
        train_meta_gaps):
    """Default initialization callback: passes the state through with no checkpoint.

    A custom callback with the same signature can replace these values, for
    example to resume from a checkpoint.

    Returns:
        The eleven arguments in the order received, followed by `None` as the checkpoint.
    """
    state = (
        iteration,
        per_player_repeats,
        per_player_policies,
        joint_policies,
        eval_joint_returns,
        joint_returns,
        meta_games,
        eval_meta_games,
        train_meta_dists,
        train_meta_values,
        train_meta_gaps,
    )
    checkpoint = None
    return state + (checkpoint,)


def run_loop(env_model, game_name,
             seed=0,
             iterations=40,
             policy_init="dqn",
             update_players_strategy="all",
             target_equilibrium="cce",
             br_selection="largest_gap",
             train_meta_solver="mgcce",
             ignore_repeats=False,
             initialize_callback=None,
             agent_kwargs=None,
             number_train=int(1e3),
             evaluate_number=int(1e3),
             simulate_number=int(1e3)):
    """Runs JPSRO.

    Each iteration trains best responses against the latest meta-distribution, records the
    train and eval gaps, grows the policy populations, rebuilds both meta-games and re-solves
    the training meta-game.

    Args:
        env_model: Environment model; must expose `player_number` and the underlying `env`.
        game_name: Name of the game, used only in the log output.
        seed: Seed for numpy's global random state.
        iterations: Last iteration index; the loop runs while `iteration <= iterations`.
        policy_init: Policy initialization strategy. Defaults to "dqn", the only value
            `initialize_policy` accepts.
        update_players_strategy: "all", "cycle" or "random".
        target_equilibrium: "cce" or "ce".
        br_selection: How new best responses are selected for addition to the population.
        train_meta_solver: Meta-solver key used for the training meta-distribution.
        ignore_repeats: Passed through to the meta-solver.
        initialize_callback: Optional callable with the signature of `initialize_callback_`;
            it may replace the initial state (e.g. from a checkpoint).
        agent_kwargs: Keyword arguments for newly created policies; `None` means no extras.
        number_train: Training episodes per best response.
        evaluate_number: Episodes used to evaluate each new joint policy.
        simulate_number: Episodes used to estimate each trained best response's value.

    Returns:
        `(eval_gap_sum_list, best_policy)`: the per-iteration sums of eval gaps, and a list
        holding, for each player, the `_policy` attribute of every policy in its population,
        followed by the final meta-distribution as the last element.
    """
    if initialize_callback is None:
        initialize_callback = initialize_callback_
    # agent_kwargs is unpacked with ** when policies are built, so None must become an empty dict.
    if agent_kwargs is None:
        agent_kwargs = {}

    # Set seed.
    np.random.seed(seed)

    # Some statistics.
    num_players = env_model.player_number  # Look in the game.

    # Initialize.
    values = initialize(env_model, agent_kwargs, train_meta_solver, policy_init, ignore_repeats, br_selection,
                        evaluate_number)

    # Initialize Callback.
    (iteration,
     per_player_repeats,
     per_player_policies,
     joint_policies,
     eval_joint_returns,
     joint_returns,
     meta_games,
     eval_meta_games,
     train_meta_dists,
     train_meta_values,
     train_meta_gaps,
     checkpoint) = initialize_callback(*values)

    eval_meta_gaps = []
    train_gap_sum_list = []
    eval_gap_sum_list = []
    # Run JPSRO.
    while iteration <= iterations:
        logging.info("Beginning JPSRO iteration %03d", iteration)
        # Train best responses against the latest meta-distribution in the environment model.
        per_player_new_policies, per_player_gaps_train = find_best_response(env_model, train_meta_dists[-1],
                                                                            meta_games[-1], iteration, joint_policies,
                                                                            target_equilibrium, update_players_strategy,
                                                                            agent_kwargs, number_train, simulate_number)
        train_meta_gaps.append([sum(gaps) for gaps in per_player_gaps_train])

        # Eval gaps come from the imported evaluation routine run on the underlying env; its
        # policies are discarded. NOTE(review): the trailing int(1e4) is presumably the number
        # of simulation episodes -- confirm against mb_jpsro.evaluation.find_best_response.
        _, per_player_gaps_eval = evaluation(env_model.env, train_meta_dists[-1], eval_meta_games[-1], iteration,
                                             joint_policies, target_equilibrium, update_players_strategy,
                                             agent_kwargs, number_train, int(1e4))
        eval_meta_gaps.append([sum(gaps) for gaps in per_player_gaps_eval])

        # Add new policies to the populations and evaluate the new joint policies.
        add_new_policies(per_player_new_policies, per_player_gaps_train, per_player_repeats, per_player_policies,
                         joint_policies, joint_returns, env_model, br_selection, evaluate_number, eval_joint_returns)

        # Update meta games.
        add_meta_game(meta_games, per_player_policies, joint_returns)
        add_meta_game(eval_meta_games, per_player_policies, eval_joint_returns)

        # Update meta distribution.
        add_meta_dist(train_meta_dists, train_meta_values, train_meta_solver, meta_games[-1], per_player_repeats,
                      ignore_repeats)

        # Stats.
        per_player_num_policies = train_meta_dists[-1].shape[:]
        train_gap_sum_list.append(sum(train_meta_gaps[-1]))
        eval_gap_sum_list.append(sum(eval_meta_gaps[-1]))
        log_string = LOG_STRING.format(iteration=iteration,
                                       game=game_name,
                                       player=("{: 12d}" * num_players).format(*list(range(num_players))),
                                       brs=target_equilibrium,
                                       num_policies=("{: 12d}" * num_players).format(*[sum(ppr)
                                                                                       for ppr in per_player_repeats]),
                                       unique=("{: 12d}" * num_players).format(*per_player_num_policies),
                                       train_meta_solver=train_meta_solver,
                                       train_value=("{: 12g}" * num_players).format(*train_meta_values[-1]),
                                       train_gap=("{: 12g}" * num_players).format(*train_meta_gaps[-1]),
                                       train_gap_sum=(sum(train_meta_gaps[-1]), sum(eval_meta_gaps[-1])),
                                       train_gap_sum_list=train_gap_sum_list,
                                       eval_gap_sum_list=eval_gap_sum_list)

        logging.info(log_string)
        iteration += 1

    best_policy = [[policy._policy for policy in player_policies] for player_policies in per_player_policies]
    best_policy.append(train_meta_dists[-1])

    return eval_gap_sum_list, best_policy
