import os

import torch
import gymnasium
import numpy as np
import pygame
from gymnasium import spaces
from gymnasium.utils import seeding
import time
import cv2

from pettingzoo import AECEnv
from pettingzoo.mpe._mpe_utils.core import Agent
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import AgentSelector

# Letters indexed by SimpleEnv.draw to label the strongest discrete
# communication channel of a non-silent agent.
alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def make_env(raw_env):
    """Build an environment factory around ``raw_env``.

    The returned callable forwards its keyword arguments to ``raw_env`` and
    wraps the result: a clipping out-of-bounds wrapper when the environment
    reports ``continuous_actions``, an asserting one otherwise, followed by
    the order-enforcing wrapper.
    """

    def env(**kwargs):
        base = raw_env(**kwargs)
        if base.continuous_actions:
            bounded = wrappers.ClipOutOfBoundsWrapper(base)
        else:
            bounded = wrappers.AssertOutOfBoundsWrapper(base)
        return wrappers.OrderEnforcingWrapper(bounded)

    return env


class SimpleEnv(AECEnv):
    """AEC environment that advances a ``world`` according to a ``scenario``.

    Agents act one at a time; once every agent has submitted an action the
    world is stepped, rewards are read from the scenario and noise is added
    through the scenario. The class also offers helpers that predict the
    next state for given actions (``next_state`` / ``batch_next_state``).
    """

    # Environment metadata: the render modes accepted by render() and the
    # frame rate used to throttle the "human" display loop.
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "is_parallelizable": True,
        "render_fps": 10,
    }

    def __init__(
        self,
        scenario,
        world,
        max_cycles,
        render_mode=None,
        continuous_actions=False,
        local_ratio=None,
        sigma=0
    ):
        """Set up rendering resources, seed the RNG and build the spaces.

        Args:
            scenario: Object supplying ``reset_world``, ``all_state``,
                ``agent_state`` and ``observation`` (all called here).
            world: World whose ``agents``, ``dim_p`` and ``dim_c`` determine
                the agent list and the action-space sizes.
            max_cycles: Stored as ``self.max_cycles``; ``step`` truncates all
                agents once this many world steps have run.
            render_mode: Stored as ``self.render_mode``.
            continuous_actions: When true, action spaces are ``Box`` spaces in
                [0, 1]; otherwise they are ``Discrete``.
            local_ratio: Stored as ``self.local_ratio``; used when rewards are
                computed to blend the agent and global rewards.
            sigma: Stored as ``self.sigma`` and handed to the scenario by
                ``add_noise``.
        """
        super().__init__()

        self.render_mode = render_mode
        pygame.init()
        self.viewer = None
        self.width = 700
        self.height = 700
        # Off-screen surface; enable_render() swaps it for a display window
        # the first time "human" rendering is requested.
        self.screen = pygame.Surface([self.width, self.height])
        self.max_size = 1
        # Font file is expected next to this module.
        self.game_font = pygame.freetype.Font(
            os.path.join(os.path.dirname(__file__), "secrcode.ttf"), 24
        )

        # The on-screen window is created lazily by enable_render().

        self.renderOn = False
        # Creates self.np_random, which reset_world below needs.
        self._seed()

        self.max_cycles = max_cycles
        self.scenario = scenario
        self.world = world
        self.continuous_actions = continuous_actions
        self.local_ratio = local_ratio
        self.sigma = sigma

        # Reset once so the state/observation lengths measured below are
        # taken from an initialised world.
        self.scenario.reset_world(self.world, self.np_random)

        self.agents = [agent.name for agent in self.world.agents]
        self.possible_agents = self.agents[:]
        # Agent name -> position in self.world.agents.
        self._index_map = {
            agent.name: idx for idx, agent in enumerate(self.world.agents)
        }

        self._agent_selector = AgentSelector(self.agents)

        # set spaces
        self.action_spaces = dict()
        self.observation_spaces = dict()
        state_dim = len(self.scenario.all_state(self.world))
        self.state_dim = state_dim
        self.agent_state_dim = len(self.scenario.agent_state(self.world))
        for agent in self.world.agents:
            # Movement part: dim_p * 2 + 1 entries (decoded in _set_action).
            if agent.movable:
                space_dim = self.world.dim_p * 2 + 1
            elif self.continuous_actions:
                space_dim = 0
            else:
                space_dim = 1
            # Communication part: extra Box entries in the continuous case; in
            # the discrete case the movement and communication choices share
            # one integer (split with % and // in _execute_world_step).
            if not agent.silent:
                if self.continuous_actions:
                    space_dim += self.world.dim_c
                else:
                    space_dim *= self.world.dim_c

            obs_dim = len(self.scenario.observation(agent, self.world))
            if self.continuous_actions:
                self.action_spaces[agent.name] = spaces.Box(
                    low=0, high=1, shape=(space_dim,)
                )
            else:
                self.action_spaces[agent.name] = spaces.Discrete(space_dim)
            self.observation_spaces[agent.name] = spaces.Box(
                low=-np.float32(np.inf),
                high=+np.float32(np.inf),
                shape=(obs_dim,),
                dtype=np.float32,
            )

        # Unbounded float32 box sized like scenario.all_state(world).
        self.state_space = spaces.Box(
            low=-np.float32(np.inf),
            high=+np.float32(np.inf),
            shape=(state_dim,),
            dtype=np.float32,
        )

        self.steps = 0

        # One pending action slot per agent, filled by step().
        self.current_actions = [None] * self.num_agents

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def _seed(self, seed=None):
        """Replace ``self.np_random`` with a generator built from ``seed``."""
        generator, _ = seeding.np_random(seed)
        self.np_random = generator

    def observe(self, agent):
        return self.scenario.observation(
            self.world.agents[self._index_map[agent]], self.world
        ).astype(np.float32)

    def state(self):
        return self.scenario.all_state(self.world).astype(np.float32)

    def reset(self, seed=None, options=None):
        if seed is not None:
            self._seed(seed=seed)
        self.scenario.reset_world(self.world, self.np_random)

        self.agents = self.possible_agents[:]
        self.rewards = {name: 0.0 for name in self.agents}
        self._cumulative_rewards = {name: 0.0 for name in self.agents}
        self.terminations = {name: False for name in self.agents}
        self.truncations = {name: False for name in self.agents}
        self.infos = {name: {} for name in self.agents}

        self.agent_selection = self._agent_selector.reset()
        self.steps = 0

        self.current_actions = [None] * self.num_agents

    def _execute_world_step(self):
        # set action for each agent
        for i, agent in enumerate(self.world.agents):
            action = self.current_actions[i]
            scenario_action = []
            if agent.movable:
                mdim = self.world.dim_p * 2 + 1
                if self.continuous_actions:
                    scenario_action.append(action[0:mdim])
                    action = action[mdim:]
                else:
                    scenario_action.append(action % mdim)
                    action //= mdim
            if not agent.silent:
                scenario_action.append(action)
            self._set_action(scenario_action, agent, self.action_spaces[agent.name])

        self.world.step()
        

        global_reward = 0.0
        if self.local_ratio is not None:
            global_reward = float(self.scenario.global_reward(self.world))

        for agent in self.world.agents:
            agent_reward = float(self.scenario.reward(agent, self.world))
            if self.local_ratio is not None:
                reward = (
                    global_reward * (1 - self.local_ratio)
                    + agent_reward * self.local_ratio
                )
            else:
                reward = agent_reward

            self.rewards[agent.name] = reward

    # set env action for a particular agent
    def _set_action(self, action, agent, action_space, time=None):
        agent.action.u = np.zeros(self.world.dim_p)
        agent.action.c = np.zeros(self.world.dim_c)

        if agent.movable:
            # physical action
            agent.action.u = np.zeros(self.world.dim_p)
            if self.continuous_actions:
                # Process continuous action as in OpenAI MPE
                # Note: this ordering preserves the same movement direction as in the discrete case
                agent.action.u[0] += action[0][2] - action[0][1]
                agent.action.u[1] += action[0][4] - action[0][3]
            else:
                # process discrete action
                if action[0] == 1:
                    agent.action.u[0] = -1.0
                if action[0] == 2:
                    agent.action.u[0] = +1.0
                if action[0] == 3:
                    agent.action.u[1] = -1.0
                if action[0] == 4:
                    agent.action.u[1] = +1.0
            sensitivity = 5.0
            if agent.accel is not None:
                sensitivity = agent.accel
            agent.action.u *= sensitivity
            action = action[1:]
        if not agent.silent:
            # communication action
            if self.continuous_actions:
                agent.action.c = action[0]
            else:
                agent.action.c = np.zeros(self.world.dim_c)
                agent.action.c[action[0]] = 1.0
            action = action[1:]
        # make sure we used all elements of action
        assert len(action) == 0

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            self._was_dead_step(action)
            return
        cur_agent = self.agent_selection
        current_idx = self._index_map[self.agent_selection]
        next_idx = (current_idx + 1) % self.num_agents
        self.agent_selection = self._agent_selector.next()

        self.current_actions[current_idx] = action

        if next_idx == 0:
            start_time = time.time()
            self._execute_world_step()
            self.add_noise()
            self.steps += 1
            if self.steps >= self.max_cycles:
                for a in self.agents:
                    self.truncations[a] = True
        else:
            self._clear_rewards()

        self._cumulative_rewards[cur_agent] = 0
        self._accumulate_rewards()

        if self.render_mode == "human":
            self.render()

    def enable_render(self, mode="human"):
        if not self.renderOn and mode == "human":
            self.screen = pygame.display.set_mode(self.screen.get_size())
            self.clock = pygame.time.Clock()
            self.renderOn = True

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return

        self.enable_render(self.render_mode)

        self.draw()
        if self.render_mode == "rgb_array":
            observation = np.array(pygame.surfarray.pixels3d(self.screen))
            return np.transpose(observation, axes=(1, 0, 2))
        elif self.render_mode == "human":
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
            return

    def draw(self):
        """Paint every world entity, plus agent messages, onto ``self.screen``.

        Entity positions are scaled by the largest absolute coordinate among
        all entities so they fit the surface. Each non-silent agent also gets
        one text line showing what it currently communicates.
        """
        surface = self.screen
        # start from a white canvas
        surface.fill((255, 255, 255))

        entities = self.world.entities

        # the largest absolute coordinate defines the visible range
        positions = np.array([ent.state.p_pos for ent in entities])
        cam_range = np.max(np.abs(positions))

        lines_written = 0
        for entity in entities:
            # geometry
            x, y = entity.state.p_pos
            # flip vertically so the display mimics the old pyglet setup
            y = -y
            # the .9 keeps entities from appearing "too" out-of-bounds
            x = (x / cam_range) * self.width // 2 * 0.9
            y = (y / cam_range) * self.height // 2 * 0.9
            x += self.width // 2
            y += self.height // 2
            centre = (x, y)
            # 350 is an arbitrary scale factor to get pygame to render
            # similar sizes as pyglet
            radius = entity.size * 350
            pygame.draw.circle(surface, entity.color * 200, centre, radius)
            pygame.draw.circle(surface, (0, 0, 0), centre, radius, 1)  # borders
            assert (
                0 < x < self.width and 0 < y < self.height
            ), f"Coordinates {(x, y)} are out of bounds."

            # text: only agents that can communicate get a message line
            if not isinstance(entity, Agent) or entity.silent:
                continue

            comm = entity.state.c
            if np.all(comm == 0):
                word = "_"
            elif self.continuous_actions:
                formatted = ",".join(f"{value:.2f}" for value in comm)
                word = "[" + formatted + "]"
            else:
                word = alphabet[np.argmax(comm)]

            message = "".join((entity.name, " sends ", word, "   "))
            message_x_pos = self.width * 0.05
            message_y_pos = self.height * 0.95 - (self.height * 0.05 * lines_written)
            self.game_font.render_to(
                surface, (message_x_pos, message_y_pos), message, (0, 0, 0)
            )
            lines_written += 1

    def close(self):
        if self.screen is not None:
            pygame.quit()
            self.screen = None

    def next_idx(self):
        current_idx = self._index_map[self.agent_selection]
        next_idx = (current_idx + 1) % self.num_agents
        return next_idx
    
    def mask(self, episode_steps, done):
        mask = 1 if episode_steps < self.max_cycles else float(not done)
        return mask
    
    def add_noise(self):
            
        # self.scenario.all_state_add_noise(self.sigma,self.world)

        self.scenario.agent_state_add_noise(self.sigma,self.world)
            
    def next_state(self, s, a, b):
        next_states = np.empty(s.shape)
        for j in range(s.shape[0]):
            actions = [a[j],b[j]]

            for i, agent in enumerate(self.world.agents):
                action = actions[i]
                scenario_action = []
                if agent.movable:
                    mdim = self.world.dim_p * 2 + 1
                    if self.continuous_actions:
                        scenario_action.append(action[0:mdim])
                        action = action[mdim:]
                    else:
                        scenario_action.append(action % mdim)
                        action //= mdim
                if not agent.silent:
                    scenario_action.append(action)
                self._set_action(scenario_action, agent, self.action_spaces[agent.name])
            
            
            
            # set actions for scripted agents
            for agent in self.world.scripted_agents:
                agent.action = agent.action_callback(agent, world)
            # gather forces applied to entities
            p_force = [None] * len(self.world.entities)
            # apply agent physical controls
            p_force = self.world.apply_action_force(p_force)
            # apply environment forces
            p_force = self.world.apply_environment_force(p_force)
            # integrate physical state
            entity_pos = []
            entity_color = []
            for entity in self.world.landmarks:  # world.entities:
                entity_pos.append(entity.state.p_pos)
                entity_color.append(entity.color)
            # communication of all other agents
            agent_vel = []
            agent_pos = []
            agent_goal = []
            agent_color = []
            for agent in self.world.agents:
                agent_vel.append(agent.state.p_vel)
                agent_pos.append(agent.state.p_pos)
                agent_goal.append(agent.goal_a.state.p_pos)
                agent_color.append(agent.color)
            
            for i, agent in enumerate(self.world.agents):
                agent_pos[i] += agent.state.p_vel * self.world.dt
                agent_vel[i] = agent.state.p_vel * (1 - self.world.damping)
            
                if p_force[i] is not None:
                    agent_vel[i] += (p_force[i] / entity.mass) * self.world.dt
                if agent.max_speed is not None:
                    speed = np.sqrt(
                        np.square(agent_vel[i][0]) + np.square(agent_vel[i][1])
                    )
                    if speed > agent.max_speed:
                        agent_vel[i] = (
                            agent_vel[i]
                            / np.sqrt(
                                np.square(agent_vel[i][0])
                                + np.square(agent_vel[i][1])
                            )
                            * agent.max_speed
                        )
            
            next_states[j] =  np.concatenate(
                agent_vel
                + agent_pos
                + agent_goal
                + agent_color
                + entity_pos
                + entity_color
            )
        return next_states
    
    def get_collision_force(self, entity_a, entity_b, delta_pos):
        if (not entity_a.collide) or (not entity_b.collide):
            return [None, None]  # not a collider
        if entity_a is entity_b:
            return [None, None]  # don't collide against itself
        # compute actual distance between entities
        dist = np.sqrt(np.sum(np.square(delta_pos)))
        # minimum allowable distance
        dist_min = entity_a.size + entity_b.size
        # softmax penetration
        k = self.world.contact_margin
        penetration = np.logaddexp(0, -(dist - dist_min) / k) * k
        force = self.world.contact_force * delta_pos / dist * penetration
        force_a = +force if entity_a.movable else None
        force_b = -force if entity_b.movable else None
        return [force_a, force_b]

    
    def batch_next_state(self,s,a,b,device):
        next_states = s.clone()
        next_states = next_states.detach().cpu()
        actions = [a,b]
        u = []
        for i, agent in enumerate(self.world.agents):
            action = actions[i]
            u0 = action[:,2] - action[:,1]
            u1 = action[:,4] - action[:,3]
            
            sensitivity = 5.0
            if agent.accel is not None:
                sensitivity = agent.accel
            u.append(sensitivity * torch.stack((u0, u1), dim=1))
        p_force = [None] * len(self.world.entities)
        for i, agent in enumerate(self.world.agents):
            if agent.movable:
                p_force[i] = u[i]
            
            for a, agent_a in enumerate(self.world.agents):
                for b, agent_b in enumerate(self.world.agents):
                    if b <= a:
                        continue
                    p_a = next_states[:,4+2*a:6+2*a]
                    p_b = next_states[:,4+2*b:6+2*b]
                    delta_pos = p_a - p_b
                    delta_pos = delta_pos.detach().cpu().numpy()
                    [f_a, f_b] = self.get_collision_force(agent_a, agent_b, delta_pos)
                    if f_a is not None:
                        if p_force[a] is None:
                            p_force[a] = 0.0
                        f_a = torch.tensor(f_a).to(device)
                        p_force[a] = f_a + p_force[a]
                    if f_b is not None:
                        if p_force[b] is None:
                            p_force[b] = 0.0
                        f_b = torch.tensor(f_b).to(device)
                        p_force[b] = f_b + p_force[b]
                        
            for a, agent_a in enumerate(self.world.agents):
                for b, agent_b in enumerate(self.world.landmarks):
                    p_a = next_states[:,4+2*a:6+2*a]
                    p_b = next_states[:,18+2*b:20+2*b]
                    delta_pos = p_a - p_b
                    delta_pos = delta_pos.detach().cpu().numpy()
                    [f_a, f_b] = self.get_collision_force(agent_a, agent_b, delta_pos)
                    if f_a is not None:
                        if p_force[a] is None:
                            p_force[a] = 0.0
                        f_a = torch.tensor(f_a).to(device)
                        p_force[a] = f_a + p_force[a]
                    if f_b is not None:
                        if p_force[b] is None:
                            p_force[b] = 0.0
                        f_b = torch.tensor(f_b).to(device)
                        p_force[b] = f_b + p_force[b]
            
            
        
        
        for i, agent in enumerate(self.world.agents):
            next_states[:,2*i+4] = next_states[:,2*i+4] + next_states[:,2*i] * self.world.dt
            next_states[:,2*i+5] = next_states[:,2*i+5] + next_states[:,2*i+1] * self.world.dt
            next_states[:,2*i] = next_states[:,2*i] * (1 - self.world.damping)
            next_states[:,2*i+1] = next_states[:,2*i+1] * (1 - self.world.damping)
            
            if p_force[i] is not None:
                a = p_force[i][:,0].detach().cpu()
                b = p_force[i][:,1].detach().cpu()
                next_states[:,2*i] = next_states[:,2*i] + (a / agent.mass) * self.world.dt
                next_states[:,2*i+1] = next_states[:,2*i+1] + (b / agent.mass) * self.world.dt
                
            if agent.max_speed is not None:
                states = s.clone()
                states = states.detach().cpu().numpy()
                speed = np.sqrt(
                    np.square(states[:,2*i]) + np.square(states[:,2*i+1])
                )
                next_states[speed>agent.max_speed,2*i] = (
                        next_states[speed>agent.max_speed,2*i]
                        / np.sqrt(
                            np.square(next_states[speed>agent.max_speed,2*i])
                            + np.square(next_states[speed>agent.max_speed,2*i+1])
                        )
                        * agent.max_speed
                    )
                next_states[speed>agent.max_speed,2*i+1] = (
                        next_states[speed>agent.max_speed,2*i]
                        / np.sqrt(
                            np.square(next_states[speed>agent.max_speed,2*i])
                            + np.square(next_states[speed>agent.max_speed,2*i+1])
                        )
                        * agent.max_speed
                    )
                
                
        return next_states
            
            
    