import os

import torch
import gymnasium
import numpy as np
import pygame
from gymnasium import spaces
from gymnasium.utils import seeding
import time
import cv2

from pettingzoo import AECEnv
from pettingzoo.discrete_mpe._mpe_utils.core import Agent
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import AgentSelector

alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

class DiscretizedActionSpace:
    """Evenly spaced grid of scalar action values over ``[low, high]``.

    Sizing notes kept from the original comments: a 5-dimensional action in
    [-1, 1] discretized at 0.01 gives on the order of 10^10 grid points; an
    8-dimensional state in [-10, 10] discretized at 0.05 gives 400^8
    (roughly 10^20) points.
    """

    def __init__(self, low, high, num_actions, shape):
        """Store the bounds (repeated ``shape[0]`` times) and build the grid."""
        dim = shape[0]
        self.high = np.array([high for _ in range(dim)])
        self.low = np.array([low for _ in range(dim)])

        self.num_actions = num_actions
        self.shape = shape
        # Equally spaced candidate values, end points included.
        self.actions = np.linspace(low, high, num_actions)

    def sample(self):
        """Draw an array of ``self.shape`` whose entries are random grid values."""
        return np.random.choice(self.actions, size=self.shape)

    def map_action(self, discrete_action):
        """Translate a grid index (or array of indices) into the grid value."""
        return self.actions[discrete_action]

def make_env(raw_env):
    """Return a factory that builds ``raw_env`` and applies the standard wrappers.

    The produced environment is wrapped with ``ClipOutOfBoundsWrapper`` when it
    reports ``continuous_actions``, otherwise with ``AssertOutOfBoundsWrapper``,
    and finally with ``OrderEnforcingWrapper``.
    """

    def env(**kwargs):
        wrapped = raw_env(**kwargs)
        bounds_wrapper = (
            wrappers.ClipOutOfBoundsWrapper
            if wrapped.continuous_actions
            else wrappers.AssertOutOfBoundsWrapper
        )
        return wrappers.OrderEnforcingWrapper(bounds_wrapper(wrapped))

    return env


class SimpleEnv(AECEnv):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "is_parallelizable": True,
        "render_fps": 10,
    }

    def __init__(
        self,
        scenario,
        world,
        max_cycles,
        render_mode=None,
        continuous_actions=False,
        local_ratio=None,
        sigma=1
    ):
        """Build the environment around a scenario/world pair.

        Args:
            scenario: object providing ``reset_world``, ``observation``,
                ``all_state``, ``agent_state``, ``reward`` and related hooks
                used throughout this class.
            world: the simulated world; its ``agents`` list defines the
                agent names and their order.
            max_cycles: ``step()`` flags every agent as truncated once this
                many world steps have been executed.
            render_mode: ``None``, ``"human"`` or ``"rgb_array"``.
            continuous_actions: selects ``Box`` action spaces instead of
                ``Discrete`` ones.
            local_ratio: when not ``None``, weight of the per-agent reward
                in the local/global mix computed by ``_execute_world_step``.
            sigma: value forwarded to ``scenario.agent_state_add_noise`` by
                ``add_noise``.

        Side effects: initialises pygame, loads ``secrcode.ttf`` from this
        module's directory and opens a ``cv2.VideoWriter`` on ``output.avi``
        in the current working directory, regardless of ``render_mode``.
        """
        super().__init__()

        self.render_mode = render_mode
        pygame.init()
        self.viewer = None
        self.width = 700
        self.height = 700
        # Off-screen surface; enable_render() swaps it for a real window in human mode.
        self.screen = pygame.Surface([self.width, self.height])
        self.frame_rate = 24
        # Video sink that render() appends frames to in human mode.
        self.avi = cv2.VideoWriter('output.avi', cv2.VideoWriter_fourcc(*'XVID'),self.frame_rate, (self.width, self.height))
        self.max_size = 1
        self.game_font = pygame.freetype.Font(
            os.path.join(os.path.dirname(__file__), "secrcode.ttf"), 24
        )

        # Set up the drawing window

        self.renderOn = False
        # Must run before reset_world below, which consumes self.np_random.
        self._seed()

        self.max_cycles = max_cycles
        self.scenario = scenario
        self.world = world
        self.continuous_actions = continuous_actions
        self.local_ratio = local_ratio
        self.sigma = sigma

        self.scenario.reset_world(self.world, self.np_random)

        self.agents = [agent.name for agent in self.world.agents]
        self.possible_agents = self.agents[:]
        # Agent name -> position in world.agents (also the index into current_actions).
        self._index_map = {
            agent.name: idx for idx, agent in enumerate(self.world.agents)
        }

        self._agent_selector = AgentSelector(self.agents)

        # set spaces
        self.action_spaces = dict()
        self.observation_spaces = dict()
        state_dim = len(self.scenario.all_state(self.world))
        self.state_dim = state_dim
        self.agent_state_dim = len(self.scenario.agent_state(self.world))
        for agent in self.world.agents:
            # Movement contributes dim_p * 2 + 1 entries; _set_action reads them
            # as a no-op plus one entry per direction on each axis.
            if agent.movable:
                space_dim = self.world.dim_p * 2 + 1
            elif self.continuous_actions:
                space_dim = 0
            else:
                space_dim = 1
            # Communication: extra dimensions when continuous, a product of
            # (movement x message) choices when discrete.
            if not agent.silent:
                if self.continuous_actions:
                    space_dim += self.world.dim_c
                else:
                    space_dim *= self.world.dim_c

            obs_dim = len(self.scenario.observation(agent, self.world))
            if self.continuous_actions:
                self.action_spaces[agent.name] = spaces.Box(
                    low=0, high=1, shape=(space_dim,)
                )
            else:
                self.action_spaces[agent.name] = spaces.Discrete(space_dim)
            self.observation_spaces[agent.name] = spaces.Box(
                low=-np.float32(np.inf),
                high=+np.float32(np.inf),
                shape=(obs_dim,),
                dtype=np.float32,
            )
            # self.action_spaces[agent.name] = DiscretizedActionSpace(low=0, high=1, num_actions=101, shape=(space_dim,))

        self.state_space = spaces.Box(
            low=-np.float32(np.inf),
            high=+np.float32(np.inf),
            shape=(state_dim,),
            dtype=np.float32,
        )
        # Grid used by state() to snap every state component: 401 points, step 0.05.
        self.dis_state_space = np.linspace(-10, 10, 401)

        self.steps = 0

        # One pending action slot per agent, filled in by step().
        self.current_actions = [None] * self.num_agents

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def _seed(self, seed=None):
        """(Re)create ``self.np_random`` from ``seed`` via ``seeding.np_random``."""
        rng, _ = seeding.np_random(seed)
        self.np_random = rng

    def observe(self, agent):
        return self.scenario.observation(
            self.world.agents[self._index_map[agent]], self.world
        ).astype(np.float32)

    # def dis_state(self):
    #     con_state = self.scenario.all_state(self.world).astype(np.float32)
    #     dis_state = np.vectorize(lambda x: self.dis_state_space[np.argmin(np.abs(self.dis_state_space - x))])(con_state)
    #     return dis_state

    def state(self):
        con_state = self.scenario.all_state(self.world).astype(np.float32)
        dis_state = np.vectorize(lambda x: self.dis_state_space[np.argmin(np.abs(self.dis_state_space - x))])(con_state)
        return dis_state

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        self.scenario.reset_world(self.world, self.np_random)

        self.agents = self.possible_agents[:]
        self.rewards = {name: 0.0 for name in self.agents}
        self._cumulative_rewards = {name: 0.0 for name in self.agents}
        self.terminations = {name: False for name in self.agents}
        self.truncations = {name: False for name in self.agents}
        self.infos = {name: {} for name in self.agents}

        self.agent_selection = self._agent_selector.reset()
        self.steps = 0

        self.current_actions = [None] * self.num_agents

    def _execute_world_step(self):
        # set action for each agent
        for i, agent in enumerate(self.world.agents):
            action = self.current_actions[i]
            scenario_action = []
            if agent.movable:
                mdim = self.world.dim_p * 2 + 1
                if self.continuous_actions:
                    scenario_action.append(action[0:mdim])
                    action = action[mdim:]
                else:
                    scenario_action.append(action % mdim)
                    action //= mdim
            if not agent.silent:
                scenario_action.append(action)
            self._set_action(scenario_action, agent, self.action_spaces[agent.name])

        self.world.step()
        

        global_reward = 0.0
        if self.local_ratio is not None:
            global_reward = float(self.scenario.global_reward(self.world))

        for agent in self.world.agents:
            agent_reward = float(self.scenario.reward(agent, self.world))
            if self.local_ratio is not None:
                reward = (
                    global_reward * (1 - self.local_ratio)
                    + agent_reward * self.local_ratio
                )
            else:
                reward = agent_reward

            self.rewards[agent.name] = reward

    # set env action for a particular agent
    def _set_action(self, action, agent, action_space, time=None):
        agent.action.u = np.zeros(self.world.dim_p)
        agent.action.c = np.zeros(self.world.dim_c)

        if agent.movable:
            # physical action
            agent.action.u = np.zeros(self.world.dim_p)
            if self.continuous_actions:
                # Process continuous action as in OpenAI MPE
                # Note: this ordering preserves the same movement direction as in the discrete case
                agent.action.u[0] += action[0][2] - action[0][1]
                agent.action.u[1] += action[0][4] - action[0][3]
            else:
                # process discrete action
                if action[0] == 1:
                    agent.action.u[0] = -1.0
                if action[0] == 2:
                    agent.action.u[0] = +1.0
                if action[0] == 3:
                    agent.action.u[1] = -1.0
                if action[0] == 4:
                    agent.action.u[1] = +1.0
            sensitivity = 5.0
            if agent.accel is not None:
                sensitivity = agent.accel
            agent.action.u *= sensitivity
            action = action[1:]
        if not agent.silent:
            # communication action
            if self.continuous_actions:
                agent.action.c = action[0]
            else:
                agent.action.c = np.zeros(self.world.dim_c)
                agent.action.c[action[0]] = 1.0
            action = action[1:]
        # make sure we used all elements of action
        assert len(action) == 0

    def step(self, action):
        """Record ``action`` for the selected agent and advance the turn order.

        If the selected agent is already terminated or truncated the call is
        delegated to ``_was_dead_step``. Otherwise the action is queued; once
        the last agent in ``world.agents`` order has acted, the world is
        stepped, noise is added through ``add_noise`` and the step counter is
        incremented (every agent is truncated when it reaches
        ``max_cycles``). For all other agents the step rewards are cleared.
        """
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            self._was_dead_step(action)
            return
        cur_agent = self.agent_selection
        current_idx = self._index_map[self.agent_selection]
        # Computed before the selector advances: 0 means this was the last agent.
        next_idx = (current_idx + 1) % self.num_agents
        self.agent_selection = self._agent_selector.next()

        self.current_actions[current_idx] = action

        if next_idx == 0:
            self._execute_world_step()
            self.add_noise()
            self.steps += 1
            if self.steps >= self.max_cycles:
                for a in self.agents:
                    self.truncations[a] = True
        else:
            self._clear_rewards()

        # The acting agent's cumulative reward restarts before accumulation.
        self._cumulative_rewards[cur_agent] = 0
        self._accumulate_rewards()

        if self.render_mode == "human":
            self.render()

    def enable_render(self, mode="human"):
        if not self.renderOn and mode == "human":
            self.screen = pygame.display.set_mode(self.screen.get_size())
            self.clock = pygame.time.Clock()
            self.renderOn = True

    def render(self):
        """Draw the current world and output it according to ``render_mode``.

        Returns:
            In ``"rgb_array"`` mode, the frame as a ``(height, width, 3)``
            array. In ``"human"`` mode the window is refreshed, the frame is
            appended to the ``output.avi`` writer and ``None`` is returned.
            With no render mode a warning is logged and ``None`` is returned.
        """
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return

        self.enable_render(self.render_mode)

        self.draw()
        if self.render_mode == "rgb_array":
            observation = np.array(pygame.surfarray.pixels3d(self.screen))
            return np.transpose(observation, axes=(1, 0, 2))
        elif self.render_mode == "human":
            pygame.display.flip()
            frame = pygame.surfarray.array3d(pygame.display.get_surface())
            # surfarray is indexed [x, y]; OpenCV expects [row, col] = [y, x].
            # A transpose performs that conversion (a 90-degree rotation would
            # additionally mirror the picture left-to-right).
            frame = np.ascontiguousarray(np.transpose(frame, axes=(1, 0, 2)))
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            self.avi.write(frame)
            self.clock.tick(self.metadata["render_fps"])
            return

    def draw(self):
        """Paint every world entity, plus agent messages, onto ``self.screen``.

        Positions are scaled by the largest absolute coordinate of any entity
        so that everything fits, with the y axis flipped. Non-silent agents
        additionally get a text line describing what they currently send.
        """
        # clear screen
        self.screen.fill((255, 255, 255))

        # Camera range: the largest absolute coordinate among all entities.
        positions = np.array([ent.state.p_pos for ent in self.world.entities])
        cam_range = np.max(np.abs(positions))

        text_line = 0
        for entity in self.world.entities:
            # geometry
            x, y = entity.state.p_pos
            # Flip vertically so the display mimics the old pyglet setup.
            y = -y
            # The 0.9 factor keeps entities from appearing "too" out-of-bounds.
            x = (x / cam_range) * self.width // 2 * 0.9
            y = (y / cam_range) * self.height // 2 * 0.9
            x += self.width // 2
            y += self.height // 2

            # 350 is an arbitrary scale factor so pygame renders sizes similar to pyglet.
            radius = entity.size * 350
            pygame.draw.circle(self.screen, entity.color * 200, (x, y), radius)
            pygame.draw.circle(self.screen, (0, 0, 0), (x, y), radius, 1)  # border
            assert (
                0 < x < self.width and 0 < y < self.height
            ), f"Coordinates {(x, y)} are out of bounds."

            # text: only agents that communicate get a message line
            if not isinstance(entity, Agent) or entity.silent:
                continue

            if np.all(entity.state.c == 0):
                word = "_"
            elif self.continuous_actions:
                formatted = [f"{comm:.2f}" for comm in entity.state.c]
                word = "[" + ",".join(formatted) + "]"
            else:
                word = alphabet[np.argmax(entity.state.c)]

            message = entity.name + " sends " + word + "   "
            message_x_pos = self.width * 0.05
            message_y_pos = self.height * 0.95 - (self.height * 0.05 * text_line)
            self.game_font.render_to(
                self.screen, (message_x_pos, message_y_pos), message, (0, 0, 0)
            )
            text_line += 1

    def close(self):
        self.avi.release()
        if self.screen is not None:
            
            pygame.quit()
            self.screen = None

    def next_idx(self):
        current_idx = self._index_map[self.agent_selection]
        next_idx = (current_idx + 1) % self.num_agents
        return next_idx
    
    def mask(self, episode_steps, done):
        mask = 1 if episode_steps < self.max_cycles else float(not done)
        return mask
    
    def add_noise(self):
            
            # self.scenario.all_state_add_noise(self.sigma,self.world)
            
            self.scenario.agent_state_add_noise(self.sigma,self.world)
            
    def next_state(self, s, a, b):
        next_states = np.empty(s.shape)
        for j in range(s.shape[0]):
            actions = [a[j],b[j]]

            for i, agent in enumerate(self.world.agents):
                action = actions[i]
                scenario_action = []
                if agent.movable:
                    mdim = self.world.dim_p * 2 + 1
                    if self.continuous_actions:
                        scenario_action.append(action[0:mdim])
                        action = action[mdim:]
                    else:
                        scenario_action.append(action % mdim)
                        action //= mdim
                if not agent.silent:
                    scenario_action.append(action)
                self._set_action(scenario_action, agent, self.action_spaces[agent.name])
            
            
            
            # set actions for scripted agents
            for agent in self.world.scripted_agents:
                agent.action = agent.action_callback(agent, world)
            # gather forces applied to entities
            p_force = [None] * len(self.world.entities)
            # apply agent physical controls
            p_force = self.world.apply_action_force(p_force)
            # apply environment forces
            p_force = self.world.apply_environment_force(p_force)
            # integrate physical state
            entity_pos = []
            entity_color = []
            for entity in self.world.landmarks:  # world.entities:
                entity_pos.append(entity.state.p_pos)
                entity_color.append(entity.color)
            # communication of all other agents
            agent_vel = []
            agent_pos = []
            agent_goal = []
            agent_color = []
            for agent in self.world.agents:
                agent_vel.append(agent.state.p_vel)
                agent_pos.append(agent.state.p_pos)
                agent_goal.append(agent.goal_a.state.p_pos)
                agent_color.append(agent.color)
            
            for i, agent in enumerate(self.world.agents):
                agent_pos[i] += agent.state.p_vel * self.world.dt
                agent_vel[i] = agent.state.p_vel * (1 - self.world.damping)
            
                if p_force[i] is not None:
                    agent_vel[i] += (p_force[i] / entity.mass) * self.world.dt
                if agent.max_speed is not None:
                    speed = np.sqrt(
                        np.square(agent_vel[i][0]) + np.square(agent_vel[i][1])
                    )
                    if speed > agent.max_speed:
                        agent_vel[i] = (
                            agent_vel[i]
                            / np.sqrt(
                                np.square(agent_vel[i][0])
                                + np.square(agent_vel[i][1])
                            )
                            * agent.max_speed
                        )
            
            next_states[j] =  np.concatenate(
                agent_vel
                + agent_pos
                + agent_goal
                + agent_color
                + entity_pos
                + entity_color
            )
        return next_states
    
    def get_collision_force(self, entity_a, entity_b, delta_pos):
        if (not entity_a.collide) or (not entity_b.collide):
            return [None, None]  # not a collider
        if entity_a is entity_b:
            return [None, None]  # don't collide against itself
        # compute actual distance between entities
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        # minimum allowable distance
        dist_min = entity_a.size + entity_b.size
        # softmax penetration
        k = self.world.contact_margin
        penetration = np.logaddexp(0, -(dist - dist_min) / k) * k
        force = self.world.contact_force * delta_pos / dist * penetration
        force_a = +force if entity_a.movable else None
        force_b = -force if entity_b.movable else None
        return [force_a, force_b]

    
    def batch_next_state(self,s,a,b,device):
        next_states = s.clone()
        next_states = next_states.detach().cpu()
        actions = [a,b]
        u = []
        for i, agent in enumerate(self.world.agents):
            action = actions[i]
            u0 = action[:,2] - action[:,1]
            u1 = action[:,4] - action[:,3]
            
            sensitivity = 5.0
            if agent.accel is not None:
                sensitivity = agent.accel
            u.append(sensitivity * torch.stack((u0, u1), dim=1))
        p_force = [None] * len(self.world.entities)
        for i, agent in enumerate(self.world.agents):
            if agent.movable:
                p_force[i] = u[i]
            
            for a, agent_a in enumerate(self.world.agents):
                for b, agent_b in enumerate(self.world.agents):
                    if b <= a:
                        continue
                    p_a = next_states[:,4+2*a:6+2*a]
                    p_b = next_states[:,4+2*b:6+2*b]
                    delta_pos = p_a - p_b
                    delta_pos = delta_pos.detach().cpu().numpy()
                    [f_a, f_b] = self.get_collision_force(agent_a, agent_b, delta_pos)
                    if f_a is not None:
                        if p_force[a] is None:
                            p_force[a] = 0.0
                        f_a = torch.tensor(f_a).to(device)
                        p_force[a] = f_a + p_force[a]
                    if f_b is not None:
                        if p_force[b] is None:
                            p_force[b] = 0.0
                        f_b = torch.tensor(f_b).to(device)
                        p_force[b] = f_b + p_force[b]
                        
            for a, agent_a in enumerate(self.world.agents):
                for b, agent_b in enumerate(self.world.landmarks):
                    p_a = next_states[:,4+2*a:6+2*a]
                    p_b = next_states[:,18+2*b:20+2*b]
                    delta_pos = p_a - p_b
                    delta_pos = delta_pos.detach().cpu().numpy()
                    [f_a, f_b] = self.get_collision_force(agent_a, agent_b, delta_pos)
                    if f_a is not None:
                        if p_force[a] is None:
                            p_force[a] = 0.0
                        f_a = torch.tensor(f_a).to(device)
                        p_force[a] = f_a + p_force[a]
                    if f_b is not None:
                        if p_force[b] is None:
                            p_force[b] = 0.0
                        f_b = torch.tensor(f_b).to(device)
                        p_force[b] = f_b + p_force[b]
            
            
        
        
        for i, agent in enumerate(self.world.agents):
            next_states[:,2*i+4] = next_states[:,2*i+4] + next_states[:,2*i] * self.world.dt
            next_states[:,2*i+5] = next_states[:,2*i+5] + next_states[:,2*i+1] * self.world.dt
            next_states[:,2*i] = next_states[:,2*i] * (1 - self.world.damping)
            next_states[:,2*i+1] = next_states[:,2*i+1] * (1 - self.world.damping)
            
            if p_force[i] is not None:
                a = p_force[i][:,0].detach().cpu()
                b = p_force[i][:,1].detach().cpu()
                next_states[:,2*i] = next_states[:,2*i] + (a / agent.mass) * self.world.dt
                next_states[:,2*i+1] = next_states[:,2*i+1] + (b / agent.mass) * self.world.dt
                
            if agent.max_speed is not None:
                states = s.clone()
                states = states.detach().cpu().numpy()
                speed = np.sqrt(
                    np.square(states[:,2*i]) + np.square(states[:,2*i+1])
                )
                next_states[speed>agent.max_speed,2*i] = (
                        next_states[speed>agent.max_speed,2*i]
                        / np.sqrt(
                            np.square(next_states[speed>agent.max_speed,2*i])
                            + np.square(next_states[speed>agent.max_speed,2*i+1])
                        )
                        * agent.max_speed
                    )
                next_states[speed>agent.max_speed,2*i+1] = (
                        next_states[speed>agent.max_speed,2*i]
                        / np.sqrt(
                            np.square(next_states[speed>agent.max_speed,2*i])
                            + np.square(next_states[speed>agent.max_speed,2*i+1])
                        )
                        * agent.max_speed
                    )
                
                
                
                
                # if speed > agent.max_speed:
                #     next_states[:,2*i] = (
                #         next_states[:,2*i]
                #         / np.sqrt(
                #             np.square(next_states[:,2*i])
                #             + np.square(next_states[:,2*i+1])
                #         )
                #         * agent.max_speed
                #     )
                #     next_states[:,2*i+1] = (
                #         next_states[:,2*i+1]
                #         / np.sqrt(
                #             np.square(next_states[:,2*i])
                #             + np.square(next_states[:,2*i+1])
                #         )
                #         * agent.max_speed
                #     )
                    
                    
        return next_states
            
            


    def state_dis_to_con(self,s):
        s = -10+s*0.05
        return s
    
    def state_con_to_dis(self,s):
        s = (s+10)//0.05
        return s

    def get_dis_states(self,s_size):
        s = np.zeros((s_size,s_size,s_size,s_size,s_size,s_size,s_size,s_size))
        s[...,:] =np.linspace(-10, 10, 401)
        return s