from typing import List, Tuple
import numpy as np
from functions.env import MDP_no_R
from functions.utils import convert_to_objs, estimate_occupancy


def rob(
        # application g
        g: Tuple[np.ndarray,np.ndarray,np.ndarray],
        # feedbacks F
        fD: List[Tuple[np.ndarray,MDP_no_R,float]],
        fPC: List[Tuple[np.ndarray,np.ndarray,float,np.ndarray]],
        fTC: List[Tuple[np.ndarray,np.ndarray,float,np.ndarray]], 
        # reward dimensions
        S: int,
        A: int,
        H: int,
        n_objs: int,
        # hyperparameters
        alpha: float,
        K: int,
        r0: np.ndarray,
        s: float, 
) -> Tuple[float, float]:
    """
    Estimate the value of application g as a midpoint and a half-width.

    Both halves of g and all feedbacks are first converted to object-space
    vectors (via estimate_occupancy and/or convert_to_objs). PDSM is then
    run twice:
      * on (gg[0], gg[1]), giving m_hat: PDSM's final objective when
        minimizing r . (gg[0] - gg[1]);
      * on the swapped pair (gg[1], gg[0]); its objective is negated to give
        M_hat, i.e. the corresponding maximization of r . (gg[0] - gg[1]).

    Args:
        g: (data_0, data_1, mapp). data_0 and data_1 are each unpacked as
            extra positional arguments to estimate_occupancy(S, A, H, ...);
            mapp is passed to convert_to_objs for both.
            NOTE(review): the annotation says np.ndarray for all three
            entries although data_0/data_1 are star-unpacked -- confirm
            against callers.
        fD: list of (DD, MD, k). DD is unpacked into estimate_occupancy, MD
            (an MDP_no_R) supplies the mapping MD.mapp and is passed on to
            PDSM, k is the constraint bound.
        fPC: list of (D1, D2, k, mapp): two datasets unpacked into
            estimate_occupancy, the constraint bound, and the mapping.
        fTC: list of (om1, om2, k, mapp): two occupancy measures that are
            converted directly (no estimation), the bound, and the mapping.
        S, A, H: forwarded to estimate_occupancy.
        n_objs: number of objects, forwarded to convert_to_objs.
        alpha: step size of PDSM.
        K: iteration parameter of PDSM (its loop runs K + 1 times).
        r0: initial reward.
        s: bound on Lagrange multipliers.

    Returns:
        (x_hat, I_hat) with x_hat = (M_hat + m_hat) / 2 and
        I_hat = (M_hat - m_hat) / 2.
    """
    def _occupancy_objs(data, mapp):
        # Estimate an occupancy measure from `data`, then map it to objects.
        return convert_to_objs(estimate_occupancy(S, A, H, *data), mapp, n_objs)

    # estimate occupancy measures and convert to objs
    # (call order is the same as before: g, then fD, fPC, fTC)
    gg = (
        _occupancy_objs(g[0], g[2]),
        _occupancy_objs(g[1], g[2]),
    )
    ffD = [(_occupancy_objs(DD, MD.mapp), MD, k) for (DD, MD, k) in fD]
    ffPC = [
        (_occupancy_objs(D1, mapp), _occupancy_objs(D2, mapp), k)
        for (D1, D2, k, mapp) in fPC
    ]
    # Trajectory-comparison feedbacks already carry occupancy measures.
    ffTC = [
        (
            convert_to_objs(om1, mapp, n_objs),
            convert_to_objs(om2, mapp, n_objs),
            k,
        )
        for (om1, om2, k, mapp) in fTC
    ]

    # Lower bound: minimize r . (gg[0] - gg[1]).
    _, m_hat, _, _ = PDSM(gg, ffD, ffPC, ffTC, K, alpha, s, r0)
    # Upper bound: minimizing the swapped objective and negating the result
    # is the same as maximizing the original one.
    _, not_M_hat, _, _ = PDSM((gg[1], gg[0]), ffD, ffPC, ffTC, K, alpha, s, r0)
    M_hat = -not_M_hat

    # compute targets: midpoint and half-width of [m_hat, M_hat]
    x_hat = (M_hat + m_hat) / 2
    I_hat = (M_hat - m_hat) / 2

    return x_hat, I_hat


def PDSM(
        g,fD,fPC,fTC,K,alpha,s,r0,enforce_r_0=True
):
    """
    Projected primal-dual subgradient method on the Lagrangian

        L(r, lam) = r.(g[0] - g[1])
                    + sum_i lam_TC[i] * (r.(om1_i - om2_i) - k_i)
                    + sum_i lam_PC[i] * (r.(d1_i - d2_i) - k_i)
                    + sum_i lam_D[i]  * (J_i(r) - r.dE_i - k_i)

    where J_i(r) is the performance returned by
    M_i.compute_optimal_policy_and_performance(r). Each iteration takes a
    descent step on r, projected onto [0, 1] (with r[0] forced to 0 when
    enforce_r_0 is True). It also takes an ascent step on every multiplier,
    projected onto [0, s]. Both steps use subgradients evaluated at the
    current iterate.

    Take in input only vectors with size (n_objs,).

    Args:
        g: pair (g[0], g[1]); the objective r.(g[0] - g[1]) is minimized.
        fD: list of (dE, M, k). M must provide
            compute_optimal_policy_and_performance(r) -> (policy, J) and
            compute_visit_distribution_objects(policy) -> vector.
        fPC: list of (d1, d2, k).
        fTC: list of (om1, om2, k).
        K: the loop runs K + 1 times (iterations 0, ..., K).
        alpha: step size shared by the primal and dual updates.
        s: upper bound on every Lagrange multiplier.
        r0: initial reward vector (not modified).
        enforce_r_0: if True, reward component 0 is set to 0 after every
            projection.

    Returns:
        (r_avg, obj, rewards_avg, objs).
        - rewards_avg[t] is the mean of the first t + 1 reward iterates
          (rewards_avg[0] is a copy of r0).
        - objs[t] = rewards_avg[t] . (g[0] - g[1]), as a numpy array.
        - r_avg = rewards_avg[-1] and obj = objs[-1].
    """
    # Objective direction, hoisted out of the loop. It is cast to float so
    # the in-place `grad_r += ...` below also works for integer inputs.
    delta_g = np.asarray(g[0] - g[1], dtype=float)

    # Constraint directions do not depend on the iterate; compute them once.
    tc_terms = [(om1 - om2, k) for (om1, om2, k) in fTC]
    pc_terms = [(d1 - d2, k) for (d1, d2, k) in fPC]

    # Lagrange multipliers start at zero.
    lambda_TC_k = np.zeros(len(tc_terms))
    lambda_PC_k = np.zeros(len(pc_terms))
    lambda_D_k = np.zeros(len(fD))

    # Current iterate plus a running sum of all iterates. The sum lets each
    # new average be computed in O(n_objs) instead of re-averaging history.
    r_k = np.copy(r0)
    reward_sum = np.array(r0, dtype=float)
    n_iterates = 1
    rewards_avg = [np.copy(r0)]
    objs = [np.dot(r0, delta_g)]

    for _ in range(K + 1):
        ############ Subgradients, all evaluated at (r_k, lambda_k)
        grad_r = delta_g.copy()

        grad_lambda_TC = np.zeros_like(lambda_TC_k)
        for i, (diff, k) in enumerate(tc_terms):
            grad_r += lambda_TC_k[i] * diff
            grad_lambda_TC[i] = np.dot(r_k, diff) - k

        grad_lambda_PC = np.zeros_like(lambda_PC_k)
        for i, (diff, k) in enumerate(pc_terms):
            grad_r += lambda_PC_k[i] * diff
            grad_lambda_PC[i] = np.dot(r_k, diff) - k

        grad_lambda_D = np.zeros_like(lambda_D_k)
        for i, (dE, M, k) in enumerate(fD):
            # Solve the MDP once per iteration: the policy gives the
            # r-subgradient and the performance gives the lambda-subgradient.
            pi_star, J_star = M.compute_optimal_policy_and_performance(r_k)
            d_star = M.compute_visit_distribution_objects(pi_star)
            grad_r += lambda_D_k[i] * (d_star - dE)
            grad_lambda_D[i] = J_star - np.dot(r_k, dE) - k

        ############ Projected steps: descent on r, ascent on lambda
        r_k = np.clip(r_k - alpha * grad_r, 0, 1)
        if enforce_r_0:
            r_k[0] = 0
        lambda_TC_k = np.clip(lambda_TC_k + alpha * grad_lambda_TC, 0, s)
        lambda_PC_k = np.clip(lambda_PC_k + alpha * grad_lambda_PC, 0, s)
        lambda_D_k = np.clip(lambda_D_k + alpha * grad_lambda_D, 0, s)

        ############ Running average of the reward iterates
        reward_sum += r_k
        n_iterates += 1
        r_avg = reward_sum / n_iterates
        rewards_avg.append(r_avg)
        objs.append(np.dot(r_avg, delta_g))

    return (rewards_avg[-1], objs[-1], rewards_avg, np.array(objs))