from torch import logsumexp, softmax
from src.helpers import collapse_permutation_matrix_pytorch, collapse_reward_pytorch
import torch

def two_agent_value_iteration(P_a, rewards, gamma, error=0.001):
    """
    time-invariant soft value iteration function (to ensure that the policy is differentiable)
    - deterministic transition
    - two agent collaborative foraging
    - N_STATES, N_ACTIONS refers to single agent case (i.e. N_STATES = grid_H*grid_W; N_ACTIONS = 5)

    Each sweep applies the soft Bellman backup
        V(s) <- R(s) + gamma * logsumexp_a V(P_a[s, a])
    and the loop stops once the largest absolute change of V between two
    consecutive sweeps drops below `error`. There is no iteration cap, so the
    call only returns when that threshold is reached.

    inputs:
        P_a        N_STATES**2 x N_ACTIONS**2, a permutation matrix P_a(s,a) to convert V(s) to V(s') based on action a
                                                i.e. V_new = V[P_a[:,a]]
                   (entries are cast to int64 and used as row indices into V)

        rewards     N_STATES**2 X 1 - R(s1,s2) where s1 is the location index of agent1
        gamma       float - RL discount
        error       float - threshold for a stop

    returns:
        values    N_STATES**2 x 1 matrix - V(s1,s2)
        policy    N_STATES**2 x N_ACTIONS**2 - row-wise softmax over the successor
                  values V(s') that were used in the final sweep
    """
    n_states, _ = P_a.shape

    # The successor table never changes between sweeps: cast it to an index
    # tensor once instead of once per action per sweep.
    next_state = P_a.long()

    # Starting from a leaf that requires grad keeps the returned tensors
    # attached to an autograd graph.
    values = torch.zeros((n_states, 1), requires_grad=True)

    # estimate values and q-values iteratively
    while True:
        previous = values
        # One gather builds the whole table: q_values[s, a] = V(P_a[s, a]).
        q_values = previous[next_state, 0]
        values = rewards + gamma * logsumexp(q_values, dim=1)[:, None]
        if torch.max(torch.abs(values - previous)) < error:
            break

    # generate policy (from the successor values of the last sweep)
    policy = softmax(q_values, dim=1)

    return values, policy


def two_agent_value_iteration_independent_control(P_a, rewards, gamma, action_list=None, error=0.001):
    """
    time-invariant soft value iteration with an independent-control policy.

    The joint soft policy is solved first; each agent's policy is then the
    marginal of that joint policy over the other agent's action, and the
    returned joint policy is the product of the two marginals (no prediction
    of the other agent's policy).
    N_STATES, N_ACTIONS refers to single agent case (i.e. N_STATES = grid_H*grid_W; N_ACTIONS = 5)

    inputs:
    P_a          N_STATES**2 x N_ACTIONS**2, a permutation matrix P_a(s,a) to convert V(s) to V(s') based on action a
                                    i.e. V_new = V[P_a[:,a]]
    rewards      N_STATES**2 X 1 - R(s1,s2) where s1 is the location index of agent1
    gamma        float - RL discount
    action_list  list of [a1, a2] pairs; entry i names the joint action of column i.
                 Defaults to every pair with a1 as the slow index.
    error        float - threshold for a stop

    returns:
    values                N_STATES**2 x 1 matrix - V(s1,s2)
    policy_independent    N_STATES**2 x N_ACTIONS**2 - p1(a1|s) * p2(a2|s) per column
    p1                    N_STATES**2 x N_ACTIONS - marginal policy of agent 1
    p2                    N_STATES**2 x N_ACTIONS - marginal policy of agent 2
    """
    values, joint = two_agent_value_iteration(P_a, rewards, gamma, error)

    _, n_joint_actions = P_a.shape
    n_actions = int(torch.sqrt(torch.tensor(n_joint_actions)))

    if action_list is None:
        action_list = [[first, second] for first in range(n_actions) for second in range(n_actions)]

    # Column positions of the joint policy grouped by each agent's own action.
    cols_by_a1 = [
        [col for col, pair in enumerate(action_list) if pair[0] == act]
        for act in range(n_actions)
    ]
    cols_by_a2 = [
        [col for col, pair in enumerate(action_list) if pair[1] == act]
        for act in range(n_actions)
    ]

    # Marginalise the joint policy over the other agent's action.
    p1 = torch.stack([joint[:, cols].sum(dim=1) for cols in cols_by_a1], dim=1)
    p2 = torch.stack([joint[:, cols].sum(dim=1) for cols in cols_by_a2], dim=1)

    # Recombine the marginals as if the two agents acted independently.
    policy_independent = torch.zeros(joint.shape)
    for col, pair in enumerate(action_list):
        policy_independent[:, col] = p1[:, pair[0]] * p2[:, pair[1]]

    return values, policy_independent, p1, p2


def two_agent_value_iteration_independent_control_uniform_prediction(P_a, rewards, gamma, action_list=None, error=0.001):
    """
    time-invariant soft value iteration with independent control, where each
    agent predicts the other agent's action to be uniformly distributed.

    The joint soft policy is solved first. From it the conditionals
    P(a1|a2,s) and P(a2|a1,s) are formed, and each agent's policy is the
    average of its conditional over the other agent's actions. The returned
    joint policy is the product of those two single-agent policies.
    N_STATES, N_ACTIONS refers to single agent case (i.e. N_STATES = grid_H*grid_W; N_ACTIONS = 5)

    inputs:
    P_a          N_STATES**2 x N_ACTIONS**2, a permutation matrix P_a(s,a) to convert V(s) to V(s') based on action a
                                    i.e. V_new = V[P_a[:,a]]
    rewards      N_STATES**2 X 1 - R(s1,s2) where s1 is the location index of agent1
    gamma        float - RL discount
    action_list  list of [a1, a2] pairs; entry i names the joint action of column i.
                 Defaults to every pair with a1 as the slow index.
    error        float - threshold for a stop

    returns:
    values                N_STATES**2 x 1 matrix - V(s1,s2)
    policy_independent    N_STATES**2 x N_ACTIONS**2 - policy1(a1|s) * policy2(a2|s) per column
    policy1               N_STATES**2 x N_ACTIONS - policy of agent 1
    policy2               N_STATES**2 x N_ACTIONS - policy of agent 2
    """
    n1, n2 = P_a.shape
    N_ACTIONS = int(torch.sqrt(torch.tensor(n2)))
    values, policy_joint = two_agent_value_iteration(P_a, rewards, gamma, error)

    if action_list is None:
        action_list = [[a1, a2] for a1 in range(N_ACTIONS) for a2 in range(N_ACTIONS)]

    def _columns(agent, action):
        # Joint-policy columns in which `agent` (0 or 1) plays `action`,
        # ordered by the other agent's action so that position k in the group
        # always means "other agent plays k", whatever order action_list uses.
        other = 1 - agent
        cols = [i for i, ele in enumerate(action_list) if ele[agent] == action]
        return sorted(cols, key=lambda i: action_list[i][other])

    # Group over the real action count rather than a hard-coded 5.
    idx1 = [_columns(0, a1) for a1 in range(N_ACTIONS)]
    idx2 = [_columns(1, a2) for a2 in range(N_ACTIONS)]

    # Marginal policies of each agent under the joint policy.
    p1 = torch.hstack([torch.sum(policy_joint[:, idx1[i]], dim=1)[:, None] for i in range(N_ACTIONS)])
    p2 = torch.hstack([torch.sum(policy_joint[:, idx2[i]], dim=1)[:, None] for i in range(N_ACTIONS)])

    # Sized from the actual number of rows of the joint policy.
    c1 = torch.zeros((n1, N_ACTIONS, N_ACTIONS))  # c1(s,a1,a2) = P(a1|a2,s)
    c2 = torch.zeros((n1, N_ACTIONS, N_ACTIONS))  # c2(s,a2,a1) = P(a2|a1,s)
    for a_idx in range(N_ACTIONS):
        # joint / marginal; the (rows, 1) marginal column broadcasts over the group.
        c1[:, :, a_idx] = policy_joint[:, idx2[a_idx]] / p2[:, a_idx:a_idx + 1]
        c2[:, :, a_idx] = policy_joint[:, idx1[a_idx]] / p1[:, a_idx:a_idx + 1]

    # Uniform prediction: weight every action of the other agent by 1/N_ACTIONS.
    policy1 = 1/N_ACTIONS * torch.sum(c1, dim=2)
    policy2 = 1/N_ACTIONS * torch.sum(c2, dim=2)

    policy_independent = torch.zeros(policy_joint.shape)
    for i, a in enumerate(action_list):
        policy_independent[:, i] = policy1[:, a[0]] * policy2[:, a[1]]

    return values, policy_independent, policy1, policy2



# change this to pytorch
def two_agent_value_iteration_selfish(P_a, rewards, gamma, error=0.001, action_list=None):
    """
    Selfish control: each agent solves its own collapsed problem and the joint
    policy is the product of the two single-agent policies.

    The joint transition table and reward are collapsed per agent with
    collapse_permutation_matrix_pytorch / collapse_reward_pytorch, each
    collapsed problem is solved with two_agent_value_iteration, and the two
    resulting policies are expanded to one row per joint state.
    N_STATES, N_ACTIONS refers to single agent case (i.e. N_STATES = grid_H*grid_W; N_ACTIONS = 5)

    inputs:
    P_a          N_STATES**2 x N_ACTIONS**2, a permutation matrix P_a(s,a) to convert V(s) to V(s') based on action a
                                            i.e. V_new = V[P_a[:,a]]
    rewards      N_STATES**2 X 1 - R(s1,s2) where s1 is the location index of agent1
    gamma        float - RL discount
    error        float - threshold for a stop
    action_list  list of [a1, a2] pairs; entry i names the joint action of column i.
                 Defaults to every pair with a1 as the slow index.

    returns:
    None      (placeholder in the position where the sibling solvers return values)
    policy    N_STATES**2 x N_ACTIONS**2 (p1 x p2)
    p1        N_STATES**2 x N_ACTIONS (p1)
    p2        N_STATES**2 x N_ACTIONS (p2)
    """
    n_joint_states, n_joint_actions = P_a.shape
    n_states = int(torch.sqrt(torch.tensor(n_joint_states)))

    # Collapse the joint problem into one single-agent problem per agent.
    transitions_agent1 = collapse_permutation_matrix_pytorch(P_a, 1)
    transitions_agent2 = collapse_permutation_matrix_pytorch(P_a, 2)
    rewards_agent1 = collapse_reward_pytorch(rewards, 1)
    rewards_agent2 = collapse_reward_pytorch(rewards, 2)

    # Only the policies are needed; the single-agent values are discarded.
    solved_agent1 = two_agent_value_iteration(transitions_agent1, rewards_agent1, gamma, error)
    solved_agent2 = two_agent_value_iteration(transitions_agent2, rewards_agent2, gamma, error)

    # Agent 1's rows are each repeated n_states times in a row, agent 2's
    # block of rows is repeated n_states times as a whole.
    # NOTE(review): presumably this matches joint index s = s1 * N_STATES + s2 - confirm.
    p1 = torch.repeat_interleave(solved_agent1[1], n_states, dim=0)
    p2 = torch.tile(solved_agent2[1], (n_states, 1))

    if action_list is None:
        n_actions = int(torch.sqrt(torch.tensor(n_joint_actions)))
        action_list = [[first, second] for first in range(n_actions) for second in range(n_actions)]

    # Joint policy column i = p1(a1|s) * p2(a2|s) for the i-th action pair.
    n_rows, n_cols = p1.shape[0], p1.shape[1] * p2.shape[1]
    policy = torch.zeros((n_rows, n_cols))
    for col, pair in enumerate(action_list):
        policy[:, col] = p1[:, pair[0]] * p2[:, pair[1]]

    return None, policy, p1, p2