import random

from high_agent import high_agent
import numpy as np
import cv2
import copy
import gym
from mlt_maze_env2 import MazeEnv2
from mlt_maze_env import MazeEnv

from agent_pretrained import SacAgent
import torch
import nltk


def img_preprocess(img, PREPRO=True, resolution = (224,224)):
    """Turn an image array into a batched network input.

    When ``PREPRO`` is true the image is resized to ``resolution`` (interpreted
    as ``(height, width)``), scaled by 1/255 into float32, and its axis 2 (the
    channel axis) is moved to the front (HWC -> CHW). In every case a leading
    batch axis of size 1 is added.

    Args:
        img: image array; when ``PREPRO`` is true it must have a channel axis
            at position 2 after resizing.
        PREPRO: whether to resize/normalize/transpose before batching.
        resolution: target ``(height, width)``.

    Returns:
        The (optionally processed) image with a new leading axis.
    """
    if not PREPRO:
        return np.expand_dims(img, axis=0)

    out_h, out_w = resolution[0], resolution[1]
    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(img, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
    normalized = (resized / 255.0).astype(np.float32)
    channels_first = np.moveaxis(normalized, 2, 0)
    return np.expand_dims(channels_first, axis=0)



def subprocess(task, q, qb, sub_num, args):
    """Worker-process entry point for maze rollouts.

    Dispatches on ``args.rand_h``:

    * true  -> ``_subprocess_random_high``: a uniformly random high-level
      policy driving a pretrained ``SacAgent`` low-level controller, run for
      ``args.ep_max`` episodes.
    * false -> ``_subprocess_learned_high``: the learned ``high_agent`` acting
      through ``env.step_for_tab``, looping until a falsy parameter payload
      arrives on ``qb``.

    Args:
        task: maze id, forwarded to the env as ``maze_id``.
        q: outgoing queue; receives per-step transition lists and an
            end-of-episode message whose first element is ``True``.
        qb: incoming queue, read after every end-of-episode message.
        sub_num: worker index, forwarded to the env as ``subprocess_num``.
        args: run configuration namespace.
    """
    if args.rand_h:
        _subprocess_random_high(task, q, qb, sub_num, args)
    else:
        _subprocess_learned_high(task, q, qb, sub_num, args)


def _subprocess_random_high(task, q, qb, sub_num, args):
    """Roll out a random high-level policy over a pretrained SAC low-level agent.

    Runs ``args.ep_max`` episodes. Each high-level step puts one
    ``['high', None, action, frame_reward, sparse_reward, None, done, None]``
    message on ``q``. After each episode ``[True, done]`` is sent, and the
    worker blocks on ``qb`` until a truthy flag arrives.
    """
    agent = SacAgent()
    agent.load_models()
    agent.to_cuda(args.device)

    env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num=sub_num, image_obs=False)
    # Row k is the command vector appended to the low-level state for high-level action k.
    action_vec = np.array([[1, 1, 0, 0, 0, 0, 0, 0],
                           [0, 0, 1, 1, 0, 0, 0, 0],
                           [0, 0, 0, 0, 1, 1, 0, 0],
                           [0, 0, 0, 0, 0, 0, 1, 1]])
    ep_num = 0
    while ep_num < args.ep_max:
        ep_num += 1
        _, obs_pos = env.reset()
        ep_r = 0
        ep_step = 1
        max_step = env.max_step
        frame_skip = args.frame_skip
        frame_reward = 0
        effect = True
        while True:
            # randint(0, 3) is always a valid row of action_vec, so no range check is needed.
            action_high = random.randint(0, 3)
            for _ in range(frame_skip):
                state_low = np.concatenate((obs_pos, action_vec[action_high]), axis=-1)
                state_low = torch.from_numpy(state_low).cuda().float()
                action = agent.exploit(state_low)
                _, dense_reward, done, next_obs_pos, sparse_reward = env.step(action)
                obs_pos = next_obs_pos
                # Presumably detects a fallen/flipped robot -- TODO confirm what obs_pos[0] encodes.
                if next_obs_pos[0] < 0.3:
                    effect = False
                frame_reward += dense_reward if args.dense else sparse_reward
                if done or not effect:
                    break

            q.put(['high', None, action, frame_reward, sparse_reward, None, done, None])
            ep_r += frame_reward
            frame_reward = 0
            # NOTE(review): unlike subprocess_general, the episode does not end when
            # `effect` turns False; each later high-level step then runs a single
            # low-level step until max_step. Confirm this is intended.
            if done or ep_step >= max_step:
                print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r)
                break
            ep_step += 1

        q.put([True, done])

        # Block (instead of spinning on qb.empty()) until a truthy go-ahead flag arrives;
        # falsy flags are consumed and ignored, as before.
        while True:
            [flag] = qb.get()
            if flag:
                break


def _subprocess_learned_high(task, q, qb, sub_num, args):
    """Roll out the learned high-level agent and sync its weights after every episode.

    Each step puts ``['high', s_pix, action, sparse_reward, s_pix_next, done, goal]``
    on ``q``; each episode ends with ``[True, done, ep_step]``. The worker then
    reads ``[total_parameter]`` from ``qb``: a non-empty payload is moved to CUDA
    and loaded into the agent, and an empty/falsy one ends the worker.
    """
    agent = high_agent(args)
    agent.net2cuda(args.device)
    if args.find_obj:
        env = MazeEnv2(maze_id=task, resolution=args.resolution, subprocess_num=sub_num, view_flag=args.random_view)
    else:
        env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num=sub_num, view_flag=args.random_view)
        goal = np.zeros(20)

    while True:
        obs, _ = env.reset()
        s_pix = img_preprocess(obs)
        ep_r = 0
        ep_step = 1
        max_step = env.max_step
        if args.find_obj:
            goal = np.zeros(args.goal_dim)
            goal_idx = env.curr_object_goal_idx
            # NOTE(review): this marks 9 of the 10 slots reserved for goal_idx
            # ((idx+1)*10-1 is an exclusive bound). Kept as-is for checkpoint
            # compatibility; confirm whether the full 10-slot band was intended.
            goal[goal_idx * 10:(goal_idx + 1) * 10 - 1] = 1

        while True:
            action = agent.act_forSBG(s_pix, goal)
            next_obs, dense_reward, done, _, sparse_reward = env.step_for_tab(action)
            s_pix_next = img_preprocess(next_obs)
            reward = dense_reward if args.dense else sparse_reward
            # NOTE(review): this message has 7 fields and always carries sparse_reward
            # (`reward` only feeds the logged episode total), unlike the 8-field
            # 'high' messages elsewhere in this file. Confirm the consumer's layout.
            q.put(['high', s_pix, action, sparse_reward, s_pix_next, done, goal])
            ep_r += reward
            if done or ep_step >= max_step:
                ins = env.curr_object_goal_idx if args.find_obj else 0
                print('task:', task, 'ins:', ins, 'ep_step', ep_step, 'ep_reward:', ep_r, 'done', done)
                break
            s_pix = copy.deepcopy(s_pix_next)
            ep_step += 1

        q.put([True, done, ep_step])

        # Blocking get replaces the former busy-wait on qb.empty().
        [total_parameter] = qb.get()
        if not total_parameter:
            print(total_parameter)
            print(total_parameter, '-=====')
            break
        for param_dict in total_parameter:
            for name in param_dict:
                param_dict[name] = param_dict[name].cuda()
        agent.load_model_frommodel(total_parameter)



def subprocess_general(task, q, qb, sub_num, args):
    """Worker loop: roll out the high-level agent on a MazeEnv and sync weights per episode.

    With ``args.oracle`` each high-level action is applied directly through
    ``env.step_for_tab`` and one 8-field transition is sent per step. Otherwise
    each high-level action (expected 0-3) is expanded into up to
    ``args.frame_skip`` low-level steps chosen by ``agent.select_action_low``,
    and a trajectory record with the robot's xy position is sent per low step.

    After each episode ``[True, done, ep_step]`` is put on ``q`` and a
    ``[total_parameter, flag]`` message is read from ``qb``. A non-empty payload
    is moved to CUDA and loaded into the agent; an empty/falsy one ends the worker.

    Args:
        task: maze id, forwarded to the env as ``maze_id``.
        q: outgoing queue of transition/trajectory messages.
        qb: incoming queue of parameter payloads.
        sub_num: worker index, forwarded to the env as ``subprocess_num``.
        args: run configuration namespace.
    """
    agent = high_agent(args)
    agent_low = SacAgent(robot=args.robot)

    if not args.oracle:
        # Robot-specific arguments selecting the pretrained low-level checkpoint;
        # 'Ant' and any other robot use the default load_models() call.
        if args.robot == 'Point':
            agent_low.load_models(4990, args.robot)
        elif args.robot in ('Swimmer', 'Swimmer-v2'):
            agent_low.load_models(5087, args.robot)
        else:
            agent_low.load_models()
    agent.agent_low = agent_low
    agent.net2cuda(args.device)
    env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num=sub_num, robot=args.robot)
    goal = np.zeros(20)
    frame_skip = args.frame_skip

    while True:
        obs, obs_pos = env.reset()
        s_pix = img_preprocess(obs)
        # Reset per episode so a step that runs no low-level actions (high-level
        # action >= 4 or frame_skip == 0) neither crashes nor reuses the previous
        # episode's final frame / done flag.
        s_pix_next = s_pix
        done = False
        ep_r = 0
        ep_step = 1
        max_step = env.max_step if args.robot == 'Ant' else env.max_step * 3
        frame_reward = 0
        effect = True

        while True:
            if args.oracle:
                action = agent.act_forSBG(s_pix, goal)
                next_obs, dense_reward, done, _, sparse_reward = env.step_for_tab(action)
                s_pix_next = img_preprocess(next_obs)
                reward = dense_reward if args.dense else sparse_reward
                q.put(['high', s_pix, action, reward, sparse_reward, s_pix_next, done, goal])
                ep_r += reward
                if done or ep_step >= max_step:
                    print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r)
                    break
                s_pix = copy.deepcopy(s_pix_next)
                ep_step += 1
                continue

            action_high = agent.act_forSBG(s_pix, goal)
            # Swimmer only: 20% chance to replace the policy's choice with a random direction.
            if args.robot in ('Swimmer', 'Swimmer-v2') and random.random() < 0.2:
                action_high = random.randint(0, 3)

            if action_high < 4:
                for _ in range(frame_skip):
                    action = agent.select_action_low(action_high, obs_pos)
                    next_obs, dense_reward, done, next_obs_pos, sparse_reward = env.step(action)
                    obs_pos = next_obs_pos
                    # Presumably detects a fallen Ant -- TODO confirm what obs_pos[0] encodes.
                    if args.robot == 'Ant' and next_obs_pos[0] < 0.3:
                        effect = False
                    frame_reward += dense_reward if args.dense else sparse_reward
                    if done or not effect:
                        # NOTE(review): this break skips the q.put below, so the terminal
                        # low-level step is never reported. Confirm this is intended.
                        break
                    xy_pos = env.get_curr_robot_xy()
                    s_pix_next = img_preprocess(next_obs)
                    # frame_reward here is the running total within the current high-level step.
                    q.put(['high', None, action, frame_reward, sparse_reward, None, done, xy_pos])
            ep_r += frame_reward
            frame_reward = 0
            if done or ep_step >= max_step or not effect:
                print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r)
                break
            s_pix = copy.deepcopy(s_pix_next)
            ep_step += 1

        q.put([True, done, ep_step])

        # Blocking get replaces the former busy-wait on qb.empty().
        total_parameter, _ = qb.get()
        if not total_parameter:
            print(total_parameter)
            print(total_parameter, '-=====')
            break
        for param_dict in total_parameter:
            for name in param_dict:
                param_dict[name] = param_dict[name].cuda()
        agent.load_model_frommodel(total_parameter)


def subprocess_trans_baby(task, q, qb, sub_num, args):
    """Worker loop: collect random-action episodes in a BabyAI environment.

    Episodes whose mission does not end in a ``(color, type)`` object phrase,
    and all "go ..." missions, are skipped. Each step takes a uniformly random
    action in 0-3 via ``env.step_trans`` and sends
    ``['low', s_pix, action, rewardH, rewardL, s_pix_next, done, goal]`` on
    ``q``. Any non-zero env reward ends the episode with reward 1. Episodes
    are capped at 25 steps.

    After each episode ``[True, done, ep_step]`` is sent. A
    ``[total_parameter, flag]`` message is then read from ``qb``: a non-empty
    payload is moved to CUDA and loaded into the agent, and an empty/falsy one
    ends the worker.

    Raises:
        NotImplementedError: if ``args.baby`` is false.
    """
    agent = high_agent(args)

    # (color, type) pairs; used below to build a one-hot over these 20 rows.
    object_list = np.array([['red', 'key'], ['red', 'box'], ['red', 'ball'], ['red', 'door'],
                            ['green', 'key'], ['green', 'box'], ['green', 'ball'], ['green', 'door'],
                            ['blue', 'key'], ['blue', 'box'], ['blue', 'ball'], ['blue', 'door'],
                            ['purple', 'key'], ['purple', 'box'], ['purple', 'ball'], ['purple', 'door'],
                            ['yellow', 'key'], ['yellow', 'box'], ['yellow', 'ball'], ['yellow', 'door'], ])

    agent.net2cuda(args.device)
    if not args.baby:
        raise NotImplementedError
    # NOTE (translated from the original author's comment): in multi-task runs
    # the `while` loop should be moved up, before this point.
    env = gym.make('BabyAI-%s-v0' % task)
    # Use the unwrapped env (original note: without this there are many restrictions).
    env = env.unwrapped

    max_step = 25
    while True:
        state = env.reset()
        s_pix = img_preprocess(env.obs_render_rgb2(state['image']))
        mission = state['mission']
        # Word tokens of the mission's first sentence.
        word_ins = nltk.word_tokenize(nltk.sent_tokenize(mission)[0])
        ext_goal = [word_ins[-2], word_ins[-1]]
        if ext_goal[1] == 'object' or ext_goal[0] in ('a', 'the') or word_ins[0] == 'go':
            continue

        matches = object_list == ext_goal
        instruction_goal = (matches[:, 0] & matches[:, 1]).astype(np.int32)  # noqa: F841
        # NOTE(review): the instruction one-hot above is not used; an all-zero goal
        # is sent instead, so the collected data is not goal-conditioned. Swap in
        # `instruction_goal` if conditioning is intended.
        goal = np.zeros(20)

        ep_rH = 0
        ep_step = 1
        done = False

        while True:
            action = random.randint(0, 3)
            # The env's own done/info are ignored; termination is derived from the reward.
            state_next, reward, _, _ = env.step_trans(action)
            s_pix_next = img_preprocess(env.obs_render_rgb2(state_next['image']))

            done = bool(reward != 0)
            rewardH = 1. if done else 0.
            rewardL = rewardH

            q.put(['low', s_pix, action, rewardH, rewardL, s_pix_next, done, goal])
            ep_rH += rewardH
            if done or ep_step >= max_step:
                print('task:', task, 'ins:', mission, 'ep_step', ep_step, 'ep_reward:', ep_rH)
                break
            s_pix = copy.deepcopy(s_pix_next)
            ep_step += 1

        q.put([True, done, ep_step])

        # Blocking get replaces the former busy-wait on qb.empty().
        total_parameter, _ = qb.get()
        if not total_parameter:
            print(total_parameter)
            print(total_parameter, '-=====')
            break
        for param_dict in total_parameter:
            for name in param_dict:
                param_dict[name] = param_dict[name].cuda()
        agent.load_model_frommodel(total_parameter)



def subprocess_for_vrf_mj(task, q, qb, sub_num, args):
    """Worker loop for single-decision episodes on a MazeEnv via ``env.step_for_vrf``.

    The goal index is read from ``env.given_goal`` once, after an initial reset,
    and encoded as a 3-way one-hot that stays fixed for the worker's lifetime.
    Each episode resets the env, takes one action from the high-level agent,
    and sends ``['high', s_pix, action, reward, goal_true_flag, s_pix, done, goal]``
    followed by ``[True, done, 1]`` on ``q``. ``done`` is true when
    ``reward >= 1``.

    The worker then reads ``[total_parameter]`` from ``qb``: a non-empty payload
    is moved to CUDA and loaded into the agent, and an empty/falsy one ends the
    worker.
    """
    agent = high_agent(args)
    agent.net2cuda(args.device)

    env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num=sub_num, view_flag=args.random_view)
    env.reset()
    goal_idx = env.given_goal
    goal = np.zeros(3)
    goal[goal_idx] = 1

    while True:
        obs, _ = env.reset()
        s_pix = img_preprocess(obs)

        action = agent.act_forSBG(s_pix, goal)
        reward, goal_true_flag = env.step_for_vrf(action)
        done = bool(reward >= 1)

        # The episode is one step long, so the "next" observation slot repeats s_pix.
        q.put(['high', s_pix, action, reward, goal_true_flag, s_pix, done, goal])
        ep_r = 0
        ep_r += reward
        print('task:', task, 'ins:', 0, 'ep_step', 1, 'ep_reward:', ep_r, 'done', done)

        q.put([True, done, 1])

        # Blocking get replaces the former busy-wait on qb.empty().
        [total_parameter] = qb.get()
        if not total_parameter:
            print(total_parameter)
            print(total_parameter, '-=====')
            break
        for param_dict in total_parameter:
            for name in param_dict:
                param_dict[name] = param_dict[name].cuda()
        agent.load_model_frommodel(total_parameter)
