from typing import Optional, Tuple

import numpy as np

from expground.types import Dict, AgentID, PolicyID, Any, List
from expground.learner import PSROLearner
from expground.gt.payoff_matrix import PayoffMatrix


def calculate_centroid(
    candidates: List[Dict[AgentID, Dict[PolicyID, float]]]
) -> Dict[AgentID, Dict[PolicyID, float]]:
    """Average a list of meta strategies into a single meta strategy.

    The agent and policy keys of the result are taken from the first
    candidate; every candidate then contributes ``prob / len(candidates)``
    to the matching entry.

    Args:
        candidates: Meta strategies of the form ``{agent: {policy: prob}}``.

    Returns:
        A new ``{agent: {policy: mean_prob}}`` mapping.

    Raises:
        IndexError: If ``candidates`` is empty.
        KeyError: If a candidate holds an agent or policy that the first
            candidate does not have.
    """
    total = len(candidates)
    template = candidates[0]
    centroid = {
        agent: {policy: 0.0 for policy in distribution}
        for agent, distribution in template.items()
    }
    for candidate in candidates:
        for agent, distribution in candidate.items():
            accumulator = centroid[agent]
            for policy, prob in distribution.items():
                # Divide before adding, so the running sum matches a plain
                # per-candidate accumulation term by term.
                accumulator[policy] += prob / total
    return centroid


def reflect_meta_strategies(left_meta: Dict, right_meta: Dict, alpha: float = 0.001):
    """Move ``left_meta`` away from ``right_meta`` by a factor of ``alpha``.

    For every agent and policy in ``left_meta`` the result holds
    ``left + alpha * (left - right)``. The result is not clipped or
    renormalised, so it may leave the probability simplex.

    Args:
        left_meta: ``{agent: {policy: value}}`` mapping that is moved.
        right_meta: Mapping with at least the same agents and policies; the
            point that ``left_meta`` is moved away from.
        alpha: Step coefficient applied to the difference.

    Returns:
        A new ``{agent: {policy: value}}`` mapping with the keys of
        ``left_meta``.

    Raises:
        KeyError: If ``right_meta`` lacks an agent or policy that has a value
            in ``left_meta``.
    """
    return {
        agent: {
            policy: value + alpha * (value - right_meta[agent][policy])
            for policy, value in distribution.items()
        }
        for agent, distribution in left_meta.items()
    }


def simplex_project(
    ori_strategy: Dict, candidates: List[Dict]
) -> Optional[Tuple[int, Dict[AgentID, Dict[PolicyID, float]]]]:
    """Return the candidate closest to ``ori_strategy`` together with its index.

    Closeness is the squared Euclidean distance summed over every
    ``(agent, policy)`` entry of ``ori_strategy``. Entries are matched by
    agent and policy key, so the result does not depend on dict insertion
    order and agents may own different numbers of policies.

    Args:
        ori_strategy: ``{agent: {policy: value}}`` point to approximate.
        candidates: Meta strategies to choose from, same layout.

    Returns:
        ``(index, candidate)`` for the nearest candidate. When several
        candidates are equally close the last one wins. ``None`` when
        ``candidates`` is empty or no distance is comparable (all NaN).

    Raises:
        KeyError: If a candidate lacks an agent or policy that is present in
            ``ori_strategy``. Keys found only in a candidate are ignored.
    """
    best_distance = float("inf")
    selected = None
    for index, candidate in enumerate(candidates):
        distance = sum(
            (value - candidate[agent][policy]) ** 2
            for agent, distribution in ori_strategy.items()
            for policy, value in distribution.items()
        )
        # "<=" rather than "<": among equally close candidates the last one
        # is kept, which is the tie-breaking callers have always seen.
        if distance <= best_distance:
            best_distance = distance
            selected = (index, candidate)
    return selected


def projected_amoeba(
    meta_policies: Dict[AgentID, Dict[PolicyID, float]],
    policy_pool: Dict[AgentID, Any],
    payoff_matrix: PayoffMatrix,
    alpha: float = 0.001,
    t: int = 100,
    n: int = 10,
):
    """Amoeba-style (Nelder-Mead-like) search over candidate meta strategies.

    Every iteration scores each candidate with
    ``payoff_matrix.compute(meta_policies=candidate)``, orders the candidates
    from lowest to highest score, and then tries to replace the
    highest-scoring ("right most") candidate through a reflection, an
    expansion and a contraction step. Each trial point is snapped to the
    nearest existing candidate by ``simplex_project`` and judged by that
    candidate's score.

    NOTE(review): the score is presumably a regret to be minimised (the
    original comments call it "regret") -- confirm against
    ``PayoffMatrix.compute``; it only needs to support ``<`` / ``<=``.

    NOTE(review): all ``n`` starting candidates are the same ``meta_policies``
    object, although the original comment announces random candidates. With
    identical candidates every trial point snaps back to an equal strategy,
    so the search cannot move until a real initialisation is supplied.

    Args:
        meta_policies: Starting meta strategy, ``{agent: {policy: prob}}``.
        policy_pool: Not used by this function; kept for interface
            compatibility.
        payoff_matrix: Object providing ``compute(meta_policies=...)``.
        alpha: Step coefficient for the reflection and contraction steps.
        t: Number of iterations.
        n: Number of candidates; must be at least 2.

    Returns:
        The snapped contraction point of the last iteration, or ``None``
        when ``t <= 0``.

    Raises:
        ValueError: If ``n < 2`` (a centroid of "all but the worst" needs at
            least one remaining candidate).
    """
    if n < 2:
        raise ValueError("projected_amoeba needs at least two candidates (n >= 2)")

    # generate n random meta policies
    # NOTE(review): currently n references to the same starting strategy.
    candidates = [meta_policies] * n
    sigma_c = None

    for _ in range(t):
        # compute regret with given payoff table
        regrets = [payoff_matrix.compute(meta_policies=c) for c in candidates]
        # sort by value; candidates and regrets are reordered together so that
        # index 0 is the best candidate and index -1 the worst ("right most").
        order = sorted(range(n), key=regrets.__getitem__)
        candidates = [candidates[i] for i in order]
        regrets = [regrets[i] for i in order]
        best_regret = regrets[0]

        # compute centralized meta policies, except the right most
        centroid = calculate_centroid(candidates[:-1])

        # compute reflected point, then snap it to the nearest candidate
        sigma_r = reflect_meta_strategies(centroid, candidates[-1], alpha)
        sigma_r_idx, sigma_r = simplex_project(sigma_r, candidates)
        regret_r = regrets[sigma_r_idx]

        if best_regret <= regret_r < regrets[n - 2]:
            # reflection: update the right most (and keep its regret aligned)
            candidates[-1], regrets[-1] = sigma_r, regret_r
        elif regret_r < best_regret:
            # expansion
            # NOTE(review): call kept as originally written (default step,
            # moving the centroid away from sigma_r). Textbook Nelder-Mead
            # expands from the centroid *towards* the reflected point --
            # confirm which direction is intended.
            sigma_e = reflect_meta_strategies(centroid, sigma_r)
            sigma_e_idx, sigma_e = simplex_project(sigma_e, candidates)
            if regrets[sigma_e_idx] < best_regret:
                candidates[-1], regrets[-1] = sigma_e, regrets[sigma_e_idx]
            else:
                candidates[-1], regrets[-1] = sigma_r, regret_r

        # contraction: sigma_c = centroid + alpha * (candidates[-1] - centroid).
        # reflect_meta_strategies(a, b, s) yields a + s * (a - b), so a step of
        # -alpha produces exactly that expression on nested dicts.
        sigma_c = reflect_meta_strategies(centroid, candidates[-1], -alpha)
        sigma_c_idx, sigma_c = simplex_project(sigma_c, candidates)
        if regrets[sigma_c_idx] < regret_r:
            candidates[-1], regrets[-1] = sigma_c, regrets[sigma_c_idx]

        # Recompute the contraction point from the (possibly replaced) right
        # most candidate; this is the value reported to the caller.
        sigma_c = reflect_meta_strategies(centroid, candidates[-1], -alpha)
        _, sigma_c = simplex_project(sigma_c, candidates)

    return sigma_c


def stackelberg_search(learner: PSROLearner):
    """Stackelberg search on top of a PSRO learner (appears unfinished).

    The visible body only reads ``learner.compute`` into a local name; nothing
    is returned and no search is performed in the lines shown here.

    Args:
        learner: The PSRO learner whose ``compute`` attribute is read.
    """
    # compute meta strategies at first
    # NOTE(review): this binds the attribute itself and does not call it. If
    # ``compute`` is a method, the call parentheses (and any arguments) are
    # presumably missing -- confirm against PSROLearner.
    meta_strategies = learner.compute
