import os
from datetime import datetime

import torch
import numpy as np

import gym
import mujoco_py
from PPO2_hopper import fkPPO

from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer shared by train(): it is passed to ppo_agent.update() and
# used to record the "average_episode_reward" scalar. It is created at import
# time. Per the SummaryWriter docs, `comment` is appended to the default
# runs/<datetime>_<hostname> log directory name; the suffix appears to encode
# this run's settings (dist_coeff/prior_coeff = 0.001 in train()) — keep it in
# sync when those settings change.
writer = SummaryWriter(comment=".fkppo_walker2d_with_0.001_0.001_0fixlogstd_smoothL1loss")

def make_transition(state, action, reward, next_state, done, log_prob=None):
    """Pack a single environment step into a plain dict.

    The returned dict has the keys ``state``, ``action``, ``reward``,
    ``next_state``, ``log_prob`` and ``done`` (in that insertion order).
    ``log_prob`` is ``None`` unless the caller supplies one.
    """
    return dict(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        log_prob=log_prob,
        done=done,
    )


################################### Training ###################################
def _select_device():
    """Choose the torch device used for training and print which one was picked.

    Returns ``cuda:1`` when at least two GPUs are visible (the device this
    script originally hard-coded), ``cuda:0`` when exactly one GPU is visible,
    and the CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("Device set to : cpu")
        return torch.device('cpu')
    # 'cuda:1' does not exist on single-GPU machines and get_device_name()
    # would raise, so fall back to the first GPU there.
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
    return device


def _new_log_file_path(env_name_string):
    """Create ./PPO_logs/<env>/ if needed and return ``(run_num, csv_path)``.

    ``run_num`` is the number of files already in that directory, so a new run
    normally gets a fresh file name instead of overwriting an earlier log.
    """
    log_dir = os.path.join("./PPO_logs", env_name_string)
    os.makedirs(log_dir, exist_ok=True)
    run_num = len(next(os.walk(log_dir))[2])
    log_f_name = os.path.join(
        log_dir, "PPO_fkppo_" + env_name_string + "_log_" + str(run_num) + ".csv")
    return run_num, log_f_name


def _checkpoint_path(env_name_string, random_seed, run_num_pretrained):
    """Create ./PPO_preTrained/<env>/ if needed and return this run's .pth path."""
    directory = os.path.join("./PPO_preTrained", env_name_string)
    os.makedirs(directory, exist_ok=True)
    return os.path.join(
        directory,
        "PPO_fkppo_{}_{}_{}.pth".format(env_name_string, random_seed, run_num_pretrained))


def _fill_pool(env, ppo_agent, device, max_pooling_timesteps):
    """Pre-fill the agent's data pool before any policy update.

    Takes ``max_pooling_timesteps + 1`` environment steps with the current
    policy, widening its std by 0.2, and passes every visited state to
    ``ppo_agent.put_data_pool``. Rewards are discarded. The env is reset when an
    episode terminates or is truncated.
    """
    state = env.reset()[0]
    for _ in range(max_pooling_timesteps + 1):
        mu, sigma = ppo_agent.policy.select_action(torch.from_numpy(state).float().to(device))
        # Extra std on top of the (not yet trained) policy — presumably so the
        # pool covers a wider range of states than the policy alone would visit.
        dist = torch.distributions.Normal(mu, sigma + 0.2)
        action = dist.rsample()
        # gym>=0.26 step API: (obs, reward, terminated, truncated, info).
        next_state, _, terminated, truncated, _ = env.step(action.cpu().numpy())
        ppo_agent.put_data_pool(state)
        if terminated or truncated:
            state = env.reset()[0]
        else:
            state = next_state


def _collect_rollout(env, ppo_agent, device, n_steps):
    """Run ``n_steps`` policy steps from a fresh reset and fill ``ppo_agent.buffer``.

    Every step appends state, next state, action, summed log-prob, reward and a
    terminal flag to the buffer, and also adds the state to the data pool. An
    episode ends when the env reports ``terminated`` or ``truncated``; the env
    is then reset and collection continues until ``n_steps`` steps are taken.

    Returns:
        ``(episode_returns, partial_return)``: the undiscounted returns of the
        episodes that finished inside the rollout, and the return accumulated by
        the episode still running when the rollout stopped.
    """
    buffer = ppo_agent.buffer
    episode_returns = []
    current_ep_reward = 0
    state = env.reset()[0]
    for _ in range(n_steps):
        state_torch = torch.from_numpy(state).float().to(device)
        mu, sigma = ppo_agent.policy.select_action(state_torch)
        dist = torch.distributions.Normal(mu, sigma)
        action = dist.sample()
        # Joint log-prob of the action vector: sum over the last (action) dim.
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        # gym>=0.26 step API: (obs, reward, terminated, truncated, info).
        # Truncation (e.g. the TimeLimit wrapper) must also end the episode,
        # otherwise the env keeps running past its time limit.
        next_state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
        done = terminated or truncated
        nextstate_torch = torch.from_numpy(next_state).float().to(device)

        buffer.states.append(state_torch)
        buffer.nextstates.append(nextstate_torch)
        buffer.actions.append(action)
        buffer.logprobs.append(log_prob)
        buffer.rewards.append(torch.tensor([reward]))
        # Truncated steps are flagged terminal too, because the env is reset right
        # after them. NOTE(review): assumes fkPPO.update uses is_terminals to cut
        # discounted returns between episodes — confirm in PPO2_hopper.
        buffer.is_terminals.append(torch.tensor([int(done)]))

        ppo_agent.put_data_pool(state)

        current_ep_reward += reward
        if done:
            episode_returns.append(current_ep_reward)
            state = env.reset()[0]
            current_ep_reward = 0
        else:
            state = next_state
    return episode_returns, current_ep_reward


def train():
    """Train an fkPPO agent on Walker2d-v2, logging to CSV and TensorBoard.

    Outline:
      1. pick a device and build the gym environment;
      2. create ./PPO_logs/<env>/ and ./PPO_preTrained/<env>/ and choose the
         CSV log file and checkpoint paths;
      3. pre-fill the agent's data pool with ``max_pooling_timesteps + 1``
         exploratory policy steps;
      4. repeatedly collect ``max_ep_len``-step rollouts and call
         ``ppo_agent.update`` until more than ``max_training_timesteps`` steps
         have been taken. Each rollout logs the mean return of the episodes that
         finished in it, and a checkpoint is saved every ``save_model_freq``
         rollouts.
    """
    rule = "============================================================================================"
    dash_rule = "--------------------------------------------------------------------------------------------"

    print(rule)

    ################################## set device ##################################
    device = _select_device()
    print(rule)

    ####### initialize environment hyperparameters ######
    # Alternatives: Hopper-v2, Humanoid-v2, InvertedPendulum-v2,
    # InvertedDoublePendulum-v2, Reacher-v2
    env_name = "Walker2d-v2"
    env_name_string = env_name.replace('/', '-')  # usable as a directory name

    has_continuous_action_space = True  # continuous action space; else discrete

    max_ep_len = 2048                  # env steps per rollout (one PPO update each); also passed to fkPPO
    max_training_timesteps = int(5e6)  # stop once the step counter exceeds this
    max_pooling_timesteps = int(1e4)   # env steps used to pre-fill the data pool

    # NOTE(review): print_freq and log_freq are only reported below; the
    # training loop prints and logs after every rollout.
    print_freq = 5
    log_freq = 5
    save_model_freq = 50               # save a checkpoint every N rollouts

    ################ PPO hyperparameters ################
    K_epochs = 80               # update policy for K epochs in one PPO update
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor

    lr_actor = 0.0003           # learning rate for actor network
    lr_critic = 0.0003          # learning rate for critic network

    random_seed = 12            # set random seed if required (0 = no random seed)
    #####################################################

    print("training environment name : " + env_name_string)

    env = gym.make(env_name)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    print("state space dimension : ", state_dim)
    print("action space dimension : ", action_dim)

    ###################### logging ######################
    run_num, log_f_name = _new_log_file_path(env_name_string)
    print("current logging run number for " + env_name_string + " : ", run_num)
    print("logging at : " + log_f_name)

    ################### checkpointing ###################
    run_num_pretrained = 0  # change this to avoid overwriting weights in the same env folder
    checkpoint_path = _checkpoint_path(env_name_string, random_seed, run_num_pretrained)
    print("save checkpoint path : " + checkpoint_path)

    ############# print all hyperparameters #############
    print(dash_rule)
    print("max training timesteps : ", max_training_timesteps)
    print("max timesteps per episode : ", max_ep_len)
    print("model saving frequency : " + str(save_model_freq) + " rollouts")
    print("log frequency : " + str(log_freq) + " timesteps")
    print("printing average reward over episodes in last : " + str(print_freq) + " timesteps")
    print(dash_rule)
    print("state space dimension : ", state_dim)
    print("action space dimension : ", action_dim)
    print(dash_rule)
    print(dash_rule)
    print("PPO K epochs : ", K_epochs)
    print("PPO epsilon clip : ", eps_clip)
    print("discount factor (gamma) : ", gamma)
    print(dash_rule)
    print("optimizer learning rate actor : ", lr_actor)
    print("optimizer learning rate critic : ", lr_critic)
    if random_seed:
        print(dash_rule)
        print("setting random seed to ", random_seed)
        torch.manual_seed(random_seed)
        env.reset(seed=random_seed)  # seeds the env's RNG for all later resets
        np.random.seed(random_seed)

    print(rule)

    ################# training procedure ################
    ppo_agent = fkPPO(state_dim, action_dim, lr_actor, lr_critic,
                      gamma, K_epochs, eps_clip, has_continuous_action_space, max_ep_len,
                      dist_coeff=0.001, prior_coeff=0.001, n_samples=10, activation_fn='tanh', device=device)

    # NOTE(review): datetime.now() is local time, not GMT as the label says.
    start_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print(rule)

    time_step = 0
    i_episode = 0  # counts rollouts / PPO updates, not env episodes

    try:
        with open(log_f_name, "w+") as log_f:
            log_f.write('episode,timestep,reward\n')

            # collect states for the data pool before any update
            start_time_pool = datetime.now().replace(microsecond=0)
            print(rule)
            _fill_pool(env, ppo_agent, device, max_pooling_timesteps)
            print("Pooling uses time  : ", datetime.now().replace(microsecond=0) - start_time_pool)

            # start training; the pooling steps are not counted
            while time_step <= max_training_timesteps:
                score_list, partial_return = _collect_rollout(env, ppo_agent, device, max_ep_len)
                time_step += max_ep_len
                i_episode += 1

                ppo_agent.update(time_step, writer)

                # Mean return of the episodes finished in this rollout; if none
                # finished, fall back to the running episode's partial return.
                if score_list:
                    average_episode_reward = sum(score_list) / len(score_list)
                else:
                    average_episode_reward = partial_return

                print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(i_episode, time_step, average_episode_reward))

                log_f.write('{},{},{}\n'.format(i_episode, time_step, average_episode_reward))
                log_f.flush()

                writer.add_scalar("average_episode_reward", average_episode_reward, time_step)
                writer.flush()

                if i_episode % save_model_freq == 0:
                    print(dash_rule)
                    print("saving model at : " + checkpoint_path)
                    ppo_agent.save(checkpoint_path)
                    print("model saved")
                    print("Elapsed Time  : ", datetime.now().replace(microsecond=0) - start_time)
                    print(dash_rule)
    finally:
        # Release the simulator and flush TensorBoard even if training fails.
        env.close()
        writer.close()

    # print total training time
    print(rule)
    end_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("Finished training at (GMT) : ", end_time)
    print("Total training time  : ", end_time - start_time)
    print(rule)


# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    train()
    
    
    
    
    
    
    
