import os
import platform
import pathlib
# Detect the host OS once. NOTE: the name `plt` is temporary here; it is
# rebound to matplotlib.pyplot by the `import matplotlib.pyplot as plt` below.
plt = platform.system()
if plt == "Linux":
    # Ask PyOpenGL to use the EGL platform (set before any OpenGL import).
    # Presumably needed for headless/off-screen rendering — TODO confirm.
    os.environ['PYOPENGL_PLATFORM'] = 'egl'
    # Alias WindowsPath to PosixPath. Presumably lets objects pickled on
    # Windows (e.g. model checkpoints holding paths) load on Linux — verify.
    pathlib.WindowsPath = pathlib.PosixPath
import time
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import torch
import argparse
from tensorboardX import SummaryWriter
from gymnasium.spaces import Discrete
from gym.spaces import Discrete
from torch.multiprocessing import freeze_support
import yaml

import train
from detect import run

class AdvancedEnv():
    """Square grid-world shared by three agents.

    A cell is addressed either by ``(row, col)`` or by the integer
    ``maze_len * row + col`` (the "observation").  Cells can hold a fire,
    an unknown object or a goal.  Agents 0, 1 and 2 always start at
    ``(0, 0)``, ``(1, 0)`` and ``(2, 0)``.

    Actions are ``0`` = down (row + 1), ``1`` = left (col - 1),
    ``2`` = up (row - 1), ``3`` = right (col + 1); moves are clamped to the
    grid.
    """

    def __init__(self, maze_len=10,
                 grid_spacing=50,
                 grid_offset=25,
                 goal_pos=None,
                 screen_width=600,
                 screen_height=600,
                 line_offset=50,
                 obj_offset=75,
                 tr_offset=25,
                 obj_size=20,
                 bias=20,
                 fire_num=None,
                 unknown_num=None,
                 fire_pos=None,
                 unknown_pos=None,
                 *args, **kwargs):
        """Build the grid and place the obstacles.

        Args:
            maze_len: number of cells per side.
            grid_spacing, grid_offset, line_offset, obj_offset, tr_offset,
                obj_size, screen_width, screen_height: pixel geometry used
                by ``render``.
            goal_pos: list of goal cell indices (default ``[99]``).
            bias: stored on the instance as ``self.bias``; not used by this
                class itself.
            fire_num, unknown_num: when at least one is given, obstacles are
                drawn at random (a missing count is treated as 0) from the
                cells that are neither goals nor start cells.
            fire_pos, unknown_pos: explicit obstacle cells, used only when
                both ``fire_num`` and ``unknown_num`` are ``None``
                (defaults ``[32, 75, 50, 74, 90]`` / ``[8, 43, 47, 66, 80]``).
        """
        # None sentinels instead of mutable list defaults, so instances never
        # share (and cannot accidentally mutate) a common default list.
        goal_pos = [99] if goal_pos is None else list(goal_pos)
        fire_pos = [32, 75, 50, 74, 90] if fire_pos is None else list(fire_pos)
        unknown_pos = [8, 43, 47, 66, 80] if unknown_pos is None else list(unknown_pos)

        self.maze_len = maze_len
        self.action_space = Discrete(4)
        self.observation_space = Discrete(self.maze_len * self.maze_len)
        self.agents = {0: (0, 0), 1: (1, 0), 2: (2, 0)}
        # Per-agent episode outcome, written by step() (and by the caller).
        self.agent_goal = [0, 0, 0]
        self.info = {}
        self.viewer = None
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.grid_spacing = grid_spacing
        self.grid_offset = grid_offset
        self.line_offset = line_offset
        self.obj_offset = obj_offset
        self.tr_offset = tr_offset
        self.obj_size = obj_size
        self.fire_num = fire_num
        self.unknown_num = unknown_num
        self.fire_pos = fire_pos
        self.unknown_pos = unknown_pos
        self.goal = goal_pos
        self.bias = bias

        # Pixel centre of every cell, indexed by observation: x cycles through
        # the columns, y runs from the largest value (row 0) downwards.
        axis = [k * self.grid_spacing + self.grid_offset
                for k in range(1, self.maze_len + 1)]
        self.x = axis * self.maze_len
        self.y = [value for value in reversed(axis) for _ in range(self.maze_len)]

        if fire_num is None and unknown_num is None:
            self.fires = fire_pos
            self.unknowns = unknown_pos
        else:
            # Goals and the three start cells must stay free.  Filtering (not
            # list.remove) tolerates a goal that coincides with a start cell.
            reserved = set(self.goal) | {0, self.maze_len, 2 * self.maze_len}
            free = [cell for cell in range(self.maze_len * self.maze_len)
                    if cell not in reserved]
            self.fires = random.sample(free, fire_num or 0)
            taken = set(self.fires)
            free = [cell for cell in free if cell not in taken]
            self.unknowns = random.sample(free, unknown_num or 0)
            print("random obstacles: ", self.fires)
            print("random unknowns: ", self.unknowns)

    def reset(self, *, seed=None, options=None):
        """Put the agents back on their fixed start cells.

        ``seed`` and ``options`` are accepted but ignored.

        Returns:
            ``(observations, info)`` where ``observations`` maps each agent
            id to its cell index and ``info`` is an empty dict.
        """
        self.agents = {0: (0, 0), 1: (1, 0), 2: (2, 0)}
        return {agent_id: self.get_observation(agent_id) for agent_id in (0, 1, 2)}, {}

    def get_observation(self, agent_id):
        """Return the agent's position encoded as ``maze_len * row + col``."""
        row, col = self.agents[agent_id]
        return self.maze_len * row + col

    def _shares_cell(self, agent_id):
        """Return True when another agent stands on this agent's cell."""
        return list(self.agents.values()).count(self.agents[agent_id]) > 1

    def get_reward(self, agent_id):
        """Return 100 on a goal, -100 on a fire/unknown, -10 on a shared cell, else 0."""
        cell = self.get_observation(agent_id)
        if cell in self.goal:
            return 100
        if cell in self.fires or cell in self.unknowns:
            return -100
        if self._shares_cell(agent_id):
            return -10
        return 0

    def is_done(self, agent_id):
        """Report (and log) whether the agent collided with another agent.

        Only an agent-agent collision away from a goal cell yields ``True``;
        reaching a goal or hitting a fire/unknown is just printed here and is
        handled by ``step`` on the following call.
        """
        cell = self.get_observation(agent_id)
        at_goal = cell in self.goal
        # `self.goal` is a list, so this must be a membership test; the former
        # `cell != self.goal` was always True and flagged goal arrivals that
        # share a cell as collisions.
        collided = self._shares_cell(agent_id) and not at_goal

        if at_goal:
            print("agent ", agent_id, " arrives.")
        if cell in self.fires:
            print("agent ", agent_id, " hit fire.")
        elif cell in self.unknowns:
            print("agent ", agent_id, " hit unknown.")
        elif collided:
            print("agents hit!")
        return collided

    def step(self, action):
        """Advance every still-active agent by one move.

        Agents already standing on a goal, fire or unknown cell are finished:
        their outcome (100 / -100) is stored in ``self.agent_goal`` and their
        entry is **removed from the caller's ``action`` dict**.

        Args:
            action: dict ``agent_id -> action`` (0 down, 1 left, 2 up, 3 right).

        Returns:
            ``(observations, rewards, done, done, info)``; the same ``done``
            dict is returned in both the terminated and truncated slots and
            contains an ``"__all__"`` entry.

        Raises:
            ValueError: if an action is not one of 0, 1, 2, 3.
        """
        for agent_id in list(action.keys()):
            cell = self.get_observation(agent_id)
            if cell in self.goal:
                self.agent_goal[agent_id] = 100
                del action[agent_id]
            elif cell in self.fires or cell in self.unknowns:
                self.agent_goal[agent_id] = -100
                del action[agent_id]
        agent_ids = action.keys()

        # (row delta, col delta) per action code.
        moves = {0: (1, 0), 1: (0, -1), 2: (-1, 0), 3: (0, 1)}
        last = self.maze_len - 1
        for agent_id in agent_ids:
            delta = moves.get(action[agent_id])
            if delta is None:
                raise ValueError("Invalid action")
            row, col = self.agents[agent_id]
            # Clamp so that moving into a wall leaves the agent in place.
            self.agents[agent_id] = (min(max(row + delta[0], 0), last),
                                     min(max(col + delta[1], 0), last))

        observations = {i: self.get_observation(i) for i in agent_ids}
        rewards = {i: self.get_reward(i) for i in agent_ids}
        done = {i: self.is_done(i) for i in agent_ids}

        done["__all__"] = all(done.values())

        return observations, rewards, done, done, self.info

    def _cell_center(self, cell):
        """Return the pixel ``(x, y)`` at which an object in ``cell`` is drawn."""
        col = cell % self.maze_len
        row = cell // self.maze_len
        pos_x = col * self.grid_spacing + self.obj_offset
        # Row 0 is drawn at the top, i.e. at the largest y value.
        pos_y = ((self.maze_len - 1) - row) * self.grid_spacing + self.obj_offset
        return pos_x, pos_y

    def _add_robot(self, rendering):
        """Add one green robot circle to the viewer; return ``(geom, transform)``."""
        robot = rendering.make_circle(self.obj_size)
        transform = rendering.Transform()
        robot.add_attr(transform)
        robot.set_color(0, 1, 0)
        self.viewer.add_geom(robot)
        return robot, transform

    def render(self, action, *args, **kwargs):
        """Draw the grid, obstacles, goals and agents; return the frame.

        The static scene is built once, on the first call.  On every call an
        agent whose id is still a key of ``action`` is drawn at its current
        cell; any other agent is drawn on its start cell.

        Returns:
            The value of ``viewer.render(return_rgb_array=True)``.
        """
        print("render: ", action)
        from gym.envs.classic_control import rendering

        if self.viewer is None:
            self.viewer = rendering.Viewer(self.screen_width, self.screen_height)

            # Grid: horizontal lines first, then vertical lines.
            for index in range(self.maze_len + 1):
                level = self.line_offset + index * self.grid_spacing
                line = rendering.Line((self.line_offset, level),
                                      (self.screen_width - self.line_offset, level))
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)
            for index in range(self.maze_len + 1):
                level = self.line_offset + index * self.grid_spacing
                line = rendering.Line((level, self.line_offset),
                                      (level, self.screen_height - self.line_offset))
                line.set_color(0, 0, 0)
                self.viewer.add_geom(line)

            # Fires: red circles.
            for fire in self.fires:
                circle = rendering.make_circle(self.obj_size)
                circle.add_attr(rendering.Transform(translation=self._cell_center(fire)))
                circle.set_color(1, 0, 0)
                self.viewer.add_geom(circle)

            # Unknown objects: black squares.
            low = self.tr_offset - self.obj_size
            high = self.tr_offset + self.obj_size
            for unknown in self.unknowns:
                pos_x, pos_y = self._cell_center(unknown)
                square = rendering.FilledPolygon(
                    [(low, low), (low, high), (high, high), (high, low)])
                square.add_attr(rendering.Transform(
                    translation=(pos_x - self.tr_offset, pos_y - self.tr_offset)))
                square.set_color(0, 0, 0)
                self.viewer.add_geom(square)

            # Goals: blue circles.
            for goal in self.goal:
                circle = rendering.make_circle(self.obj_size)
                circle.add_attr(rendering.Transform(translation=self._cell_center(goal)))
                circle.set_color(0, 0, 1)
                self.viewer.add_geom(circle)

            # Agents: green circles, moved through their transforms below.
            self.robot1, self.robotrans1 = self._add_robot(rendering)
            self.robot2, self.robotrans2 = self._add_robot(rendering)
            self.robot3, self.robotrans3 = self._add_robot(rendering)

        transforms = (self.robotrans1, self.robotrans2, self.robotrans3)
        for agent_id, transform in enumerate(transforms):
            if agent_id in action.keys():
                cell = self.get_observation(agent_id)
            else:
                # Agent no longer active: draw it on its start cell.
                cell = agent_id * self.maze_len
            transform.set_translation(self.x[cell], self.y[cell])

        return self.viewer.render(return_rgb_array=True)

    def close(self):
        """Close the viewer window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None


def get_action(q_table, state, action, observation, reward, episode, epsilon_coefficient=0.1, epi=0,
               alpha=0.2, gamma=0.99):
    """Pick the next action and apply one tabular Q-update in place.

    Action selection: during the first ``epi`` episodes the action is drawn
    uniformly from ``{0, 1, 2, 3}``.  Afterwards it is epsilon-greedy with
    ``epsilon = epsilon_coefficient * 0.99 ** (episode - epi)``.

    Update: ``q_table[state, action]`` is moved towards
    ``reward + gamma * q_table[observation, next_action]``, i.e. the target
    uses the action that was actually selected (on-policy, SARSA-style).

    Args:
        q_table: 2-D NumPy array indexed ``[state, action]``; modified in place.
        state: state in which ``action`` was taken.
        action: action that was taken.
        observation: resulting state.
        reward: reward received for the transition.
        episode: current episode index (drives the epsilon decay).
        epsilon_coefficient: initial exploration rate.
        epi: number of leading episodes with purely random actions.
        alpha: learning rate.
        gamma: discount factor.

    Returns:
        ``(next_action, next_state, q_table)``; ``q_table`` is the same
        object that was passed in.
    """
    next_state = observation
    if episode < epi:
        next_action = np.random.choice([0, 1, 2, 3])
    else:
        epsilon = epsilon_coefficient * (0.99 ** (episode - epi))
        if epsilon <= np.random.uniform(0, 1):
            next_action = np.argmax(q_table[next_state])
        else:
            next_action = np.random.choice([0, 1, 2, 3])

    target = reward + gamma * q_table[next_state, next_action]
    q_table[state, action] = (1 - alpha) * q_table[state, action] + alpha * target
    return next_action, next_state, q_table


def vworld_control(q_table, episode, obs, action, virtual_world):
    """Screen a proposed action against the predicted ("virtual") world.

    Starting from ``action`` and rotating through ``(action + 1) % 4``, each
    candidate is checked with the module-level ``transfer`` table:

    * no entry in ``transfer`` (the move leaves the grid): the candidate is
      scored -100 and ``q_table`` is penalised with reward -100;
    * the target cell has a negative value in ``virtual_world``: the
      candidate is scored with that value and ``q_table`` is penalised with
      it;
    * otherwise the candidate is accepted and the scan stops.

    If the scan used all four tries, the action is re-chosen epsilon-greedily
    (``epsilon = 0.2 * 0.99 ** episode``) from the collected scores.

    Relies on the module-level names ``transfer``, ``env`` and ``get_action``.

    Args:
        q_table: Q-table of the agent; updated in place through ``get_action``.
        episode: current episode index.
        obs: current cell index of the agent.
        action: action proposed by the policy.
        virtual_world: 2-D array of predicted cell values, indexed ``[row][col]``.

    Returns:
        ``(action, q_table)`` with the action to execute.
    """
    size = env.maze_len
    a_score = np.zeros([4])
    tried = 0
    while tried < 4:
        tried += 1
        key = "%d_%d" % (obs, action)
        if key not in transfer:
            a_score[action] = -100
            # A blocked move leaves the agent where it is, so next state == obs.
            _, _, q_table = get_action(q_table, obs, action, obs, -100, episode, 0.5)
            action = (action + 1) % 4
            continue

        next_state = transfer[key]
        predicted = virtual_world[next_state // size][next_state % size]
        if predicted < 0:
            a_score[action] = predicted
            # Penalise the action that was just evaluated, THEN rotate.  The
            # former code rotated first for values other than -100, so the
            # penalty was applied to the following action.
            _, _, q_table = get_action(q_table, obs, action, next_state, predicted, episode, 0.5)
            action = (action + 1) % 4
        else:
            # Acceptable move: give every not-yet-scored action its Q-value.
            for index in range(4):
                if a_score[index] == 0:
                    a_score[index] = q_table[obs][index]
            break

    if tried >= 4:
        epsilon_coefficient = 0.2
        epsilon = epsilon_coefficient * (0.99 ** episode)
        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(a_score)
        else:
            candidates = [i for i, score in enumerate(a_score) if score != -100]
            if candidates:
                action = np.random.choice(candidates)
            else:
                # All four actions are scored -100; np.random.choice would
                # raise on an empty list, so fall back to argmax.
                action = np.argmax(a_score)
    return action, q_table

def normalization(width, height, xmin, xmax, ymin, ymax):
    """Convert a pixel box to centre/size form scaled by the image size.

    Args:
        width, height: image size in pixels.
        xmin, xmax, ymin, ymax: box edges in pixels.

    Returns:
        ``(x_center, y_center, box_width, box_height)``, the x values
        multiplied by ``1 / width`` and the y values by ``1 / height``.
    """
    scale_x = 1 / width
    scale_y = 1 / height
    center_x = (xmin + xmax) / 2
    center_y = (ymax + ymin) / 2
    box_w = xmax - xmin
    box_h = ymax - ymin
    return center_x * scale_x, center_y * scale_y, box_w * scale_x, box_h * scale_y


if __name__ == '__main__':
    # Training script: three tabular Q-learning agents act in AdvancedEnv.
    # Every step the rendered frame is passed to the detector (`run`), its
    # detections are written into a predicted "virtual world" that vetoes
    # moves, and cells whose prediction disagrees with the received reward
    # are saved as new labelled samples under ./dataset/VOC2007.
    freeze_support()
    parser = argparse.ArgumentParser()
    parser.add_argument('--maze_size', type=int, choices=[4, 10, 50], default=50, help='size of the maze 4/10/50')
    parser.add_argument('--mode', type=str, choices=["easy", "medium", "hard", "customized", "random"], default="easy",
                        help='choose the mode of gridworld ("easy"/"medium"/"hard"/"customized"/"random"), "easy" means fewer obstacles '
                             'and "hard" means more obstacles. You can choose the "random" mode, where you need to specify the '
                             'number of fires and unknown objects, and their positions will be randomly generated in the gridworld. '
                             'You can also choose the "customized" mode, where you need to define the positions of obstacles in the YAML file. ')
    parser.add_argument('--fire_num', type=int, default=4, help='the number of fires')
    parser.add_argument('--unknown_num', type=int, default=5, help='the number of unknown objects')
    parser.add_argument('--need_opt', type=int, default=0, choices=[0, 1], help='with coordinated optimization or without coordinated optimization, 0 means without and 1 means with')
    args = parser.parse_args()
    with open("env_config.yaml", "r") as file:
        config = yaml.safe_load(file)
    env_config = config[args.maze_size]
    print("maze size: %d\n" % args.maze_size)
    if args.need_opt:
        print("with coordinated optimization")
    else:
        print("without coordinated optimization")
    print("mode: %s\nThe positions of the obstacles are as follows." % args.mode)

    # Geometry arguments shared by every mode, in AdvancedEnv's positional order.
    geometry = [env_config[name] for name in (
        "maze_len", "grid_spacing", "grid_offset", "goal_pos",
        "screen_width", "screen_height", "line_offset", "obj_offset",
        "tr_offset", "obj_size", "bias")]
    if args.mode == "random":
        # The environment draws (and prints) the obstacle positions itself.
        env = AdvancedEnv(*geometry, args.fire_num, args.unknown_num, None, None)
    else:
        layout = env_config[args.mode]
        print("fires: ", layout["fires"])
        print("unknown objects: ", layout["unknowns"])
        env = AdvancedEnv(*geometry, None, None, layout["fires"], layout["unknowns"])

    # set seed
    # NOTE(review): the stdlib `random` module used for the "random" layout is
    # not seeded, and the environment is built before these calls.
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # tensorboard
    writer = SummaryWriter("runs/20250617_2bn6_wo/reward")
    writer0 = SummaryWriter("runs/20250617_2bn6_wo/q_value_0")
    writer1 = SummaryWriter("runs/20250617_2bn6_wo/q_value_1")
    writer2 = SummaryWriter("runs/20250617_2bn6_wo/q_value_2")
    q_writers = [writer0, writer1, writer2]

    # "<cell>_<action>" -> neighbouring cell; moves that would leave the grid
    # have no entry.  Read as a module global by vworld_control.
    transfer = dict()
    size = env.maze_len
    for i in range(size, size * size):
        transfer[str(i) + '_2'] = i - size  # up
    for i in range(size * (size - 1)):
        transfer[str(i) + '_0'] = i + size  # down
    for i in range(1, size * size):
        if i % size == 0:
            continue
        transfer[str(i) + '_1'] = i - 1  # left
    for i in range(size * size):
        if (i + 1) % size == 0:
            continue
        transfer[str(i) + '_3'] = i + 1  # right

    agent_ids = (0, 1, 2)
    q_tables = [np.zeros([size * size, 4]) for _ in agent_ids]
    last_time_steps = np.zeros(50)  # latest 50 episodes scores
    goal_average_steps = 0.98

    # Per-agent visit counter for every cell.
    visit_counts = [np.zeros((size, size)) for _ in agent_ids]

    success_cnt = 0

    images_dir = "./dataset/VOC2007/images"
    labels_dir = r"./dataset/VOC2007/labels/"
    tmp_image_path = "./tmp_image/tmp.jpg"
    # Create the working directories up front instead of failing mid-run.
    for directory in (images_dir, labels_dir, "./tmp_image"):
        os.makedirs(directory, exist_ok=True)

    n = 1          # running index of the next saved sample
    opt_flag = 0   # set once the detector has been retrained
    ss = 0         # global step counter over all episodes
    for episode in range(2000):
        print("\nepisode: ", episode)
        obs, _ = env.reset(options=episode)

        print(obs)
        # Greedy first action per agent, each from its OWN table (agent 2 was
        # previously read from agent 1's table).
        action = {i: np.argmax(q_tables[i][obs[i]]) for i in agent_ids}
        state = obs
        virtual_world = np.zeros([env.maze_len, env.maze_len])
        for i in agent_ids:
            env.agent_goal[i] = 0
        stp = 0
        while stp < 2000:
            ss += 1
            stp += 1
            print("step: ", stp)
            print("平均分: ", last_time_steps.mean())
            print("episode: ", episode)

            # q-value of the second-to-last cell, logged for every agent
            p = env.maze_len * env.maze_len - 2
            for i in agent_ids:
                values = q_tables[i][p]
                print("RL%d 98, down, left, up, right: " % i, values[0], values[1], values[2], values[3])
            for i in agent_ids:
                q_writers[i].add_scalar("q_value", q_tables[i][p][3], ss)

            for i in agent_ids:
                if i in action.keys():
                    visit_counts[i][state[i] // env.maze_len][state[i] % env.maze_len] += 1

            image_cnt = len(os.listdir(images_dir))
            print(image_cnt)

            # Retrain the detector once, when about 50 samples were collected.
            if args.need_opt == 1:
                if image_cnt in (50, 51, 52) and opt_flag == 0:
                    train.run()
                    opt_flag = 1

            img = env.render(action)
            plt.imsave(tmp_image_path, img)
            det = run(source=tmp_image_path, weights="./runs/train/exp/weights/last.pt")

            # Detector class ids:
            # 0 agent0
            # 1 agent1
            # 2 obstacle
            # 3 unknown
            # 4 goal
            cnt_obstacle = 0
            for box in det:
                col = int((box[0] + env.bias - env.line_offset) / env.grid_spacing)
                row = int((box[1] + env.bias - env.line_offset) / env.grid_spacing)
                if int(box[5]) == 2:
                    cnt_obstacle += 1
                    virtual_world[row][col] = -100
                elif int(box[5]) == 3:
                    virtual_world[row][col] = 100
            for goal_cell in env.goal:
                virtual_world[goal_cell // env.maze_len][goal_cell % env.maze_len] = 100

            print(cnt_obstacle)

            # Let the virtual world veto/replace each active agent's action.
            for i in agent_ids:
                if i in action.keys():
                    action[i], q_tables[i] = vworld_control(q_tables[i], episode, state[i], action[i], virtual_world)

            # env.step removes finished agents from `action`.
            obs, rewards, dones, _truncated, infos = env.step(action)

            for key in action.keys():
                cell_row = obs[key] // env.maze_len
                cell_col = obs[key] % env.maze_len
                # Prediction and reward disagree: save the frame with the
                # visited cell labelled as an obstacle.
                if obs[key] not in env.goal and abs(virtual_world[cell_row][cell_col] - rewards[key]) > 50:
                    label = 2
                    x_min = cell_col * env.grid_spacing + env.line_offset
                    y_min = cell_row * env.grid_spacing + env.line_offset
                    x_max = x_min + env.grid_spacing
                    y_max = y_min + env.grid_spacing
                    print(x_min, y_min, x_max, y_max)
                    with open(os.path.join(labels_dir, str(n) + ".txt"), 'w') as f:
                        x, y, w, h = normalization(env.screen_width, env.screen_height, x_min, x_max, y_min, y_max)
                        f.write(f"{label} {x} {y} {w} {h} \n")
                        ori = det.tolist()
                        print(ori)
                        for detected in ori:
                            x, y, w, h = normalization(env.screen_width, env.screen_height,
                                                       detected[0], detected[2], detected[1], detected[3])
                            f.write(f"{int(detected[5])} {x} {y} {w} {h} \n")
                    # Open the frame only when it is needed, and close it again.
                    with Image.open(tmp_image_path) as image:
                        image.save("./dataset/VOC2007/images/%d.jpg" % n)
                    # Every 8th sample goes to the validation list.  open(..., "a")
                    # creates the file; the former os.makedirs() on this path
                    # created a directory and made the open fail.
                    if n % 8 == 0:
                        list_path = "./dataset/VOC2007/val2017.txt"
                    else:
                        list_path = "./dataset/VOC2007/train2017.txt"
                    with open(list_path, "a") as list_file:
                        list_file.write("./images/%d.jpg\n" % n)
                    n = n + 1

            # agent collision
            for key1 in action.keys():
                for key2 in action.keys():
                    if key1 != key2 and obs[key1] == obs[key2]:
                        print("hit!")
                        env.agent_goal[key1] = -10
                        env.agent_goal[key2] = -10
            print(env.agent_goal)

            # Q-update for the executed action and choice of the next one.
            for i in agent_ids:
                if i in action.keys():
                    next_action, _next_state, q_tables[i] = get_action(
                        q_tables[i], state[i], action[i], obs[i], rewards[i], episode, epi=100)
                    action[i] = next_action

            state = obs

            # The episode ends once every agent was removed by env.step.
            if len(action) == 0:
                if env.agent_goal[0] == 100 and env.agent_goal[1] == 100 and env.agent_goal[2] == 100:
                    success_cnt += 1
                    print("all arrive!")
                score = (sum(env.agent_goal) / len(env.agent_goal)) / 100
                # NOTE(review): '%d' truncates the fractional score in this log line.
                print('episodes:  %d, steps: %d, episode_reward：%dscore: %f' % (episode, stp, score, last_time_steps.mean()))
                for i in agent_ids:
                    np.savetxt("q_table%d.txt" % i, q_tables[i], delimiter=",")
                last_time_steps = np.hstack((last_time_steps[1:], [score]))
                writer.add_scalar("reward_show", score, episode)
                break

        # Stop when the mean score of the last 50 episodes is high enough.
        if (last_time_steps.mean() >= goal_average_steps):
            for i in agent_ids:
                np.savetxt("q_table%d.txt" % i, q_tables[i], delimiter=",")
            env.close()
            break

    print("success count: ", success_cnt)
    print("ss: ", ss)
    print("finish")

    for counts in visit_counts:
        print(counts)
    for counts in visit_counts:
        print(counts.tolist())
    with open('output1.txt', 'a') as file:
        file.write("success count: " + str(success_cnt) + "\n")
        for counts in visit_counts:
            file.write(str(counts) + "\n")
        for counts in visit_counts:
            file.write(str(counts.tolist()) + "\n")

    # Flush pending TensorBoard events and release the viewer (close() is a
    # no-op when the viewer is already closed).
    for summary_writer in (writer, writer0, writer1, writer2):
        summary_writer.close()
    env.close()


