#Train CARI on derk env
import numpy as np
import torch
from gym_derk.envs import DerkEnv
from util import SingleAgentWrapper, DiscreteActionEnv, BotPolicy
from util import cari_concat, set_reward_function
from gym.wrappers.time_limit import TimeLimit
from reward_fns import RewardCalculator

import os
from datetime import datetime
# import pybullet_envs
from PPO import PPO
from torch.utils.tensorboard import SummaryWriter
import copy
import argparse


################################### Training ###################################

def _interval_bins(low, high, n_bins):
    """Split ``[low, high]`` into ``n_bins`` equal, contiguous intervals.

    Returns a ``(n_bins, 2)`` tensor whose rows are ``[start, end]`` pairs.
    """
    edges = torch.linspace(low, high, n_bins + 1)
    # Duplicate every edge and drop the two outermost copies so consecutive
    # values pair up as [start, end] rows.
    return edges.repeat_interleave(2)[1:-1].reshape(-1, 2)


def _log_update_stats(writer, losses, time_step):
    """Write the sequence returned by ``ppo_agent.update()`` to TensorBoard.

    ``losses`` is indexed positionally (0..14); the tag for each index is
    given by the table below.
    """
    scalar_tags = (
        ("Loss/loss", 3),
        ("Loss/loss_actor", 0),
        ("Loss/loss_critic", 1),
        ("Loss/loss_entropy", 2),
        ("Debug/entropy", 4),
        ("Debug/clipfrac", 5),
        ("Debug/relative_entropy", 6),
        ("Debug/kl_div_last", 7),
        ("Debug/residual_var_vf", 8),
    )
    for tag, idx in scalar_tags:
        writer.add_scalar(tag, losses[idx], time_step)

    writer.add_scalars(
        "ValueFunction/target_values",
        {"mean": losses[9], "min": losses[10], "max": losses[11]},
        time_step,
    )
    writer.add_scalars(
        "ValueFunction/pred_values",
        {"mean": losses[12], "min": losses[13], "max": losses[14]},
        time_step,
    )


def train(args):
    """Train a weight-conditioned PPO agent on the Derk environment.

    The observation fed to the policy is the environment state concatenated
    (via ``cari_concat``) with a 2-element reward-weight vector; the reward is
    the weighted sum of a damage term and a heal term computed by
    ``RewardCalculator``. New weights are sampled on every environment reset.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``max_training_steps``, ``update_every``, ``k_epochs``,
            ``eps_clip``, ``gamma``, ``lr_actor``, ``lr_critic``,
            ``random_seed`` (0 disables seeding) and ``n_pseudo``. Every
            attribute of ``args`` is also encoded into the TensorBoard run
            directory name.

    Side effects:
        Writes TensorBoard events under ``tbr_logs/derk/``, a CSV log under
        ``PPO_logs/derk/`` and model checkpoints under
        ``PPO_preTrained/derk/``. The writer, log file and environment are
        closed even if training raises.
    """
    # FROM https://raw.githubusercontent.com/nikhilbarhate99/PPO-PyTorch/master/train.py
    banner = "============================================================================================"
    rule = "--------------------------------------------------------------------------------------------"

    print(banner)
    param_names = "_-_".join(f"{key}-{val}" for key, val in vars(args).items())

    ####### initialize environment hyperparameters ######

    action_space = "multidiscrete"  # "continuous" or "multidiscrete" else discrete

    max_ep_len = 150                   # max timesteps in one episode
    max_training_timesteps = int(args.max_training_steps)   # break training loop if timesteps > max_training_timesteps

    print_freq = max_ep_len * 1         # print avg reward in the interval (in num timesteps)
    log_freq = max_ep_len * 1           # log avg reward in the interval (in num timesteps)
    save_model_freq = int(5e4)          # save model frequency (in num timesteps)

    action_std = 0.6                    # starting std for action distribution (Multivariate Normal)
    bot_policy_update_freq = 20         # refresh the bot policy every nth PPO update

    ## Note : print/log frequencies should be >= max_ep_len

    ################ PPO hyperparameters ################

    update_timestep = max_ep_len * args.update_every     # update policy every n timesteps
    K_epochs = args.k_epochs            # update policy for K epochs in one PPO update

    eps_clip = args.eps_clip            # clip parameter for PPO
    gamma = args.gamma                  # discount factor

    lr_actor = args.lr_actor            # learning rate for actor network
    lr_critic = args.lr_critic          # learning rate for critic network

    random_seed = args.random_seed      # set random seed if required (0 = no random seed)
    n_arenas = 32
    # Number of pseudo-transitions added per real transition (really
    # n_pseudo * n_arenas extra steps per env step). Taken from --n_pseudo.
    n_pseudo = args.n_pseudo

    # Seed torch/numpy before any random draw (network initialisation and the
    # initial bot weights below) so a seeded run is reproducible from the start.
    if random_seed:
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    env = None
    log_f = None
    writer = SummaryWriter(f'tbr_logs/derk/{param_names}')
    try:
        env_name = "derk"
        print("training environment name : " + env_name)

        # All built-in Derk reward components are zeroed here.
        base_reward_fn = {
            "killEnemyStatue": 0,
            "killEnemyUnit": 0,
            "damageEnemyUnit": 0,
            "healTeammate1": 0,
            "healTeammate2": 0,
            "friendlyFire": 0,
            "statueDamageTaken": 0,
            "fallDamageTaken": 0,
            "healEnemy": 0,
            "timeScaling": 0
        }
        # random seed set to derk_appinstance.py create_session() -function "pWkn91perNetBJQ3ymNc9"
        env = DerkEnv(n_arenas=n_arenas, turbo_mode=True, reward_function=base_reward_fn, session_args=dict(interleaved=False))

        # The first three action components are binned into d_steps
        # intervals each; the last two are passed as None.
        d_steps = 5
        discretization = [
            _interval_bins(
                env.action_space[i].low.item(),
                env.action_space[i].high.item(),
                d_steps,
            )
            for i in range(3)
        ] + [None, None]

        env = DiscreteActionEnv(env, discretization)

        # state space dimension
        state_dim = env.observation_space.shape[0]

        # action space dimension
        if action_space == "continuous":
            action_dim = env.action_space.shape[0]
        elif action_space == "multidiscrete":
            action_dim = env.action_space.nvec
        else:
            action_dim = env.action_space.n

        # initialize a PPO agent; its input is the state plus the weight vector
        n_weights = 2
        alphas = np.ones(n_weights)
        weights = np.random.dirichlet(alphas, 1)[0]
        ppo_agent = PPO(state_dim + n_weights, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, action_space, action_std)

        # The bots play with a frozen snapshot of the learner's actor, which
        # is refreshed every `bot_policy_update_freq` PPO updates below.
        # Alternatives tried previously: loading a pretrained PPO checkpoint
        # and using its actor, sharing `ppo_agent.policy.actor` directly, or
        # passing `None`.
        bot_net = copy.deepcopy(ppo_agent.policy.actor)
        bp = BotPolicy(env=env, bot_net=bot_net, weights=weights)
        env = SingleAgentWrapper(env, bp)
        env = TimeLimit(env, max_ep_len)

        ###################### logging ######################

        #### log files for multiple runs are NOT overwritten
        log_dir = "PPO_logs" + '/' + env_name + '/'
        os.makedirs(log_dir, exist_ok=True)

        #### the run number is the count of files already in the log directory
        run_num = len(next(os.walk(log_dir))[2])

        #### create new log file for each run
        log_f_name = log_dir + '/PPO_' + env_name + "_log_" + str(run_num) + ".csv"

        print("current logging run number for " + env_name + " : ", run_num)
        print("logging at : " + log_f_name)

        ################### checkpointing ###################

        run_num_pretrained = 0      #### change this to prevent overwriting weights in same env_name folder

        directory = "PPO_preTrained" + '/' + env_name + '/'
        os.makedirs(directory, exist_ok=True)

        checkpoint_path = directory + "PPO_{}_{}_{}.pth".format(env_name, random_seed, run_num_pretrained)
        print("save checkpoint path : " + checkpoint_path)

        ############# print all hyperparameters #############

        print(rule)

        print("max training timesteps : ", max_training_timesteps)
        print("max timesteps per episode : ", max_ep_len)

        print("model saving frequency : " + str(save_model_freq) + " timesteps")
        print("log frequency : " + str(log_freq) + " timesteps")
        print("printing average reward over episodes in last : " + str(print_freq) + " timesteps")

        print(rule)

        print("state space dimension : ", state_dim)
        print("action space dimension : ", action_dim)

        print(rule)

        print("Initializing a discrete or multidiscrete action space policy")

        print(rule)

        print("PPO update frequency : " + str(update_timestep) + " timesteps")
        print("PPO K epochs : ", K_epochs)
        print("PPO epsilon clip : ", eps_clip)
        print("discount factor (gamma) : ", gamma)

        print(rule)

        print("optimizer learning rate actor : ", lr_actor)
        print("optimizer learning rate critic : ", lr_critic)

        if random_seed:
            print(rule)
            print("setting random seed to ", random_seed)
            # torch/numpy were seeded at the top; only the env remains.
            env.seed(random_seed)

        print(banner)

        ################# training procedure ################

        # track total training time
        # NOTE: datetime.now() is local time even though the label says GMT.
        start_time = datetime.now().replace(microsecond=0)
        print("Started training at (GMT) : ", start_time)

        print(banner)

        # logging file
        log_f = open(log_f_name, "w+")
        log_f.write('episode,timestep,reward\n')

        # printing and logging variables
        print_running_reward = 0
        print_running_episodes = 1

        log_running_reward = 0
        log_running_episodes = 1

        time_step = 0
        i_episode = 0
        n_updates = 0

        # training loop
        while time_step <= max_training_timesteps:
            # sample new weights on every reset, one row per arena; rounding
            # to 0 decimals turns each entry into 0 or 1
            weights = np.random.dirichlet(alphas, n_arenas).round(0)
            set_reward_function(weights, env)

            state = env.reset()
            returnn = 0
            rc = RewardCalculator(n_arenas)
            state = cari_concat(state, weights)
            # new weights for the bots as well
            bot_weights = np.random.dirichlet(alphas, 1)[0].round(0)
            bp.set_weights(bot_weights)

            current_ep_reward = 0
            for _ in range(1, max_ep_len + 1):

                # select action with policy
                action = ppo_agent.select_action(state)
                action_to_save = action.clone()

                action = [tuple(a.tolist()) for a in action]
                next_state, _, done, _ = env.step(action)
                next_state = cari_concat(next_state, weights)

                # Pseudo-transitions: replicate the real transition n_pseudo
                # times and overwrite the weight columns of the replicas with
                # freshly sampled weights.
                # assumes state.shape[0] == n_arenas (one row per arena) - TODO confirm
                pseudo_weights = np.random.dirichlet(alphas, n_pseudo * state.shape[0]).astype("float32").round(0)
                state_to_save = torch.FloatTensor(state.copy())
                state_to_save = state_to_save.repeat(n_pseudo + 1, 1)
                state_to_save[n_arenas:, -n_weights:] = torch.tensor(pseudo_weights)
                next_state_to_save = torch.FloatTensor(next_state.copy())
                next_state_to_save = next_state_to_save.repeat(n_pseudo + 1, 1)
                next_state_to_save[n_arenas:, -n_weights:] = torch.tensor(pseudo_weights)

                # Log-probabilities of the taken actions are re-evaluated
                # under the old policy for the real and replicated states.
                action_to_save = action_to_save.repeat((n_pseudo + 1, 1))
                device = action_to_save.device
                dists = ppo_agent.policy_old.actor(state_to_save.to(device))
                logprobs_to_save = ppo_agent.policy_old.actor.get_logprobs(action_to_save, dists)

                rc.update_focus(state, action, next_state)
                # Temporarily tile the focus so it lines up with the
                # replicated batch; restored right after the reward is computed.
                focus_copy = rc.focus.clone()
                rc.focus = rc.focus.repeat((n_pseudo + 1))
                # calculate the rewards with new weights
                reward_calc_actions = [tuple(a.tolist()) for a in action_to_save]
                reward_dmg = rc.damage_enemy(state_to_save.numpy(), reward_calc_actions, next_state_to_save.numpy()) / 10.0  # normalize to be more equivalent with healing
                reward_heal = rc.heal_teammate(state_to_save.numpy(), reward_calc_actions, next_state_to_save.numpy())
                reward = np.column_stack([reward_dmg, reward_heal])
                reward = np.sum(np.concatenate((weights, pseudo_weights), axis=0) * reward, axis=1)
                # refocus back to "normal"
                rc.focus = focus_copy

                returnn += reward

                if np.abs(reward[0]) > 0.05:
                    print(reward[0])

                # A row is terminal when its first state feature is <= 0
                # (the original comment calls this "done by hp") ...
                done_mask = state_to_save[:, 0] <= 0.0
                if isinstance(done, list):
                    done = done[0]
                # ... or when the environment/time limit reports done.
                done_mask = np.logical_or(done_mask, np.repeat(done, n_arenas * (n_pseudo + 1)))

                # saving reward and is_terminals
                ppo_agent.buffer.rewards.append(reward)
                ppo_agent.buffer.is_terminals.append(done_mask)
                ppo_agent.buffer.states.append(state_to_save)
                ppo_agent.buffer.actions.append(action_to_save)
                ppo_agent.buffer.logprobs.append(logprobs_to_save)

                state = next_state

                time_step += 1
                current_ep_reward += reward.mean()  # mean over arenas

                # update PPO agent
                if time_step % update_timestep == 0:
                    n_updates += 1
                    losses = ppo_agent.update()
                    if n_updates % bot_policy_update_freq == 0:
                        print("updating bot policy")
                        bot_net = copy.deepcopy(ppo_agent.policy.actor)
                        bp.bot_net = bot_net

                    _log_update_stats(writer, losses, time_step)

                # log in logging file
                if time_step % log_freq == 0:

                    # log average reward till last episode
                    log_avg_reward = log_running_reward / log_running_episodes
                    log_avg_reward = round(log_avg_reward, 4)

                    log_f.write('{},{},{}\n'.format(i_episode, time_step, log_avg_reward))
                    log_f.flush()

                    log_running_reward = 0
                    log_running_episodes = 0

                # printing average reward
                if time_step % print_freq == 0:

                    # print average reward till last episode
                    print_avg_reward = print_running_reward / print_running_episodes
                    print_avg_reward = round(print_avg_reward, 2)
                    writer.add_scalar("Rew/print_avg_rew", print_avg_reward, time_step)

                    print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(i_episode, time_step, print_avg_reward))

                    print_running_reward = 0
                    print_running_episodes = 0

                # save model weights
                if time_step % save_model_freq == 0:
                    print(rule)
                    print("saving model at : " + checkpoint_path)
                    ppo_agent.save(checkpoint_path)
                    print("model saved")
                    print("Elapsed Time  : ", datetime.now().replace(microsecond=0) - start_time)
                    print(rule)

                # break; if the episode is over
                if done:
                    break

            print("RETURNN ", returnn.mean())

            print_running_reward += current_ep_reward
            print_running_episodes += 1

            log_running_reward += current_ep_reward
            log_running_episodes += 1

            i_episode += 1
    finally:
        # Release resources whether training finished or raised.
        writer.flush()
        writer.close()
        if log_f is not None:
            log_f.close()
        if env is not None:
            env.close()

    # print total training time
    print(banner)
    end_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("Finished training at (GMT) : ", end_time)
    print("Total training time  : ", end_time - start_time)
    print(banner)




if __name__ == '__main__':
    # Command-line options as (flag, default, type) rows. Declaration order is
    # preserved because train() builds the run name from vars(args) in order.
    cli_options = (
        ("--max_training_steps", 500000, int),
        ("--update_every", 4, int),
        ("--k_epochs", 20, int),
        ("--eps_clip", 0.2, float),
        ("--lr_actor", 0.003, float),
        ("--lr_critic", 0.005, float),
        ("--gamma", 0.99, float),
        ("--lamb", 0.95, float),
        ("--bs", 4096, int),
        ("--random_seed", 0, int),
        ("--n_pseudo", 0, int),
        ("--entropy_coef", 0.01, float),
    )

    parser = argparse.ArgumentParser(description='DERK PPO-CARI args')
    for flag, default_value, value_type in cli_options:
        parser.add_argument(flag, default=default_value, type=value_type)

    train(parser.parse_args())