import random

from high_agent import high_agent
import numpy as np
import cv2
import copy
import gym
from mlt_maze_env import MazeEnv
from agent_pretrained import SacAgent
import torch



def img_preprocess(img, PREPRO=True, resolution = (224,224)):
    """Turn an observation image into a batch of one for the network.

    With ``PREPRO`` enabled the image is resized to ``resolution`` using
    bilinear interpolation, divided by 255, cast to float32, and its last
    axis is moved to the front (channels-last -> channels-first). In both
    modes a leading axis of length 1 is prepended.

    Args:
        img: image array. With ``PREPRO`` it presumably is an (H, W, C) frame
            from ``MazeEnv`` (the ``moveaxis`` below needs 3 dims after the
            resize) -- TODO confirm channel count / dtype.
        PREPRO: whether to resize/normalise/transpose before batching.
        resolution: target size read as ``(height, width)``; cv2 expects the
            ``dsize`` argument as ``(width, height)``, hence the swap.

    Returns:
        The (optionally preprocessed) image with a new axis 0.
    """
    batched = img
    if PREPRO:
        height, width = resolution[0], resolution[1]
        resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        scaled = (resized / 255.0).astype(np.float32)
        batched = np.moveaxis(scaled, 2, 0)
    return np.expand_dims(batched, axis=0)



def subprocess(task, q, qb, sub_num, args):
    """Roll out episodes on maze ``task`` and stream transitions to ``q``.

    The mode is chosen by ``args.rand_h``:

    * true  -- a pretrained low-level ``SacAgent`` follows a uniformly random
      high-level direction; ``args.ep_max`` episodes are run, then the
      function returns.
    * false -- a ``high_agent`` acts through ``env.step_for_tab``; after every
      episode the worker blocks for new parameters on ``qb`` and exits when
      the reply is falsy.

    Args:
        task: maze id passed to ``MazeEnv``.
        q: queue this worker writes transitions and end-of-episode messages to.
        qb: queue this worker reads learner replies from.
        sub_num: worker index passed to ``MazeEnv`` as ``subprocess_num``.
        args: run configuration; fields read include ``rand_h``, ``device``,
            ``resolution``, ``dense`` and (random mode) ``ep_max`` /
            ``frame_skip``.
    """
    if args.rand_h:
        _subprocess_random_high(task, q, qb, sub_num, args)
    else:
        _subprocess_learned_high(task, q, qb, sub_num, args)


def _subprocess_random_high(task, q, qb, sub_num, args):
    """Random high-level directions over a pretrained low-level SAC agent."""
    agent = SacAgent()
    agent.load_models()
    agent.to_cuda(args.device)
    env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num = sub_num, image_obs=False)
    # Command vector appended to the low-level state for each of the four
    # high-level actions (two slots set per action).
    action_vec = np.array([[1, 1, 0, 0, 0, 0, 0, 0],
                           [0, 0, 1, 1, 0, 0, 0, 0],
                           [0, 0, 0, 0, 1, 1, 0, 0],
                           [0, 0, 0, 0, 0, 0, 1, 1]])
    frame_skip = args.frame_skip

    ep_num = 0
    while ep_num < args.ep_max:
        ep_num += 1
        obs, obs_pos = env.reset()  # img, goal_onehot, goal_word
        ep_r = 0
        ep_step = 1
        max_step = env.max_step
        frame_reward = 0
        effect = True
        while True:
            # randint(0, 3) is always a valid high-level action index.
            action_high = random.randint(0, 3)
            for _ in range(frame_skip):
                state_low = np.concatenate((obs_pos, action_vec[action_high]), axis=-1)
                # NOTE(review): .cuda() targets the default CUDA device while the
                # agent was moved with args.device; confirm these always agree.
                state_low = torch.from_numpy(state_low).cuda().float()
                action = agent.exploit(state_low)
                next_obs, dense_reward, done, next_obs_pos, sparse_reward = env.step(action)
                obs_pos = next_obs_pos
                # Presumably detects the robot falling over -- TODO confirm what
                # obs_pos[0] measures in MazeEnv.
                if next_obs_pos[0] < 0.3:
                    effect = False
                frame_reward += dense_reward if args.dense else sparse_reward
                if done or not effect:
                    break

            # NOTE(review): `action` is the last low-level action, not action_high;
            # confirm the consumer expects that.
            q.put(['high', None, action, frame_reward, sparse_reward, None, done, None])
            ep_r += frame_reward
            frame_reward = 0
            # `effect` is never reset within an episode, so once it trips the
            # episode ends instead of limping on with one-frame high-level steps.
            if done or ep_step >= max_step or not effect:
                print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r)
                break
            ep_step += 1

        q.put([True, done])

        # Block on the learner's reply instead of busy-polling qb.empty()
        # (which burns a core and is unreliable for multiprocessing queues).
        # Falsy replies are consumed and ignored, as before.
        while True:
            [flag] = qb.get()
            if flag:
                break


def _subprocess_learned_high(task, q, qb, sub_num, args):
    """Learned high-level agent on the tabular step; syncs weights per episode."""
    agent = high_agent(args)
    agent.net2cuda(args.device)
    env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num = sub_num)
    goal = np.zeros(20)

    while True:
        obs, obs_pos = env.reset()  # img, goal_onehot, goal_word
        s_pix = img_preprocess(obs)
        ep_r = 0
        ep_step = 1
        max_step = env.max_step

        while True:
            action = agent.act_forSBG(s_pix, goal)
            next_obs, dense_reward, done, next_obs_pos, sparse_reward = env.step_for_tab(action)
            s_pix_next = img_preprocess(next_obs)
            reward = dense_reward if args.dense else sparse_reward
            q.put(['high', s_pix, action, reward, sparse_reward, s_pix_next, done, goal])
            ep_r += reward
            if done or ep_step >= max_step:
                print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r,'done',done)
                break
            s_pix = copy.deepcopy(s_pix_next)
            ep_step += 1

        q.put([True, done, ep_step])

        # Block for the learner's reply (a 1-element sequence). A falsy
        # payload is the shutdown signal.
        [total_parameter] = qb.get()
        if not total_parameter:
            print(total_parameter)
            print(total_parameter,'-=====')
            break
        # Move every received tensor onto the GPU before loading.
        for params in total_parameter:
            for name in params:
                params[name] = params[name].cuda()
        agent.load_model_frommodel(total_parameter)



def subprocess_general(task, q, qb, sub_num, args):
    """Roll out hierarchical episodes for ``args.robot`` and stream them to ``q``.

    A ``high_agent`` chooses a high-level action from the preprocessed image.
    With ``args.oracle`` the action is applied through ``env.step_for_tab``.
    Otherwise it is executed for up to ``args.frame_skip`` low-level steps by
    a ``SacAgent``, whose checkpoint depends on the robot. The environment is
    rendered on reset and after every low-level step.

    After each episode an end-of-episode message ``[True, done, ep_step]`` is
    sent. The worker then blocks on ``qb`` for ``[total_parameter, flag]``:
    truthy parameters are moved to the GPU and loaded into the agent, and a
    falsy payload stops the worker.

    Args:
        task: maze id passed to ``MazeEnv``.
        q: queue this worker writes transitions and end-of-episode messages to.
        qb: queue this worker reads learner replies from.
        sub_num: worker index passed to ``MazeEnv`` as ``subprocess_num``.
        args: run configuration; fields read include ``robot``, ``oracle``,
            ``device``, ``resolution``, ``frame_skip`` and ``dense``.
    """
    agent = high_agent(args)
    agent_low = SacAgent(robot=args.robot)

    if not args.oracle:
        # Robot-specific checkpoint ids; any other robot (including 'Ant')
        # loads the SacAgent default checkpoint.
        checkpoints = {'Point': 4990, 'Swimmer': 5087, 'Swimmer-v2': 5087}
        if args.robot in checkpoints:
            agent_low.load_models(checkpoints[args.robot], args.robot)
        else:
            agent_low.load_models()
    agent.agent_low = agent_low
    agent.net2cuda(args.device)
    env = MazeEnv(maze_id=task, resolution=args.resolution, subprocess_num = sub_num, robot=args.robot)
    goal = np.zeros(20)
    frame_skip = args.frame_skip
    is_swimmer = args.robot in ('Swimmer', 'Swimmer-v2')

    while True:
        obs, obs_pos = env.reset()  # img, goal_onehot, goal_word
        env.render()
        s_pix = img_preprocess(obs)
        ep_r = 0
        ep_step = 1
        # Non-Ant robots get three times the environment's step budget.
        max_step = env.max_step if args.robot == 'Ant' else env.max_step * 3

        frame_reward = 0
        effect = True
        while True:
            if args.oracle:
                action = agent.act_forSBG(s_pix, goal)
                next_obs, dense_reward, done, next_obs_pos, sparse_reward = env.step_for_tab(action)
                s_pix_next = img_preprocess(next_obs)
                reward = dense_reward if args.dense else sparse_reward
                q.put(['high', s_pix, action, reward, sparse_reward, s_pix_next, done, goal])
                ep_r += reward
                if done or ep_step >= max_step:
                    print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r)
                    break

                s_pix = copy.deepcopy(s_pix_next)
                ep_step += 1
            else:
                action_high = agent.act_forSBG(s_pix, goal)
                # Swimmer only: with probability 0.2 replace the policy's choice
                # by a uniformly random direction.
                if is_swimmer and random.random() < 0.2:
                    action_high = random.randint(0, 3)

                # NOTE(review): if action_high >= 4 no env step happens here, so
                # s_pix_next/done may be stale or unbound below -- confirm the
                # policy only emits 0..3.
                if action_high < 4:
                    for _ in range(frame_skip):
                        action = agent.select_action_low(action_high, obs_pos)
                        next_obs, dense_reward, done, next_obs_pos, sparse_reward = env.step(action)
                        obs_pos = next_obs_pos
                        env.render()
                        # Presumably detects the Ant falling over -- TODO confirm
                        # what obs_pos[0] measures in MazeEnv.
                        if args.robot == 'Ant' and next_obs_pos[0] < 0.3:
                            effect = False

                        frame_reward += dense_reward if args.dense else sparse_reward
                        # NOTE(review): the terminating low-level step is not put
                        # on q; confirm the consumer relies on the end message.
                        if done or not effect:
                            break
                        xy_pos = env.get_curr_robot_xy()
                        q.put(['high', None, action, frame_reward, sparse_reward, None, done, xy_pos])
                    # Only the final frame feeds the next high-level decision, so
                    # preprocess once per high-level step, not once per frame.
                    s_pix_next = img_preprocess(next_obs)
                ep_r += frame_reward
                frame_reward = 0
                if done or ep_step >= max_step or not effect:
                    print('task:', task, 'ins:', 0, 'ep_step', ep_step, 'ep_reward:', ep_r)
                    break
                s_pix = copy.deepcopy(s_pix_next)
                ep_step += 1

        q.put([True, done, ep_step])

        # Block on the learner's reply instead of busy-polling qb.empty()
        # (which burns a core and is unreliable for multiprocessing queues).
        [total_parameter, _flag] = qb.get()
        if not total_parameter:
            print(total_parameter)
            print(total_parameter,'-=====')
            break
        # Move every received tensor onto the GPU before loading.
        for params in total_parameter:
            for name in params:
                params[name] = params[name].cuda()
        agent.load_model_frommodel(total_parameter)