import os
import glob
import time
from datetime import datetime

import torch
import numpy as np
from copy import deepcopy
import gym
from util import get_output_folder
from PPO import PPO
import sys
sys.path.append("..")
from attacker.Attacker import Attacker
from env.ControlSlide import ControlSlideEnv
from env.CarFindFlag import CarFindFlagEnv
from env.CarFindFlag_e import CarFindFlagEEnv
from env.CarFindFlag_m import CarFindFlagMEnv
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

################################### Training ###################################
def _make_env(env_name):
    """Instantiate the training environment selected by *env_name*.

    Raises:
        ValueError: if *env_name* is not one of the environments handled here.
    """
    if env_name == "ControlSlideEnv":
        return ControlSlideEnv()
    if env_name == "CarFindFlagMEnv":
        print(env_name)
        return CarFindFlagMEnv()
    raise ValueError(
        "unsupported env_name {!r}; expected 'ControlSlideEnv' or "
        "'CarFindFlagMEnv'".format(env_name)
    )


def _make_run_dir(root, env_name):
    """Create *root* and the run folder returned by ``get_output_folder``."""
    os.makedirs(root, exist_ok=True)
    run_dir = get_output_folder(root, env_name)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def train(args, attacker):
    """Train a PPO agent on the environment named by ``args.env_name``.

    Rewards are logged to a CSV file under ``PPO_logs``; model checkpoints,
    per-episode rewards/lengths and a dump of ``args`` are written under
    ``PPO_preTrained``.  When ``args.ATTACK`` is truthy every action chosen by
    the policy is passed through ``attacker.antiAction`` before it is sent to
    the environment, and the attacker is updated at the end of each episode.

    Args:
        args: namespace-like object.  Attributes read here: ``env_name``,
            ``max_episode_length``, ``max_training_epochs``, ``action_std``,
            ``action_std_decay_rate``, ``min_action_std``,
            ``action_std_decay_freq``, ``K_epochs``, ``lr_actor``,
            ``lr_critic``, ``random_seed`` and ``ATTACK``.  The whole
            ``args.__dict__`` is saved to ``setting.txt``.
        attacker: only used when ``args.ATTACK`` is truthy; must provide
            ``antiAction(action, step_index, state)`` returning a pair,
            ``update(trajectory)`` and a ``similarity`` attribute.

    Raises:
        ValueError: if ``args.env_name`` is not a supported environment.
    """
    print("============================================================================================")

    ####### initialize environment hyperparameters ######
    env_name = args.env_name

    has_continuous_action_space = True  # continuous action space; else discrete

    max_ep_len = args.max_episode_length                   # max timesteps in one episode
    max_training_timesteps = args.max_training_epochs * max_ep_len   # stop training once time_step exceeds this

    # NOTE: the three frequencies below are compared against the episode
    # counter (i_episode), not the timestep counter, even though the messages
    # printed further down label them "timesteps".
    print_freq = 100                     # print avg reward every n episodes
    log_freq = 2                         # log avg reward every n episodes
    save_model_freq = int(10000)         # save model every n episodes

    action_std = args.action_std                   # starting std for action distribution (Multivariate Normal)
    action_std_decay_rate = args.action_std_decay_rate        # linearly decay action_std (action_std = action_std - action_std_decay_rate)
    min_action_std = args.min_action_std                # minimum action_std (stop decay after action_std <= min_action_std)
    action_std_decay_freq = args.action_std_decay_freq  # action_std decay frequency (in num timesteps)
    #####################################################

    ################ PPO hyperparameters ################
    update_timestep = max_ep_len * 10      # update policy every n timesteps
    K_epochs = args.K_epochs               # update policy for K epochs in one PPO update

    eps_clip = 0.2          # clip parameter for PPO
    gamma = 0.99            # discount factor

    lr_actor = args.lr_actor       # learning rate for actor network
    lr_critic = args.lr_critic       # learning rate for critic network

    random_seed = args.random_seed         # set random seed if required (0 = no random seed)
    #####################################################

    print("training environment name : " + env_name)

    # Fails fast with ValueError on an unknown name instead of leaving `env`
    # unbound and crashing a few lines later.
    env = _make_env(env_name)

    # state space dimension
    state_dim = env.observation_space.shape[0]

    # action space dimension
    if has_continuous_action_space:
        action_dim = env.action_space.shape[0]
    else:
        action_dim = env.action_space.n

    ###################### logging ######################

    #### log files for multiple runs are NOT overwritten
    log_dir = _make_run_dir("PPO_logs", env_name)
    print(log_dir)

    #### the run number is the count of files already in the log directory
    _, _, existing_log_files = next(os.walk(log_dir))
    run_num = len(existing_log_files)

    #### create new log file for each run
    # os.path.join is used for every path below so the result is correct
    # whether or not get_output_folder returns a trailing separator.
    log_f_name = os.path.join(log_dir, 'PPO_' + env_name + "_log_" + str(run_num) + ".csv")

    print("current logging run number for " + env_name + " : ", run_num)
    print("logging at : " + log_f_name)
    #####################################################

    ################### checkpointing ###################
    directory = _make_run_dir("PPO_preTrained", env_name)

    checkpoint_path = directory
    print("save checkpoint path : " + checkpoint_path)
    #####################################################

    ################## save args ########################
    with open(os.path.join(directory, 'setting.txt'), 'w') as f:
        for arg_name, arg_value in vars(args).items():
            f.write(str(arg_name) + ':' + str(arg_value) + '\n')
    #####################################################

    ############# print all hyperparameters #############
    print("--------------------------------------------------------------------------------------------")
    print("max training timesteps : ", max_training_timesteps)
    print("max timesteps per episode : ", max_ep_len)
    print("model saving frequency : " + str(save_model_freq) + " timesteps")
    print("log frequency : " + str(log_freq) + " timesteps")
    print("printing average reward over episodes in last : " + str(print_freq) + " timesteps")
    print("--------------------------------------------------------------------------------------------")
    print("state space dimension : ", state_dim)
    print("action space dimension : ", action_dim)
    print("--------------------------------------------------------------------------------------------")
    if has_continuous_action_space:
        print("Initializing a continuous action space policy")
        print("--------------------------------------------------------------------------------------------")
        print("starting std of action distribution : ", action_std)
        print("decay rate of std of action distribution : ", action_std_decay_rate)
        print("minimum std of action distribution : ", min_action_std)
        print("decay frequency of std of action distribution : " + str(action_std_decay_freq) + " timesteps")
    else:
        print("Initializing a discrete action space policy")
    print("--------------------------------------------------------------------------------------------")
    print("PPO update frequency : " + str(update_timestep) + " timesteps")
    print("PPO K epochs : ", K_epochs)
    print("PPO epsilon clip : ", eps_clip)
    print("discount factor (gamma) : ", gamma)
    print("--------------------------------------------------------------------------------------------")
    print("optimizer learning rate actor : ", lr_actor)
    print("optimizer learning rate critic : ", lr_critic)
    if random_seed:
        print("--------------------------------------------------------------------------------------------")
        print("setting random seed to ", random_seed)
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)
    #####################################################

    print("============================================================================================")

    ################# training procedure ################

    # initialize a PPO agent
    ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std)

    # track total training time
    # NOTE(review): datetime.now() is naive local time, so the "(GMT)" label in
    # the messages is only accurate when the host timezone is UTC.
    start_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)

    print("============================================================================================")

    # printing and logging variables
    print_running_reward = 0
    print_running_episodes = 0

    log_running_reward = 0
    log_running_episodes = 0

    time_step = 0
    i_episode = 0

    episode_rewards = []    # total reward of every finished episode
    trajectory = []         # attacker transitions collected since its last update
    episode_lengths = []    # step count of every episode that ended with done

    # The log file and the environment are released even if training raises.
    try:
        with open(log_f_name, "w+") as log_f:
            log_f.write('episode,timestep,reward\n')

            # training loop
            while time_step <= max_training_timesteps:

                state = env.reset()
                current_ep_reward = 0

                for t in range(1, max_ep_len + 1):

                    # select action with policy
                    action = ppo_agent.select_action(state)

                    if args.ATTACK:
                        # the attacker may replace the policy's action
                        env_action, attack_info = attacker.antiAction(action, t - 1, state)
                        prev_state = deepcopy(state)
                    else:
                        env_action = deepcopy(action)

                    state, reward, done, _ = env.step(env_action)

                    # saving reward and is_terminals
                    ppo_agent.buffer.rewards.append(reward)
                    ppo_agent.buffer.is_terminals.append(done)
                    if args.ATTACK:
                        trajectory.append([env_action, reward, prev_state, state, attack_info])
                    time_step += 1
                    current_ep_reward += reward

                    # update PPO agent
                    if time_step % update_timestep == 0:
                        ppo_agent.update()

                    # if continuous action space; then decay action std of output action distribution
                    if has_continuous_action_space and time_step % action_std_decay_freq == 0:
                        ppo_agent.decay_action_std(action_std_decay_rate, min_action_std)

                    # NOTE(review): all end-of-episode bookkeeping below runs
                    # only when the env reports done; an episode that merely
                    # reaches max_ep_len skips it (including attacker.update).
                    # Confirm the environments always set done by then.
                    if done:
                        episode_lengths.append(t)

                        # printing average reward
                        if i_episode > 0 and i_episode % print_freq == 0:
                            # average over episodes finished before this one
                            print_avg_reward = print_running_reward / print_running_episodes
                            print_avg_reward = round(print_avg_reward, 2)

                            print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {:<10} current_ep_reward : {}".format(i_episode, time_step, print_avg_reward, round(current_ep_reward, 2)))

                            print_running_reward = 0
                            print_running_episodes = 0

                        # save model weights
                        if i_episode > 0 and i_episode % save_model_freq == 0:
                            print(
                                "--------------------------------------------------------------------------------------------")
                            checkpoint_path = os.path.join(
                                directory,
                                "PPO_{}_{}_{}.pth".format(env_name, i_episode, round(current_ep_reward, 2)),
                            )
                            print("saving model at : " + checkpoint_path)
                            ppo_agent.save(checkpoint_path)
                            print("model saved")
                            print("Elapsed Time  : ", datetime.now().replace(microsecond=0) - start_time)
                            print(
                                "--------------------------------------------------------------------------------------------")
                            np.save(os.path.join(directory, "reward.npy"), np.array(episode_rewards))
                            np.save(os.path.join(directory, "steps.npy"), np.array(episode_lengths))
                            if args.ATTACK:
                                print("similarity save")
                                np.save(os.path.join(directory, "sim.npy"), np.array(attacker.similarity))

                        # log in logging file
                        if i_episode > 0 and i_episode % log_freq == 0:
                            # average over episodes finished before this one
                            log_avg_reward = log_running_reward / log_running_episodes
                            log_avg_reward = round(log_avg_reward, 4)

                            log_f.write('{},{},{}\n'.format(i_episode, time_step, log_avg_reward))
                            log_f.flush()

                            log_running_reward = 0
                            log_running_episodes = 0

                        if args.ATTACK:
                            attacker.update(trajectory)
                            trajectory = []
                        break

                episode_rewards.append(current_ep_reward)
                print_running_reward += current_ep_reward
                print_running_episodes += 1

                log_running_reward += current_ep_reward
                log_running_episodes += 1

                i_episode += 1
    finally:
        env.close()

    # print total training time
    print("============================================================================================")
    end_time = datetime.now().replace(microsecond=0)
    print("Started training at (GMT) : ", start_time)
    print("Finished training at (GMT) : ", end_time)
    print("Total training time  : ", end_time - start_time)
    print("============================================================================================")


    
    
    
    
    
    
    
