import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import sys
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from models.r2d2_final import OBLR2D2Agent, DIAL_OBLR2D2Agent, convert_msg_to_actions
from utils.memory import Memory, LocalBuffer, OBLMemory, OBLLocalBuffer, MIOBLMemory, MIOBLLocalBuffer, DIALRandomMemory, DialLocalBuffer,  MIIQLMemory, MIIQLLocalBuffer
from tensorboardX import SummaryWriter

from models.r2d2_config import initial_exploration, batch_size, dial_batch_size, update_target, log_interval, eval_argmax, eval_interval, device, replay_memory_capacity, lr, dial_lr, dial_iql_eps, sequence_length, local_mini_batch, dial_local_mini_batch, use_mi_loss, dru_sigma
from utils.pbmaze_config import env_config, iql_env_config, fixed_sender_config, fixed_sender_iql_env_config
from models.ct_util_config import ct_util_dict

from phone_booth_colab_maze_final import PBCMaze
from pbcmaze_belief_model import ReceiverBeliefModel, SenderBeliefModel
from collections import deque
from eval_final import evaluate, ctdu_evaluate, evaluate_policy, ctdu_evaluate_policy

# Output directories (relative to the working directory) for the saved
# .npy result arrays and the trained model checkpoints.
RESULT_PATH = "results_final/"
MODEL_PATH = "trained_models_final/"
# Number of independent training runs; run i is seeded with i.
NUM_RUNS = 4

# Discrete action id tables.
# NOTE: the CTDU_-prefixed names are unpacked from the 6-element ACTIONS list
# and the unprefixed names from the 7-element CTDU_ACTIONS list; the naming
# looks crossed but the element counts line up, so the ids are well defined.
ACTIONS = [0, 1, 2, 3, 4, 5]
CTDU_ACTIONS = [0, 1, 2, 3, 4, 5, 6]
(CTDU_LEFT, CTDU_RIGHT, CTDU_UP, CTDU_DOWN, CTDU_NOOP, CTDU_SEND) = ACTIONS
(LEFT, RIGHT, UP, DOWN, NOOP, HINT_UP, HINT_DOWN) = CTDU_ACTIONS

def set_seed(seed):
    """Seed the torch, ``random`` and NumPy global generators with *seed*.

    The generators are seeded in that order so that repeated calls with the
    same value reproduce the same random streams.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

def _ct_util_tag():
    """Return the filename fragment describing the communication utility config.

    Reads ``ct_util_dict`` and supports the ``dru``, ``gb_softmax`` and
    ``rao_gb_softmax`` utility types. Any other type aborts the program with a
    non-zero exit status (``SystemExit``) and a message on stderr.
    """
    util_type = ct_util_dict['ct_util_type']
    if util_type == 'dru':
        return '_' + util_type + '_' + ('sigmoid' if ct_util_dict['comm_narrow'] else 'softmax') + '_sigma' + str(ct_util_dict['sigma']) + '_' + ('hard' if ct_util_dict['hard'] else 'soft')
    if util_type == 'gb_softmax':
        return '_' + util_type + '_' + ('tau' + str(ct_util_dict['tau'])) + '_' + ('st' if ct_util_dict['hard'] else 'reparam')
    if util_type == 'rao_gb_softmax':
        return '_' + util_type + '_' + ('tau' + str(ct_util_dict['tau'])) + '_' + ('k' + str(ct_util_dict['k']))
    # sys.exit with a message exits with status 1; the bare exit() used before
    # reported success (status 0) on what is a configuration error.
    sys.exit("Unsupported ct_util_type: {!r}".format(util_type))


def _variant_tag():
    """Return the filename fragment for the reward / loss variant being trained.

    MI shaping and/or MI loss take precedence over the intermediate-reward
    variant; the plain variant contributes an empty fragment.
    """
    if env_config['use_mi_shaping'] or use_mi_loss:
        return ("_mi_log2" if env_config['use_mi_shaping'] else "") + ("_mi_loss" if use_mi_loss else "")
    if env_config['use_intermediate_reward']:
        return "_ir"
    return ""


# @profile
def main():
    """Train a sender (agent 0) and a receiver (agent 1) for NUM_RUNS seeded runs.

    Each run alternates the two agents' turns in a ``PBCMaze`` environment and
    proceeds in two stages:

    * Stage 1 (the first ``num_episodes_for_mi_training`` episodes): both
      agents act through ``get_action`` and are trained with OBL sampling.
    * Stage 2 (the remaining episodes): both agents act through
      ``get_iql_action``, fill their IQL buffers, and agent 0 additionally
      collects and trains on DIAL transitions.

    Per-run metrics are accumulated and written as ``.npy`` arrays under
    ``RESULT_PATH``; both agents' models are saved under ``MODEL_PATH`` after
    every run.
    """
    # Make sure the output directories exist (race-free, no pre-check needed).
    os.makedirs(RESULT_PATH, exist_ok=True)
    os.makedirs(MODEL_PATH, exist_ok=True)

    # Every output name shares the same suffix: variant + eval mode + comm utility.
    ct_util_str = _ct_util_tag()
    run_tag = _variant_tag() + ("_argmax" if eval_argmax else "") + ct_util_str

    sender_result_filename = "obl_dial_sender_time_to_pb" + run_tag + ".npy"
    receiver_result_filename = "obl_dial_receiver_time_to_pb" + run_tag + ".npy"
    reward_result_filename = "obl_dial_reward" + run_tag + ".npy"
    running_reward_result_filename = "obl_dial_running_reward" + run_tag + ".npy"
    # NOTE(review): the trailing space inside the model stems is kept on purpose
    # so previously saved checkpoints keep their names — confirm before removing.
    sender_model_path = MODEL_PATH + "obl_dial_sender_model " + run_tag
    receiver_model_path = MODEL_PATH + "obl_dial_receiver_model " + run_tag

    sender_time_to_booth_result = []
    receiver_time_to_booth_result = []
    reward_result = []
    eval_reward_result = []
    running_reward_result = []
    running_eval_reward_result = []
    for run_idx in range(NUM_RUNS):
        print("Run: " + str(run_idx + 1))
        # Each run uses its own index as the seed.
        set_seed(run_idx)

        # Training schedule for this run.
        num_episodes = 80000
        num_episodes_for_mi_training = 8000
        stage_2_training = False
        stage_2_only_use_mi_loss = False

        # The evaluation env copies the training env's configuration but never
        # uses MI shaping or intermediate rewards.
        env = PBCMaze(env_args=env_config)
        env.reset()
        eval_env = PBCMaze(env_args=env_config)
        eval_env.reset()
        eval_env.load_env_config(env.save_env_config())
        eval_env.use_mi_shaping = False
        eval_env.use_intermediate_reward = False

        """
        Agent 0 obs: ((channel, width, height), goal feature)
        Agent 1 obs: ((channel, width, height), communication token)
        """
        a0_input_shape  = env.get_obs_size(0)
        a1_input_shape = env.get_obs_size(1)
        a0_num_actions = 6
        a0_comm_bits = 2
        a1_num_actions = 5

        # Uniform prior policies handed to the belief models.
        receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
        sender_pi_0 = [1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7]
        rb_model = ReceiverBeliefModel(receiver_pi_0, env)
        sb_model = SenderBeliefModel(sender_pi_0, env)
        # The MI-loss variant of agent 0 uses MI-aware replay memories/buffers.
        if use_mi_loss:
            a0_agent = DIAL_OBLR2D2Agent(a0_input_shape, a0_num_actions, a0_comm_bits, MIIQLMemory(replay_memory_capacity), MIIQLLocalBuffer(), MIOBLMemory(replay_memory_capacity), MIOBLLocalBuffer(), DIALRandomMemory(replay_memory_capacity), DialLocalBuffer(), lr, batch_size, dial_batch_size, device, 0, rb_model, ct_util_dict, use_mi_loss, multi_pb = False)
        else:
            a0_agent = DIAL_OBLR2D2Agent(a0_input_shape, a0_num_actions, a0_comm_bits, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), DIALRandomMemory(replay_memory_capacity), DialLocalBuffer(), lr, batch_size, dial_batch_size, device, 0, rb_model, ct_util_dict, use_mi_loss, multi_pb = False)
        a1_agent = OBLR2D2Agent(a1_input_shape, a1_num_actions, a0_comm_bits, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 1, sb_model, ct_util_dict, multi_pb = False)

        writer = SummaryWriter('logs')
        running_score = 0
        running_eval_score = 0
        dial_eps = dial_iql_eps
        steps = 0
        loss = 0
        per_run_sender_time_to_booth_list = []
        per_run_receiver_time_to_booth_list = []
        per_run_reward = []
        per_run_eval_reward = []
        per_run_running_reward = []
        per_run_running_eval_reward = []

        for e in range(num_episodes):
            done = False

            score = 0
            a0_reward = None
            a1_reward = None
            obs, state = env.reset()

            # Zero-initialised recurrent state pairs for both agents.
            # (presumably 16 matches the agents' recurrent hidden size — TODO confirm)
            a0_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
            a1_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
            a1_next_hidden = None

            # Dial transitions [agent 0 time t transition, agent 1 time t + 1 transition, agent 0 time t + 2 transition, agent 1 state and hidden ]
            dial_transitions = []
            sys.stdout.flush()
            while not done:
                steps += 1

                # A carried-over window of exactly [a0 transition, a1 transition]
                # is extended with the booth flag recorded after agent 0's last move.
                if len(dial_transitions) == 2:
                    dial_transitions.append(both_in_booth_flag)

                a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                if stage_2_training:
                    a0_policy, a0_action, a0_next_hidden, a0_message, _ = a0_agent.get_iql_action(a0_obs, dial_eps, a0_hidden)
                else:
                    a0_policy, a0_action, a0_next_hidden, a0_message = a0_agent.get_action(a0_obs, a0_hidden)
                # Convert once; .cpu() is a no-op for CPU tensors and keeps the
                # conversion valid if the policy tensor lives on a CUDA device.
                a0_policy_np = a0_policy.squeeze().detach().cpu().numpy()

                # OBL Sampling - Stage 1
                a0_curr_env_config = env.save_env_config()
                if not stage_2_training or stage_2_only_use_mi_loss:
                    a0_agent.obl_sampling(a0_hidden, a0_next_hidden, a0_policy_np, a0_action, a0_curr_env_config, a1_agent, a1_next_hidden if a1_next_hidden is not None else a1_hidden, message = a0_message)

                if len(a0_agent.local_buffer.memory) == local_mini_batch:
                    a0_agent.push_to_memory()
                # Agent 0's turn
                if a0_action == CTDU_SEND:
                    # A SEND action is translated into the env action encoded by the message.
                    a0_reward, done, a0_info = env.step(0, convert_msg_to_actions(a0_message.size(-1), a0_message), policy = a0_policy_np)
                else:
                    a0_reward, done, a0_info = env.step(0, a0_action, policy = a0_policy_np)
                both_in_booth_flag = env.sender_in_booth and env.receiver_in_booth
                mask = 0 if done else 1

                # Add to agent 1's IQL buffer
                if stage_2_training:
                    # a1_reward is None until agent 1 has acted once this episode.
                    if a1_reward is not None:
                        # Add to agent 1's buffer
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 1's IQL learning
                    if len(a1_agent.iql_memory) > batch_size:
                        loss, _ = a1_agent.train_iql_model()
                        if steps % update_target == 0:
                            a1_agent.update_target_model()

                    dial_transitions.append([a0_obs, a0_hidden, a0_action, torch.Tensor(env.get_obs(0)).to(device), a0_next_hidden, a0_reward, mask])

                # Update after a0 has taken an action
                if a1_next_hidden is not None:
                    a1_hidden = a1_next_hidden

                # Agent 0's OBL learning
                if steps > initial_exploration and len(a0_agent.memory) > batch_size:
                    if not stage_2_training:
                        loss, _ = a0_agent.train_model(obl = True, use_mi_loss = use_mi_loss)
                    elif stage_2_only_use_mi_loss:
                        loss, _ = a0_agent.train_model(obl = True, use_mi_loss = use_mi_loss, use_only_mi_loss = stage_2_only_use_mi_loss)
                    if steps % update_target == 0:
                        a0_agent.update_target_model()

                # DIAL training
                if len(a0_agent.dial_memory) > dial_batch_size:
                    a0_agent.train_model_dial(a1_agent)

                # Update belief
                a1_agent.belief_model.update_belief(comm_token = env.comm_token)

                a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                if stage_2_training:
                    a1_policy, a1_action, a1_next_hidden, _, _ = a1_agent.get_iql_action(a1_obs, dial_eps, a1_hidden)
                else:
                    a1_policy, a1_action, a1_next_hidden, _ = a1_agent.get_action(a1_obs, a1_hidden)

                a1_curr_env_config = env.save_env_config()
                # OBL Sampling - Stage 1
                if not stage_2_training:
                    a1_agent.obl_sampling(a1_hidden, a1_next_hidden, a1_policy.squeeze().detach().cpu().numpy(), a1_action, a1_curr_env_config, a0_agent, a0_next_hidden)
                if len(a1_agent.local_buffer.memory) == local_mini_batch:
                    a1_agent.push_to_memory()

                # Sanity check: nothing between agent 0's step and here may move the agents.
                assert both_in_booth_flag == (env.sender_in_booth and env.receiver_in_booth)

                # Agent 1's turn
                a1_reward, done, info = env.step(1, a1_action)
                mask = 0 if done else 1

                # Add to agent 0's buffer + DIAL
                if stage_2_training:
                    next_a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                    # Agent 0 is credited with the joint reward of this round.
                    if use_mi_loss:
                        a0_agent.iql_buffer.push(a0_obs, next_a0_obs, a0_action, a0_reward + a1_reward, mask, torch.Tensor(a0_info['mi_term_masks']).to(device), a0_hidden)
                    else:
                        a0_agent.iql_buffer.push(a0_obs, next_a0_obs, a0_action, a0_reward + a1_reward, mask, a0_hidden)
                    if len(a0_agent.iql_buffer.memory) == local_mini_batch:
                        a0_agent.push_to_iql_memory(use_mi_loss = use_mi_loss)

                    if done:
                        # Need to add to a1's buffer
                        # (there is no further agent-0 step in this episode that would push it).
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 0's IQL learning
                    if len(a0_agent.iql_memory) > batch_size:
                        loss, _ = a0_agent.train_iql_model(use_mi_loss = use_mi_loss)
                        if steps % update_target == 0:
                            a0_agent.update_target_model()

                    dial_transitions.append([a1_obs, a1_hidden, a1_action, torch.Tensor(env.get_obs(1)).to(device), a1_next_hidden, a1_reward, mask])
                    # Linearly decay the stage-2 exploration rate down to a floor of 0.1.
                    dial_eps -= 0.00001
                    dial_eps = max(dial_eps, 0.1)
                    if len(dial_transitions) % 5 == 0 or done:
                        # compute target to add to buffer
                        dial_transitions.append(both_in_booth_flag)
                        a0_agent.dial_push_to_local_buffer(dial_transitions, a1_agent)
                        # Remove the first 2 transitions and in booth flag
                        dial_transitions = dial_transitions[3 : ]
                        # Remove the last in booth flag
                        dial_transitions = dial_transitions[ : 2]


                a0_hidden = a0_next_hidden

                if len(a0_agent.dial_buffer.memory) == dial_local_mini_batch:
                    a0_agent.push_to_dial_memory(a1_agent)

                # Agent 1's OBL learning
                if steps > initial_exploration and len(a1_agent.memory) > batch_size:
                    if not stage_2_training:
                        # Only train obl in stage 1
                        loss, _ = a1_agent.train_model(obl = True)
                    if steps % update_target == 0:
                        a1_agent.update_target_model()

                # Update belief
                a0_agent.belief_model.update_belief()

                score += a0_reward + a1_reward

            # Exponential moving average of the episode return.
            running_score = 0.99 * running_score + 0.01 * score
            # Steps to phone booth
            if eval_argmax:
                if e % eval_interval == 0:
                    if e % 2000 == 0:
                        eval_score, info = ctdu_evaluate(eval_env, a0_agent, a1_agent, print_actions = True)
                    else:
                        eval_score, info = ctdu_evaluate(eval_env, a0_agent, a1_agent)
                    running_eval_score = 0.99 * running_eval_score + 0.01 * eval_score
                    sender_time_to_pb = info["sender_time_to_booth"]
                    receiver_time_to_pb = info["receiver_time_to_booth"]
                    per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                    per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                    per_run_eval_reward.append(eval_score)
                    per_run_running_eval_reward.append(running_eval_score)
                    if e % 2000 == 0:
                        print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f} | dial eps: {:.2f}'.format(
                            run_idx + 1, e, running_eval_score, eval_score, sender_time_to_pb, receiver_time_to_pb, dial_eps))
                        ctdu_evaluate_policy(a0_agent, a1_agent, iql_env_config)
                    sys.stdout.flush()
            else:
                # Without argmax evaluation, log the training episode itself; `info`
                # is the dict returned by agent 1's final env.step of the episode.
                sender_time_to_pb = info["sender_time_to_booth"]
                receiver_time_to_pb = info["receiver_time_to_booth"]
                per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                per_run_reward.append(score)
                per_run_running_reward.append(running_score)
                if e % log_interval == 0:
                    print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f}'.format(
                        run_idx + 1, e, running_score, score, sender_time_to_pb, receiver_time_to_pb))
                    writer.add_scalar('log/score', float(running_score), e)
                    writer.add_scalar('log/loss', float(loss), e)
                    sys.stdout.flush()


            # Reset belief
            a0_agent.belief_model.reset_belief()
            a1_agent.belief_model.reset_belief()

            # Switch to stage 2 once the stage-1 episode budget is used up.
            if (e + 1) >= num_episodes_for_mi_training and not stage_2_training:
                stage_2_training = True
                print("Stage 2 training starts")
                sys.stdout.flush()

        sender_time_to_booth_result.append(per_run_sender_time_to_booth_list)
        receiver_time_to_booth_result.append(per_run_receiver_time_to_booth_list)
        if eval_argmax:
            eval_reward_result.append(per_run_eval_reward)
            running_eval_reward_result.append(per_run_running_eval_reward)
        else:
            reward_result.append(per_run_reward)
            running_reward_result.append(per_run_running_reward)

        # Save model
        a0_agent.save_model(sender_model_path + "_{}.pt".format(run_idx + 1))
        a1_agent.save_model(receiver_model_path + "_{}.pt".format(run_idx + 1))

        # Flush and release this run's TensorBoard event file before the next
        # run opens a new writer.
        writer.close()


    # The reward files hold evaluation rewards when eval_argmax is set and
    # training rewards otherwise.
    if eval_argmax:
        np.save(RESULT_PATH + reward_result_filename, np.array(eval_reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_eval_reward_result))
    else:
        np.save(RESULT_PATH + reward_result_filename, np.array(reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_reward_result))

    np.save(RESULT_PATH + sender_result_filename, np.array(sender_time_to_booth_result))
    np.save(RESULT_PATH + receiver_result_filename, np.array(receiver_time_to_booth_result))


# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
