#mdenv contextual policy training
import numpy as np
import torch
from mdtoyenv import MultiDiscreteToyEnv
from reward_fns import within_radius, neg_distance
from util import argmax_onehot
from util import cari_concat, get_hypersphere_locations
from gym.wrappers.time_limit import TimeLimit

import os
import glob
import time
from datetime import datetime
from PPO import PPO
from torch.utils.tensorboard import SummaryWriter
import argparse





################################### Training ###################################

def _experiment_layout(experiment_name):
    """Look up the output locations used by a named experiment.

    Returns:
        A tuple ``(log_dir, checkpoint_dir, tensorboard_dir, round_weights)``.

    Raises:
        ValueError: if ``experiment_name`` is not one of the supported experiments.
    """
    layouts = {
        "mdenv_cari_train": ("PPO_logs", "PPO_preTrained", "tbr_logs", True),
        "warmstart": ("warmstart_logs", "PPO_warmstart", "tbr_warmstart", False),
        "teaser_policy": ("teaser_log", "PPO_teaser", "tbr_teaser", True),
    }
    if experiment_name not in layouts:
        raise ValueError(
            f"unknown experiment_name {experiment_name!r}; "
            f"expected one of {sorted(layouts)}"
        )
    return layouts[experiment_name]


def train(args):
    """Train a weight-conditioned PPO policy on the multi-discrete toy environment.

    Every episode samples a fresh Dirichlet weight vector, appends it to the
    observation and scores each transition as the weighted sum of the per-goal
    rewards. Each real transition is additionally stored ``n_pseudo`` more times
    with re-sampled weights (and rewards recomputed for those weights).

    Args:
        args: parsed command-line namespace. The attributes read are
            ``max_training_steps``, ``update_every``, ``k_epochs``, ``eps_clip``,
            ``lr_actor``, ``lr_critic``, ``gamma``, ``lamb``, ``bs``,
            ``random_seed``, ``n_pseudo``, ``entropy_coef``, ``n_weights``,
            ``smooth_reward``, ``n_env_dims``, ``warmstart`` and
            ``experiment_name``.

    Raises:
        ValueError: if ``experiment_name`` is not supported, or if ``warmstart``
            is 1 while ``n_pseudo`` is 0.

    Side effects:
        Writes a CSV log, TensorBoard event files and periodic checkpoints under
        the experiment-specific directories, and prints progress to stdout.
    """
    #FROM https://raw.githubusercontent.com/nikhilbarhate99/PPO-PyTorch/master/train.py

    # Validate the configuration up front so that a bad invocation fails before
    # any environment, agent, writer or log file has been created.
    log_dir, directory, tb_dir, round_weights = _experiment_layout(args.experiment_name)
    warmstart = args.warmstart
    n_pseudo = args.n_pseudo #number of pseudotransitions to add per each transition (really is n_pseudo * n_arenas new steps)
    if warmstart == 1 and n_pseudo == 0:
        raise ValueError("warmstart not possible if n_pseudo=0")

    print("============================================================================================")

    # Build a run identifier from all arguments, then shorten the long keys.
    # The replacements are applied sequentially, in this order.
    param_names = "_-_".join(f"{key}-{val}" for (key, val) in vars(args).items())
    abbreviations = (
        ("max_training_steps", "ts"),
        ("lr_actor", "lr_a"),
        ("lr_critic", "lr_c"),
        ("gamma", "gam"),
        ("random_seed", "seed"),
        ("smooth_reward", "smooth_rew"),
        ("n_env_dims", "env_dim"),
        ("experiment_name", ""),
    )
    for long_name, short_name in abbreviations:
        param_names = param_names.replace(long_name, short_name)

    ####### initialize environment hyperparameters ######

    action_space = "multidiscrete"  # "continuous" or "multidiscrete" else discrete

    max_ep_len = 150                   # max timesteps in one episode
    max_training_timesteps = int(args.max_training_steps)   # break training loop if timesteps > max_training_timesteps

    print_freq = max_ep_len * 1      # print avg reward in the interval (in num timesteps)
    log_freq = max_ep_len * 1          # log avg reward in the interval (in num timesteps)
    save_model_freq = int(5e4)          # save model frequency (in num timesteps)

    action_std = 0.6                    # starting std for action distribution (Multivariate Normal)
    n_weights = int(args.n_weights)
    n_env_dims = int(args.n_env_dims)
    if args.smooth_reward == 1:
        print("KEKE")
        rew_fn = neg_distance
    else:
        print("ASD")
        rew_fn = within_radius

    #####################################################

    ## Note : print/log frequencies must not be smaller than max_ep_len. The
    ## episode counters are reset to 0 after each report, so a shorter interval
    ## could contain no finished episode and the averages would divide by zero.

    ################ PPO hyperparameters ################
    update_timestep = max_ep_len * args.update_every * 32   # update policy every n timesteps
    K_epochs = args.k_epochs             # update policy for K epochs in one PPO update

    eps_clip = args.eps_clip          # clip parameter for PPO
    gamma = args.gamma           # discount factor
    lambd = args.lamb            #GAE lambda

    lr_actor = args.lr_actor     # learning rate for actor network
    lr_critic = args.lr_critic       # learning rate for critic network
    bs = args.bs
    entropy_coef = args.entropy_coef

    random_seed = args.random_seed        # set random seed if required (0 = no random seed)
    n_arenas = 1

    #####################################################

    env_name = "mdtoyenv"
    print("training environment name : " + env_name)
    locs, rads = get_hypersphere_locations(n_env_dims, n_weights)
    print(f"GOALS: {locs}, RADS {rads}")

    env = MultiDiscreteToyEnv(locs, rads)
    env = TimeLimit(env, max_ep_len)

    # state space dimension
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.nvec

    # initialize a PPO agent; the policy input is the observation plus the weight vector
    alphas = 0.01*np.ones(n_weights)
    ppo_agent = PPO(state_dim + n_weights, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, action_space, action_std, lambd, bs, entropy_coef)

    ###################### logging ######################

    #### log files for multiple runs are NOT overwritten

    #### get number of log files in log directory
    # The directory has to exist before it is walked: os.walk yields nothing for
    # a missing path, so next() on it would raise StopIteration on a first run.
    os.makedirs(log_dir, exist_ok=True)
    run_num = len(next(os.walk(log_dir), (log_dir, [], []))[2])

    log_dir = log_dir + '/' + env_name + '/'
    os.makedirs(log_dir, exist_ok=True)

    #### create new log file for each run
    log_f_name = log_dir + '/' + param_names + ".csv"

    print("current logging run number for " + env_name + " : ", run_num)
    print("logging at : " + log_f_name)

    #####################################################

    ################### checkpointing ###################

    directory = directory + '/' + env_name + '/'
    os.makedirs(directory, exist_ok=True)

    checkpoint_path = directory + f"{n_weights}.pth"
    print("save checkpoint path : " + checkpoint_path)

    #####################################################

    ############# print all hyperparameters #############

    print("--------------------------------------------------------------------------------------------")

    print("max training timesteps : ", max_training_timesteps)
    print("max timesteps per episode : ", max_ep_len)

    print("model saving frequency : " + str(save_model_freq) + " timesteps")
    print("log frequency : " + str(log_freq) + " timesteps")
    print("printing average reward over episodes in last : " + str(print_freq) + " timesteps")

    print("--------------------------------------------------------------------------------------------")

    print("state space dimension : ", state_dim)
    print("action space dimension : ", action_dim)

    print("--------------------------------------------------------------------------------------------")

    print("Initializing a discrete or multidiscrete action space policy")

    print("--------------------------------------------------------------------------------------------")

    print("PPO update frequency : " + str(update_timestep) + " timesteps")
    print("PPO K epochs : ", K_epochs)
    print("PPO epsilon clip : ", eps_clip)
    print("discount factor (gamma) : ", gamma)

    print("--------------------------------------------------------------------------------------------")

    print("optimizer learning rate actor : ", lr_actor)
    print("optimizer learning rate critic : ", lr_critic)

    if random_seed:
        print("--------------------------------------------------------------------------------------------")
        print("setting random seed to ", random_seed)
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)

    #####################################################

    print("============================================================================================")

    ################# training procedure ################

    # track total training time
    # NOTE(review): datetime.now() is local time although the messages say GMT.
    start_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)

    print("============================================================================================")

    # (TensorBoard tag, index into the sequence returned by ppo_agent.update())
    scalar_tags = (
        ("Loss/loss", 3),
        ("Loss/loss_actor", 0),
        ("Loss/loss_critic", 1),
        ("Loss/loss_entropy", 2),
        ("Debug/entropy", 4),
        ("Debug/clipfrac", 5),
        ("Debug/relative_entropy", 6),
        ("Debug/kl_div_last", 7),
        ("Debug/residual_var_vf", 8),
    )

    writer = None
    log_f = None
    try:
        writer = SummaryWriter(f'{tb_dir}/{param_names}')

        # logging file
        log_f = open(log_f_name,"w+")
        log_f.write('episode,timestep,reward\n')

        # printing and logging variables
        print_running_reward = 0
        print_running_episodes = 1

        log_running_reward = 0
        log_running_episodes = 1

        time_step = 0
        i_episode = 0
        n_updates = 0

        # training loop
        while time_step <= max_training_timesteps:
            #sample new weights on every reset for every arena and for every agent
            weights = np.random.dirichlet(alphas, n_arenas)
            if round_weights:
                weights = argmax_onehot(weights)

            state = env.reset()
            returnn = 0
            state = cari_concat(state, weights)

            current_ep_reward = 0
            for t in range(1, max_ep_len+1):

                # select action with policy
                action = ppo_agent.select_action(state)
                action_to_save = action.clone()
                next_state, _, done, _ = env.step(action)
                if isinstance(done, bool):
                    done = np.array(done).reshape((n_arenas,))

                # Warmstart: stop adding pseudo transitions in the second half of
                # training, switching only while the rollout buffer is empty so a
                # single buffer never mixes the two row counts.
                if warmstart == 1 and time_step > max_training_timesteps // 2 and len(ppo_agent.buffer.rewards) == 0:
                    n_pseudo = 0

                #add pseudo transitions
                pseudo_weights = np.random.dirichlet(alphas, n_pseudo * state.shape[0]).astype("float32")
                if round_weights:
                    pseudo_weights = argmax_onehot(pseudo_weights)

                # Row block 0 keeps the real weights; the remaining rows get the
                # freshly sampled pseudo weights in the last n_weights columns.
                state_to_save = torch.FloatTensor(state.copy())
                state_to_save = state_to_save.repeat(n_pseudo + 1, 1)
                state_to_save[n_arenas:, -n_weights:] = torch.tensor(pseudo_weights)

                next_state = cari_concat(next_state, weights)
                next_state_to_save = torch.FloatTensor(next_state.copy())
                next_state_to_save = next_state_to_save.repeat(n_pseudo + 1, 1)
                next_state_to_save[n_arenas:, -n_weights:] = torch.tensor(pseudo_weights)

                # The same action is stored for every copy, but its log-probability
                # is re-evaluated under each copy's own weight vector.
                # NOTE(review): this runs with autograd enabled; confirm that
                # PPO.update() detaches the stored log-probs.
                action_to_save = action_to_save.repeat((n_pseudo+1, 1))
                device = action_to_save.device
                dists = ppo_agent.policy_old.actor(state_to_save.to(device))
                logprobs_to_save = ppo_agent.policy_old.actor.get_logprobs(action_to_save, dists)

                # calculate the rewards with new weights: one column per goal,
                # then a weighted sum over goals for every stored row
                rewards = [
                    rew_fn(
                        next_state_to_save[:,:env.n_dim].numpy(),
                        goal.reshape((1,env.n_dim)).repeat(n_pseudo+1, 0),
                        rad * np.ones((n_pseudo + 1, 1)),
                    )
                    for goal, rad in zip(env.goal_locs, env.goal_rads)
                ]
                reward = np.column_stack(rewards)
                reward = np.sum(np.concatenate((weights, pseudo_weights), axis=0) * reward, axis=1)
                done = done.repeat(n_pseudo+1)

                returnn += reward
                # saving reward and is_terminals
                ppo_agent.buffer.rewards.append(reward)
                ppo_agent.buffer.is_terminals.append(done)
                ppo_agent.buffer.states.append(state_to_save)
                ppo_agent.buffer.actions.append(action_to_save)
                ppo_agent.buffer.logprobs.append(logprobs_to_save)

                state = next_state

                time_step +=1
                current_ep_reward += reward[:n_arenas].mean() #mean over arenas, not pseudo-observations

                # update PPO agent
                if time_step % update_timestep == 0:
                    n_updates +=1
                    losses = ppo_agent.update()

                    for tag, idx in scalar_tags:
                        writer.add_scalar(tag, losses[idx], time_step)

                    writer.add_scalars("ValueFunction/target_values", {"mean": losses[9], "min": losses[10], "max": losses[11]}, time_step)
                    writer.add_scalars("ValueFunction/pred_values", {"mean": losses[12], "min": losses[13], "max": losses[14]}, time_step)

                # log in logging file
                if time_step % log_freq == 0:

                    # log average reward till last episode
                    log_avg_reward = log_running_reward / log_running_episodes
                    log_avg_reward = round(log_avg_reward, 4)

                    log_f.write('{},{},{}\n'.format(i_episode, time_step, log_avg_reward))
                    log_f.flush()

                    log_running_reward = 0
                    log_running_episodes = 0

                # printing average reward
                if time_step % print_freq == 0:

                    # print average reward till last episode
                    print_avg_reward = print_running_reward / print_running_episodes
                    print_avg_reward = round(print_avg_reward, 2)
                    writer.add_scalar("Rew/print_avg_rew", print_avg_reward, time_step)

                    print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(i_episode, time_step, print_avg_reward))

                    print_running_reward = 0
                    print_running_episodes = 0

                # save model weights
                if time_step % save_model_freq == 0:
                    print("--------------------------------------------------------------------------------------------")
                    print("saving model at : " + checkpoint_path)
                    ppo_agent.save(checkpoint_path)
                    print("model saved")
                    print("Elapsed Time  : ", datetime.now().replace(microsecond=0) - start_time)
                    print("--------------------------------------------------------------------------------------------")

                # break; if the episode is over
                if done[0]:
                    break
            print("RETURNN ", returnn.mean())

            print_running_reward += current_ep_reward
            print_running_episodes += 1

            log_running_reward += current_ep_reward
            log_running_episodes += 1

            i_episode += 1
    finally:
        # Release the writer, log file and environment even if training fails.
        if writer is not None:
            writer.flush()
            writer.close()
        if log_f is not None:
            log_f.close()
        env.close()

    # print total training time
    print("============================================================================================")
    end_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("Finished training at (GMT) : ", end_time)
    print("Total training time  : ", end_time - start_time)
    print("============================================================================================")




if __name__ == '__main__':
    # Command-line entry point: parse the hyperparameters and hand the resulting
    # namespace to train() unchanged.
    # Each row is (flag, default, type); flags are registered in this order.
    cli_options = (
        ("--max_training_steps", 500000, int),
        ("--update_every", 4, int),
        ("--k_epochs", 20, int),
        ("--eps_clip", 0.2, float),
        ("--lr_actor", 0.003, float),
        ("--lr_critic", 0.01, float),
        ("--gamma", 0.99, float),
        ("--lamb", 0.95, float),
        ("--bs", 4096, int),
        ("--random_seed", 0, int),
        ("--n_pseudo", 1, int),
        ("--entropy_coef", 0.01, float),
        ("--n_weights", 3, int),
        ("--smooth_reward", 0, int),
        ("--n_env_dims", 2, int),
        ("--warmstart", 0, int),
        ("--experiment_name", "mdenv_cari_train", str),
    )

    parser = argparse.ArgumentParser(description='MDENV PPO-CARI args')
    for flag, default_value, value_type in cli_options:
        parser.add_argument(flag, default=default_value, type=value_type)

    args = parser.parse_args()
    train(args)