import os
import platform
import pathlib
# Detect the host OS. NOTE: the name `plt` is only used for the check below;
# it is rebound to matplotlib.pyplot by a later import in this file.
plt = platform.system()
if plt == "Linux":
    # Ask PyOpenGL to use the EGL platform (presumably for headless
    # rendering — TODO confirm against the deployment environment).
    os.environ['PYOPENGL_PLATFORM'] = 'egl'
    # Alias WindowsPath to PosixPath; presumably so objects serialized with
    # Windows paths can be loaded on Linux — TODO confirm.
    pathlib.WindowsPath = pathlib.PosixPath
import time
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
import argparse
from tensorboardX import SummaryWriter
from gymnasium.spaces import Discrete
from gym.spaces import Discrete
from torch.multiprocessing import freeze_support
import yaml

from sensing_agent import Sensor, crop_img
from control_agent import Controller


class AdvancedEnv():
    """Square grid world shared by two agents (ids 0 and 1).

    Cells are numbered row-major: ``index = row * maze_len + col``. Each agent
    position is kept as a ``(row, col)`` tuple in ``self.agents``. Cells in
    ``self.goal`` are rewarded, cells in ``self.fires`` / ``self.unknowns``
    are penalised, and two agents sharing a non-goal cell is a collision.
    """

    def __init__(self, maze_len=10,
                 grid_spacing=50,
                 grid_offset=25,
                 goal_pos=[99],
                 screen_width=600,
                 screen_height=600,
                 line_offset=50,
                 obj_offset=75,
                 tr_offset=25,
                 obj_size=20,
                 bias=20,
                 fire_num=None,
                 unknown_num=None,
                 fire_pos=[32, 75, 50, 74, 90],
                 unknown_pos=[8, 43, 47, 66, 80],
                 *args, **kwargs):
        """Build the grid and place the obstacles.

        Args:
            maze_len: number of cells per side of the square grid.
            grid_spacing, grid_offset: pixel spacing/offset used to build the
                per-cell robot coordinates ``self.x`` / ``self.y``.
            goal_pos: iterable of goal cell indices.
            screen_width, screen_height: viewer size in pixels.
            line_offset: pixel margin of the grid lines.
            obj_offset: pixel offset of object centres inside the viewer.
            tr_offset: offset used to build and place the unknown squares.
            obj_size: radius / half-side of the drawn objects.
            bias: stored on the instance; not used inside this class.
            fire_num, unknown_num: when both are ``None`` the fixed positions
                ``fire_pos`` / ``unknown_pos`` are used; otherwise that many
                fire and unknown cells are sampled at random.
            fire_pos, unknown_pos: fixed obstacle cell indices (may be
                ``None`` when random placement is requested).
            *args, **kwargs: accepted and ignored.

        Raises:
            ValueError: in random mode, if a goal cell cannot be removed from
                the candidate list (out of range or already removed), or if
                more cells are requested than are available.
        """
        self.maze_len = maze_len
        self.action_space = Discrete(4)
        self.observation_space = Discrete(self.maze_len * self.maze_len)
        self.agents = {0: (0, 0), 1: (1, 0)}
        self.agent_goal = [0, 0]
        self.info = {}
        self.viewer = None
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.grid_spacing = grid_spacing
        self.grid_offset = grid_offset
        self.line_offset = line_offset
        self.obj_offset = obj_offset
        self.tr_offset = tr_offset
        self.obj_size = obj_size
        self.fire_num = fire_num
        self.unknown_num = unknown_num
        # Copy the position lists: the signature defaults are shared mutable
        # objects, so storing them by reference would let one instance's
        # edits leak into every later instance.
        self.fire_pos = None if fire_pos is None else list(fire_pos)
        self.unknown_pos = None if unknown_pos is None else list(unknown_pos)
        self.goal = list(goal_pos)
        self.bias = bias

        # Per-cell pixel coordinates for the robots, indexed by cell number.
        # x cycles through the columns; y is constant along a row and
        # decreases as the row index grows.
        x_y = [k * self.grid_spacing + self.grid_offset
               for k in range(1, self.maze_len + 1)]
        self.x = x_y * self.maze_len
        self.y = [coord for coord in reversed(x_y) for _ in range(self.maze_len)]

        if fire_num is None and unknown_num is None:
            self.fires = self.fire_pos
            self.unknowns = self.unknown_pos
        else:
            pos = list(range(self.maze_len * self.maze_len))
            for goal_cell in self.goal:
                pos.remove(goal_cell)
            # Keep the first cell of rows 0, 1 and 2 free (rows 0 and 1 hold
            # the agents' start cells).
            for reserved_row in range(3):
                pos.remove(reserved_row * self.maze_len)
            self.fires = random.sample(pos, fire_num)
            for fire_cell in self.fires:
                pos.remove(fire_cell)
            self.unknowns = random.sample(pos, unknown_num)
            print("random obstacles: ", self.fires)
            print("random unknowns: ", self.unknowns)

    def reset(self, *, seed=None, options=None):
        """Put both agents back on their fixed start cells.

        ``seed`` and ``options`` are accepted but unused.

        Returns:
            ``(observations, info)`` where ``observations`` maps each agent id
            to its cell index and ``info`` is an empty dict.
        """
        self.agents = {0: (0, 0), 1: (1, 0)}
        return {0: self.get_observation(0), 1: self.get_observation(1)}, {}

    def get_observation(self, agent_id):
        """Return the agent's position encoded as a row-major cell index."""
        row, col = self.agents[agent_id]
        return self.maze_len * row + col

    def _shares_cell(self, agent_id):
        """Return True when another agent stands on the same cell."""
        return list(self.agents.values()).count(self.agents[agent_id]) > 1

    def get_reward(self, agent_id):
        """Return the reward for the agent's current cell.

        100 on a goal cell, -100 on a fire or unknown cell, -10 when sharing
        a (non-goal) cell with the other agent, 0 otherwise.
        """
        cell = self.get_observation(agent_id)
        if cell in self.goal:
            return 100
        if cell in self.fires or cell in self.unknowns:
            return -100
        if self._shares_cell(agent_id):
            return -10
        return 0

    def is_done(self, agent_id):  # collision
        """Report whether the agent collided with the other agent.

        Also prints a message describing what the agent's cell contains.
        Sharing a goal cell is not a collision.

        Returns:
            bool: True only when the agents share a non-goal cell.
        """
        cell = self.get_observation(agent_id)
        at_goal = cell in self.goal
        # The goal is a list of cells, so test membership; comparing the cell
        # index to the list itself is always unequal.
        collided = self._shares_cell(agent_id) and not at_goal

        if at_goal:
            print("agent ", agent_id, " arrives.")
        if cell in self.fires:
            print("agent ", agent_id, " hit fire.")
        elif cell in self.unknowns:
            print("agent ", agent_id, " hit unknown.")
        elif collided:
            print("agents hit!")
        return collided

    def step(self, action):
        """Advance every still-active agent by one move.

        Agents already standing on a goal, fire or unknown cell are removed
        from ``action`` (the caller's dict is mutated in place) and their
        outcome is recorded in ``self.agent_goal``; the remaining agents move.

        Args:
            action: dict mapping agent id to 0 (down), 1 (left), 2 (up) or
                3 (right). Moves are clamped at the grid border.

        Returns:
            ``(observations, rewards, done, done, info)``; the same ``done``
            dict is returned twice and contains an ``"__all__"`` entry.

        Raises:
            ValueError: if an action value is not one of 0-3.
        """
        for agent_id in list(action.keys()):
            cell = self.get_observation(agent_id)
            if cell in self.goal:
                self.agent_goal[agent_id] = 100
                del action[agent_id]
            elif cell in self.fires:
                self.agent_goal[agent_id] = -100
                del action[agent_id]
            elif cell in self.unknowns:
                self.agent_goal[agent_id] = -100
                del action[agent_id]
        agent_ids = action.keys()

        last = self.maze_len - 1
        for agent_id in agent_ids:
            row, col = self.agents[agent_id]
            move = action[agent_id]
            if move == 0:  # move down
                row = min(row + 1, last)
            elif move == 1:  # move left
                col = max(col - 1, 0)
            elif move == 2:  # move up
                row = max(row - 1, 0)
            elif move == 3:  # move right
                col = min(col + 1, last)
            else:
                raise ValueError("Invalid action")
            self.agents[agent_id] = (row, col)

        observations = {i: self.get_observation(i) for i in agent_ids}
        rewards = {i: self.get_reward(i) for i in agent_ids}
        done = {i: self.is_done(i) for i in agent_ids}

        done["__all__"] = all(done.values())

        return observations, rewards, done, done, self.info

    def _cell_center(self, cell):
        """Return the viewer pixel ``(x, y)`` at which objects in ``cell`` are drawn."""
        row, col = divmod(cell, self.maze_len)
        pos_x = col * self.grid_spacing + self.obj_offset
        pos_y = ((self.maze_len - 1) - row) * self.grid_spacing + self.obj_offset
        return pos_x, pos_y

    def _build_scene(self, rendering):
        """Create the viewer and add grid lines, obstacles, goals and robots."""
        self.viewer = rendering.Viewer(self.screen_width, self.screen_height)
        margin = self.line_offset

        # create gridworld: horizontal lines first, then vertical lines
        for index in range(self.maze_len + 1):
            level = margin + index * self.grid_spacing
            self.line = rendering.Line((margin, level), (self.screen_width - margin, level))
            self.line.set_color(0, 0, 0)
            self.viewer.add_geom(self.line)
        for index in range(self.maze_len + 1):
            level = margin + index * self.grid_spacing
            self.line = rendering.Line((level, margin), (level, self.screen_height - margin))
            self.line.set_color(0, 0, 0)
            self.viewer.add_geom(self.line)

        # obstacles: red circles
        for fire in self.fires:
            self.fire = rendering.make_circle(self.obj_size)
            self.circletrans = rendering.Transform(translation=self._cell_center(fire))
            self.fire.add_attr(self.circletrans)
            self.fire.set_color(1, 0, 0)
            self.viewer.add_geom(self.fire)

        # unknowns: black squares
        low = self.tr_offset - self.obj_size
        high = self.tr_offset + self.obj_size
        for unknown in self.unknowns:
            pos_x, pos_y = self._cell_center(unknown)
            self.unknown = rendering.FilledPolygon(
                [(low, low), (low, high), (high, high), (high, low)])
            self.circletrans = rendering.Transform(
                translation=(pos_x - self.tr_offset, pos_y - self.tr_offset))
            self.unknown.add_attr(self.circletrans)
            self.unknown.set_color(0, 0, 0)
            self.viewer.add_geom(self.unknown)

        # goals: blue circles
        for goal_cell in self.goal:
            self.diamond = rendering.make_circle(self.obj_size)
            self.circletrans = rendering.Transform(translation=self._cell_center(goal_cell))
            self.diamond.add_attr(self.circletrans)
            self.diamond.set_color(0, 0, 1)
            self.viewer.add_geom(self.diamond)

        # robots: green circles moved through their own transforms
        self.robot1 = rendering.make_circle(self.obj_size)
        self.robotrans1 = rendering.Transform()
        self.robot1.add_attr(self.robotrans1)
        self.robot1.set_color(0, 1, 0)
        self.viewer.add_geom(self.robot1)

        self.robot2 = rendering.make_circle(self.obj_size)
        self.robotrans2 = rendering.Transform()
        self.robot2.add_attr(self.robotrans2)
        self.robot2.set_color(0, 1, 0)
        self.viewer.add_geom(self.robot2)

    def render(self, action, *args, **kwargs):
        """Draw the grid world and return the frame.

        The static scene is built on the first call. An agent whose id is a
        key of ``action`` is drawn at its current cell; any other agent is
        drawn at its start cell.

        Returns:
            Whatever ``viewer.render(return_rgb_array=True)`` returns.
        """
        print("render: ", action)
        from gym.envs.classic_control import rendering

        if self.viewer is None:
            self._build_scene(rendering)

        for agent_id, transform in ((0, self.robotrans1), (1, self.robotrans2)):
            if agent_id in action:
                cell = self.get_observation(agent_id)
            else:
                cell = agent_id * self.maze_len
            transform.set_translation(self.x[cell], self.y[cell])

        return self.viewer.render(return_rgb_array=True)

    def close(self):
        """Close the viewer if one was created."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None


def vworld_control(control_agent, episode, obs, action, virtual_world):
    """Screen candidate actions for ``obs`` against ``virtual_world``.

    Starting from ``action``, up to four candidates are tried in cyclic
    order. A candidate with no entry in the module-level ``transfer`` table,
    or whose target cell has a negative value in ``virtual_world``, is scored,
    reported to ``control_agent.get_action`` and skipped. The first candidate
    whose target cell is non-negative stops the search. When four candidates
    were examined (including the case where the fourth one was accepted), the
    final action is re-drawn epsilon-greedily from the collected scores.

    Relies on the module-level names ``transfer`` and ``env``.

    Returns:
        The chosen action index.
    """
    scores = np.zeros([4])
    for attempts in range(1, 5):
        lookup = "%d_%d" % (obs, action)

        if lookup not in transfer:
            # No transition is registered for this move from this state.
            scores[action] = -100
            _, _ = control_agent.get_action(obs, action, obs, -100, episode, 0.5)
            action = (action + 1) % 4
            continue

        target = transfer[lookup]
        cell_value = virtual_world[int(target / env.maze_len)][int(target % env.maze_len)]

        if cell_value == -100:
            scores[action] = -100
            _, _ = control_agent.get_action(obs, action, target, -100, episode, 0.5)
            action = (action + 1) % 4
        elif cell_value < 0:
            scores[action] = cell_value
            action = (action + 1) % 4
            # NOTE(review): unlike the other branches, get_action is called
            # here after `action` has already been advanced — confirm whether
            # the penalised action should be the one that was just tried.
            _, _ = control_agent.get_action(obs, action, target, cell_value, episode, 0.5)
        else:
            # Acceptable move: fill the still-unscored slots from the Q-table.
            for candidate in range(4):
                if scores[candidate] == 0:
                    scores[candidate] = control_agent.q_table[obs][candidate]
            break

    if attempts >= 4:
        exploration = 0.2 * (0.99 ** episode)
        if exploration <= np.random.uniform(0, 1):
            action = np.argmax(scores)
        else:
            # NOTE(review): np.random.choice raises if every score is -100.
            allowed = [idx for idx, score in enumerate(scores) if score != -100]
            action = np.random.choice(allowed)
    return action

def check_sense_range(sensing_agents_list, agent_obs, maze_len):
    """Return the sensors whose inclusive range contains the observed cell.

    ``agent_obs`` is decoded into a row (``agent_obs / maze_len``, truncated)
    and a column (``agent_obs % maze_len``); a sensor matches when the row
    lies in ``[start_x, end_x]`` and the column in ``[start_y, end_y]``.
    The input order of the sensors is preserved.
    """
    row = int(agent_obs / maze_len)
    col = int(agent_obs % maze_len)
    return [
        sensor
        for sensor in sensing_agents_list
        if sensor.start_x <= row <= sensor.end_x
        and sensor.start_y <= col <= sensor.end_y
    ]

def overlap(sensing_agent0, sensing_agent1):
    """Return the cell range covered by both sensors, or ``None``.

    Sensor ranges are inclusive on both ends (``start_* <= cell <= end_*``),
    so an intersection exactly one row or one column wide is still a valid
    overlap; the bounds are therefore compared with ``<=``.

    Returns:
        ``(x_start, y_start, x_end, y_end)`` with inclusive bounds, or
        ``None`` when the ranges share no cell.
    """
    x_start = max(sensing_agent0.start_x, sensing_agent1.start_x)
    y_start = max(sensing_agent0.start_y, sensing_agent1.start_y)
    x_end = min(sensing_agent0.end_x, sensing_agent1.end_x)
    y_end = min(sensing_agent0.end_y, sensing_agent1.end_y)
    if x_start <= x_end and y_start <= y_end:
        return (x_start, y_start, x_end, y_end)
    return None

def normalization(width, height, xmin, xmax, ymin, ymax):
    """Convert a corner-format box into centre/size form scaled by image size.

    Returns:
        ``(x, y, w, h)``: the box centre and extent, with horizontal values
        multiplied by ``1 / width`` and vertical values by ``1 / height``.
    """
    inv_width = 1 / width
    inv_height = 1 / height
    centre_x = (xmin + xmax) / 2
    centre_y = (ymax + ymin) / 2
    box_w = xmax - xmin
    box_h = ymax - ymin
    return (
        centre_x * inv_width,
        centre_y * inv_height,
        box_w * inv_width,
        box_h * inv_height,
    )


if __name__ == '__main__':
    freeze_support()
    parser = argparse.ArgumentParser()
    parser.add_argument('--maze_size', type=int, choices=[4, 10, 50], default=10, help='size of the maze 4/10/50')
    parser.add_argument('--mode', type=str, choices=["easy", "medium", "hard", "customized", "random"], default="easy",
                        help='choose the mode of gridworld ("easy"/"medium"/"hard"/"customized"/"random"), "easy" means fewer obstacles '
                             'and "hard" means more obstacles. You can choose the "random" mode, where you need to specify the '
                             'number of fires and unknown objects, and their positions will be randomly generated in the gridworld. '
                             'You can also choose the "customized" mode, where you need to define the positions of obstacles in the YAML file. ')
    parser.add_argument('--fire_num', type=int, default=4, help='the number of fires')
    parser.add_argument('--unknown_num', type=int, default=5, help='the number of unknown objects')
    parser.add_argument('--sensor1_start_x', type=int, default=0, help='perception range of the first sensor')
    parser.add_argument('--sensor1_start_y', type=int, default=0, help='perception range of the first sensor')
    parser.add_argument('--sensor1_end_x', type=int, default=3, help='perception range of the first sensor')
    parser.add_argument('--sensor1_end_y', type=int, default=1, help='perception range of the first sensor')
    parser.add_argument('--sensor2_start_x', type=int, default=0, help='perception range of the second sensor')
    parser.add_argument('--sensor2_start_y', type=int, default=2, help='perception range of the second sensor')
    parser.add_argument('--sensor2_end_x', type=int, default=3, help='perception range of the second sensor')
    parser.add_argument('--sensor2_end_y', type=int, default=3, help='perception range of the second sensor')
    parser.add_argument('--need_opt', type=bool, default=0,
                        help='with coordinated optimization or without coordinated optimization, 0 means without and 1 means with')
    args = parser.parse_args()
    with open("env_config.yaml", "r") as file:
        config = yaml.safe_load(file)
    # print(config)
    env_config = config[args.maze_size]
    print("maze size: %d" % args.maze_size)
    if args.need_opt:
        print("with coordinated optimization")
    else:
        print("without coordinated optimization")
    print("mode: %s\nThe positions of the obstacles are as follows." % args.mode)
    if args.mode == "easy":
        print("fires: ", env_config["easy"]["fires"])
        print("unknown objects: ", env_config["easy"]["unknowns"])
        env = AdvancedEnv(env_config["maze_len"], env_config["grid_spacing"],
                          env_config["grid_offset"], env_config["goal_pos"],
                          env_config["screen_width"], env_config["screen_height"],
                          env_config["line_offset"], env_config["obj_offset"],
                          env_config["tr_offset"], env_config["obj_size"], env_config["bias"],
                          None, None, env_config["easy"]["fires"], env_config["easy"]["unknowns"])
    elif args.mode == "medium":
        print("fires: ", env_config["medium"]["fires"])
        print("unknown objects: ", env_config["medium"]["unknowns"])
        env = AdvancedEnv(env_config["maze_len"], env_config["grid_spacing"],
                          env_config["grid_offset"], env_config["goal_pos"],
                          env_config["screen_width"], env_config["screen_height"],
                          env_config["line_offset"], env_config["obj_offset"],
                          env_config["tr_offset"], env_config["obj_size"], env_config["bias"],
                          None, None, env_config["medium"]["fires"], env_config["medium"]["unknowns"])
    elif args.mode == "hard":
        print("fires: ", env_config["hard"]["fires"])
        print("unknown objects: ", env_config["hard"]["unknowns"])
        env = AdvancedEnv(env_config["maze_len"], env_config["grid_spacing"],
                          env_config["grid_offset"], env_config["goal_pos"],
                          env_config["screen_width"], env_config["screen_height"],
                          env_config["line_offset"], env_config["obj_offset"],
                          env_config["tr_offset"], env_config["obj_size"], env_config["bias"],
                          None, None, env_config["hard"]["fires"], env_config["hard"]["unknowns"])
    elif args.mode == "customized":
        print("fires: ", env_config["customized"]["fires"])
        print("unknown objects: ", env_config["customized"]["unknowns"])
        env = AdvancedEnv(env_config["maze_len"], env_config["grid_spacing"],
                          env_config["grid_offset"], env_config["goal_pos"],
                          env_config["screen_width"], env_config["screen_height"],
                          env_config["line_offset"], env_config["obj_offset"],
                          env_config["tr_offset"], env_config["obj_size"], env_config["bias"],
                          None, None, env_config["customized"]["fires"], env_config["customized"]["unknowns"])
    else:
        # print(env_config)
        env = AdvancedEnv(env_config["maze_len"], env_config["grid_spacing"],
                          env_config["grid_offset"], env_config["goal_pos"],
                          env_config["screen_width"], env_config["screen_height"],
                          env_config["line_offset"], env_config["obj_offset"],
                          env_config["tr_offset"], env_config["obj_size"], env_config["bias"],
                          args.fire_num, args.unknown_num, None, None)
    print("perception range of sensor1: (%d, %d)->(%d, %d)" % (args.sensor1_start_x, args.sensor1_start_y, args.sensor1_end_x, args.sensor1_end_y))
    print("perception range of sensor2: (%d, %d)->(%d, %d)" % (args.sensor2_start_x, args.sensor2_start_y, args.sensor2_end_x, args.sensor2_end_y))

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    writer = SummaryWriter("runs/20250617_2cn3_wo/reward")
    writer0 = SummaryWriter("runs/20250617_2cn3_wo/q_value_0")
    writer1 = SummaryWriter("runs/20250617_2cn3_wo/q_value_1")


    transfer = dict()
    size = env.maze_len
    for i in range(size, size * size):
        transfer[str(i) + '_2'] = i - size
    for i in range(size * (size - 1)):
        transfer[str(i) + '_0'] = i + size
    for i in range(1, size * size):
        if i % size == 0:
            continue
        transfer[str(i) + '_1'] = i - 1
    for i in range(size * size):
        if (i + 1) % size == 0:
            continue
        transfer[str(i) + '_3'] = i + 1

    pos = [[i * env.maze_len + j for j in range(env.maze_len)] for i in range(env.maze_len)]

    control_agent0 = Controller(env.maze_len, 4)
    control_agent1 = Controller(env.maze_len, 4)
    last_time_steps = np.zeros(50)
    goal_average_steps = 0.98

    agent0_cnt = np.zeros((env.maze_len, env.maze_len))
    agent1_cnt = np.zeros((env.maze_len, env.maze_len))


    success_cnt = 0
    hit_cnt = 0

    sensing_agents = []
    sensing_agent0 = Sensor(args.sensor1_start_x, args.sensor1_start_y, args.sensor1_end_x, args.sensor1_end_y, "sensor0")
    sensing_agents.append(sensing_agent0)

    sensing_agent1 = Sensor(args.sensor2_start_x, args.sensor2_start_y, args.sensor2_end_x, args.sensor2_end_y, "sensor1")
    sensing_agents.append(sensing_agent1)

    overlap_range = overlap(sensing_agent0, sensing_agent1)

    n=1
    image_cnt = 0
    ss = 0
    timer = time.time()
    for eposide in range(3000):
        print("\nepisode: ", eposide)
        obs, _ = env.reset()
        episode_reward = 0
        cnt_obstacle_rate = 0
        max_cnt_obstacle_rate = 0
        # print(obs)
        sensing_agent0.cnt_obstacle_rate = 0
        sensing_agent0.max_cnt_obstacle_rate = 0
        sensing_agent1.cnt_obstacle_rate = 0
        sensing_agent1.max_cnt_obstacle_rate = 0
        action0 = np.argmax(control_agent0.q_table[obs[0]])
        action1 = np.argmax(control_agent1.q_table[obs[1]])
        action = {0: action0, 1: action1}
        state = obs
        env.agent_goal[0] = 0
        env.agent_goal[1] = 0
        stp = 0
        while stp < 2000:
            virtual_world = np.zeros([env.maze_len, env.maze_len])
            ss += 1
            stp += 1
            print("step: ", stp)

            for sensing_agent in sensing_agents:
                sensing_agent.cnt_obstacle = 0

            # log q-value
            sample_pos = env.maze_len*env.maze_len-2
            RL0_value = np.array([control_agent0.q_table[sample_pos][0], control_agent0.q_table[sample_pos][1], control_agent0.q_table[sample_pos][2], control_agent0.q_table[sample_pos][3]])
            RL1_value = np.array([control_agent1.q_table[sample_pos][0], control_agent1.q_table[sample_pos][1], control_agent1.q_table[sample_pos][2], control_agent1.q_table[sample_pos][3]])
            print("control agent 0, down, left, up, right: ", RL0_value[0], RL0_value[1], RL0_value[2], RL0_value[3])
            print("control agent 1, down, left, up, right: ", RL1_value[0], RL1_value[1], RL1_value[2], RL1_value[3])
            writer0.add_scalar("q_value", RL0_value[3], ss)
            writer1.add_scalar("q_value", RL1_value[3], ss)

            if 0 in action.keys():
                agent0_cnt[int(state[0] / env.maze_len)][int(state[0] % env.maze_len)] += 1
            if 1 in action.keys():
                agent1_cnt[int(state[1] / env.maze_len)][int(state[1] % env.maze_len)] += 1

            print(sensing_agent0.image_cnt())
            print(sensing_agent1.image_cnt())

            # retrain sensing agents models
            if args.need_opt:
                if (sensing_agent0.image_cnt() == 50 or sensing_agent0.image_cnt() == 51 or sensing_agent0.image_cnt()==52) and sensing_agent0.train_cnt == 0:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 100 or sensing_agent0.image_cnt() == 101 or sensing_agent0.image_cnt() == 102) and sensing_agent0.train_cnt == 1:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 150 or sensing_agent0.image_cnt() == 151 or sensing_agent0.image_cnt() == 152) and sensing_agent0.train_cnt == 2:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 200 or sensing_agent0.image_cnt() == 201 or sensing_agent0.image_cnt() == 202) and sensing_agent0.train_cnt == 3:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 250 or sensing_agent0.image_cnt() == 251 or sensing_agent0.image_cnt()==252) and sensing_agent0.train_cnt == 4:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 300 or sensing_agent0.image_cnt() == 301 or sensing_agent0.image_cnt() == 302) and sensing_agent0.train_cnt == 5:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 350 or sensing_agent0.image_cnt() == 351 or sensing_agent0.image_cnt() == 352) and sensing_agent0.train_cnt == 6:
                    sensing_agent0.sensor_train()
                if (sensing_agent0.image_cnt() == 400 or sensing_agent0.image_cnt() == 401 or sensing_agent0.image_cnt() == 402) and sensing_agent0.train_cnt == 7:
                    sensing_agent0.sensor_train()

                if (sensing_agent1.image_cnt() == 50 or sensing_agent1.image_cnt() == 51 or sensing_agent1.image_cnt()==52) and sensing_agent1.train_cnt == 0:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 100 or sensing_agent1.image_cnt() == 101 or sensing_agent1.image_cnt() == 102) and sensing_agent1.train_cnt == 1:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 150 or sensing_agent1.image_cnt() == 151 or sensing_agent1.image_cnt() == 152) and sensing_agent1.train_cnt == 2:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 200 or sensing_agent1.image_cnt() == 201 or sensing_agent1.image_cnt() == 202) and sensing_agent1.train_cnt == 3:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 250 or sensing_agent1.image_cnt() == 251 or sensing_agent1.image_cnt()== 252) and sensing_agent1.train_cnt == 4:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 300 or sensing_agent1.image_cnt() == 301 or sensing_agent1.image_cnt() == 302) and sensing_agent1.train_cnt == 5:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 350 or sensing_agent1.image_cnt() == 351 or sensing_agent1.image_cnt() == 352) and sensing_agent1.train_cnt == 6:
                    sensing_agent1.sensor_train()
                if (sensing_agent1.image_cnt() == 400 or sensing_agent1.image_cnt() == 401 or sensing_agent1.image_cnt() == 402) and sensing_agent1.train_cnt == 7:
                    sensing_agent1.sensor_train()

            # obtain raw data
            img = env.render(action)
            plt.imsave("./tmp_image/tmp.jpg", img)
            time.sleep(0.1)
            sensing_agent0.newly_added_image = crop_img("./tmp_image/tmp.jpg", sensing_agent0.start_y, sensing_agent0.start_x, sensing_agent0.end_y,
                     sensing_agent0.end_x, env.line_offset, env.grid_spacing, sensing_agent0.image_path)
            sensing_agent1.newly_added_image = crop_img("./tmp_image/tmp.jpg", sensing_agent1.start_y, sensing_agent1.start_x, sensing_agent1.end_y,
                     sensing_agent1.end_x, env.line_offset, env.grid_spacing, sensing_agent1.image_path)

            sensing_agent0.det_res = sensing_agent0.sensor_detect()
            sensing_agent1.det_res = sensing_agent1.sensor_detect()
            # print(sensing_agent0.det_res)
            # print(sensing_agent1.det_res)

            reg_res = {}
            reg_res0 = {}
            reg_res0_overlap = []
            for index in range(len(sensing_agent0.det_res)):
                y = int(((sensing_agent0.det_res[index][0]+sensing_agent0.det_res[index][2]) / 2) / env.grid_spacing)
                x = int(((sensing_agent0.det_res[index][1]+sensing_agent0.det_res[index][3]) / 2) / env.grid_spacing)
                # print(x,y)
                # print(sensing_agent0.start_x, sensing_agent0.start_y)
                trans_x = x + sensing_agent0.start_x
                trans_y = y + sensing_agent0.start_y
                reg_res0[(trans_x, trans_y)] = (sensing_agent0.det_res[index][4], int(sensing_agent0.det_res[index][5]))
                if overlap_range and int(sensing_agent0.det_res[index][5]) == 2 and trans_x >= overlap_range[0] and trans_x <= overlap_range[2] and trans_y >= overlap_range[1] and trans_y <= overlap_range[3]:
                    reg_res0_overlap.append((trans_x, trans_y))
            # print("reg_res0: ", reg_res0)

            reg_res1 = {}
            reg_res1_overlap = []
            for index in range(len(sensing_agent1.det_res)):
                y = int(((sensing_agent1.det_res[index][0] + sensing_agent1.det_res[index][2]) / 2) / env.grid_spacing)
                x = int(((sensing_agent1.det_res[index][1] + sensing_agent1.det_res[index][3]) / 2) / env.grid_spacing)
                trans_x = x + sensing_agent1.start_x
                trans_y = y + sensing_agent1.start_y
                reg_res1[(trans_x, trans_y)]=(sensing_agent1.det_res[index][4], int(sensing_agent1.det_res[index][5]))
                if overlap_range and int(sensing_agent1.det_res[index][5]) == 2 and trans_x >= overlap_range[0] and trans_x <= overlap_range[2] and trans_y >= overlap_range[
                    1] and trans_y <= overlap_range[3]:
                    reg_res1_overlap.append((trans_x, trans_y))

            if (reg_res1_overlap != reg_res0_overlap):
                unknown_pos0 = list(set(reg_res1_overlap) - set(reg_res0_overlap))
                # print("unknown_pos0: ", unknown_pos0)
                if unknown_pos0:
                    print("add to sensing agent 0 dataset")
                    for pos in unknown_pos0:
                        label = 2
                        x_min = (pos[1]-sensing_agent0.start_y) * env.grid_spacing
                        y_min = (pos[0]-sensing_agent0.start_x) * env.grid_spacing
                        x_max = x_min + env.grid_spacing
                        y_max = y_min + env.grid_spacing
                        # print("unknown pos 0: ", x_min, x_max, y_min, y_max)
                        txts_save_path = sensing_agent0.labels_path
                        txt_name = str(sensing_agent0.newly_added_data_cnt) + ".txt"
                        os.mkdir(txts_save_path) if not os.path.exists(txts_save_path) else None
                        f = open(os.path.join(txts_save_path, txt_name), 'w')
                        # 5050
                        # x, y, w, h = normalization(820, 820, x_min, x_max, y_min, y_max)
                        # 1010
                        x, y, w, h = normalization(1+env.grid_spacing*(sensing_agent0.end_y-sensing_agent0.start_y+1), 1+env.grid_spacing*(sensing_agent0.end_x-sensing_agent0.start_x+1), x_min, x_max, y_min, y_max)
                        f.write(str(label) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h) + ' ' + '\n')
                        ori = sensing_agent0.det_res.tolist()
                        for k in range(len(ori)):
                            x, y, w, h = normalization(1+env.grid_spacing*(sensing_agent0.end_y-sensing_agent0.start_y+1), 1+env.grid_spacing*(sensing_agent0.end_x-sensing_agent0.start_x+1), ori[k][0], ori[k][2], ori[k][1],
                                                       ori[k][3])
                            f.write(str(int(ori[k][5])) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(
                                h) + ' ' + '\n')
                        f.close()
                        newly_added_image_pth = os.path.join(sensing_agent0.images_path,
                                                             "%d.jpg" % sensing_agent0.newly_added_data_cnt)
                        sensing_agent0.newly_added_image.save(newly_added_image_pth)
                        if sensing_agent0.newly_added_data_cnt % 8 == 0:
                            if not os.path.exists(sensing_agent0.val_path):
                                os.makedirs(sensing_agent0.val_path)
                            list_file = open(sensing_agent0.val_path, "a")
                            list_file.write("./images/%d.jpg\n" % sensing_agent0.newly_added_data_cnt)
                            sensing_agent0.newly_added_data_cnt += 1
                            list_file.close()
                        else:
                            if not os.path.exists(sensing_agent0.train_path):
                                os.makedirs(sensing_agent0.train_path)
                            list_file = open(sensing_agent0.train_path, "a")
                            list_file.write("./images/%d.jpg\n" % sensing_agent0.newly_added_data_cnt)
                            sensing_agent0.newly_added_data_cnt += 1
                            list_file.close()


                unknown_pos1 = list(set(reg_res0_overlap) - set(reg_res1_overlap))
                # print("unknown_pos1: ", unknown_pos1)
                if unknown_pos1:
                    print("add to sensing agent 1 dataset")
                    for pos in unknown_pos1:
                        label = 2
                        x_min = (pos[1] - sensing_agent1.start_y) * env.grid_spacing
                        y_min = (pos[0] - sensing_agent1.start_x) * env.grid_spacing
                        x_max = x_min + env.grid_spacing
                        y_max = y_min + env.grid_spacing
                        # print("unknown pos 1: ", x_min, x_max, y_min, y_max)
                        txts_save_path = sensing_agent1.labels_path
                        txt_name = str(sensing_agent1.newly_added_data_cnt) + ".txt"
                        os.mkdir(txts_save_path) if not os.path.exists(txts_save_path) else None
                        f = open(os.path.join(txts_save_path, txt_name), 'w')
                        # 5050
                        # x, y, w, h = normalization(820, 820, x_min, x_max, y_min, y_max)
                        # 1010
                        x, y, w, h = normalization(1+env.grid_spacing*(sensing_agent1.end_y-sensing_agent1.start_y+1), 1+env.grid_spacing*(sensing_agent1.end_x-sensing_agent1.start_x+1), x_min, x_max, y_min, y_max)
                        f.write(str(label) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h) + ' ' + '\n')
                        ori = sensing_agent1.det_res.tolist()

                        for k in range(len(ori)):
                            # 5050
                            # x, y, w, h = normalization(820, 820, ori[k][0], ori[k][2], ori[k][1], ori[k][3])
                            # 1010
                            # print(ori[k][0], ori[k][2], ori[k][1], ori[k][3])
                            x, y, w, h = normalization(1+env.grid_spacing*(sensing_agent1.end_y-sensing_agent1.start_y+1), 1+env.grid_spacing*(sensing_agent1.end_x-sensing_agent1.start_x+1), ori[k][0], ori[k][2], ori[k][1],
                                                       ori[k][3])
                            f.write(str(int(ori[k][5])) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(
                                h) + ' ' + '\n')
                        f.close()
                        newly_added_image_pth = os.path.join(sensing_agent1.images_path,
                                                             "%d.jpg" % sensing_agent1.newly_added_data_cnt)
                        sensing_agent1.newly_added_image.save(newly_added_image_pth)
                        if sensing_agent1.newly_added_data_cnt % 8 == 0:
                            if not os.path.exists(sensing_agent1.val_path):
                                os.makedirs(sensing_agent1.val_path)
                            list_file = open(sensing_agent1.val_path, "a")
                            list_file.write("./images/%d.jpg\n" % sensing_agent1.newly_added_data_cnt)
                            sensing_agent1.newly_added_data_cnt += 1
                            list_file.close()
                        else:
                            if not os.path.exists(sensing_agent1.train_path):
                                os.makedirs(sensing_agent1.train_path)
                            list_file = open(sensing_agent1.train_path, "a")
                            list_file.write("./images/%d.jpg\n" % sensing_agent1.newly_added_data_cnt)
                            sensing_agent1.newly_added_data_cnt += 1
                            list_file.close()

            # print(reg_res0)
            # print(reg_res1)
            for k in reg_res0:
                reg_res[k]=reg_res0[k]
            for k in reg_res1:
                if k in reg_res:
                    if reg_res1[k][1] != reg_res[k][1]:
                        if reg_res1[k][0] > reg_res[k][0]:
                            reg_res[k] = reg_res1[k]
                else:
                    reg_res[k] = reg_res1[k]
            # print(reg_res)

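            # Virtual world construction from the merged recognition results.
            # Class ids stored in reg_res[k][1]:
            #   0 - agent0
            #   1 - agent1
            #   2 - obstacle (cell value -100)
            #   3 - goal (cell value 100)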
            for k in reg_res:
                if reg_res[k][1] == 2:
                    virtual_world[k[0]][k[1]] = -100
                elif reg_res[k][1] == 3:
                    virtual_world[k[0]][k[1]] = 100
            virtual_world[env.maze_len-1][env.maze_len-1] = 100
            # print(virtual_world)
            # choose action and simulate in the virtual world
            if 0 in action.keys():
                a0 = vworld_control(control_agent0, eposide, state[0], action[0], virtual_world)
                action[0] = a0
            if 1 in action.keys():
                a1 = vworld_control(control_agent1, eposide, state[1], action[1], virtual_world)
                action[1] = a1

            # execute in the real world, obtain the feedback
            obs, rewards, dones, x, infos = env.step(action)
            # print(obs)

            # coordinated optimization
            for key in action.keys():
                if obs[key] not in env.goal and abs(virtual_world[int(obs[key] / env.maze_len)][int(obs[key] % env.maze_len)] - rewards[key]) > 50:

                    within_sensing_agent = check_sense_range(sensing_agents, obs[key], env.maze_len)
                    if len(within_sensing_agent) == 0:
                        break
                    else:
                        for sensing_agent in within_sensing_agent:
                            # 5050
                            # label = 2
                            # x_min = (int(obs[key] % 50)) * 16 + 10
                            # y_min = (int(obs[key] / 50)) * 16 + 10
                            # x_max = x_min + 16
                            # y_max = y_min + 16
                            # print(x_min, y_min, x_max, y_max)
                            # 1010
                            x = int(obs[key] / env.maze_len)
                            y = int(obs[key] % env.maze_len)
                            x = x - sensing_agent.start_x
                            y = y - sensing_agent.start_y
                            label = 2
                            x_min = y * env.grid_spacing
                            y_min = x * env.grid_spacing
                            x_max = x_min + env.grid_spacing
                            y_max = y_min + env.grid_spacing
                            # print(sensing_agent.sensor_name, x_min, x_max, y_min, y_max)
                            txts_save_path = sensing_agent.labels_path
                            txt_name = str(sensing_agent.newly_added_data_cnt) + ".txt"
                            os.mkdir(txts_save_path) if not os.path.exists(txts_save_path) else None
                            f = open(os.path.join(txts_save_path, txt_name), 'w')
                            # 5050
                            # x, y, w, h = normalization(820, 820, x_min, x_max, y_min, y_max)
                            # 1010
                            x, y, w, h = normalization(1+env.grid_spacing*(sensing_agent.end_y-sensing_agent.start_y+1), 1+env.grid_spacing*(sensing_agent.end_x-sensing_agent.start_x+1), x_min, x_max, y_min, y_max)
                            f.write(str(label) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h) + ' ' + '\n')
                            ori = sensing_agent.det_res.tolist()
                            # print(ori)
                            for k in range(len(ori)):
                                # 5050
                                # x, y, w, h = normalization(820, 820, ori[k][0], ori[k][2], ori[k][1], ori[k][3])
                                # 1010
                                # print(ori[k][0], ori[k][2], ori[k][1], ori[k][3])
                                x, y, w, h = normalization(1+env.grid_spacing*(sensing_agent.end_y-sensing_agent.start_y+1), 1+env.grid_spacing*(sensing_agent.end_x-sensing_agent.start_x+1), ori[k][0], ori[k][2], ori[k][1], ori[k][3])
                                f.write(str(int(ori[k][5])) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(
                                    h) + ' ' + '\n')
                            f.close()
                            newly_added_image_pth = os.path.join(sensing_agent.images_path, "%d.jpg" % sensing_agent.newly_added_data_cnt)
                            sensing_agent.newly_added_image.save(newly_added_image_pth)
                            if sensing_agent.newly_added_data_cnt % 8 == 0:
                                if not os.path.exists(sensing_agent.val_path):
                                    os.makedirs(sensing_agent.val_path)
                                list_file = open(sensing_agent.val_path, "a")
                                list_file.write("./images/%d.jpg\n" % sensing_agent.newly_added_data_cnt)
                                sensing_agent.newly_added_data_cnt += 1
                                list_file.close()
                            else:
                                if not os.path.exists(sensing_agent.train_path):
                                    os.makedirs(sensing_agent.train_path)
                                list_file = open(sensing_agent.train_path, "a")
                                list_file.write("./images/%d.jpg\n" % sensing_agent.newly_added_data_cnt)
                                sensing_agent.newly_added_data_cnt += 1
                                list_file.close()

            for key1 in action.keys():
                for key2 in action.keys():
                    if key1 != key2 and obs[key1] == obs[key2]:
                        print("hit!")
                        env.agent_goal[key1] = -10
                        env.agent_goal[key2] = -10
            print(env.agent_goal)

            if 0 in action.keys():
                action0, state0 = control_agent0.get_action(state[0], action[0], obs[0], rewards[0], eposide, 0.2, epi=50)
                action[0] = action0
            if 1 in action.keys():
                action1, state1 = control_agent1.get_action(state[1], action[1], obs[1], rewards[1], eposide, 0.2, epi=50)
                action[1] = action1
            state = obs

            if len(action) == 0:
                if env.agent_goal[0]==100 and env.agent_goal[1]==100:
                    success_cnt += 1
                    print("all arrive!")
                print('episodes:  %d, steps: %d, episode_reward：%dscore: %f' % (eposide, stp, (sum(env.agent_goal)/len(env.agent_goal))/100, last_time_steps.mean()))
                np.savetxt("q_table0.txt", control_agent0.q_table, delimiter=",")
                np.savetxt("q_table1.txt", control_agent1.q_table, delimiter=",")
                last_time_steps = np.hstack((last_time_steps[1:], [(sum(env.agent_goal)/len(env.agent_goal))/100]))
                writer.add_scalar("reward_show", (sum(env.agent_goal)/len(env.agent_goal))/100, eposide)
                break

        if (last_time_steps.mean() >= goal_average_steps):
            np.savetxt("q_table0.txt", control_agent0.q_table, delimiter=",")
            np.savetxt("q_table1.txt", control_agent1.q_table, delimiter=",")
            env.close()
            break
        time.sleep(0.5)
        print('episodes:  %d, steps: %d, episode_reward：%dscore: %f' % (
        eposide, stp, (sum(env.agent_goal) / len(env.agent_goal)) / 100, last_time_steps.mean()))

    print("both arrive count: ", success_cnt)
    print("ss: ", ss)
    print("finish")

    print(agent0_cnt)
    print(agent1_cnt)
    print(agent0_cnt.tolist())
    print(agent1_cnt.tolist())
    with open('output.txt', 'w') as file:
        file.write(str(agent0_cnt))
        file.write(str(agent1_cnt))
        file.write(str(agent0_cnt.tolist()))
        file.write(str(agent1_cnt.tolist()))

