# -*- coding: utf-8 -*-



import os
import random
import sys
import time

from mlt_maze_env2 import MazeEnv2
from agent_pretrained import SacAgent
import argparse

from util_maze import *

# multiprocess
from multiprocessing import Pool, Queue, Pipe, Process




# Command-line configuration. Parsed at module import time, so every process
# that imports this module parses sys.argv again.
parser = argparse.ArgumentParser()


parser.add_argument("--env_name", default='Maze_square_blocked2', type=str)   # other mazes: Maze_square_blocked, Maze_square_random
# Run id; it appears in the reward-log filename.
# Past runs: 4800 norndview, 2randdoor, 6800 closed 4 ball lr 0.00001,
# 7800 far 4 ball lr 0.0001, far bigger yel ball, 7802 central closed 4 ball.
parser.add_argument("--train_num", default=9905, type=int)
parser.add_argument("--device", type=int, default=0)
# Multi-value options are given as separate integers on the command line,
# e.g. `--resolution 224 224` (a single token would be left unparsed otherwise).
parser.add_argument("--resolution", type=int, nargs=2, default=(224, 224))
parser.add_argument("--state_dim", type=int, nargs=3, default=(3, 224, 224))  # presumably (channels, H, W) -- confirm
parser.add_argument("--action_dim", type=int, default=4)
parser.add_argument("--mem_size", type=int, default=300)
parser.add_argument("--embed_dim", type=int, default=512)
parser.add_argument("--goal_dim", type=int, default=30)
parser.add_argument("--learn_rate", type=float, default=0.0001)  # 0.0001; run 3800 tried 0.00002 (full random)
parser.add_argument("--gamma", type=float, default=0.95)
parser.add_argument("--save_path", type=str, default='model_save_high/')
parser.add_argument("--load_model", type=int, default=None, help="input load ep number")
parser.add_argument("--mltpro_num", type=int, default=4)
parser.add_argument("--ep_max", type=int, default=10000)
parser.add_argument("--rand_h", action="store_true", default=False)
parser.add_argument("--frame_skip", type=int, default=25)
parser.add_argument("--dense", action="store_true", default=False)
parser.add_argument("--abla", action="store_true", default=False)
parser.add_argument("--random_view", action="store_true", default=False)
# find_obj is on by default; a store_true flag with default=True alone could
# never be switched off, so --no_find_obj writes False to the same dest.
parser.add_argument("--find_obj", action="store_true", default=True)
parser.add_argument("--no_find_obj", dest="find_obj", action="store_false",
                    help="disable find_obj (enabled by default)")



args = parser.parse_args()
# nargs produces lists; normalise to tuples so CLI values match the defaults.
args.resolution = tuple(args.resolution)
args.state_dim = tuple(args.state_dim)






if __name__ == '__main__':
    # Learner (main process) for multi-process training of the high-level agent.
    #
    # Each of the --mltpro_num worker processes (target `subprocess`, defined
    # elsewhere -- presumably via `from util_maze import *`) streams messages
    # back on its own experience queue:
    #   ['high', s_pix, action, reward, s_pix_next, done, subgoal]  -> transition
    #   [True, done, ep_step]                                        -> episode over
    # Once every worker has reported its episode over, the learner trains once
    # and then pushes the updated parameters to each worker's parameter queue.

    if args.mltpro_num < 1:
        # With zero workers ep_num would never advance (infinite loop).
        sys.exit('--mltpro_num must be >= 1 (got {})'.format(args.mltpro_num))

    # Checkpoint directory used by --load_model. NOTE: args.train_num is not
    # part of this path, so different runs of the same env share it.
    MOD_SAVE_PATH = args.save_path + args.env_name + '_model_save_find_obj/'
    # makedirs also creates args.save_path itself when it does not exist yet.
    os.makedirs(MOD_SAVE_PATH, exist_ok=True)

    agent = high_agent(args)
    # Pretrained low-level SAC agent. It is loaded here but never attached to
    # `agent` in this script.
    agent_low = SacAgent()
    agent_low.load_models()
    task = args.env_name      # e.g. 'Maze_S_shape', 'Maze_spiral_shape', 'Maze_square_random', 'Maze_U_shape'
    print('task bagin')

    ep_max = args.ep_max
    ep_num = 0
    n_workers = args.mltpro_num
    iter_num = 25                          # report every n_workers * iter_num episodes
    report_every = n_workers * iter_num

    success = 0             # finished episodes with done=True (sparse mode); reset per report
    record_rH = []          # mean high-level reward of each reporting window
    total_reward_H = 0.     # high-level reward summed over the current window
    ep_step_total_temp = 0  # sum of ep_step values reported by the workers

    # One experience queue (worker -> learner) and one parameter queue
    # (learner -> worker) per worker; lists allow any number of workers.
    exp_queues = [Queue(300) for _ in range(n_workers)]
    param_queues = [Queue(300) for _ in range(n_workers)]
    workers = []
    for i in range(n_workers):
        p = Process(target=subprocess,
                    args=(task, exp_queues[i], param_queues[i], i, args))
        p.start()
        workers.append(p)

    if args.load_model is not None:
        print ('loading model ==========================')
        ep_num = args.load_model
        agent.load_model_frompath(MOD_SAVE_PATH + 'epi_{:08d}'.format(ep_num))

    # Moved to the GPU only after the workers have been started -- presumably
    # so CUDA is not initialised in the parent before forking; keep this order.
    agent.net2cuda(args.device)
    start_time = time.time()

    while ep_num <= ep_max:
        # Subgoals reported by each worker this round, indexed by worker id.
        # At least 32 slots are kept because train_mlt always received 32.
        subgoals = [[] for _ in range(max(32, n_workers))]
        finished = [False] * n_workers

        # Drain the experience queues until every worker's episode is over.
        while not all(finished):
            received = False
            for i, q in enumerate(exp_queues):
                if q.empty():
                    continue
                received = True
                msg = q.get()
                if msg[0] == 'high':
                    _, s_pix, action, reward, s_pix_next, done, subgoal = msg
                    agent.memory.save_exp(s_pix, action, reward, s_pix_next, done, i)
                    total_reward_H += reward
                    subgoals[i].append(subgoal)
                elif msg[0] == True:  # noqa: E712 -- end-of-episode marker
                    _, done, ep_step = msg
                    ep_step_total_temp += ep_step
                    if done and not args.dense:
                        success += 1
                    finished[i] = True
                    break
            if not received:
                # Nothing queued: yield the CPU instead of spinning at 100%.
                time.sleep(0.001)

        ep_num += n_workers

        agent.train_mlt(goal_=subgoals, pro_num_=n_workers)

        print('All subpro done.', 'ep:', ep_num)

        trained_parameters = agent.get_curr_para()

        # Entropy-rate schedule: -0.1 every 10000 episodes, floored at 0.1.
        if ep_num % 10000 == 0:
            agent.entropy_rate = max(agent.entropy_rate - 0.1, 0.1)

        # Take .data of every tensor and move it to the CPU before sending
        # the parameters to the workers.
        for param_dict in trained_parameters:
            for name in param_dict:
                param_dict[name] = param_dict[name].data.cpu()

        for qb in param_queues:
            qb.put([trained_parameters,])

        # NOTE(review): periodic checkpointing is disabled, so this script never
        # writes to MOD_SAVE_PATH. To re-enable:
        # if ep_num % (n_workers * 50) == 0:
        #     agent.save_model(MOD_SAVE_PATH + 'epi_%08d' % ep_num + ('_dense' if args.dense else ''))

        if ep_num % report_every == 0 and ep_num != 0:
            curr_time = time.time() - start_time

            record_rH.append(total_reward_H / report_every)
            print('==========================================================================')
            print('==========================================================================')
            print(ep_num, '*** curr reward ***', record_rH[-1])
            print(ep_num, '*** average reward  ***', record_rH)
            # The whole reward history is rewritten on every report.
            with open('reward_record_{}_{}_find_obj.txt'.format(args.env_name, args.train_num), 'w+') as f:
                f.write('time: {} s\n'.format(curr_time))

                f.write('{}\n'.format(ep_num))
                f.write('{}\n'.format('evaluate average reward per {} ep'.format(report_every)))
                f.write('{}\n'.format(record_rH))

            total_reward_H = 0.
            success = 0

    # Workers receive no stop message; they are terminated directly.
    for p in workers:
        p.terminate()
        p.join()
        p.close()
    sys.exit()









