import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import sys
import random
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F
from models.r2d2_final import OBLR2D2Agent
from utils.memory import Memory, LocalBuffer, OBLMemory, OBLLocalBuffer, MIOBLMemory, MIOBLLocalBuffer
from tensorboardX import SummaryWriter

from models.r2d2_config import initial_exploration, batch_size, update_target, log_interval, eval_argmax, eval_interval, device, replay_memory_capacity, lr, sequence_length, local_mini_batch, use_mi_loss, dial_iql_eps, dru_sigma
from utils.pbmaze_config import env_config, iql_env_config, fixed_sender_config, fixed_sender_iql_env_config
from phone_booth_colab_maze_final import PBCMaze
from pbcmaze_belief_model_ecs import ReceiverBeliefModel, SenderBeliefModel
from collections import deque
from eval_final import evaluate, evaluate_policy
from models.ct_util_config import ct_util_dict


# Directory prefix for the .npy result arrays written at the end of main().
RESULT_PATH = "results_final/"
# Directory prefix for the per-run sender/receiver model checkpoints.
MODEL_PATH = "trained_models_final/"
# Number of independent training runs; run index i is seeded with seed i.
NUM_RUNS = 4

def set_seed(seed):
    """Seed the PyTorch, Python and NumPy random number generators.

    Args:
        seed: Integer seed applied identically to all three generators.
    """
    # Same order as before: torch first, then the stdlib RNG, then NumPy.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

def main():
    """Train a sender (agent 0) / receiver (agent 1) OBL pair and save the results.

    For each of ``NUM_RUNS`` seeds this function builds a fresh training and
    evaluation ``PBCMaze`` environment, two ``OBLR2D2Agent`` instances with
    their belief models, and runs ``num_episodes`` episodes in which the two
    agents act in alternation. After ``num_episodes_for_mi_training`` episodes
    the run switches to "stage 2"; with ``use_iql_for_stage_2`` hard-coded to
    ``False`` the IQL code paths below stay inactive and OBL training simply
    continues.

    Outputs:
        * ``.npy`` arrays (rewards, running rewards, time-to-booth for both
          agents) under ``RESULT_PATH``; one row per run.
        * One sender and one receiver checkpoint per run under ``MODEL_PATH``.
        * TensorBoard scalars under ``logs`` (only when ``eval_argmax`` is
          false).
    """
    # Save results
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(RESULT_PATH, exist_ok=True)
    os.makedirs(MODEL_PATH, exist_ok=True)

    use_iql_for_stage_2 = False

    # Every output name shares one configuration-dependent suffix; build it
    # once instead of repeating the same expression for each file.
    argmax_tag = "_argmax" if eval_argmax else ""
    iql_tag = "_util_iql" if use_iql_for_stage_2 else ""
    if env_config['use_mi_shaping'] or use_mi_loss:
        suffix = (("_mi_log2" if env_config['use_mi_shaping'] else "")
                  + ("_mi_loss" if use_mi_loss else "")
                  + iql_tag + argmax_tag)
    elif env_config['use_intermediate_reward']:
        suffix = "_ir" + argmax_tag
    else:
        suffix = iql_tag + argmax_tag

    sender_result_filename = "obl_sender_time_to_pb" + suffix + ".npy"
    receiver_result_filename = "obl_receiver_time_to_pb" + suffix + ".npy"
    reward_result_filename = "obl_reward" + suffix + ".npy"
    # Only written when MI shaping or the MI loss is enabled (see the save
    # step at the end of this function).
    mi_reward_result_filename = "obl_mi_reward" + suffix + ".npy"
    running_reward_result_filename = "obl_running_reward" + suffix + ".npy"
    # NOTE(review): the space inside the model base names looks accidental, but
    # it is kept so that existing checkpoints / loaders keep matching.
    sender_model_path = MODEL_PATH + "obl_sender_model " + suffix
    receiver_model_path = MODEL_PATH + "obl_receiver_model " + suffix

    sender_time_to_booth_result = []
    receiver_time_to_booth_result = []
    reward_result = []
    eval_reward_result = []
    eval_mi_reward_result = []
    running_reward_result = []
    running_eval_reward_result = []
    for run_idx in range(NUM_RUNS):
        print("Run: " + str(run_idx + 1))
        # Set seed
        set_seed(run_idx)

        # Env
        num_episodes = 80000
        num_episodes_for_mi_training = 8000
        stage_2_training = False

        env = PBCMaze(env_args=env_config)
        env.reset()
        # The evaluation env copies the training env's configuration but has
        # MI shaping and the intermediate reward switched off.
        eval_env = PBCMaze(env_args=env_config)
        eval_env.reset()
        eval_env.load_env_config(env.save_env_config())
        eval_env.use_mi_shaping = False
        eval_env.use_intermediate_reward = False

        # Agent 0 obs: ((channel, width, height), goal feature)
        # Agent 1 obs: ((channel, width, height), communication token)
        a0_input_shape = env.get_obs_size(0)
        a1_input_shape = env.get_obs_size(1)
        a0_num_actions = 7
        a1_num_actions = 5
        a0_comm_bits = 2

        # Uniform priors handed to the belief models.
        receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
        sender_pi_0 = [1/6, 1/6, 1/6, 1/6, 1/6, 1/6]
        rb_model = ReceiverBeliefModel(receiver_pi_0, env)
        sb_model = SenderBeliefModel(sender_pi_0, env)
        # The sender uses the MI-aware OBL replay structures when the MI loss is on.
        if use_mi_loss:
            a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, a0_comm_bits, Memory(replay_memory_capacity), LocalBuffer(), MIOBLMemory(replay_memory_capacity), MIOBLLocalBuffer(), lr, batch_size, device, 0, rb_model, ct_util_dict, use_mi_loss, multi_pb = False)
        else:
            a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, a0_comm_bits, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 0, rb_model, ct_util_dict, use_mi_loss, multi_pb = False)
        a1_agent = OBLR2D2Agent(a1_input_shape, a1_num_actions, a0_comm_bits, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 1, sb_model, ct_util_dict)

        # NOTE(review): every run logs into the same 'logs' directory, so the
        # runs' scalars end up side by side there.
        writer = SummaryWriter('logs')

        running_score = 0
        running_eval_score = 0
        epsilon = 1.0
        dial_eps = dial_iql_eps
        steps = 0
        loss = 0
        per_run_sender_time_to_booth_list = []
        per_run_receiver_time_to_booth_list = []
        per_run_reward = []
        per_run_eval_reward = []
        per_run_running_reward = []
        per_run_running_eval_reward = []
        for e in range(num_episodes):
            done = False

            # stage_2_training only flips between episodes (at the bottom of
            # this loop), so the training mode is fixed for a whole episode.
            # obl_active == (not stage_2) or (stage_2 and not use_iql).
            iql_active = stage_2_training and use_iql_for_stage_2
            obl_active = not iql_active

            score = 0
            a0_reward = None
            a1_reward = None
            obs, state = env.reset()

            # Fresh zero (h, c) recurrent state for both agents each episode.
            # 16 is presumably the agents' LSTM hidden size -- TODO confirm
            # against models.r2d2_final.
            a0_hidden = (torch.zeros(1, 1, 16).to(device), torch.zeros(1, 1, 16).to(device))
            a1_hidden = (torch.zeros(1, 1, 16).to(device), torch.zeros(1, 1, 16).to(device))
            a1_next_hidden = None
            while not done:
                steps += 1

                # Agent 0's turn
                a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                if iql_active:
                    a0_policy, a0_action, a0_next_hidden, a0_message, _ = a0_agent.get_iql_action(a0_obs, dial_eps, a0_hidden)
                else:
                    a0_policy, a0_action, a0_next_hidden, a0_message = a0_agent.get_action(a0_obs, a0_hidden)
                # .cpu() is a no-op for CPU tensors and keeps .numpy() from
                # raising if the policy lives on a CUDA device.
                a0_policy_np = a0_policy.squeeze().detach().cpu().numpy()

                # OBL Sampling
                a0_curr_env_config = env.save_env_config()
                if obl_active:
                    # Before agent 1 has acted in this episode there is no
                    # "next" hidden state yet, so fall back to its initial one.
                    a0_agent.obl_sampling_flat(a0_hidden, a0_next_hidden, a0_policy_np, a0_action, a0_curr_env_config, a1_agent, a1_next_hidden if a1_next_hidden is not None else a1_hidden)
                a0_reward, done, info = env.step(0, a0_action, policy = a0_policy_np)

                # Add to agent 1's IQL buffer
                if iql_active:
                    if a1_reward is not None:
                        # Add to agent 1's buffer
                        mask = 0 if done else 1
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 1's IQL learning
                    if steps > initial_exploration and len(a1_agent.iql_memory) > batch_size:
                        loss, _ = a1_agent.train_iql_model()

                        if steps % update_target == 0:
                            a1_agent.update_target_model()

                # Update after a0 has taken an action
                if a1_next_hidden is not None:
                    a1_hidden = a1_next_hidden

                if len(a0_agent.local_buffer.memory) == local_mini_batch:
                    a0_agent.push_to_memory()

                if steps > initial_exploration and len(a0_agent.memory) > batch_size:
                    if obl_active:
                        loss, _ = a0_agent.train_model(obl = True, use_mi_loss = use_mi_loss)

                    if steps % update_target == 0:
                        a0_agent.update_target_model()
                # Update belief
                a1_agent.belief_model.update_belief(comm_token = env.comm_token)

                # Agent 1's turn
                a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                if iql_active:
                    a1_policy, a1_action, a1_next_hidden, _, _ = a1_agent.get_iql_action(a1_obs, dial_eps, a1_hidden)
                else:
                    a1_policy, a1_action, a1_next_hidden, _ = a1_agent.get_action(a1_obs, a1_hidden)
                a1_curr_env_config = env.save_env_config()
                if obl_active:
                    a1_agent.obl_sampling_flat(a1_hidden, a1_next_hidden, a1_policy.squeeze().detach().cpu().numpy(), a1_action, a1_curr_env_config, a0_agent, a0_next_hidden)
                a1_reward, done, info = env.step(1, a1_action)

                # Add to agent 0's buffer
                if iql_active:
                    mask = 0 if done else 1
                    next_a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                    a0_agent.iql_buffer.push(a0_obs, next_a0_obs, a0_action, a0_reward + a1_reward, mask, a0_hidden)
                    if len(a0_agent.iql_buffer.memory) == local_mini_batch:
                        a0_agent.push_to_iql_memory()

                    if done:
                        # Need to add to a1's buffer: agent 0 will not act
                        # again, so agent 1's last transition is stored here.
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Linear decay of the exploration rate, floored at 0.1.
                    dial_eps -= 0.00001
                    dial_eps = max(dial_eps, 0.1)

                    # Agent 0's IQL learning
                    if steps > initial_exploration and len(a0_agent.iql_memory) > batch_size:
                        loss, _ = a0_agent.train_iql_model()
                        if steps % update_target == 0:
                            a0_agent.update_target_model()

                a0_hidden = a0_next_hidden

                if len(a1_agent.local_buffer.memory) == local_mini_batch:
                    a1_agent.push_to_memory()

                if steps > initial_exploration and len(a1_agent.memory) > batch_size:
                    if obl_active:
                        loss, _ = a1_agent.train_model(obl = True)

                    if steps % update_target == 0:
                        a1_agent.update_target_model()
                # Update belief
                a0_agent.belief_model.update_belief()

                score += a0_reward + a1_reward

            # Exponential moving average of the episode return.
            running_score = 0.99 * running_score + 0.01 * score
            # Steps to phone booth
            if eval_argmax:
                # Statistics come from a separate evaluation rollout, taken
                # every eval_interval episodes.
                if e % eval_interval == 0:
                    verbose_eval = (e % 2000 == 0)
                    if verbose_eval:
                        eval_score, info = evaluate(eval_env, a0_agent, a1_agent, print_actions = True)
                    else:
                        eval_score, info = evaluate(eval_env, a0_agent, a1_agent)
                    running_eval_score = 0.99 * running_eval_score + 0.01 * eval_score
                    sender_time_to_pb = info["sender_time_to_booth"]
                    receiver_time_to_pb = info["receiver_time_to_booth"]
                    per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                    per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                    per_run_eval_reward.append(eval_score)
                    per_run_running_eval_reward.append(running_eval_score)
                    if verbose_eval:
                        print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f} | dial eps: {:.2f}'.format(
                            run_idx + 1, e, running_eval_score, eval_score, sender_time_to_pb, receiver_time_to_pb, dial_eps))
                        evaluate_policy(a0_agent, a1_agent, iql_env_config)
                    sys.stdout.flush()
            else:
                # Statistics come from the training episode itself; `info` is
                # the one returned by the episode's final env.step call.
                sender_time_to_pb = info["sender_time_to_booth"]
                receiver_time_to_pb = info["receiver_time_to_booth"]
                per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                per_run_reward.append(score)
                per_run_running_reward.append(running_score)
                if e % log_interval == 0:
                    print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f}'.format(
                        run_idx + 1, e, running_score, score, sender_time_to_pb, receiver_time_to_pb))
                    writer.add_scalar('log/score', float(running_score), e)
                    writer.add_scalar('log/loss', float(loss), e)
                    sys.stdout.flush()

            # Reset belief
            a0_agent.belief_model.reset_belief()
            a1_agent.belief_model.reset_belief()

            # Switch to stage 2 once the MI-training episode budget is used up.
            if (e + 1) >= num_episodes_for_mi_training and not stage_2_training:
                stage_2_training = True
                print("Stage 2 training starts")
                sys.stdout.flush()

        sender_time_to_booth_result.append(per_run_sender_time_to_booth_list)
        receiver_time_to_booth_result.append(per_run_receiver_time_to_booth_list)
        if eval_argmax:
            eval_reward_result.append(per_run_eval_reward)
            # NOTE(review): no separate MI reward is collected in this loop, so
            # the MI-reward file currently stores the plain evaluation rewards.
            eval_mi_reward_result.append(per_run_eval_reward)
            running_eval_reward_result.append(per_run_running_eval_reward)
        else:
            reward_result.append(per_run_reward)
            running_reward_result.append(per_run_running_reward)

        # Save model
        a0_agent.save_model(sender_model_path + "_{}.pt".format(run_idx + 1))
        a1_agent.save_model(receiver_model_path + "_{}.pt".format(run_idx + 1))

        # Flush pending events and release the event file before the next run
        # opens a new writer.
        writer.close()

    if eval_argmax:
        np.save(RESULT_PATH + reward_result_filename, np.array(eval_reward_result))
        if env_config['use_mi_shaping'] or use_mi_loss:
            np.save(RESULT_PATH + mi_reward_result_filename, np.array(eval_mi_reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_eval_reward_result))
    else:
        np.save(RESULT_PATH + reward_result_filename, np.array(reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_reward_result))

    np.save(RESULT_PATH + sender_result_filename, np.array(sender_time_to_booth_result))
    np.save(RESULT_PATH + receiver_result_filename, np.array(receiver_time_to_booth_result))


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
