from tqdm import tqdm
from scipy.optimize import minimize

from utils.constants import *
from utils.publishers_game import GAME_TYPES, ExtendedPublishersGame
from utils.general_utils import bootstrap_ci


# Regret Simulation

class RegretMinimizingPlayer:
    """One publisher that picks strategies with a regularized no-regret rule.

    Internally the player works with a lifted point ``(lam, y)`` of length
    ``d + 1``, obtained from ``solve_opt_prob``. The strategy actually played
    is ``x = y / lam`` (length ``d``).

    Attributes:
        d: dimension of the played strategy ``x``.
        eta: learning rate handed to ``solve_opt_prob``.
        U_tilde: running sum of lifted utility vectors (length ``d + 1``).
        u_tilde: most recent lifted utility vector (length ``d + 1``).
        lam, y, x: latest solution and played strategy; ``None`` before the
            first call to ``act``.
    """

    def __init__(self, d, eta):
        self.d = d
        self.eta = eta
        # Both lifted vectors carry one extra leading coordinate for lam.
        self.U_tilde = np.zeros(d + 1)
        self.u_tilde = np.zeros(d + 1)
        # Nothing has been played yet; act() populates these.
        self.lam = None
        self.y = None
        self.x = None

    def act(self):
        """Compute, store and return this round's strategy ``x = y / lam``."""
        if self.lam is None:
            warm_start = None
        else:
            # Warm-start the solver from the previous round's (lam, y).
            warm_start = np.concatenate(([self.lam], self.y))
        self.lam, self.y = solve_opt_prob(self.u_tilde, self.U_tilde, self.eta,
                                          initial_guess=warm_start)
        self.x = self.y / self.lam
        return self.x

    def update(self, u):
        """Record utility gradient ``u`` observed for the last played ``x``.

        The lifted vector is ``(-<x, u>, u_1, ..., u_d)``. It replaces
        ``u_tilde`` and is added into the running sum ``U_tilde``.
        """
        lead = -np.dot(self.x, u)
        self.u_tilde = np.concatenate(([lead], u))
        self.U_tilde += self.u_tilde


def solve_opt_prob(u_tilde, U_tilde, eta, initial_guess=None):
    """Solve the per-round regularized problem over the lifted point ``(lam, y)``.

    Maximizes ``eta * <U_tilde + u_tilde, z> + sum(log(z))`` over
    ``z = (lam, y_1, ..., y_d)`` subject to ``DELTA_ <= z_j <= 1`` and
    ``y_i <= lam`` for every ``i``. The most recent vector ``u_tilde`` is
    therefore counted once more on top of ``U_tilde`` (which, as maintained by
    ``RegretMinimizingPlayer.update``, already contains it).

    Args:
        u_tilde: most recent lifted utility vector, length ``d + 1``.
        U_tilde: accumulated lifted utility vector, length ``d + 1``.
        eta: learning rate scaling the linear term.
        initial_guess: optional SLSQP starting point of length ``d + 1``.
            A uniform random point in ``[DELTA_, 1]^(d+1)`` is used when None.

    Returns:
        ``(lam, y)``: the scalar ``lam`` and the length-``d`` vector ``y``.

    Raises:
        ValueError: if neither the given start nor one random restart yields
            a successful, feasible solution.
    """
    d = len(u_tilde) - 1
    tmp = U_tilde + u_tilde

    def f(lam_y):
        # Negated because scipy minimizes.
        return -(eta * np.dot(tmp, lam_y) + np.sum(np.log(lam_y)))

    def f_grad(lam_y):
        return -tmp * eta - 1 / lam_y

    bounds = [(DELTA_, 1)] * (d + 1)
    # Bind ``i`` as a default argument. A bare ``lambda lam_y: ... lam_y[i]``
    # looks ``i`` up when called (late binding), so every constraint would
    # compare against the last coordinate only.
    constraints = [{'type': 'ineq', 'fun': lambda lam_y, i=i: lam_y[0] - lam_y[i]}
                   for i in range(1, d + 1)]

    def random_start():
        return np.random.uniform(low=DELTA_, high=1, size=d + 1)

    def run(x0):
        return minimize(fun=f, x0=x0, bounds=bounds, method='SLSQP', constraints=constraints,
                        options={'ftol': DELTA_}, jac=f_grad)

    def is_acceptable(res):
        # SLSQP may report success while slightly violating the feasible set;
        # allow DELTA_ slack on y <= lam only.
        lam, y = res.x[0], res.x[1:]
        return bool(res.success and (DELTA_ <= lam <= 1)
                    and np.all(DELTA_ <= y) and np.all(y <= lam + DELTA_))

    result = run(random_start() if initial_guess is None else initial_guess)
    if not is_acceptable(result):
        # One retry from a fresh random start, whether the first attempt
        # failed outright or returned an infeasible point.
        result = run(random_start())
        if not is_acceptable(result):
            raise ValueError("Optimization failed")
    return result.x[0], result.x[1:]


def is_eps_pne(G: ExtendedPublishersGame, eps, print_eps=False):
    """Return whether the current profile ``G.x`` is an eps-PNE.

    For each publisher ``i``, a best response in ``[0, 1]^k`` is searched with
    SLSQP, starting from ``G.x[i]``. The objective is
    ``G.calc_u_deviation(xi, i)``; its gradient is ``G.calc_grad`` evaluated
    at ``G.x`` with only row ``i`` replaced by ``xi``. The profile is an
    eps-PNE when no publisher gains more than ``eps`` over ``G.get_u(i)``.

    Args:
        G: the game whose current profile is checked.
        eps: tolerated utility gain from a unilateral deviation.
        print_eps: if True, check every publisher and print its gain and best
            strategy instead of stopping at the first violation.

    Returns:
        True if no publisher can gain more than ``eps``, else False.
        Fixed: previously the result was always True when ``print_eps`` was set.

    Raises:
        ValueError: if the best-response optimization fails for a publisher.
    """
    bounds = [(0, 1)] * G.k
    is_pne = True
    for i in range(G.n):
        temp_x = G.x.copy()

        def f(xi):
            return -G.calc_u_deviation(xi, i)

        def f_grad(xi):
            # Other publishers stay at G.x; only publisher i's row varies.
            temp_x[i] = xi
            return -G.calc_grad(temp_x, i)

        result = minimize(fun=f, x0=G.x[i], bounds=bounds, method='SLSQP',
                          options={'ftol': DELTA_}, jac=f_grad)
        if not result.success:
            raise ValueError("Finding best response for eps-PNE failed")

        best_strategy = result.x
        diff = G.calc_u_deviation(best_strategy, i) - G.get_u(i)
        if diff > eps:
            is_pne = False
            if not print_eps:
                return False
        if print_eps:
            print(f'Player {i} eps: {diff}')
            print(f'Player {i} best strategy: {best_strategy}')
            print()
    return is_pne


def no_regret_dynamics(G: ExtendedPublishersGame, T, eta, eps):
    """Run no-regret dynamics on ``G`` until an eps-PNE is reached or ``T`` rounds pass.

    Each round every publisher plays ``RegretMinimizingPlayer.act()`` and ``G``
    is updated with the joint profile (``update_x`` + ``save_u_all``). If the
    profile is an eps-PNE the dynamics stop. Otherwise each player receives
    its row of ``G.get_grad_all()``.

    Args:
        G: the game; its state is advanced through ``update_x``/``save_u_all``.
        T: maximum number of rounds, used as a guard against infinite loops.
        eta: learning rate for every player.
        eps: tolerance for the eps-PNE stopping check.

    Returns:
        ``(publishers_welfare, users_welfare, rounds_played, conv)``.
        ``rounds_played`` is the number of rounds executed (0 if ``T <= 0``).
        ``conv`` is 1 if an eps-PNE was reached, else 0.
    """
    players = [RegretMinimizingPlayer(d=G.k, eta=eta) for _ in range(G.n)]

    conv = 0
    # Tracked explicitly so T <= 0 does not leave the loop variable unbound.
    rounds_played = 0
    # We use T to avoid infinite loops.
    # In practice, in all our experiments, the algorithm converges in less than T rounds
    for round_num in range(T):
        rounds_played = round_num + 1
        x_new = np.array([player.act() for player in players])

        G.update_x(x_new)
        G.save_u_all()
        if is_eps_pne(G, eps):
            conv = 1
            break

        grads = G.get_grad_all()
        for player, grad in zip(players, grads):
            player.update(grad)

    publishers_welfare = G.get_publishers_welfare()
    users_welfare = G.get_users_welfare()

    if not conv:
        print("No convergence")
        # Diagnostic only: print every publisher's deviation gain.
        is_eps_pne(G, eps, print_eps=True)
        print()
    return publishers_welfare, users_welfare, rounds_played, conv


def calc_regret(history, G: ExtendedPublishersGame):
    """Compute each publisher's regret against the best fixed strategy in hindsight.

    For publisher ``i`` the regret is
    ``sum_t [u_i(c_i, x^t_{-i}) - u_i(x^t)]``. Here ``x^t`` runs over
    ``history`` and ``c_i`` is the constant strategy in ``[0, 1]^k`` that
    SLSQP finds maximizing ``sum_t G.calc_u_deviation(c_i, i)``. The search
    starts from the publisher's last played strategy.

    Args:
        history: sequence of strategy profiles (each indexable by publisher).
        G: the game. ``G.update_x`` is called for every profile, so G is left
            at ``history[-1]``.

    Returns:
        Array of length ``G.n`` with each publisher's regret. It is all zeros
        for an empty history.

    Raises:
        ValueError: if the best-constant-response optimization fails.
    """
    regret = np.zeros(G.n)
    if len(history) == 0:
        return regret
    bounds = [(0, 1)] * G.k

    def neg_total_deviation_utility(xi, player):
        # update_x is used as a mutating statement (as elsewhere in this file)
        # rather than relying on it returning the game object.
        # NOTE(review): O(len(history)) per objective evaluation.
        total = 0
        for profile in history:
            G.update_x(profile)
            total += G.calc_u_deviation(xi, player)
        return -total

    for i in range(G.n):
        result = minimize(fun=neg_total_deviation_utility, x0=history[-1][i], args=(i,),
                          bounds=bounds, method='SLSQP', options={'ftol': DELTA_})
        if not result.success:
            raise ValueError("Finding best constant response failed")

        best_constant_strategy = result.x
        assert np.all(0 <= best_constant_strategy) and np.all(best_constant_strategy <= 1), best_constant_strategy
        for profile in history:
            G.update_x(profile)
            real_u = G.get_u(i)
            best_constant_strategy_u = G.calc_u_deviation(best_constant_strategy, i)
            regret[i] += best_constant_strategy_u - real_u
    return regret


def no_regret_dynamics_regret_calc(G: ExtendedPublishersGame, T, eta, return_hist=False):
    """Run exactly ``T`` rounds of no-regret dynamics and measure the regret.

    ``G.initialize()`` is called first. Then, every round, each publisher
    plays ``act()``, ``G`` is updated with the joint profile, and each player
    receives its row of ``G.get_grad_all()``.

    Args:
        G: the game; re-initialized and then advanced in place.
        T: number of rounds; must be at least 2 because the last two profiles
            are compared.
        eta: learning rate for every player.
        return_hist: if True, return the profile history instead of metrics.

    Returns:
        If ``return_hist``: ``(profile_hist, eps_PNE)``. ``profile_hist`` has
        shape ``(T, ...)`` and ``eps_PNE`` equals ``sum(publishers_regret) / T``.
        Otherwise: ``(publishers_welfare, users_welfare, regret, eps_x)``.
        ``regret`` is the per-round regret averaged over publishers.
        ``eps_x`` is ``G.d(last, second_to_last).sum()``; G.d is presumably a
        distance between profiles (TODO confirm).

    Raises:
        ValueError: if ``T < 2``.
    """
    if T < 2:
        # Fail fast with a clear message instead of an IndexError on
        # profile_hist[-2] (or a division by zero when T == 0) after all the work.
        raise ValueError(f"T must be at least 2, got {T}")

    players = [RegretMinimizingPlayer(d=G.k, eta=eta) for _ in range(G.n)]

    G.initialize()
    profile_hist = []

    for _ in range(T):
        x_new = np.array([player.act() for player in players])

        G.update_x(x_new)

        grads = G.get_grad_all()
        for player, grad in zip(players, grads):
            player.update(grad)

        profile_hist.append(x_new)

    # Welfare is read before calc_regret, which replays the history through G.
    publishers_welfare = G.get_publishers_welfare()
    users_welfare = G.get_users_welfare()
    profile_hist = np.array(profile_hist)
    publishers_regret = calc_regret(profile_hist.copy(), G)
    regret = sum(publishers_regret / T) / G.n
    eps_x = G.d(profile_hist[-1], profile_hist[-2]).sum()

    if return_hist:
        eps_PNE = regret * G.n  # the epsilon for the epsilon-PNE if the game is Socially Concave
        return profile_hist, eps_PNE
    return publishers_welfare, users_welfare, regret, eps_x


def full_simulation(ranking_function, additional_param, instances, k, n, s, lam, sim_params,
                    simulation_func=None, tqdm_key=None):
    """Simulate every instance with one ranking function and aggregate the metrics.

    For each ``(x_0, x_star_lst)`` in ``instances`` a game is built as
    ``GAME_TYPES[ranking_function](k, n, s, lam, x_0, x_star_lst[, additional_param])``.
    It is passed to ``simulation_func(G, **sim_params)``, which must return
    four numbers. With the default ``no_regret_dynamics`` these are publishers
    welfare, users welfare, rounds played and the convergence flag.

    Args:
        ranking_function: key into ``GAME_TYPES``.
        additional_param: extra constructor argument, omitted when None.
        instances: sequence of ``(x_0, x_star_lst)`` pairs.
        k, n, s, lam: game parameters forwarded to the constructor.
        sim_params: keyword arguments for ``simulation_func``.
        simulation_func: simulation to run; defaults to ``no_regret_dynamics``.
        tqdm_key: optional label formatter; when given, a tqdm progress bar is shown.

    Returns:
        8-tuple: the means of the four metrics, followed by their
        ``bootstrap_ci`` results in the same order.
    """
    B = len(instances)
    if simulation_func is None:
        simulation_func = no_regret_dynamics

    publishers_welfare_res, users_welfare_res, convergence_rate, convergence_ratio = \
        (np.zeros(B) for _ in range(4))

    indices = range(B)
    if tqdm_key is not None:
        label = tqdm_key("k", "n", "s", "lambda")
        values = tqdm_key(k, n, s, lam)
        indices = tqdm(indices, desc=f'{label}={values}')

    for i in indices:
        x_0, x_star_lst = instances[i]
        game_args = (k, n, s, lam, x_0, x_star_lst)
        if additional_param is not None:
            game_args += (additional_param,)
        G = GAME_TYPES[ranking_function](*game_args)

        (publishers_welfare_res[i], users_welfare_res[i],
         convergence_rate[i], convergence_ratio[i]) = simulation_func(G, **sim_params)

    columns = (publishers_welfare_res, users_welfare_res, convergence_rate, convergence_ratio)
    means = tuple(np.mean(column) for column in columns)
    intervals = tuple(bootstrap_ci(column, B) for column in columns)
    return means + intervals


def comparison(ranking_function_lst, additional_param_lst, k_vals, n_vals, s_vals, lam_vals, B, params,
               simulation_func=None, generate_x0_x_star=None, tqdm_key=None):
    """Run ``full_simulation`` for every ranking function over a parameter grid.

    For each ``(k, n, s)`` triple, ``B`` instances ``(x_0, x_star_lst)`` are
    generated once. They are reused for every ``lam`` value and every ranking
    function, so all ranking functions are compared on identical instances.

    Args:
        ranking_function_lst (sized iterable): ranking functions to compare.
        additional_param_lst (sized iterable): extra parameter per ranking
            function (same length as ``ranking_function_lst``; None for none).
        k_vals (iterable or scalar): embedding space dimension(s).
        n_vals (iterable or scalar): number(s) of publishers.
        s_vals (iterable or scalar): number(s) of information needs.
        lam_vals (iterable or scalar): lambda value(s).
        B (int): number of instances per ``(k, n, s)`` triple.
        params (dict): keyword arguments forwarded to the simulation function.
        simulation_func (callable or None): simulation passed to
            ``full_simulation``; None selects its default.
        generate_x0_x_star (callable or None): ``(k, n, s) -> (x_0, x_star_lst)``.
            None draws uniform ``(n, k)`` and ``(s, k)`` arrays.
        tqdm_key (callable or None): progress-bar label formatter.

    A value without ``__iter__`` is treated as a one-element grid.

    Returns:
        list[dict]: one dict per ranking function, in input order. Each maps
        ``(k, n, s, lam)`` to the tuple returned by ``full_simulation``.
    """
    assert len(ranking_function_lst) == len(additional_param_lst), \
        "ranking_function_lst and additional_param_lst must have the same length"
    assert simulation_func is None or callable(simulation_func), "simulation_func must be a function or None"
    assert generate_x0_x_star is None or callable(generate_x0_x_star), \
        "generate_x0_x_star must be a function or None"

    def as_grid(vals):
        return vals if hasattr(vals, '__iter__') else [vals]

    k_vals, n_vals, s_vals, lam_vals = map(as_grid, (k_vals, n_vals, s_vals, lam_vals))

    def generate_instance(k, n, s):
        if generate_x0_x_star is None:
            return np.random.rand(n, k), np.random.rand(s, k)
        return generate_x0_x_star(k, n, s)

    res = [dict() for _ in range(len(ranking_function_lst))]
    for k_val in k_vals:
        for n_val in n_vals:
            for s_val in s_vals:
                instances = [generate_instance(k_val, n_val, s_val) for _ in range(B)]
                for lam_val in lam_vals:
                    grid_key = (k_val, n_val, s_val, lam_val)
                    for results, ranking_function, extra in zip(res, ranking_function_lst, additional_param_lst):
                        results[grid_key] = full_simulation(
                            ranking_function, extra, instances, k_val, n_val, s_val, lam_val, params,
                            simulation_func=simulation_func, tqdm_key=tqdm_key)
    return res
