from fourrooms import FourRooms
from utils import *
from time import sleep
import numpy as np
import multiprocessing as mp
import time
import os
import shutil
from copy import copy

# ---------------------------------------------------------------------------
# Experiment configuration (module-level constants read by train() and by the
# __main__ launcher below).
# ---------------------------------------------------------------------------
# Root directory for results; each launch writes into a timestamped
# sub-directory created underneath it.
log_dir = "./result/2stage/"
# Directory holding checkpoints to restore from; None disables loading.
model_path = None
# Checkpoint number (the "cpt_{cpt}" part of the file name) to restore when
# model_path is set.
cpt = 1000
# Discount factor handed to the critic
discount = 0.99
# Learning rates - termination, intra-option, critic
lr_term = 0.25
lr_intra = 0.25
lr_critic = 0.5
# Epsilon for epsilon-greedy for policy over options
epsilon = 1e-1
# Temperature of the softmax intra-option policies
temperature = 1e-3
# Passed straight to Critic(...); their exact meaning is defined in utils
# (presumably a regularisation weight and a deliberation cost - TODO confirm).
gamma = 0
delib = 0
# Number of runs (run ids are 0..nruns-1 and double as RNG seeds)
nruns = 10
# Number of episodes per run
nepisodes = 1000
# Maximum number of steps per episode
nsteps = 1000
# Episode from which the learned terminations are replaced by a termination
# probability table and policy/termination updates stop. Only the first
# element is read.
insert_ep = [500]
# Option counts to evaluate; train() repeats the experiment for each entry
noptions = [1,4,8]
# Number of worker processes the runs are split across
nprocess = 10

def train(rank, process_runs, csv_name, model_dir):
    """Train option-critic agents on FourRooms for the runs owned by one worker.

    For every option count in ``noptions`` the assigned runs are trained for
    ``nepisodes`` episodes. Until episode ``insert_ep[0]`` options terminate
    according to their learned sigmoid terminations and the intra-option
    policies and terminations are updated. From that episode on, termination
    is sampled from a probability table that is initialised once from the
    learned terminations and then pushed towards 1; only the critic keeps
    learning.

    Checkpoints are written every 10 episodes to ``model_dir`` and the
    per-episode statistics, averaged over this worker's runs, are written as
    CSV files below ``csv_name``.

    Args:
        rank: Worker index; only used to keep CSV file names unique.
        process_runs: Run ids handled by this worker. Each id seeds the run's
            ``RandomState`` and is part of the checkpoint file names.
        csv_name: Directory in which a ``{ops}op`` sub-directory holding the
            CSV logs is created.
        model_dir: Existing directory that receives the ``.npz`` checkpoints.

    Returns:
        None. Results are written to disk.
    """
    # Nothing to train. Averaging over zero runs would yield NaNs, which
    # np.savetxt cannot format with the integer format used below.
    if len(process_runs) == 0:
        return

    for ops in noptions:
        history_name = os.path.join(csv_name, f"{ops}op")
        # All workers share this directory and start at the same time, so a
        # check-then-create sequence can race; exist_ok makes creation atomic
        # from this function's point of view.
        os.makedirs(history_name, exist_ok=True)
        num_runs = len(process_runs)
        # Per (run, episode): how often each option was active. Only slot 0 of
        # the third axis is written in this function.
        policy_over_option_select = np.zeros((num_runs, nepisodes, 2, ops))
        # History of steps and average durations
        history_op = np.zeros((num_runs, nepisodes, 2))

        for index, run in enumerate(process_runs):
            # Random number generator for reproducibility
            rng = np.random.RandomState(run)

            env = FourRooms()

            nstates = env.observation_space.shape[0]
            nactions = env.action_space.shape[0]

            # 1. The intra-option policies - linear softmax functions
            option_policies = [SoftmaxPolicy(rng, lr_intra, nstates, nactions, temperature) for _ in range(ops)]

            # 2. The termination function - linear sigmoid function
            option_terminations = [SigmoidTermination(rng, lr_term, nstates) for _ in range(ops)]

            # 3. The epsilon-greedy policy over options
            policy_over_options = EpsGreedyPolicy(rng, nstates, ops, epsilon)

            # Critic
            critic = Critic(lr_critic, gamma, option_policies, delib, discount, policy_over_options.Q_Omega_table, nstates, ops, nactions)

            # load model
            if model_path is not None:
                load_name = os.path.join(model_path,f"runs_{run}_op_{ops}_cpt_{cpt}.npz")
                # np.load keeps the archive open until closed; read what is
                # needed inside the context manager so the handle is released.
                with np.load(load_name) as CPT:
                    policy_weights = CPT["option_policy"]
                    termination_weights = CPT["termination"]
                    # NOTE(review): the critic above was constructed with the
                    # previous Q_Omega_table object; rebinding the attribute
                    # here may leave the critic on the old table - confirm.
                    policy_over_options.Q_Omega_table = CPT["policy_over_option"]
                for o in range(ops):
                    option_policies[o].weights = policy_weights[o]
                    option_terminations[o].weights = termination_weights[o]
                # The checkpoint's "new_term_table" entry is intentionally not
                # read: its value was always discarded by the reset to None
                # below, and for checkpoints saved before the insertion
                # episode it is a pickled None that np.load refuses to
                # deserialise with its default allow_pickle=False.

            print('Goal: ', env.goal)
            termination_off = False
            start_ep = insert_ep[0]
            new_term_table = None
            test = False

            for episode in range(nepisodes):
                # One-time switch: snapshot the learned terminations into a
                # table that drives termination from now on.
                if episode >= insert_ep[0] and not termination_off:
                    termination_off = True
                    new_term_table = copy_term(option_terminations, nstates)

                state = env.reset(rng)
                option = policy_over_options.sample(state)
                action = option_policies[option].sample(state)
                critic.cache(state, option, action)

                duration = 1
                option_switches = 0
                avg_duration = 0.0

                for step in range(nsteps):

                    state, reward, done, _ = env.step(action, rng)

                    # Termination might occur upon entering new state
                    if not termination_off:
                        terminated = option_terminations[option].sample(state)
                    else:
                        # Raise the stored probability on every visit (capped
                        # at 1). The new value is written back, so increments
                        # accumulate across visits and episodes.
                        term_prob = new_term_table[option, state]
                        new_term_prob = min(term_prob + 0.01*(episode-start_ep), 1.)
                        new_term_table[option, state] = new_term_prob
                        terminated = rng.uniform() < new_term_prob

                    if terminated:
                        option = policy_over_options.sample(state)
                        option_switches += 1
                        # Incremental mean of the durations of finished options
                        avg_duration += (1.0/option_switches)*(duration - avg_duration)
                        duration = 1

                    policy_over_option_select[index, episode, 0, option] += 1
                    action = option_policies[option].sample(state)

                    if not test:
                        # Critic update
                        critic.update_Qs(state, option, action, reward, done, option_terminations, new_term_table)

                        if not termination_off:
                            # Intra-option policy update with baseline
                            Q_U = critic.Q_U(state, option, action)
                            Q_U = Q_U - critic.Q_Omega(state, option)
                            option_policies[option].update(state, action, Q_U)

                            # Termination condition update
                            option_terminations[option].update(state, critic.A_Omega(state, option))

                    duration += 1

                    if done:
                        break

                # `step` is the zero-based index of the last executed step
                history_op[index, episode, 0] = step
                history_op[index, episode, 1] = avg_duration

                # save model
                if episode%10 == 0 and episode>0 and not test:
                    op_arr = np.zeros((ops,nstates,nactions))
                    ter_arr = np.zeros((ops,nstates))
                    for o in range(ops):
                        op_arr[o] = copy(option_policies[o].weights)
                        ter_arr[o] = copy(option_terminations[o].weights)
                    poo_arr = copy(policy_over_options.Q_Omega_table)
                    # Still None before the insertion episode
                    ter_table_arr = copy(new_term_table)
                    model_name = os.path.join(model_dir,f"runs_{run}_op_{ops}_cpt_{episode}.npz")
                    np.savez(model_name,option_policy=op_arr,termination=ter_arr,policy_over_option=poo_arr,new_term_table=ter_table_arr)


        # write history (averaged over this worker's runs)
        op_log_step = np.mean(history_op[:,:,0], axis=0)
        op_log_avg_duration = np.mean(history_op[:,:,1], axis=0)
        op_log_avg_opselect = np.mean(policy_over_option_select, axis=0)
        step_csv = history_name + f"/op{ops}_log_step_m1_notupdatepi_run_{rank}.csv"
        avg_duration_csv = history_name + f"/op{ops}_log_avg_duration_m1_notupdatepi_run_{rank}.csv"
        avg_opselect_csv = history_name + f"/op{ops}_log_avg_opselect_m1_notupdatepi_run_{rank}.csv"
        np.savetxt(step_csv, op_log_step, fmt='%4d', delimiter=' ')
        np.savetxt(avg_duration_csv, op_log_avg_duration, fmt='%2f', delimiter=' ')
        # Option-selection counts are additionally averaged over episodes,
        # leaving a (2, ops) table; '%4d' truncates the fractional part.
        np.savetxt(avg_opselect_csv, np.mean(op_log_avg_opselect, axis=0), fmt='%4d', delimiter=' ')
    

def split_list_n_list(origin_list, n):
    """Lazily yield ``n`` consecutive slices of ``origin_list``.

    Every slice has ``ceil(len(origin_list) / n)`` elements, except that the
    last non-empty slice may be shorter and any slices after the data runs
    out are empty. Exactly ``n`` slices are always produced.
    """
    whole, leftover = divmod(len(origin_list), n)
    # Round the chunk length up whenever the items do not divide evenly.
    chunk_len = whole + 1 if leftover else whole
    begin = 0
    for _ in range(n):
        end = begin + chunk_len
        yield origin_list[begin:end]
        begin = end


if __name__ == '__main__':

    # Every launch logs into its own timestamped directory under log_dir:
    # CSV statistics go to <dir>/csvs, checkpoints to <dir>/model.
    ex_name = time.strftime("%m-%d_%H-%M-%S", time.localtime())
    save_name = os.path.join(log_dir, ex_name)
    csv_name = os.path.join(save_name, "csvs")
    model_dir = os.path.join(save_name,"model")
    os.makedirs(csv_name, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Distribute the run ids over the worker processes.
    runs_l = list(range(nruns))
    each_p_runs = split_list_n_list(runs_l, nprocess)
    subprocess = []
    for rank, pr in enumerate(each_p_runs):
        # split_list_n_list pads with empty chunks when there are fewer run
        # chunks than processes; a worker without runs has nothing to average
        # and would fail when writing its CSV files, so do not start one.
        # Empty chunks only occur at the tail, so ranks of real workers are
        # unaffected.
        if not pr:
            continue
        p = mp.Process(target=train, args=(rank, pr, csv_name, model_dir))
        p.start()
        subprocess.append(p)

    # Wait for all workers to finish.
    for p in subprocess:
        p.join()


