import numpy as np
from scipy import optimize
from typing import List, Tuple
import warnings
import matplotlib.pyplot as plt
from modules.agents.Probability import ProbabilityEmpiricalMeasure
from modules.agents.Policy import PolicyFiniteActionsFiniteStates
from modules.utils.Log import Logger


def compute_distances(dist: callable, action_space: np.ndarray) -> np.ndarray:
    """Build the pairwise distance matrix over ``action_space``.

    Entry ``[i, j]`` of the result is ``dist(action_space[i], action_space[j])``,
    i.e. the row action is passed first and the column action second.

    Args:
        dist: binary function returning the distance between two actions.
        action_space: iterable of actions.

    Returns:
        ``np.array`` built from the row-major list of distances.
    """
    matrix_rows = []
    for row_action in action_space:
        matrix_rows.append([dist(row_action, col_action) for col_action in action_space])
    return np.array(matrix_rows)


def compute_distances_p(dist: callable, action_space: np.ndarray, p: float) -> np.ndarray:
    """Return the pairwise distance matrix raised element-wise to the power ``p``.

    Args:
        dist: binary function returning the distance between two actions.
        action_space: iterable of actions.
        p: exponent applied to every entry of ``compute_distances(dist, action_space)``.
    """
    base_distances = compute_distances(dist, action_space)
    return np.power(base_distances, p)


def regularizer_before_min(s: int, a: int, lam: float, advantage: np.ndarray, distances_p: callable) -> np.ndarray:
    """Return the vector ``lam * distances_p(s)[a, :] - advantage[s]``.

    Its minimum over the last index is the regularizer of ``(s, a)``.

    Args:
        s: state index passed to ``distances_p`` and used to index ``advantage``.
        a: row of the state's p-distance matrix (the current action).
        lam: dual multiplier scaling the transport cost.
        advantage: indexed by state; ``advantage[s]`` is presumably a per-action
            vector broadcast against the distance row -- TODO confirm.
        distances_p: callable mapping a state index to its p-distance matrix.
    """
    transport_cost = distances_p(s)[a, :]
    penalized = lam * transport_cost - advantage[s]
    return np.asarray(penalized)


def regularizer(s: int, a: int, lam: float, advantage: np.ndarray, distances_p: callable) -> float:
    """Return ``min_j (lam * distances_p(s)[a, j] - advantage[s][j])``.

    This is the minimum of the vector produced by ``regularizer_before_min``.
    """
    candidates = regularizer_before_min(s=s, a=a, lam=lam, advantage=advantage, distances_p=distances_p)
    return candidates.min()


def regularizer_mins(s: int, a: int, lam: float, advantage: np.ndarray, distances_p: callable,
                     tol: float = 1e-2) -> np.ndarray:
    """Return the indices that (approximately) attain the regularizer's minimum.

    An index ``j`` is selected when ``vals[j] <= min(vals) + tol``, where
    ``vals = regularizer_before_min(s, a, lam, advantage, distances_p)``.
    The slack ``tol`` lets several near-tied actions be reported as minimizers
    instead of relying on exact floating-point equality.

    Args:
        s: state index.
        a: current action (row of the state's p-distance matrix).
        lam: dual multiplier.
        advantage: indexed by state, see ``regularizer_before_min``.
        distances_p: callable mapping a state index to its p-distance matrix.
        tol: absolute slack above the minimum that still counts as a minimizer.
            Defaults to ``1e-2``, the value previously hard-coded here.

    Returns:
        Array of indices into ``vals`` (a row of the distance matrix).

    Raises:
        ValueError: if ``tol`` is negative (the result could then be empty).
    """
    if tol < 0:
        raise ValueError(f'tol must be non-negative, got {tol}')
    vals = regularizer_before_min(s=s, a=a, lam=lam, advantage=advantage, distances_p=distances_p)
    return np.where(np.min(vals) + tol >= vals)[0]


def regularizer_table(lam: float, advantage: np.ndarray, distances_p: callable, action_space: np.ndarray, state_space_idx: np.ndarray) -> np.ndarray:
    """Tabulate ``regularizer`` for every state index (rows) and action (columns).

    Args:
        lam: dual multiplier.
        advantage: indexed by state, see ``regularizer_before_min``.
        distances_p: callable mapping a state index to its p-distance matrix.
        action_space: actions, used directly as row indices into the distance matrix.
        state_space_idx: state indices.

    Returns:
        ``np.array`` whose entry ``[i, j]`` is the regularizer of
        ``(state_space_idx[i], action_space[j])``.
    """
    table = []
    for s in state_space_idx:
        row = [regularizer(s=s, a=a, lam=lam, advantage=advantage, distances_p=distances_p)
               for a in action_space]
        table.append(row)
    return np.array(table)


# NOTE(review): disabled helper. It is stored as an inert module-level string
# literal, so `d_underbar_overbar_p` is never actually defined. If restored, it
# would return the smallest and largest p-distance from action `a` to the
# near-minimizers reported by `regularizer_mins` -- presumably the two one-sided
# derivatives in `lam` of the regularizer, for building a subgradient of the
# dual objective. Confirm before re-enabling.
"""
def d_underbar_overbar_p(s: int, a: int, lam: float, advantage: np.ndarray, distances_p: callable) -> Tuple[float, float]:
    minimizers = regularizer_mins(s=s, a=a, lam=lam, advantage=advantage, distances_p=distances_p)
    dist_p = distances_p(s)[a, minimizers]
    return np.min(dist_p), np.max(dist_p)
"""


def solve_dual(policy: PolicyFiniteActionsFiniteStates,
               rho: ProbabilityEmpiricalMeasure,
               advantage,
               distances_p: callable,
               epsilon_p: float,
               action_space: np.ndarray,
               state_space: np.ndarray,
               logger: Logger,
               do_plot: bool = False) -> Tuple[float, List[List[np.ndarray]]]:
    """Solve the scalar dual problem over the multiplier ``lam >= 0``.

    The dual objective is::

        obj(lam) = lam * epsilon_p - rho.E[ policy(s).E[ r_lam(s, a) ] ]

    where ``r_lam(s, a) = min_j (lam * distances_p(s)[a, j] - advantage[s][j])``
    (see ``regularizer_table``).

    Args:
        policy: callable mapping a state value to a distribution over actions
            exposing ``expected_value_function``.
        rho: distribution over state indices exposing ``expected_value_function``.
        advantage: indexed by state, see ``regularizer_before_min``.
        distances_p: callable mapping a state index to its p-distance matrix.
        epsilon_p: transport budget (p-th power of the radius).
        action_space: actions, used directly as indices.
        state_space: state values; only their positions are used as indices.
        logger: currently unused.
        do_plot: if true, plot ``obj`` around the optimum with matplotlib.

    Returns:
        ``(lambda_star, mins)`` where ``lambda_star >= 0`` is the minimizer found
        and ``mins[s][a]`` are the indices returned by ``regularizer_mins`` at
        ``lambda_star``.
    """
    # find indices of the states (so that it works for continuous state spaces)
    state_space_idx = np.arange(len(state_space))

    def obj(lam: float) -> float:
        # regularized advantage as a function of s and a, with its expected value
        r = regularizer_table(lam=lam,
                              advantage=advantage,
                              distances_p=distances_p,
                              action_space=action_space,
                              state_space_idx=state_space_idx)
        e = [policy(s_value).expected_value_function(r[s_idx]) for s_idx, s_value in enumerate(state_space)]
        return lam * epsilon_p - rho.expected_value_function(e)

    # Solve without bounds and clip at 0. Recent SciPy versions pick the
    # 'bounded' method whenever `bounds` is passed, and that method rejects the
    # infinite upper bound (while 'brent' rejects `bounds` altogether), so the
    # former `bounds=(0, np.inf)` call fails there. Older SciPy silently ignored
    # `bounds` and ran Brent, which is exactly what is now requested explicitly.
    # Clipping is valid because, assuming `policy(s)` and `rho` weigh with
    # non-negative probabilities, obj is convex in lam (linear term minus an
    # expectation of minima of affine functions): the minimizer over lam >= 0
    # is the unconstrained minimizer projected onto [0, inf).
    # (Bisection and L-BFGS-B were tried previously and abandoned.)
    res = optimize.minimize_scalar(obj, method='brent', tol=1e-5)
    lambda_star = max(float(res.x), 0.0)

    if do_plot:
        # Avoid a degenerate [0, 0] range when the optimum is at the boundary.
        upper = 2 * lambda_star if lambda_star > 0 else 1.0
        points = np.linspace(0, upper, 200)
        plt.plot(points, [obj(x) for x in points])
        plt.scatter([lambda_star], [obj(lambda_star)])
        plt.show()

    # output minimizers
    mins = [[regularizer_mins(s=s,
                              a=a,
                              lam=lambda_star,
                              advantage=advantage,
                              distances_p=distances_p)
             for a in action_space] for s in state_space_idx]
    print(epsilon_p)
    print(lambda_star)
    return lambda_star, mins


def solve_primal(policy: PolicyFiniteActionsFiniteStates,
                 rho: ProbabilityEmpiricalMeasure,
                 distances_p: callable,
                 mins: List[List[np.ndarray]],
                 epsilon_p: float,
                 action_space: np.ndarray,
                 state_space: np.ndarray,
                 logger: Logger) -> PolicyFiniteActionsFiniteStates:
    """Build the new policy from the dual minimizers.

    For every state ``s`` and action ``a``, two replacement actions are picked
    among ``mins[s][a]``: the one closest to ``a`` (``a_min``) and the one
    farthest from ``a`` (``a_max``), measured with ``distances_p(s)``. Moving the
    probability of ``a`` onto them yields two policies ``pimin`` / ``pimax``
    with expected p-costs ``d1`` / ``d2`` (weighted by ``rho`` and ``policy``).
    The returned policy is ``tstar * pimin + (1 - tstar) * pimax``, where
    ``tstar`` solves ``tstar * d1 + (1 - tstar) * d2 = epsilon_p``, clipped
    to ``[0, 1]``.

    Args:
        policy: callable mapping a state value to a distribution over actions.
        rho: distribution over state indices exposing ``get_probability``.
        distances_p: callable mapping a state index to its p-distance matrix.
        mins: ``mins[s][a]`` are the candidate actions from ``solve_dual``.
        epsilon_p: transport budget (p-th power of the radius).
        action_space: actions, used directly as indices.
        state_space: state values; only their positions are used as indices.
        logger: currently unused.

    Returns:
        The mixed policy, indexed by state index.

    Raises:
        ValueError: if a mixed per-state distribution is not normalized.
    """
    # find indices of the states (so that it works for continuous state spaces)
    state_space_idx = np.arange(len(state_space))

    # Two all-zero policies that accumulate the mass moved to the closest /
    # farthest minimizer of each action.
    zero_probability = ProbabilityEmpiricalMeasure(action_space,
                                                   np.zeros(action_space.shape))
    pimin = PolicyFiniteActionsFiniteStates(state_space_idx,
                                            action_space,
                                            [zero_probability.copy() for _ in state_space_idx])
    pimax = PolicyFiniteActionsFiniteStates(state_space_idx,
                                            action_space,
                                            [zero_probability.copy() for _ in state_space_idx])
    d1 = 0  # expected p-cost of pimin
    d2 = 0  # expected p-cost of pimax
    for s_idx, s_value in enumerate(state_space):
        # Per-state invariants, hoisted out of the action loop.
        dist_s = distances_p(s_idx)
        rho_s = rho.get_probability(s_idx)
        for a in action_space:
            candidates = mins[s_idx][a]
            # p-distance of the current action from each of its minimizers
            candidate_dist = dist_s[a, candidates]
            a_min = candidates[np.argmin(candidate_dist)]
            a_max = candidates[np.argmax(candidate_dist)]
            prob_a = policy(s_value).get_probability(a)
            pimin(s_idx).add_probability(a_min, prob_a)
            pimax(s_idx).add_probability(a_max, prob_a)
            d1 += rho_s*prob_a*dist_s[a, a_min]
            d2 += rho_s*prob_a*dist_s[a, a_max]
    print(d1)
    print(d2)
    if d1 > epsilon_p:
        warnings.warn('Something went wrong: d1 > eps. Probably lambda is not optimal.')
    if d2 < epsilon_p:
        warnings.warn('Something went wrong: d2 < eps. Probably lambda is not optimal.')
    if d1 == d2:
        # Both policies have the same cost: any mixture works; use pimax.
        tstar = 0
    else:
        # Weight on pimin such that tstar*d1 + (1 - tstar)*d2 == epsilon_p.
        tstar = (d2 - epsilon_p)/(d2 - d1)
        tstar = float(np.clip(tstar, 0, 1))

    # compute new policy
    new_policy = []
    for s in state_space_idx:
        new_policy_s = tstar * pimin(s) + (1 - tstar) * pimax(s)
        if not new_policy_s.is_normalized():
            raise ValueError(f'Mixed policy at state index {s} is not normalized (tstar={tstar}).')
        new_policy_s.normalize()  # is_normalized() has numerical tolerance, so normalize anyway
        new_policy.append(new_policy_s)
    return PolicyFiniteActionsFiniteStates(action_space=action_space,
                                           state_space=state_space_idx,
                                           policy=new_policy)

