import numpy as np
from functions.env import MDP_no_R
from functions.utils import convert_to_trajectories, convert_to_objs

def gen_r_star(n_objs):
    """
    Draw a random target reward vector with one entry per object.

    Every entry is sampled uniformly in [0,1), then the entry of object 0 is
    overwritten with 0, so the first object never carries any reward.

    Returns a 1-D float array of length n_objs.
    """
    rewards = np.random.random(n_objs)
    # object 0 is pinned to zero reward
    rewards[0] = 0
    return rewards

def gen_initial_state(S):
    """
    Pick an initial state uniformly at random among the S states
    (i.e. an integer in {0, ..., S-1}).
    """
    start_state = np.random.choice(S)
    return start_state
    
def gen_transition_model(S,A,H):
    """
    Draw a random transition model of shape (H,S,A,S).

    For every (h,s,a) the next-state distribution (last axis, length S) is
    sampled uniformly from the probability simplex, i.e. from a Dirichlet
    with all concentration parameters equal to 1.
    """
    flat_alpha = [1] * S
    return np.random.dirichlet(flat_alpha, size=(H, S, A))

def gen_MDP(S,A,H,n_objs):
    """
    Build a random reward-free MDP.

    The initial state, the transition model and the state-to-object map are
    sampled (in this order) with gen_initial_state, gen_transition_model and
    gen_mapp, and then wrapped into an MDP_no_R.
    """
    # NOTE: the sampling order determines how the global numpy RNG is
    # consumed, so it must not be changed if seeded runs are to be reproduced
    start_state = gen_initial_state(S)
    dynamics = gen_transition_model(S, A, H)
    object_map = gen_mapp(S, H, n_objs)

    return MDP_no_R(S, A, H, dynamics, start_state, object_map, n_objs)

def gen_policy(S,A,H,stochastic=False):
    """
    Sample a random policy as an array of shape (H,S,A), where pi[h,s] is a
    probability distribution over the A actions.

    S, A, H: number of states, number of actions, horizon.
    stochastic: if True, every pi[h,s] is drawn uniformly from the simplex
        over actions; if False (default), every pi[h,s] is a one-hot vector
        on an action chosen uniformly at random.
    """
    if stochastic:
        # one Dirichlet(1,...,1) draw over the A actions for each (h,s);
        # the distribution must live on the action axis, not on the states
        return np.random.dirichlet((1,) * A, size=(H, S))

    # if deterministic, sample each action randomly and one-hot encode it
    actions = np.random.choice(A, size=(H, S))
    return np.eye(A)[actions]

def get_random_policy(S,A,H):
    """
    Return the uniform policy: an array of shape (H,S,A) in which every
    entry equals 1/A.
    """
    uniform = np.ones((H, S, A))
    uniform /= A
    return uniform

def gen_mapp(S,H,n_objs):
    """
    Sample a random object map.

    The result is an integer array of shape (H,S) whose entry [h,s] is the
    index (in {0, ..., n_objs-1}) of the object assigned to state s at step h,
    drawn uniformly at random.
    """
    return np.random.randint(low=0, high=n_objs, size=(H, S))


def gen_problem_instance(
        S, A, H, n_objs, n1, n2, nD, nPC, nTC,
        N1, N2, ND, NPC, NTC,
        mD, mPC, mTC, theta
):
    """
    Randomly generate a full problem instance: a target reward, an
    application, and three lists of feedback (demonstrations, policy
    comparisons, trajectory comparisons).

    S, A, H, n_objs: sizes of the randomly generated MDPs.
    n1, n2: number of trajectories collected with the two application
        policies.
    nD: number of trajectories collected with each demonstration policy.
    nPC: number of trajectories collected with each policy in a policy
        comparison.
    ND: argument forwarded to RF_Express when estimating the transition
        model of each demonstration MDP.
    mD, mPC, mTC: how many demonstration / policy-comparison /
        trajectory-comparison feedback items to generate.
    theta: intuitively, how informative the feedback is. Every feedback value
        k is the true performance gap plus a uniform sample in [0, theta).
    nTC, N1, N2, NPC, NTC: accepted but not read anywhere in this function.

    Returns (g, g_exact), (fD, fD_exact), (fPC, fPC_exact),
    (fTC, fTC_exact), r_star. The "exact" variants hold exact quantities
    expressed in terms of the n_objs objects instead of sampled data.
    """
    # NOTE: all randomness comes from the global numpy RNG, so the order of
    # the calls below is part of the behaviour for seeded runs.

    # target reward
    reward = gen_r_star(n_objs)

    # ---- application g: two policies in a target MDP ----
    target = gen_MDP(S, A, H, n_objs)
    pol_a = gen_policy(S, A, H)
    pol_b = gen_policy(S, A, H)
    occ_a = target.compute_visit_distribution_objects(pol_a)
    occ_b = target.compute_visit_distribution_objects(pol_b)
    data_a = target.collect_trajectories(pol_a, n1)
    data_b = target.collect_trajectories(pol_b, n2)
    application = (data_a, data_b, target.mapp)
    application_exact = (occ_a, occ_b)

    # ---- demonstration feedback fD ----
    demos, demos_exact = [], []
    for _ in range(mD):
        env = gen_MDP(S, A, H, n_objs)

        # the expert's policy and its demonstrations
        expert = gen_policy(S, A, H)
        expert_occ = env.compute_visit_distribution_objects(expert)
        expert_data = env.collect_trajectories(expert, nD)

        # estimated copy of the environment (transition model from RF-Express)
        p_est = env.RF_Express(ND)
        env_est = MDP_no_R(S, A, H, p_est, env.s0, env.mapp, n_objs)

        # feedback value: suboptimality of the expert plus noise in [0,theta)
        _, best_perf = env.compute_optimal_policy_and_performance(reward)
        expert_perf = env.compute_policy_performance(reward, expert)
        assert expert_perf <= best_perf
        gap = np.random.random() * theta + (best_perf - expert_perf)

        demos.append((expert_data, env_est, gap))
        demos_exact.append((expert_occ, env, gap))

    # ---- policy-comparison feedback fPC ----
    pol_cmps, pol_cmps_exact = [], []
    for _ in range(mPC):
        env = gen_MDP(S, A, H, n_objs)

        # the two policies being compared
        first = gen_policy(S, A, H)
        second = gen_policy(S, A, H)
        first_occ = env.compute_visit_distribution_objects(first)
        second_occ = env.compute_visit_distribution_objects(second)
        first_data = env.collect_trajectories(first, nPC)
        second_data = env.collect_trajectories(second, nPC)

        # feedback value: performance difference plus noise in [0,theta)
        first_perf = env.compute_policy_performance(reward, first)
        second_perf = env.compute_policy_performance(reward, second)
        gap = np.random.random() * theta + (first_perf - second_perf)

        pol_cmps.append((first_data, second_data, gap, env.mapp))
        pol_cmps_exact.append((first_occ, second_occ, gap))

    # ---- trajectory-comparison feedback fTC ----
    traj_cmps, traj_cmps_exact = [], []
    for _ in range(mTC):
        env = gen_MDP(S, A, H, n_objs)

        # two trajectories collected with the uniform policy
        uniform = get_random_policy(S, A, H)
        visited, _ = env.collect_trajectories(uniform, 2)
        traj_a, traj_b = convert_to_trajectories(visited, S, H)

        # feedback value: return difference plus noise in [0,theta)
        objs_a = convert_to_objs(traj_a, env.mapp, n_objs)
        objs_b = convert_to_objs(traj_b, env.mapp, n_objs)
        ret_a = np.dot(objs_a, reward)
        ret_b = np.dot(objs_b, reward)
        gap = np.random.random() * theta + (ret_a - ret_b)

        traj_cmps.append((traj_a, traj_b, gap, env.mapp))
        traj_cmps_exact.append((objs_a, objs_b, gap))

    return (
        (application, application_exact),
        (demos, demos_exact),
        (pol_cmps, pol_cmps_exact),
        (traj_cmps, traj_cmps_exact),
        reward,
    )

def get_policy_always_same_action(
        S: int,
        A: int,
        H: int,
        action: int,
        noise: float = 0
):
    """
    Construct a policy that always favours the same action, with potentially
    some noise.

    Every action other than `action` is played with probability `noise`, and
    `action` receives the remaining mass 1-(A-1)*noise, so that each pi[h,s]
    sums to one for any number of actions A. With A=3 this is 1-2*noise:
    noise=0 means the action is always played, noise=0.5 means that the action
    in input is never played, but always the other two actions are played.

    S, A, H: number of states, number of actions, horizon.
    action: index of the favoured action.
    noise: a float in [0,0.5] such that (A-1)*noise <= 1.

    Returns an array of shape (H,S,A).

    Raises ValueError if noise is outside [0,0.5] or would leave a negative
    probability for `action`.
    """
    if noise < 0 or noise > 0.5:
        raise ValueError('Invalid noise value.')

    # probability mass left for the favoured action once each of the other
    # A-1 actions has received `noise`
    main_prob = 1 - (A - 1) * noise
    if main_prob < -1e-12:
        # with more than 3 actions a noise below 0.5 can already be too large
        raise ValueError('Invalid noise value.')

    pi = np.full((H, S, A), noise, dtype=float)
    # clip away tiny negative rounding errors
    pi[:, :, action] = max(main_prob, 0.0)

    return pi

def get_my_MDP(s0, mapp):
    """
    Construct the example MDP of the paper: three lanes (states), three
    actions, horizon five and four objects, with the transition model
    described in the appendix.

    s0: initial state, forwarded to MDP_no_R.
    mapp: state-to-object map, forwarded to MDP_no_R.
    """
    n_states, n_actions, horizon = 3, 3, 5
    n_objects = 4

    # One next-state table per action, indexed [state][next state].
    # Actions 0, 1, 2 are called L, C, R in the paper's notation
    # (presumably left / centre / right).
    tables = (
        [[1, 0, 0],
         [0.6, 0.4, 0],
         [0, 0.6, 0.4]],
        [[0.55, 0.45, 0],
         [0.3, 0.4, 0.3],
         [0, 0.45, 0.55]],
        [[0.3, 0.7, 0],
         [0, 0.3, 0.7],
         [0, 0, 1]],
    )

    # the model is stationary: the same table is used at every step h
    p = np.zeros((horizon, n_states, n_actions, n_states))
    for a, table in enumerate(tables):
        p[:, :, a, :] = table

    return MDP_no_R(n_states, n_actions, horizon, p, s0, mapp, n_objects)

def get_my_problem_instance():
    """
    Build the hand-crafted problem instance used in the simulation in the
    paper.

    Returns r_star, g, fD, fPC, fTC where
      - r_star is the reward of the four objects,
      - g = (d1, d2, mapp) holds the visit distributions of two policies in
        the target MDP and the target object map,
      - fD is a list with one (d, M, k) demonstration feedback,
      - fPC is a list of (d, d', k, mapp) policy comparisons,
      - fTC is a list of (omega, omega', k, mapp) trajectory comparisons.
    """
    # 3 lanes, 3 actions, 5 steps
    S, A, H = 3, 3, 5

    def build_mapp(rows):
        # maps are written one row per lane and transposed before use
        return np.array(rows).transpose()

    def target_mapp():
        # a fresh array on every call: the target map is used by two MDPs
        return build_mapp([[0, 2, 0, 0, 1],
                           [0, 0, 1, 2, 1],
                           [0, 0, 3, 2, 2]])

    def occupancy(mdp, action, noise):
        # visit distribution in mdp of the (noisy) constant-action policy
        policy = get_policy_always_same_action(S, A, H, action, noise=noise)
        dist, _ = mdp.compute_visit_distribution(policy)
        return dist

    # reward of the four objects
    r_star = np.array([0, 0.7, 0.2, 0.1])

    ### g: two noiseless policies in the target MDP
    target = get_my_MDP(1, target_mapp())
    occ_action2 = occupancy(target, 2, 0)
    occ_action0 = occupancy(target, 0, 0)
    g = (occ_action2, occ_action0, target.mapp)

    ### fD: one demonstration in its own MDP
    demo_mdp = get_my_MDP(2, build_mapp([[0, 0, 3, 1, 0],
                                         [0, 3, 2, 0, 1],
                                         [0, 0, 1, 0, 3]]))
    demo_occ = occupancy(demo_mdp, 1, 0.2)
    fD = [(demo_occ, demo_mdp, 1)]

    ### fPC: two policy comparisons in a third MDP
    cmp_mdp = get_my_MDP(0, build_mapp([[0, 0, 2, 3, 0],
                                        [0, 3, 0, 1, 0],
                                        [0, 0, 1, 0, 3]]))
    occ1 = occupancy(cmp_mdp, 0, 0.15)
    occ2 = occupancy(cmp_mdp, 1, 0.3)
    occ3 = occupancy(cmp_mdp, 1, 0.5)
    occ4 = occupancy(cmp_mdp, 2, 0.15)
    fPC = [(occ1, occ2, 0, cmp_mdp.mapp), (occ4, occ3, 0.5, cmp_mdp.mapp)]

    ### fTC: trajectory comparisons in the same env as the target
    traj_mdp = get_my_MDP(1, target_mapp())

    # each row of a trajectory is one step, with a single 1 marking a column
    omegas = [
        np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 0, 0]]),
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0]]),
        np.array([[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]]),
        np.array([[0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]]),
        np.array([[0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]),
        np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 0, 0]]),
    ]
    fTC = [
        (omegas[0], omegas[1], 0.3, traj_mdp.mapp),
        (omegas[2], omegas[3], 1, traj_mdp.mapp),
        (omegas[4], omegas[5], -0.5, traj_mdp.mapp),
    ]

    return r_star, g, fD, fPC, fTC