import itertools
import numpy as np
import mb_jpsro.mb_torch_rl_policy as rl_policy
from mb_jpsro.meta_solver import UPDATE_PLAYERS_STRATEGY, BRS
# Entries of a (sliced) meta-distribution smaller than this are set to 0.0
# before the slice is renormalised and used for sampling joint policies.
DIST_TOL = 1e-8
# Deviation incentives smaller than this are reported as exactly 0.0.
GAP_TOL = 1e-8


def _select_players(num_players, iteration, update_players_strategy):
    """Return the list of players that receive a new best response this iteration.

    Raises:
        ValueError: if ``update_players_strategy`` is not "all", "cycle" or "random".
    """
    if update_players_strategy == "all":
        return list(range(num_players))
    if update_players_strategy == "cycle":
        return [iteration % num_players]
    if update_players_strategy == "random":
        return [np.random.randint(0, num_players)]
    raise ValueError("update_players_strategy must be a valid player update strategy: "
                     "%s. Received: %s" % (UPDATE_PLAYERS_STRATEGY, update_players_strategy))


def _clip_probabilities(dist):
    """Zero entries below ``DIST_TOL`` and cap entries at 1.0, in place; returns ``dist``."""
    dist[dist < DIST_TOL] = 0.0
    dist[dist > 1.0] = 1.0
    return dist


def _deviation_incentive(best_response_value, on_policy_value):
    """Return ``max(best_response_value - on_policy_value, 0)``, snapped to 0.0 below ``GAP_TOL``."""
    incentive = max(best_response_value - on_policy_value, 0)
    return 0.0 if incentive < GAP_TOL else incentive


def _joint_policies_slice(joint_policies, per_player_num_policies, player, own_policy_id):
    """Collect the joint policies with ``player`` fixed to ``own_policy_id``.

    ``itertools.product`` varies the last player fastest, i.e. the same
    row-major order as ``np.ravel`` of the matching meta-distribution slice.
    """
    joint_policy_ids = itertools.product(*[(own_policy_id,) if p_ == player else range(np_)
                                           for p_, np_ in enumerate(per_player_num_policies)])
    return [joint_policies[jpid] for jpid in joint_policy_ids]


def find_best_response(env, meta_dist, meta_game, iteration, joint_policies, target_equilibrium,
                       update_players_strategy, agent_kwargs, number_train, simulate_number):
    """Returns new best response policies.

    For each player selected by ``update_players_strategy`` a best response is
    trained (via ``compute_best_response``) against ``meta_dist``:

    * ``"cce"``: one best response per player, against ``meta_dist`` with that
      player's own axis summed out.
    * ``"ce"``: one best response per own policy ``pid`` whose clipped slice
      has positive mass, against ``meta_dist`` conditioned on ``pid``.

    Args:
        env: environment forwarded to ``compute_best_response``.
        meta_dist: joint distribution over policies with one axis per player.
        meta_game: per-player payoffs; ``meta_game[player]`` is combined
            elementwise with ``meta_dist`` (presumably shape
            ``(num_players, *meta_dist.shape)`` -- confirm with caller).
        iteration: current iteration, used by the "cycle" strategy.
        joint_policies: indexed by a tuple of per-player policy ids; each
            entry is a joint policy indexable by player id.
        target_equilibrium: "cce" or "ce".
        update_players_strategy: "all", "cycle" or "random".
        agent_kwargs, number_train, simulate_number: forwarded to
            ``compute_best_response``.

    Returns:
        ``(per_player_new_policies, per_player_deviation_incentives)``: two
        lists with one (possibly empty) list per player. Players not updated
        this iteration get empty lists. Under "ce" each incentive is
        weighted by the probability mass of its conditioning policy.

    Raises:
        ValueError: for an unknown ``update_players_strategy`` or
            ``target_equilibrium``.
    """
    num_players = meta_game.shape[0]
    per_player_num_policies = meta_dist.shape[:]
    players = _select_players(num_players, iteration, update_players_strategy)

    per_player_new_policies = []
    per_player_deviation_incentives = []

    if target_equilibrium == "cce":
        for player in range(num_players):
            new_policies, incentives = [], []
            if player in players:
                # The player's own slot (its last policy) is never used:
                # compute_best_response substitutes the new policy for `player`.
                joint_policies_slice = _joint_policies_slice(
                    joint_policies, per_player_num_policies, player,
                    per_player_num_policies[player] - 1)
                # Marginalise the responding player out of the joint distribution.
                meta_dist_slice = _clip_probabilities(np.sum(meta_dist, axis=player))
                meta_dist_slice /= np.sum(meta_dist_slice)

                new_policy, best_response_values = compute_best_response(
                    env, joint_policies_slice, meta_dist_slice.ravel(), player, agent_kwargs,
                    number_train, simulate_number)

                on_policy_value = np.sum(meta_game[player] * meta_dist)
                new_policies.append(new_policy)
                incentives.append(_deviation_incentive(best_response_values, on_policy_value))
            per_player_new_policies.append(new_policies)
            per_player_deviation_incentives.append(incentives)

    elif target_equilibrium == "ce":
        for player in range(num_players):
            new_policies, incentives = [], []
            if player in players:
                for pid in range(per_player_num_policies[player]):
                    # Condition the joint distribution on `player` being recommended `pid`.
                    inds = tuple((pid,) if player == p_ else slice(None) for p_ in range(num_players))
                    meta_dist_slice = _clip_probabilities(np.ravel(meta_dist[inds]).copy())
                    meta_dist_slice_sum = np.sum(meta_dist_slice)

                    # Slices with no (clipped) mass never get recommended; skip them.
                    if meta_dist_slice_sum > 0.0:
                        meta_dist_slice /= meta_dist_slice_sum
                        joint_policies_slice = _joint_policies_slice(
                            joint_policies, per_player_num_policies, player, pid)
                        new_policy, best_response_values = compute_best_response(
                            env, joint_policies_slice, meta_dist_slice, player, agent_kwargs,
                            number_train, simulate_number)

                        on_policy_value = np.sum(np.ravel(meta_game[player][inds]) * meta_dist_slice)
                        new_policies.append(new_policy)
                        # Weight the conditional incentive by the probability of recommending `pid`.
                        incentives.append(meta_dist_slice_sum
                                          * _deviation_incentive(best_response_values, on_policy_value))
            per_player_new_policies.append(new_policies)
            per_player_deviation_incentives.append(incentives)

    else:
        raise ValueError("target_equilibrium must be a valid best response strategy: %s. "
                         "Received: %s" % (BRS, target_equilibrium))

    return per_player_new_policies, per_player_deviation_incentives


def compute_best_response(env, joint_policies_slice, meta_dist_slice, player, agent_kwargs, number_train, simulate_number):
    """Train a new DQN best response for ``player`` against a distribution over joint policies.

    Each of ``number_train`` training episodes samples one joint policy from
    ``joint_policies_slice`` with probabilities ``meta_dist_slice``. The new
    policy acts (``is_evaluation=False``) whenever ``player`` is to move. Every
    other player acts through the sampled joint policy with
    ``is_evaluation=True``. After training, the policy's value is estimated
    with ``best_response_returns`` over ``simulate_number`` episodes.

    Args:
        env: environment exposing ``reset()`` / ``step(actions)``; time steps
            expose ``last()`` and ``observations["current_player"]``.
        joint_policies_slice: sequence of joint policies, each indexable by
            player id.
        meta_dist_slice: array of probabilities over ``joint_policies_slice``
            (flattened before sampling).
        player: id of the player whose best response is trained.
        agent_kwargs: keyword arguments forwarded to ``rl_policy.DQNPolicy``.
        number_train: number of training episodes.
        simulate_number: number of evaluation episodes.

    Returns:
        ``(new_policy, value)`` where ``value`` is ``player``'s entry of the
        mean returns computed by ``best_response_returns``.
    """
    new_policy = rl_policy.DQNPolicy(env, player, **agent_kwargs)
    # Sampling inputs do not change between episodes; compute them once.
    num_joint_policies = len(joint_policies_slice)
    probs = meta_dist_slice.ravel()

    for _ in range(number_train):
        time_step = env.reset()
        sample_agent = joint_policies_slice[np.random.choice(num_joint_policies, p=probs)]

        while not time_step.last():
            player_id = time_step.observations["current_player"]
            if player_id == player:
                agent_output = new_policy.step(time_step, is_evaluation=False)
            else:
                agent_output = sample_agent[player_id].step(time_step, is_evaluation=True)
            time_step = env.step([agent_output.action])

        # Hand the terminal time step to the learner as well (presumably so it
        # observes the episode's final rewards -- see DQNPolicy.step).
        new_policy.step(time_step, is_evaluation=False)

    reward = best_response_returns(env, new_policy, player, joint_policies_slice, meta_dist_slice, simulate_number)
    return new_policy, reward[player]


def best_response_returns(env, new_policy, player, joint_policies_slice, meta_dist_slice, num_episodes):
    """Monte-Carlo estimate of every player's mean return with ``new_policy`` in seat ``player``.

    Each episode samples one joint policy from ``joint_policies_slice`` with
    probabilities ``meta_dist_slice``. Seat ``player`` is played by
    ``new_policy`` and every other seat by the sampled joint policy.

    Returns:
        Array of length ``env.num_players``: the per-player rewards returned
        by ``sample_episode``, averaged over ``num_episodes`` episodes.
    """
    seat_count = int(env.num_players)
    summed_rewards = np.zeros(seat_count)
    candidate_ids = list(range(len(joint_policies_slice)))

    for _ in range(num_episodes):
        chosen = np.random.choice(candidate_ids, p=meta_dist_slice.ravel())
        joint_policy = joint_policies_slice[chosen]
        seat_policies = [new_policy if seat == player else joint_policy[seat]
                         for seat in range(seat_count)]
        summed_rewards += sample_episode(env.reset(), env, seat_policies).reshape(-1)

    return summed_rewards / num_episodes


def sample_episode(time_step, env, policies):
    """Play ``env`` from ``time_step`` until the episode ends and return the final rewards.

    At each non-terminal step the player given by
    ``time_step.observations["current_player"]`` chooses an action through
    ``policies[player].step(time_step, is_evaluation=True)``. That single
    action is passed to ``env.step`` wrapped in a list.

    Implemented as a loop (not recursion, one frame per step) so episodes
    longer than Python's recursion limit do not raise ``RecursionError``.

    Returns:
        ``np.float32`` array built from ``time_step.rewards`` at the terminal step.
    """
    while not time_step.last():
        current_player = time_step.observations["current_player"]
        agent_output = policies[current_player].step(time_step, is_evaluation=True)
        time_step = env.step([agent_output.action])
    return np.array(time_step.rewards, dtype=np.float32)

