import torch
from models.r2d2_config import device
from phone_booth_colab_maze_final import PBCMaze
from models.r2d2_final import convert_msg_to_actions

# Discrete action indices shared by the evaluation helpers below.
# NOTE(review): the names look swapped relative to the lists they unpack —
# the CTDU_* names (ending in SEND) come from the 6-entry ACTIONS list, while
# the plain names (ending in HINT_UP/HINT_DOWN) come from the 7-entry
# CTDU_ACTIONS list. Only the integer values are relied on; confirm intent.
ACTIONS = [*range(6)]
CTDU_ACTIONS = [*range(7)]
(CTDU_LEFT, CTDU_RIGHT, CTDU_UP, CTDU_DOWN,
 CTDU_NOOP, CTDU_SEND) = ACTIONS
(LEFT, RIGHT, UP, DOWN,
 NOOP, HINT_UP, HINT_DOWN) = CTDU_ACTIONS

def evaluate(eval_env, a0_agent, a1_agent, print_actions = False):
    """Run one greedy evaluation episode with two alternating agents.

    Each iteration, agent 0 acts on ``eval_env.get_obs(0)``; its policy is
    forwarded to ``eval_env.step`` through the ``policy=`` keyword. Then
    agent 1 acts on ``eval_env.get_obs(1)``. Both agents are queried with
    ``argmax=True``. Each agent carries its own recurrent hidden state,
    initialised to a pair of ``(1, 1, 16)`` zero tensors on ``device``.

    Args:
        eval_env: environment with ``reset()``, ``get_obs(agent_idx)`` and
            ``step(agent_idx, action, ...)`` returning ``(reward, done, info)``.
        a0_agent, a1_agent: objects whose
            ``get_action(obs, hidden, argmax=...)`` returns
            ``(policy, action, next_hidden, message)``.
        print_actions: if true, print the action indices chosen by each agent.

    Returns:
        ``(score, info)``. ``score`` is the sum of both agents' step rewards
        over the episode. ``info`` is what the last ``step`` call returned.
    """
    eval_env.reset()

    a0_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
    a1_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))

    score = 0
    done = False
    a0_actions = []
    a1_actions = []
    with torch.no_grad():
        while not done:
            # Agent 0's turn
            a0_obs = torch.Tensor(eval_env.get_obs(0)).to(device)
            a0_policy, a0_action, a0_hidden, _ = a0_agent.get_action(a0_obs, a0_hidden, argmax=True)
            # .cpu() is required before .numpy() when `device` is a GPU; it is a
            # no-op for tensors already on the CPU.
            a0_reward, done, info = eval_env.step(
                0, a0_action, policy=a0_policy.squeeze().detach().cpu().numpy())
            # NOTE(review): `done` from agent 0's step is overwritten by agent 1's
            # step below, so only agent 1's flag ends the loop. Confirm the env
            # keeps reporting done once an episode has ended.

            # Agent 1's turn
            a1_obs = torch.Tensor(eval_env.get_obs(1)).to(device)
            _, a1_action, a1_hidden, _ = a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            a1_reward, done, info = eval_env.step(1, a1_action)

            score += a0_reward + a1_reward
            a0_actions.append(a0_action.item())
            a1_actions.append(a1_action.item())
    if print_actions:
        print("a0 actions: {}".format(a0_actions))
        print("a1 actions: {}".format(a1_actions))
    return score, info

def ctdu_evaluate(eval_env, a0_agent, a1_agent, print_actions = False):
    """Run one greedy evaluation episode where agent 0 may choose a SEND action.

    This is the same turn structure as a plain two-agent evaluation, with one
    difference. When agent 0's action equals ``CTDU_SEND``, the environment
    does not receive the raw action. It receives
    ``convert_msg_to_actions(message.size(-1), message)``, built from the
    message agent 0 returned in the same ``get_action`` call. The recorded
    action list still stores agent 0's original (SEND) action index.

    Args:
        eval_env: environment with ``reset()``, ``get_obs(agent_idx)`` and
            ``step(agent_idx, action, ...)`` returning ``(reward, done, info)``.
        a0_agent, a1_agent: objects whose
            ``get_action(obs, hidden, argmax=...)`` returns
            ``(policy, action, next_hidden, message)``.
        print_actions: if true, print the action indices chosen by each agent.

    Returns:
        ``(score, info)``. ``score`` is the sum of both agents' step rewards
        over the episode. ``info`` is what the last ``step`` call returned.
    """
    eval_env.reset()

    a0_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
    a1_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))

    score = 0
    done = False
    a0_actions = []
    a1_actions = []
    with torch.no_grad():
        while not done:
            # Agent 0's turn
            a0_obs = torch.Tensor(eval_env.get_obs(0)).to(device)
            a0_policy, a0_action, a0_hidden, a0_message = a0_agent.get_action(a0_obs, a0_hidden, argmax=True)
            if a0_action == CTDU_SEND:
                # SEND: hand the env the message translated into actions.
                a0_env_action = convert_msg_to_actions(a0_message.size(-1), a0_message)
            else:
                a0_env_action = a0_action
            # .cpu() is required before .numpy() when `device` is a GPU; it is a
            # no-op for tensors already on the CPU.
            a0_reward, done, info = eval_env.step(
                0, a0_env_action, policy=a0_policy.squeeze().detach().cpu().numpy())
            # NOTE(review): `done` from agent 0's step is overwritten by agent 1's
            # step below, so only agent 1's flag ends the loop. Confirm the env
            # keeps reporting done once an episode has ended.

            # Agent 1's turn
            a1_obs = torch.Tensor(eval_env.get_obs(1)).to(device)
            _, a1_action, a1_hidden, _ = a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            a1_reward, done, info = eval_env.step(1, a1_action)

            score += a0_reward + a1_reward
            a0_actions.append(a0_action.item())
            a1_actions.append(a1_action.item())
    if print_actions:
        print("a0 actions: {}".format(a0_actions))
        print("a1 actions: {}".format(a1_actions))
    return score, info

def evaluate_policy(a0_agent, a1_agent, iql_env_config):
    """Probe which action agent 0 chooses for goal 2 ("UP") and goal 3 ("DOWN").

    A fresh ``PBCMaze`` is built from ``iql_env_config`` and reset. The
    agents are then moved by a fixed script until agent 0 is at
    ``env.booth_loc`` and agent 1 is at ``env.receiver_booth_loc``: agent 0
    always steps ``RIGHT`` and agent 1 always steps ``LEFT``. Both agents'
    ``get_action`` is still called on every scripted step so that their
    recurrent hidden states follow the observations. Next, ``env.goal`` is
    overwritten with each probe goal. Agent 0 is queried greedily with
    ``discretize_message=False`` from the same post-walk hidden state, and
    its chosen actions are printed.

    Returns:
        ``(a0_up_action, a0_down_action)``, the actions returned by
        ``a0_agent.get_action`` for goals 2 and 3 respectively.
    """
    with torch.no_grad():
        env = PBCMaze(env_args=iql_env_config)
        env.reset()
        a0_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
        a1_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))

        # NOTE(review): there is no step limit. If the scripted moves never put
        # both agents in their booths at the same time, this loops forever.
        while env.agent0_loc[0] != env.booth_loc or env.agent1_loc[0] != env.receiver_booth_loc:
            a0_obs = torch.Tensor(env.get_obs(0)).to(device)
            _, _, a0_hidden, _ = a0_agent.get_action(a0_obs, a0_hidden, argmax=True)
            env.step(0, RIGHT)
            a1_obs = torch.Tensor(env.get_obs(1)).to(device)
            _, _, a1_hidden, _ = a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            env.step(1, LEFT)

        # Goal values 2 and 3 are the ones reported as "UP" and "DOWN" below.
        probe_actions = []
        for goal in (2, 3):
            env.goal = goal
            a0_obs = torch.Tensor(env.get_obs(0)).to(device)
            a1_obs = torch.Tensor(env.get_obs(1)).to(device)
            _, a0_goal_action, _, _ = a0_agent.get_action(
                a0_obs, a0_hidden, argmax=True, discretize_message=False)
            # Agent 1 is queried as before; its result is not used here.
            a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            probe_actions.append(a0_goal_action)
        a0_up_action, a0_down_action = probe_actions

        print("Goal: UP, action: {}".format(a0_up_action))
        print("Goal: DOWN, action: {}".format(a0_down_action))
        return a0_up_action, a0_down_action

def ctdu_evaluate_policy(a0_agent, a1_agent, iql_env_config):
    """Probe the (non-discretized) message agent 0 emits for goals 2 and 3.

    A fresh ``PBCMaze`` is built from ``iql_env_config`` and reset. The
    agents are then moved by a fixed script until agent 0 is at
    ``env.booth_loc`` and agent 1 is at ``env.receiver_booth_loc``: agent 0
    always steps ``RIGHT`` and agent 1 always steps ``LEFT``. Both agents'
    ``get_action`` is called on every scripted step so that their recurrent
    hidden states follow the observations. Next, ``env.goal`` is overwritten
    with each probe goal. Agent 0 is queried greedily with
    ``discretize_message=False`` from the same post-walk hidden state.

    Returns:
        ``(a0_up_message, a0_down_message)``, the message element of
        ``a0_agent.get_action``'s result for goals 2 and 3 respectively.
    """
    with torch.no_grad():
        env = PBCMaze(env_args=iql_env_config)
        env.reset()
        a0_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))
        a1_hidden = (torch.zeros(1, 1, 16, device=device), torch.zeros(1, 1, 16, device=device))

        # NOTE(review): there is no step limit. If the scripted moves never put
        # both agents in their booths at the same time, this loops forever.
        while env.agent0_loc[0] != env.booth_loc or env.agent1_loc[0] != env.receiver_booth_loc:
            a0_obs = torch.Tensor(env.get_obs(0)).to(device)
            _, _, a0_hidden, _ = a0_agent.get_action(a0_obs, a0_hidden, argmax=True)
            env.step(0, RIGHT)
            a1_obs = torch.Tensor(env.get_obs(1)).to(device)
            _, _, a1_hidden, _ = a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            env.step(1, LEFT)

        # Goal 2 is the "up" probe and goal 3 the "down" probe.
        probe_messages = []
        for goal in (2, 3):
            env.goal = goal
            a0_obs = torch.Tensor(env.get_obs(0)).to(device)
            a1_obs = torch.Tensor(env.get_obs(1)).to(device)
            _, _, _, a0_goal_message = a0_agent.get_action(
                a0_obs, a0_hidden, argmax=True, discretize_message=False)
            # Agent 1 is queried as before; its result is not used here.
            a1_agent.get_action(a1_obs, a1_hidden, argmax=True)
            probe_messages.append(a0_goal_message)
        a0_up_message, a0_down_message = probe_messages
        return a0_up_message, a0_down_message
