# -*- coding: utf-8 -*-

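# Training driver for the high-level maze agent: starts worker processes,
# collects the transitions they report through queues, trains the agent and
# sends the updated parameters back to the workers.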



import os
import random
import sys
import time

from mlt_maze_env import MazeEnv
from agent_pretrained import SacAgent
import argparse

from util_maze import *

# multiprocess
from multiprocessing import Pool, Queue, Pipe, Process




def _int_tuple(text):
    """Convert a command-line value such as ``"224,224"`` into a tuple of ints.

    Surrounding parentheses/brackets and blanks are tolerated, so
    ``"(3, 224, 224)"`` is accepted as well.

    Raises:
        argparse.ArgumentTypeError: if any component is not an integer.
    """
    try:
        return tuple(int(part) for part in text.strip().strip('()[]').split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(
            'expected comma-separated integers, got {!r}'.format(text))


# Command-line interface of the training script.
parser = argparse.ArgumentParser()

# Run identification.
parser.add_argument("--env_name", default='Maze_square_random', type=str)  # Maze_square_random, Maze_square_blocked
parser.add_argument("--train_num", default=6801, type=int)  # author's note: 4800 norndview, 2randdoor
parser.add_argument("--device", type=int, default=0)

# Tuple-valued options. Without an explicit ``type`` a value given on the
# command line would reach the program as a raw string instead of a tuple, so
# they are parsed from comma-separated integers. The tuple defaults are used
# untouched when the option is omitted.
parser.add_argument("--resolution", default=(224, 224), type=_int_tuple,
                    help="comma-separated integers, e.g. 224,224")
parser.add_argument("--state_dim", default=(3, 224, 224), type=_int_tuple,
                    help="comma-separated integers, e.g. 3,224,224")

# Numeric settings.
parser.add_argument("--action_dim", type=int, default=4)
parser.add_argument("--mem_size", type=int, default=300)
parser.add_argument("--embed_dim", type=int, default=512)
parser.add_argument("--goal_dim", type=int, default=20)
parser.add_argument("--learn_rate", type=float, default=0.0001)  # author's note: 0.0001, 3800 try to use 0.00002 full random
parser.add_argument("--gamma", type=float, default=0.95)

# Checkpointing.
parser.add_argument("--save_path", type=str, default='model_save_high/')
parser.add_argument("--load_model", type=int, default=None, help="input load ep number")

# Parallelism and run length.
parser.add_argument("--mltpro_num", type=int, default=4)
parser.add_argument("--ep_max", type=int, default=10000)
parser.add_argument("--frame_skip", type=int, default=25)

# Boolean switches (all off unless given).
parser.add_argument("--rand_h", action="store_true", default=False)
parser.add_argument("--dense", action="store_true", default=False)
parser.add_argument("--abla", action="store_true", default=False)
parser.add_argument("--random_view", action="store_true", default=False)
parser.add_argument("--find_obj", action="store_true", default=False)



# Parse sys.argv once, at module level. The resulting namespace is read by the
# main guard below and is also passed as an argument to every worker process.
args = parser.parse_args()






# NOTE: a large commented-out `--rand_h` evaluation branch used to sit at the
# top of the main guard. It was dead code and has been removed; this script
# itself does not branch on ``args.rand_h``.

_QUEUE_CAPACITY = 300         # maxsize of every queue shared with a worker
_MIN_SUBGOAL_SLOTS = 32       # the original script always passed 32 sub-goal lists to train_mlt
_REPORT_ITERATIONS = 25       # training rounds between two reward reports
_ENTROPY_DECAY_PERIOD = 10000  # entropy_rate is lowered when ep_num is a multiple of this


def prepare_model_dir(save_path, env_name):
    """Create (if needed) and return ``<save_path><env_name>_model_save/``.

    The path is built by plain string concatenation, exactly as before, so
    ``save_path`` is expected to end with a path separator. Missing parent
    directories are created too, and an existing directory is not an error.
    """
    model_dir = save_path + env_name + '_model_save/'
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def collect_worker_round(queues, save_exp, subgoals, totals, dense):
    """Consume worker messages until every worker has reported an episode end.

    Two message layouts are handled (each message is a sequence):

    * ``['high', s_pix, action, reward, s_pix_next, done, subgoal]`` - the
      transition is forwarded to ``save_exp`` together with the worker index,
      ``reward`` is added to ``totals['reward']`` and ``subgoal`` is appended
      to ``subgoals[worker]``.
    * ``[True, done, ep_step]`` - marks the end of that worker's episode;
      ``ep_step`` is added to ``totals['steps']`` and ``totals['success']`` is
      incremented when ``done`` is truthy and ``dense`` is falsy.

    Messages with any other first element are read and discarded.

    Args:
        queues: one queue-like object (``empty()`` / ``get()``) per worker.
        save_exp: callable ``(s_pix, action, reward, s_pix_next, done, worker)``.
        subgoals: list of lists, indexed by worker, extended in place.
        totals: dict with keys ``'reward'``, ``'success'`` and ``'steps'``,
            updated in place so accumulation order matches the original loop.
        dense: when truthy, finished episodes are not counted as successes.
    """
    finished = [False] * len(queues)
    while not all(finished):
        # One sweep reads at most one message per worker. The loop polls
        # without blocking, like the original.
        for rank, inbox in enumerate(queues):
            if inbox.empty():
                continue
            message = inbox.get()
            if message[0] != True:  # noqa: E712 - equality (not identity), as in the original
                if message[0] == 'high':
                    _, s_pix, action, reward, s_pix_next, done, subgoal = message
                    save_exp(s_pix, action, reward, s_pix_next, done, rank)
                    totals['reward'] += reward
                    subgoals[rank].append(subgoal)
                continue
            _, done, ep_step = message
            totals['steps'] += ep_step
            if done and not dense:
                totals['success'] += 1
            finished[rank] = True
            # Abort the sweep so the completion flags are re-checked at once.
            break


def write_reward_report(path, elapsed, ep_num, rewards):
    """Overwrite ``path`` with the elapsed time, episode count and reward history."""
    with open(path, 'w') as report:
        report.write('time: {} s\n'.format(elapsed))
        report.write('{}\n'.format(ep_num))
        report.write('{}\n'.format('evaluate average reward per 100 ep'))
        report.write('{}\n'.format(rewards))


def main(args):
    """Run the training loop driven by ``args`` (the parsed command line).

    Each round waits for all ``args.mltpro_num`` workers to finish an episode,
    trains the agent, and puts the current parameters (moved to CPU) on every
    worker's reply queue. Every ``mltpro_num * 25`` episodes the averaged
    reward is printed and written to ``reward_record_<env>_<train_num>.txt``.
    The workers are always terminated on the way out, also when an exception
    interrupts training, and the function ends with ``sys.exit()``.

    Raises:
        ValueError: if ``args.mltpro_num`` is smaller than 1.
    """
    worker_count = args.mltpro_num
    if worker_count < 1:
        raise ValueError('--mltpro_num must be at least 1, got {}'.format(worker_count))

    model_dir = prepare_model_dir(args.save_path, args.env_name)

    agent = high_agent(args)
    # Not referenced again in this script. It is still constructed and loaded
    # because the original did so and that may have side effects - TODO confirm
    # whether it can be dropped.
    agent_low = SacAgent()
    agent_low.load_models()
    task = args.env_name      # 'Maze_S_shape', 'Maze_spiral_shape', 'Maze_square_random', 'Maze_U_shape'
    print('task bagin')
    ep_max = args.ep_max
    ep_num = 0
    record_rH = []
    # 'success' and 'steps' are accumulated as before but never reported.
    totals = {'reward': 0., 'success': 0, 'steps': 0}

    # One inbound and one reply queue per worker. The original hard-coded four
    # pairs, which broke for --mltpro_num > 4.
    inboxes = [Queue(_QUEUE_CAPACITY) for _ in range(worker_count)]
    outboxes = [Queue(_QUEUE_CAPACITY) for _ in range(worker_count)]
    report_period = worker_count * _REPORT_ITERATIONS

    workers = []
    try:
        for rank in range(worker_count):
            worker = Process(target=subprocess,
                             args=(task, inboxes[rank], outboxes[rank], rank, args))
            worker.start()
            workers.append(worker)

        if args.load_model is not None:
            print ('loading model ==========================')
            ep_num = args.load_model
            agent.load_model_frompath(model_dir + 'epi_{:08d}'.format(ep_num))

        agent.net2cuda(args.device)
        start_time = time.time()
        while ep_num <= ep_max:
            # Keep at least the 32 slots the original always allocated, in case
            # train_mlt relies on that length.
            subgoals = [[] for _ in range(max(_MIN_SUBGOAL_SLOTS, worker_count))]
            collect_worker_round(inboxes, agent.memory.save_exp, subgoals, totals, args.dense)

            ep_num += worker_count

            agent.train_mlt(goal_=subgoals, pro_num_=worker_count)

            print('All subpro done.', 'ep:', ep_num)

            trained_parameters = agent.get_curr_para()

            if ep_num % _ENTROPY_DECAY_PERIOD == 0:
                # Lower the entropy rate by 0.1, never below 0.1.
                agent.entropy_rate = max(agent.entropy_rate - 0.1, 0.1)

            # Move every entry to CPU (in place) before it goes on the queues.
            for params in trained_parameters:
                for name in params:
                    params[name] = params[name].data.cpu()

            for outbox in outboxes:
                outbox.put([trained_parameters, ])

            if ep_num % report_period == 0 and ep_num != 0:
                elapsed = time.time() - start_time

                record_rH.append(totals['reward'] / report_period)
                print('==========================================================================')
                print('==========================================================================')
                print(ep_num, '*** curr reward ***', record_rH[-1])
                print(ep_num, '*** average reward  ***', record_rH)

                write_reward_report(
                    'reward_record_{}_{}.txt'.format(args.env_name, args.train_num),
                    elapsed, ep_num, record_rH)

                totals['reward'] = 0.
                totals['success'] = 0
    finally:
        # Workers are non-daemonic; if they are left running, interpreter exit
        # waits on them. Shut them down however the loop ended.
        for worker in workers:
            worker.terminate()
            worker.join()
            worker.close()
    sys.exit()


if __name__ == '__main__':
    main(args)









