# -*- coding: utf-8 -*-



import os
import random
import sys
import time

from mlt_maze_env import MazeEnv
from agent_pretrained import SacAgent
import argparse

from util_maze import *

# multiprocess
from multiprocessing import Pool, Queue, Pipe, Process




def _int_tuple(text):
    """Parse a comma-separated CLI value such as ``"224,224"`` into a tuple of ints.

    argparse applies ``type`` only to values given on the command line (and to
    string defaults), so the tuple defaults below are used unchanged. Without
    this converter, a value passed on the command line arrived as a raw string.
    """
    try:
        return tuple(int(part) for part in text.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(
            'expected comma-separated integers, got {!r}'.format(text)) from None


# Command-line configuration shared by the training and random-high-level runs.
parser = argparse.ArgumentParser()

parser.add_argument("--env_name", default='Maze_U_shape', type=str)
parser.add_argument("--train_num", default=5, type=int)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--resolution", default=(224, 224), type=_int_tuple,
                    help="comma-separated ints, e.g. 224,224")
parser.add_argument("--state_dim", default=(3, 224, 224), type=_int_tuple,
                    help="comma-separated ints, e.g. 3,224,224")
parser.add_argument("--action_dim", type=int, default=4)
parser.add_argument("--mem_size", type=int, default=300)
parser.add_argument("--embed_dim", type=int, default=512)
parser.add_argument("--goal_dim", type=int, default=20)
parser.add_argument("--learn_rate", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.95)
parser.add_argument("--save_path", type=str, default='model_save_high/')
parser.add_argument("--load_model", type=int, default=None, help="input load ep number")
parser.add_argument("--mltpro_num", type=int, default=4,
                    help="number of worker processes")
parser.add_argument("--ep_max", type=int, default=6000)
parser.add_argument("--rand_h", action="store_true", default=False)
parser.add_argument("--frame_skip", type=int, default=25)
parser.add_argument("--dense", action="store_true", default=False)


args = parser.parse_args()






def _is_episode_end(msg):
    """Return True for a worker's end-of-episode message (first element is ``True``)."""
    return msg[0] == True  # noqa: E712 -- equality, matching the original `!= True` test


def _start_workers(task, run_args):
    """Start ``run_args.mltpro_num`` worker processes running ``subprocess``.

    Worker ``i`` receives ``(task, data_queue_i, ctrl_queue_i, i, run_args)``.
    The main process reads from the data queues and writes to the control
    queues. Both queues are bounded at 300 items.

    Returns:
        (workers, data_queues, ctrl_queues): three lists indexed by worker.
    """
    n_workers = run_args.mltpro_num
    data_queues = [Queue(300) for _ in range(n_workers)]
    ctrl_queues = [Queue(300) for _ in range(n_workers)]
    workers = [
        Process(target=subprocess,
                args=(task, data_queues[i], ctrl_queues[i], i, run_args))
        for i in range(n_workers)
    ]
    for worker in workers:
        worker.start()
    return workers, data_queues, ctrl_queues


def _stop_workers(workers):
    """Terminate, reap and release every worker process."""
    for worker in workers:
        worker.terminate()
        worker.join()
        worker.close()


def _iter_worker_messages(data_queues, idle_sleep=0.001):
    """Yield ``(worker_index, message)`` until every worker has sent its end-of-episode message.

    Each pass reads at most one message per queue, so per-worker message order
    is preserved. The end-of-episode message itself is also yielded. When a
    whole pass finds every queue empty, the generator sleeps ``idle_sleep``
    seconds instead of spinning.
    """
    finished = [False] * len(data_queues)
    while not all(finished):
        received = False
        for i, q in enumerate(data_queues):
            if q.empty():
                continue
            msg = q.get()
            received = True
            if _is_episode_end(msg):
                finished[i] = True
            yield i, msg
        if not received:
            time.sleep(idle_sleep)


def _print_train_report(ep_num, record_rH, success_recoder):
    """Print the latest and accumulated training reward / success statistics."""
    print('==========================================================================')
    print('==========================================================================')
    print(ep_num, '*** curr reward ***', record_rH[-1])
    print(ep_num, '*** curr success ***', success_recoder[-1])

    print(ep_num, '*** average reward  ***', record_rH)
    print(ep_num, '*** average success rate ***', success_recoder)


def _write_train_record(path, ep_num, record_rH, success_recoder, elapsed=None):
    """Overwrite ``path`` with the accumulated training statistics.

    A ``time: <elapsed> s`` header line is written only when ``elapsed`` is given.
    """
    with open(path, 'w+') as f:
        if elapsed is not None:
            f.write('time: {} s\n'.format(elapsed))
        f.write('{}\n'.format(ep_num))
        f.write('{}\n'.format('evaluate average reward per 100 ep'))
        f.write('{}\n'.format(record_rH))
        f.write('{}\n'.format('success rate'))
        f.write('{}\n'.format(success_recoder))


def _evaluate_high_agent(agent, env, run_args, n_episodes=10):
    """Roll out ``n_episodes`` test episodes with the high-level agent driving the low-level one.

    Each high-level step queries ``agent.act_forSBG`` with an all-zero goal of
    length ``run_args.goal_dim``. Actions ``< 4`` are executed as up to
    ``run_args.frame_skip`` low-level steps chosen by
    ``agent.select_action_low``; other values run no low-level step
    (presumably a non-movement action -- TODO confirm). An episode ends when
    the env reports ``done`` or after ``env.max_step`` high-level steps.

    Returns:
        (success_sum, reward_sum, episodes_finished, low_level_steps)
    """
    success_sum = 0
    reward_sum = 0
    episodes_finished = 0
    low_level_steps = 0
    goal = np.zeros(run_args.goal_dim)
    for _ in range(n_episodes):
        obs, obs_pos = env.reset()
        s_pix = img_preprocess(obs)
        effect = True
        # Reset per episode: if the first high-level action runs no low-level
        # step, `done` must not leak in from the previous episode.
        done = False
        test_step = 0
        ep_reward_test = 0
        test_success = 0
        ep_max_step = env.max_step
        while True:
            test_step += 1
            action_high = agent.act_forSBG(s_pix, goal)
            if action_high < 4:
                for _ in range(run_args.frame_skip):
                    action = agent.select_action_low(action_high, obs_pos)
                    obs, dense_reward, done, obs_pos, sparse_reward = env.step(action)
                    low_level_steps += 1
                    # NOTE(review): `effect` only cuts the low-level rollout
                    # short; it does not end the episode.
                    if obs_pos[0] < 0.3:
                        effect = False
                    if run_args.dense:
                        ep_reward_test += dense_reward
                        if sparse_reward > 0:
                            test_success += 1
                    else:
                        ep_reward_test += sparse_reward
                    if done or not effect:
                        break
            if done and not run_args.dense:
                test_success += 1
            # Encode the current observation (unchanged if no low-level step ran).
            s_pix = img_preprocess(obs)
            if done or test_step >= ep_max_step:
                episodes_finished += 1
                break
        success_sum += test_success
        reward_sum += ep_reward_test
    return success_sum, reward_sum, episodes_finished, low_level_steps


def _run_random_high(run_args):
    """Baseline run (``--rand_h``): aggregate worker rollouts without training a high-level agent.

    Every ``25 * mltpro_num`` episodes, the window averages of the high-level
    reward and success count are written to
    ``./result_rand_h/reward_record_MJ_<task>_<train_num>_rand-H.txt``.
    """
    # NOTE(review): after moving it to the GPU, this process never uses `agent`.
    agent = SacAgent()
    agent.load_models()
    task = run_args.env_name  # 'Maze_S_shape', 'Maze_spiral_shape', 'Maze_square_random', 'Maze_U_shape'
    print('task bagin')
    result_path = './result_rand_h/'
    os.makedirs(result_path, exist_ok=True)

    n_workers = run_args.mltpro_num
    report_every = n_workers * 25
    ep_num = 0
    success = 0
    total_reward_H = 0.
    record_rH = []
    success_recoder = []

    workers, data_queues, ctrl_queues = _start_workers(task, run_args)
    agent.to_cuda(run_args.device)
    while ep_num <= run_args.ep_max:
        for _, msg in _iter_worker_messages(data_queues):
            if _is_episode_end(msg):
                # Index rather than unpack: accepts both [True, done] and the
                # [True, done, ep_step] form the training branch receives.
                if msg[1]:
                    success += 1
            elif msg[0] == 'high':
                total_reward_H += msg[3]  # rewardH
        ep_num += n_workers
        print('All subpro done.', 'ep:', ep_num)

        # Signal every worker through its control queue.
        for q in ctrl_queues:
            q.put([True])

        if ep_num % report_every == 0 and ep_num != 0:
            record_rH.append(total_reward_H / report_every)
            success_recoder.append(success / report_every)
            _print_train_report(ep_num, record_rH, success_recoder)
            _write_train_record(
                '{}reward_record_MJ_{}_{}_rand-H.txt'.format(result_path, task, run_args.train_num),
                ep_num, record_rH, success_recoder)
            total_reward_H = 0.
            success = 0

    _stop_workers(workers)


def _train_high(run_args):
    """Train the high-level agent on experience streamed from the worker processes.

    After each batch of ``mltpro_num`` episodes, this function:
    - trains once and broadcasts CPU copies of the parameters to the workers;
    - checkpoints every ``50 * mltpro_num`` episodes;
    - evaluates every ``5 * mltpro_num`` episodes, but only when the high-level
      reward accumulated since the last evaluation is positive;
    - writes statistics every ``25 * mltpro_num`` episodes.
    """
    # Checkpoints live in '<save_path><env_name>_model_save/'. train_num is not
    # part of this path, so runs with different --train_num share it.
    mod_save_path = run_args.save_path + run_args.env_name + '_model_save/'
    # exist_ok: the directory already exists on every run after the first.
    os.makedirs(mod_save_path, exist_ok=True)

    agent = high_agent(run_args)
    agent_low = SacAgent()
    agent_low.load_models()
    task = run_args.env_name  # 'Maze_S_shape', 'Maze_spiral_shape', 'Maze_square_random', 'Maze_U_shape'
    env = MazeEnv(maze_id=task, resolution=run_args.resolution)
    print('task bagin')

    n_workers = run_args.mltpro_num
    report_every = n_workers * 25       # training-statistics window
    test_every = n_workers * 5          # how often an evaluation may run
    test_report_every = n_workers * 25  # how often evaluation stats are written
    # train_mlt has always received 32 subgoal lists; keep at least that many,
    # but grow with the worker count so subgoals[i] never overflows.
    n_goal_slots = max(32, n_workers)

    ep_num = 0
    success = 0
    total_reward_H = 0.
    ep_step_total_temp = 0
    record_rH = []
    success_recoder = []
    reward_flag_before_test = 0.
    test_success_recorder_temp = 0.
    test_reward_recorder_temp = 0.
    test_effect_num = 0
    ep_step_total_test_temp = 0
    test_success_recorder = []
    test_reward_recorder = []

    workers, data_queues, ctrl_queues = _start_workers(task, run_args)

    if run_args.load_model is not None:
        print('loading model ==========================')
        ep_num = run_args.load_model
        agent.load_model_frompath(mod_save_path + 'epi_{:08d}'.format(ep_num))

    agent.agent_low = agent_low
    agent.net2cuda(run_args.device)
    start_time = time.time()
    while ep_num <= run_args.ep_max:
        subgoals = [[] for _ in range(n_goal_slots)]
        for i, msg in _iter_worker_messages(data_queues):
            if _is_episode_end(msg):
                _, done, ep_step = msg
                ep_step_total_temp += ep_step
                if done and not run_args.dense:
                    success += 1
            elif msg[0] == 'high':
                _, s_pix, action, rewardH, rewardL, s_pix_next, done, subgoal = msg
                agent.memory.save_exp(s_pix, action, rewardL, s_pix_next, done, i)
                total_reward_H += rewardH
                reward_flag_before_test += rewardH
                subgoals[i].append(subgoal)
                if run_args.dense and rewardL > 0:
                    success += 1

        ep_num += n_workers
        agent.train_mlt(goal_=subgoals, pro_num_=n_workers)
        print('All subpro done.', 'ep:', ep_num)

        trained_parameters = agent.get_curr_para()
        if ep_num % 10000 == 0:
            agent.entropy_rate = agent.entropy_rate - 0.1
            if agent.entropy_rate < 0.1:
                agent.entropy_rate = 0.1

        # Move every tensor to CPU before sending it through the worker queues.
        for params in trained_parameters:
            for name in params:
                params[name] = params[name].data.cpu()
        for q in ctrl_queues:
            q.put([trained_parameters])

        if ep_num % (n_workers * 50) == 0:
            print('saving model ===================================')
            suffix = '_dense' if run_args.dense else ''
            agent.save_model(mod_save_path + 'epi_%08d' % ep_num + suffix)

        # The flag is reset only when an evaluation actually runs.
        if ep_num % test_every == 0 and reward_flag_before_test > 0:
            reward_flag_before_test = 0.
            t_success, t_reward, t_episodes, t_steps = _evaluate_high_agent(agent, env, run_args)
            test_success_recorder_temp += t_success
            test_reward_recorder_temp += t_reward
            test_effect_num += t_episodes
            ep_step_total_test_temp += t_steps

        if ep_num % test_report_every == 0:
            print('****************************************************')
            # No evaluation runs while the training reward stays <= 0, so the
            # counters may still be zero; record 0.0 instead of dividing by zero.
            test_success_recorder.append(
                test_success_recorder_temp / test_effect_num if test_effect_num else 0.)
            if run_args.dense:
                test_reward_recorder.append(
                    test_reward_recorder_temp / ep_step_total_test_temp
                    if ep_step_total_test_temp else 0.)
                print('test reward:{},step:{},dense'.format(test_reward_recorder_temp, ep_step_total_test_temp))
                print('test success rate:{},dense'.format(test_success_recorder))
                print('test reward {},dense'.format(test_reward_recorder))
                test_file = 'reward_record_MJ_test_{}_{}_dense.txt'.format(task, run_args.train_num)
            else:
                test_reward_recorder.append(
                    test_reward_recorder_temp / test_effect_num if test_effect_num else 0.)
                print('test reward:{},effe_num:{}'.format(test_reward_recorder_temp, test_effect_num))
                print('test success rate:{}'.format(test_success_recorder))
                print('test reward {}'.format(test_reward_recorder))
                test_file = 'reward_record_MJ_test_{}_{}.txt'.format(task, run_args.train_num)
            with open(test_file, 'w+') as f:
                f.write('{}\n'.format(ep_num))
                f.write('{}\n'.format(test_success_recorder))
                f.write('{}\n'.format('reward'))
                f.write('{}\n'.format(test_reward_recorder))
            test_success_recorder_temp = 0.
            test_reward_recorder_temp = 0.
            test_effect_num = 0
            ep_step_total_test_temp = 0

        if ep_num % report_every == 0 and ep_num != 0:
            elapsed = time.time() - start_time
            if run_args.dense:
                # Dense mode: reward per worker-reported step (ep_step from end messages).
                record_rH.append(total_reward_H / ep_step_total_temp if ep_step_total_temp else 0.)
                record_file = 'reward_record_MJ_{}_{}_dense.txt'.format(task, run_args.train_num)
            else:
                record_rH.append(total_reward_H / report_every)
                record_file = 'reward_record_MJ_{}_{}.txt'.format(task, run_args.train_num)
            success_recoder.append(success / report_every)
            _print_train_report(ep_num, record_rH, success_recoder)
            _write_train_record(record_file, ep_num, record_rH, success_recoder, elapsed)
            ep_step_total_temp = 0
            total_reward_H = 0.
            success = 0

    _stop_workers(workers)


if __name__ == '__main__':
    if args.mltpro_num < 1:
        sys.exit('--mltpro_num must be at least 1')
    if args.rand_h:
        _run_random_high(args)
    else:
        _train_high(args)
    sys.exit()









