import os
from datetime import datetime

import torch
# from PIL import Image
# from torchvision import transforms
import numpy as np

import gym
#import roboschool
from PPO_hopper import PPO
# from CartPoleEnvC import CartPoleEnvC

from torch.utils.tensorboard import SummaryWriter
# Run-name suffix for the TensorBoard writer shared by train() (it is also
# handed to ppo_agent.update inside the training loop).
# NOTE(review): the writer is constructed at import time; SummaryWriter
# presumably creates its run directory on construction, so consider moving
# this into train().
_WRITER_COMMENT = ".ppo_fixedactionstd_smoothL1loss"
writer = SummaryWriter(comment=_WRITER_COMMENT)


################################### Training ###################################
def train():
    """Train a PPO agent on a continuous-control gym environment.

    Creates the environment and a ``PPO`` agent, then repeatedly:
      1. collects a rollout of ``max_ep_len`` environment steps (episodes that
         end inside the rollout are reset in place),
      2. runs one ``ppo_agent.update`` on it,
      3. logs the mean return of the episodes finished in that rollout to a
         CSV file, stdout and the module-level TensorBoard ``writer``,
    until more than ``max_training_timesteps`` steps have been taken.
    A checkpoint is saved every ``save_model_freq`` rollouts/updates.

    All hyperparameters are hard-coded below. The code uses the gym >= 0.26
    API: ``reset`` returns ``(obs, info)`` and ``step`` returns
    ``(obs, reward, terminated, truncated, info)``.
    """
    from datetime import timezone

    banner = "============================================================================================"
    rule = "--------------------------------------------------------------------------------------------"

    print(banner)

    ################################## set device ##################################
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        torch.cuda.empty_cache()
        print("Device set to : " + str(torch.cuda.get_device_name(device)))
    else:
        print("Device set to : cpu")
    print(banner)

    ####### initialize environment hyperparameters ######
    env_name = "Hopper-v2"             # CartPole-v1  Acrobot-v1  ALE/Atlantis-v5, ALE/Assault-v5, ALE/Freeway-v5
    env_name_string = env_name.replace('/', '-')   # '/' cannot appear in file names

    has_continuous_action_space = True  # continuous action space; else discrete

    max_ep_len = 2048                   # env steps per rollout (one PPO update per rollout)
    max_training_timesteps = int(2e6)   # stop once total env steps exceed this

    save_model_freq = 50                # save model every N rollouts / PPO updates

    action_std = 0.1                    # std for the action distribution passed to PPO
    # NOTE: the three decay settings below are only printed; this script never
    # decays action_std.
    action_std_decay_rate = 0.05
    min_action_std = 0.1
    action_std_decay_freq = int(2.5e5)
    #####################################################

    ################ PPO hyperparameters ################
    K_epochs = 80           # update policy for K epochs in one PPO update
    eps_clip = 0.2          # clip parameter for PPO
    gamma = 0.99            # discount factor

    lr_actor = 0.0003       # learning rate for actor network
    lr_critic = 0.0003      # learning rate for critic network

    random_seed = 12        # set random seed if required (0 = no random seed)
    #####################################################

    print("training environment name : " + env_name_string)

    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    if has_continuous_action_space:
        action_dim = env.action_space.shape[0]
    else:
        action_dim = env.action_space.n

    ###################### logging ######################
    # Log files for multiple runs are NOT overwritten: each run gets the next
    # number after the files already present in the directory.
    log_dir = os.path.join("./PPO_logs", env_name_string)
    os.makedirs(log_dir, exist_ok=True)
    run_num = len(next(os.walk(log_dir))[2])
    log_f_name = os.path.join(
        log_dir, "PPO_ppo_" + env_name_string + "_log_" + str(run_num) + ".csv")

    print("current logging run number for " + env_name_string + " : ", run_num)
    print("logging at : " + log_f_name)

    ################### checkpointing ###################
    run_num_pretrained = 0      # change this to prevent overwriting weights in same env_name folder

    directory = os.path.join("./PPO_preTrained", env_name_string)
    os.makedirs(directory, exist_ok=True)
    checkpoint_path = os.path.join(
        directory,
        "PPO_ppo_{}_{}_{}.pth".format(env_name_string, random_seed, run_num_pretrained))
    print("save checkpoint path : " + checkpoint_path)

    ############# print all hyperparameters #############
    print(rule)
    print("max training timesteps : ", max_training_timesteps)
    print("max timesteps per episode : ", max_ep_len)
    print("model saving frequency : " + str(save_model_freq) + " PPO updates")
    print(rule)
    print("state space dimension : ", state_dim)
    print("action space dimension : ", action_dim)
    print(rule)
    if has_continuous_action_space:
        print("Initializing a continuous action space policy")
        print(rule)
        print("starting std of action distribution : ", action_std)
        print("decay rate of std of action distribution : ", action_std_decay_rate)
        print("minimum std of action distribution : ", min_action_std)
        print("decay frequency of std of action distribution : " + str(action_std_decay_freq) + " timesteps")
    else:
        print("Initializing a discrete action space policy")
    print(rule)
    print("PPO K epochs : ", K_epochs)
    print("PPO epsilon clip : ", eps_clip)
    print("discount factor (gamma) : ", gamma)
    print(rule)
    print("optimizer learning rate actor : ", lr_actor)
    print("optimizer learning rate critic : ", lr_critic)
    if random_seed:
        print(rule)
        print("setting random seed to ", random_seed)
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        # gym >= 0.26 seeds the environment through reset(); later resets
        # without a seed keep drawing from this seeded RNG.
        env.reset(seed=random_seed)

    print(banner)

    ################# training procedure ################
    ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                    has_continuous_action_space, action_std, device=device)

    # Timestamps are taken in UTC so the "(GMT)" labels below are accurate.
    start_time = datetime.now(timezone.utc).replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)

    print(banner)

    log_f = open(log_f_name, "w+")
    try:
        log_f.write('episode,timestep,reward\n')

        time_step = 0
        i_episode = 0       # counts rollouts / PPO updates, not environment episodes
        score_list = []     # returns of episodes finished in the current rollout

        while time_step <= max_training_timesteps:
            state, _ = env.reset()
            current_ep_reward = 0

            for _ in range(max_ep_len):
                state_torch = torch.from_numpy(state).float().to(device)

                # Acting needs no autograd graph: the stored log-probs are the
                # "old policy" values and must be constants during the update.
                # Without no_grad every step's graph would be kept alive in the buffer.
                with torch.no_grad():
                    mu, sigma = ppo_agent.policy.select_action(state_torch)
                    dist = torch.distributions.Normal(mu, sigma)
                    action = dist.sample()
                    log_prob = dist.log_prob(action).sum(-1, keepdim=True)

                next_state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                # A time-limit truncation ends the episode too (as the legacy gym
                # ``done`` flag did); otherwise the env would be stepped past its
                # time limit. It is also recorded as an episode boundary so returns
                # are not carried across the reset below.
                done = terminated or truncated
                nextstate_torch = torch.FloatTensor(next_state).to(device)

                ppo_agent.buffer.states.append(state_torch)
                ppo_agent.buffer.nextstates.append(nextstate_torch)
                ppo_agent.buffer.actions.append(action)
                ppo_agent.buffer.logprobs.append(log_prob)
                ppo_agent.buffer.rewards.append(torch.tensor([reward]))
                ppo_agent.buffer.is_terminals.append(torch.tensor([int(done)]))

                time_step += 1
                current_ep_reward += reward

                if done:
                    score_list.append(current_ep_reward)
                    state, _ = env.reset()
                    current_ep_reward = 0
                else:
                    state = next_state

            i_episode += 1

            ppo_agent.update(time_step, writer)

            if score_list:
                average_episode_reward = sum(score_list) / len(score_list)
            else:
                # No episode finished during this rollout: report the partial return.
                average_episode_reward = current_ep_reward

            print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(
                i_episode, time_step, average_episode_reward))

            log_f.write('{},{},{}\n'.format(i_episode, time_step, average_episode_reward))
            log_f.flush()

            writer.add_scalar("average_episode_reward", average_episode_reward, time_step)
            writer.flush()

            score_list = []

            if i_episode % save_model_freq == 0:
                print(rule)
                print("saving model at : " + checkpoint_path)
                ppo_agent.save(checkpoint_path)
                print("model saved")
                print("Elapsed Time  : ", datetime.now(timezone.utc).replace(microsecond=0) - start_time)
                print(rule)
    finally:
        # Release the log file, environment and TensorBoard writer even if
        # training is interrupted by an exception.
        log_f.close()
        env.close()
        writer.close()

    print(banner)
    end_time = datetime.now(timezone.utc).replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("Finished training at (GMT) : ", end_time)
    print("Total training time  : ", end_time - start_time)
    print(banner)
    print("============================================================================================")


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    train()
