"""Tools for HARL."""
import os
import random
import numpy as np
import torch
import math
from harl.envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv


def check(value):
    """Return ``value`` as a torch tensor if it is a numpy array, else unchanged.

    Args:
        value: any object; only ``np.ndarray`` instances are converted.
    Returns:
        ``torch.from_numpy(value)`` for numpy arrays (the tensor shares memory
        with the array), otherwise ``value`` itself.
    """
    if not isinstance(value, np.ndarray):
        return value
    return torch.from_numpy(value)


def get_shape_from_obs_space(obs_space):
    """Get shape from observation space.

    The space is identified by its class name, so any object whose class is
    called ``Box`` (e.g. a gym Box) or a plain ``list`` is accepted.

    Args:
        obs_space: (gym.spaces.Box or list) observation space
    Returns:
        obs_shape: ``obs_space.shape`` for a Box; the list itself for a list.
    Raises:
        NotImplementedError: for any other observation-space type.
    """
    space_kind = obs_space.__class__.__name__
    if space_kind == "Box":
        return obs_space.shape
    if space_kind == "list":
        # A list is treated as an already-materialized shape.
        return obs_space
    raise NotImplementedError


def get_shape_from_act_space(act_space):
    """Get shape from action space.

    The space is identified by its class name.

    Args:
        act_space: (gym.spaces) action space -- Discrete, MultiDiscrete, Box
            or MultiBinary.
    Returns:
        act_shape: (int) 1 for Discrete, otherwise ``act_space.shape[0]``.
    Raises:
        NotImplementedError: for any other action-space type (previously this
            surfaced as an UnboundLocalError).
    """
    space_kind = act_space.__class__.__name__
    if space_kind == "Discrete":
        # A single discrete action is stored as one scalar index.
        return 1
    if space_kind in ("MultiDiscrete", "Box", "MultiBinary"):
        return act_space.shape[0]
    raise NotImplementedError(f"Unsupported action space type: {space_kind}")


def make_train_env(env_name, seed, n_threads, env_args):
    """Make env for training.

    Args:
        env_name: (str) environment family, one of "smac", "smacv2",
            "mamujoco", "pettingzoo_mpe", "gym", "football", "lag" or
            "dexhands".
        seed: (int) base seed; the env built for worker ``rank`` is seeded
            with ``seed + rank * 1000``.
        n_threads: (int) number of parallel environments.
        env_args: (dict) environment-specific arguments passed to the env.
    Returns:
        A ``DexHandsEnv`` (constructed with ``n_threads`` in its args) for
        "dexhands"; otherwise a ``ShareDummyVecEnv`` when ``n_threads == 1``
        and a ``ShareSubprocVecEnv`` wrapping ``n_threads`` envs when not.
    Raises:
        NotImplementedError: for an unsupported ``env_name``. It is raised
            when an env factory runs, i.e. inside the vectorized wrapper.
    """
    if env_name == "dexhands":
        from harl.envs.dexhands.dexhands_env import DexHandsEnv

        return DexHandsEnv({"n_threads": n_threads, **env_args})

    def get_env_fn(rank):
        def init_env():
            # Env modules are imported lazily so that only the simulator
            # actually requested needs to be installed.
            if env_name == "smac":
                from harl.envs.smac.StarCraft2_Env import StarCraft2Env

                env = StarCraft2Env(env_args)
            elif env_name == "smacv2":
                from harl.envs.smacv2.smacv2_env import SMACv2Env

                env = SMACv2Env(env_args)
            elif env_name == "mamujoco":
                from harl.envs.mamujoco.multiagent_mujoco.mujoco_multi import (
                    MujocoMulti,
                )

                env = MujocoMulti(env_args=env_args)
            elif env_name == "pettingzoo_mpe":
                from harl.envs.pettingzoo_mpe.pettingzoo_mpe_env import (
                    PettingZooMPEEnv,
                )

                assert env_args["scenario"] in [
                    "simple_v2",
                    "simple_spread_v2",
                    "simple_reference_v2",
                    "simple_speaker_listener_v3",
                ], "only cooperative scenarios in MPE are supported"
                env = PettingZooMPEEnv(env_args)
            elif env_name == "gym":
                from harl.envs.gym.gym_env import GYMEnv

                env = GYMEnv(env_args)
            elif env_name == "football":
                from harl.envs.football.football_env import FootballEnv

                env = FootballEnv(env_args)
            elif env_name == "lag":
                from harl.envs.lag.lag_env import LAGEnv

                env = LAGEnv(env_args)
            else:
                message = "Can not support the " + env_name + " environment."
                print(message)
                raise NotImplementedError(message)
            # Offset each worker's seed so that parallel envs differ.
            env.seed(seed + rank * 1000)
            return env

        return init_env

    if n_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    else:
        return ShareSubprocVecEnv([get_env_fn(i) for i in range(n_threads)])


def make_eval_env(env_name, seed, n_threads, env_args):
    """Make env for evaluation.

    Args:
        env_name: (str) environment family, one of "smac", "smacv2",
            "mamujoco", "pettingzoo_mpe", "gym", "football" or "lag".
        seed: (int) base seed; the env built for worker ``rank`` is seeded
            with ``seed * 50000 + rank * 10000``.
        n_threads: (int) number of parallel evaluation environments.
        env_args: (dict) environment-specific arguments passed to the env.
    Returns:
        A ``ShareDummyVecEnv`` when ``n_threads == 1``, otherwise a
        ``ShareSubprocVecEnv`` wrapping ``n_threads`` envs.
    Raises:
        NotImplementedError: immediately for "dexhands"; for any other
            unsupported ``env_name``, when the env factory runs.
    """
    if env_name == "dexhands":
        raise NotImplementedError(
            "dexhands does not support running multiple instances"
        )

    def get_env_fn(rank):
        def init_env():
            # Env modules are imported lazily so that only the simulator
            # actually requested needs to be installed.
            if env_name == "smac":
                from harl.envs.smac.StarCraft2_Env import StarCraft2Env

                env = StarCraft2Env(env_args)
            elif env_name == "smacv2":
                from harl.envs.smacv2.smacv2_env import SMACv2Env

                env = SMACv2Env(env_args)
            elif env_name == "mamujoco":
                from harl.envs.mamujoco.multiagent_mujoco.mujoco_multi import (
                    MujocoMulti,
                )

                env = MujocoMulti(env_args=env_args)
            elif env_name == "pettingzoo_mpe":
                from harl.envs.pettingzoo_mpe.pettingzoo_mpe_env import (
                    PettingZooMPEEnv,
                )

                env = PettingZooMPEEnv(env_args)
            elif env_name == "gym":
                from harl.envs.gym.gym_env import GYMEnv

                env = GYMEnv(env_args)
            elif env_name == "football":
                from harl.envs.football.football_env import FootballEnv

                env = FootballEnv(env_args)
            elif env_name == "lag":
                from harl.envs.lag.lag_env import LAGEnv

                env = LAGEnv(env_args)
            else:
                message = "Can not support the " + env_name + " environment."
                print(message)
                raise NotImplementedError(message)
            # The evaluation seed scheme differs from training
            # (seed * 50000 + rank * 10000), presumably so that evaluation
            # episodes do not reuse training seeds.
            env.seed(seed * 50000 + rank * 10000)
            return env

        return init_env

    if n_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    else:
        return ShareSubprocVecEnv([get_env_fn(i) for i in range(n_threads)])


def make_render_env(env_name, seed, env_args):
    """Make env for rendering.

    Args:
        env_name: (str) environment family.
        seed: (int) base seed; the env is seeded with ``seed * 60000``
            (except dexhands, which is not seeded here).
        env_args: (dict) environment-specific arguments passed to the env.
    Returns:
        (env, manual_render, manual_expand_dims, manual_delay, env_num):
            env: the single (non-vectorized) environment instance.
            manual_render: whether the caller must call ``env.render()``.
            manual_expand_dims: whether the caller must add the
                parallel-env dimension itself.
            manual_delay: whether the caller should ``time.sleep()``
                between frames.
            env_num: number of parallel envs the returned env represents.
    Raises:
        NotImplementedError: for an unsupported ``env_name``.
    """
    manual_render = True  # manually call the render() function
    manual_expand_dims = True  # manually expand the num_of_parallel_envs dimension
    manual_delay = True  # manually delay the rendering by time.sleep()
    env_num = 1  # number of parallel envs
    if env_name == "smac":
        from harl.envs.smac.StarCraft2_Env import StarCraft2Env

        env = StarCraft2Env(args=env_args)
        # smac does not support calling render() manually;
        # it uses save_replay() instead.
        manual_render = False
        manual_delay = False
        env.seed(seed * 60000)
    elif env_name == "smacv2":
        from harl.envs.smacv2.smacv2_env import SMACv2Env

        env = SMACv2Env(args=env_args)
        manual_render = False
        manual_delay = False
        env.seed(seed * 60000)
    elif env_name == "mamujoco":
        from harl.envs.mamujoco.multiagent_mujoco.mujoco_multi import MujocoMulti

        env = MujocoMulti(env_args=env_args)
        env.seed(seed * 60000)
    elif env_name == "pettingzoo_mpe":
        from harl.envs.pettingzoo_mpe.pettingzoo_mpe_env import PettingZooMPEEnv

        env = PettingZooMPEEnv({**env_args, "render_mode": "human"})
        env.seed(seed * 60000)
    elif env_name == "gym":
        from harl.envs.gym.gym_env import GYMEnv

        env = GYMEnv(env_args)
        env.seed(seed * 60000)
    elif env_name == "football":
        from harl.envs.football.football_env import FootballEnv

        env = FootballEnv(env_args)
        manual_render = False  # football renders automatically
        env.seed(seed * 60000)
    elif env_name == "dexhands":
        from harl.envs.dexhands.dexhands_env import DexHandsEnv

        env = DexHandsEnv({"n_threads": 64, **env_args})
        manual_render = False  # dexhands renders automatically
        # dexhands uses parallel envs, so the dimension is already expanded.
        manual_expand_dims = False
        manual_delay = False
        env_num = 64
        # NOTE(review): dexhands is not seeded here, unlike every other
        # branch -- confirm whether DexHandsEnv seeds itself via env_args.
    elif env_name == "lag":
        from harl.envs.lag.lag_env import LAGEnv

        env = LAGEnv(env_args)
        env.seed(seed * 60000)
    else:
        message = "Can not support the " + env_name + " environment."
        print(message)
        raise NotImplementedError(message)
    return env, manual_render, manual_expand_dims, manual_delay, env_num


def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs from ``args["seed"]``.

    If ``args["seed_specify"]`` is falsy, a seed in [1000, 10000) is drawn
    first and written back into ``args["seed"]`` as a plain ``int``.

    Args:
        args: (dict) must contain "seed_specify"; must also contain "seed"
            when "seed_specify" is truthy. Mutated in place when a seed is
            drawn.
    """
    if not args["seed_specify"]:
        # np.random.randint returns a NumPy integer. random.seed() on
        # Python >= 3.11 rejects any type other than None/int/float/str/bytes/
        # bytearray, so convert to a builtin int.
        args["seed"] = int(np.random.randint(1000, 10000))
    random.seed(args["seed"])
    np.random.seed(args["seed"])
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes spawned after this point.
    os.environ["PYTHONHASHSEED"] = str(args["seed"])
    torch.manual_seed(args["seed"])
    # Safe to call without CUDA: torch ignores these calls in that case.
    torch.cuda.manual_seed(args["seed"])
    torch.cuda.manual_seed_all(args["seed"])


def get_num_agents(env, env_args, envs):
    """Get the number of agents in the environment.

    Args:
        env: (str) environment name.
        env_args: (dict) environment arguments; for "smac" the agent count is
            looked up from the map named by ``env_args["map_name"]``.
        envs: the constructed environment; for every supported env other
            than "smac" its ``n_agents`` attribute is returned.
    Returns:
        The number of agents (presumably an int -- taken as-is from the map
        parameters or ``envs.n_agents``).
    Raises:
        NotImplementedError: if ``env`` is not a supported environment name
            (previously this silently returned None).
    """
    if env == "smac":
        from harl.envs.smac.smac_maps import get_map_params

        return get_map_params(env_args["map_name"])["n_agents"]
    if env in (
        "smacv2",
        "mamujoco",
        "pettingzoo_mpe",
        "gym",
        "football",
        "dexhands",
        "lag",
    ):
        return envs.n_agents
    raise NotImplementedError(f"Can not support the {env} environment.")

def get_relative_direction(pos_x: float, pos_y: float) -> str:
    """Name the compass sector that a relative offset points toward.

    The angle is measured counter-clockwise from the +x axis (East), with +y
    as North, and split into eight 45-degree sectors centred on each
    cardinal/intercardinal direction.

    Args:
        pos_x: x coordinate of the offset
        pos_y: y coordinate of the offset
    Returns:
        One of "East", "Northeast", "North", "Northwest", "West",
        "Southwest", "South", "Southeast". "East" is also returned when no
        sector matches (e.g. a NaN angle); (0, 0) maps to "East".
    """
    heading = (math.degrees(math.atan2(pos_y, pos_x)) + 360) % 360

    # Half-open [lower, upper) sectors; anything outside them is East.
    sectors = (
        (22.5, 67.5, "Northeast"),
        (67.5, 112.5, "North"),
        (112.5, 157.5, "Northwest"),
        (157.5, 202.5, "West"),
        (202.5, 247.5, "Southwest"),
        (247.5, 292.5, "South"),
        (292.5, 337.5, "Southeast"),
    )
    for lower, upper, label in sectors:
        if lower <= heading < upper:
            return label
    return "East"

def _parse_own_info(obs_text: str) -> dict:
    """Parse the "Own Unit Information:" paragraph of ``obs_text`` into a dict.

    Health and shield are given as percentages and returned as fractions.
    A missing shield is 0; other missing optional fields are None.
    Raises StopIteration if no such paragraph exists.
    """
    sections = obs_text.split('\n\n')
    own_section = next(s for s in sections if "Own Unit Information:" in s)
    # Health is mandatory; [:-1] drops the trailing '%'.
    own_health = float(own_section.split('Health:')[1].split('\n')[0].strip()[:-1]) / 100

    own_shield = 0
    if 'Shield:' in own_section:
        own_shield = float(own_section.split('Shield:')[1].split('\n')[0].strip()[:-1]) / 100

    own_sight_range = None
    if 'Sight range:' in own_section:
        own_sight_range = float(own_section.split('Sight range:')[1].split('units')[0].strip())

    own_shoot_range = None
    if 'Shoot range:' in own_section:
        own_shoot_range = float(own_section.split('Shoot range:')[1].split('units')[0].strip())

    own_pos = None
    if 'Position:' in own_section:
        pos_str = own_section.split('Position:')[1].split(')')[0].strip('( ')
        x, y = map(float, pos_str.split(','))
        own_pos = (x, y)

    own_type = None
    if 'Unit type:' in own_section:
        own_type = own_section.split('Unit type:')[1].split('\n')[0].strip()

    return {
        "health": own_health,
        "shield": own_shield,
        "sight_range": own_sight_range,
        "shoot_range": own_shoot_range,
        "pos": own_pos,
        "type": own_type
    }


def _is_unit_block_boundary(line: str) -> bool:
    """Return True if ``line`` starts a new unit entry or a new section."""
    return (
        "Enemy #" in line
        or "Ally #" in line
        or "Units Information" in line
        or "Own Unit Information" in line
    )


def _parse_unit_block(lines: list, start: int, max_lines: int = 8):
    """Parse one ally/enemy entry whose header line is ``lines[start]``.

    Scans at most ``max_lines`` lines (header included), stopping early at
    the next unit header or section header, so that fields belonging to a
    following unit are never attributed to this one.

    Returns:
        ``(relative_pos, health, shield, unit_type)`` with health/shield as
        fractions, or None if the entry has no relative position or no
        health reading.
    """
    relative_pos = None
    health = None
    shield = 0
    unit_type = None
    for offset, text in enumerate(lines[start:start + max_lines]):
        if offset > 0 and _is_unit_block_boundary(text):
            break
        if "Relative position:" in text:
            pos_str = text.split('(')[1].split(')')[0]
            x, y = map(float, pos_str.split(', '))
            relative_pos = (x, y)
        if 'Health:' in text:
            health = float(text.split('Health:')[1].split('%')[0].strip()) / 100
        if 'Shield:' in text:
            shield = float(text.split('Shield:')[1].split('%')[0].strip()) / 100
        if "Unit type:" in text:
            unit_type = text.split(': ')[1].strip()
    if relative_pos is None or health is None:
        return None
    return relative_pos, health, shield, unit_type


def update_positions_from_obs(global_memory, env_idx: int, obs_text: str, agent_id: int, timestep: int) -> None:
    """Extract own and observed-unit info from observation text into global memory.

    First parses the "Own Unit Information:" paragraph and reports it through
    ``global_memory.update_self_position``. It then walks the
    "Enemy Units Information" / "Ally Units Information" sections and reports
    each "Enemy #<id>:" / "Ally #<id>:" entry through
    ``global_memory.update_unit_position``. Walking stops at the
    "Own Unit Information" header. Entries without a relative position or a
    health reading are skipped.

    Args:
        global_memory: object exposing ``update_self_position`` and
            ``update_unit_position`` (called with keyword arguments only).
        env_idx: Index of the environment
        obs_text: Formatted observation text from obs2text()
        agent_id: ID of the agent whose observation this is
        timestep: Current timestep

    Returns:
        None

    Raises:
        StopIteration: if ``obs_text`` has no "Own Unit Information:" paragraph.
    """
    own_info = _parse_own_info(obs_text)

    global_memory.update_self_position(
        env_id=env_idx,
        agent_id=agent_id,
        own_info=own_info,
        timestep=timestep
    )

    lines = obs_text.split('\n')
    current_section = None

    for index, line in enumerate(lines):
        # Track which section we're in.
        if "Enemy Units Information" in line:
            current_section = "enemy"
            continue
        if "Ally Units Information" in line:
            current_section = "ally"
            continue
        if "Own Unit Information" in line:
            break

        if current_section == "enemy" and "Enemy #" in line:
            unit_id = int(line.split('Enemy #')[1].split(':')[0])
        elif current_section == "ally" and "Ally #" in line:
            unit_id = int(line.split('Ally #')[1].split(':')[0])
        else:
            continue

        # Each entry is parsed with fresh defaults, so values never leak
        # from a previously parsed unit.
        parsed = _parse_unit_block(lines, index)
        if parsed is None:
            continue
        relative_pos, health, shield, race = parsed

        global_memory.update_unit_position(
            env_id=env_idx,
            agent_id=agent_id,
            own_info=own_info,
            unit_id=unit_id,
            unit_type=current_section,  # "enemy" or "ally"
            race=race,  # Using unit type as race
            relative_pos=relative_pos,
            health=health,
            shield=shield,
            timestep=timestep
        )

def _format_minimap_units(units: dict, label: str, timestep: int, skip_id=None) -> list:
    """Format living units from ``units`` (keyed by unit id) as minimap text lines.

    Units with ``health == 0`` and the unit whose id equals ``skip_id`` are
    omitted. Each returned string ends with a newline.
    """
    out = []
    for unit_id, info in sorted(units.items()):
        if unit_id == skip_id or info['health'] == 0:
            continue
        obs_type = "Directly observed" if info.get("directly", False) else "Shared information"
        direction = get_relative_direction(info['pos'][0], info['pos'][1])
        out.append(f"  - {label} #{unit_id} ({info['race']}) [{obs_type}]:\n")
        out.append(f"    * Health: {info['health']:%}\n")
        out.append(f"    * Shield: {info['shield']:%}\n")
        out.append(f"    * Relative position: ({info['pos'][0]}, {info['pos'][1]})\n")
        out.append(f"    * Relative direction: {direction}\n")
        out.append(f"    * Last seen: {timestep - info['last_seen']} timesteps ago\n")
    return out


def get_ego_minimap_text(global_memory, local_memory, env_idx: int, agent_id: int, timestep: int) -> str:
    """Generate text description of units from agent's perspective.

    Reads ``global_memory.recent_history[env_idx]["agent_knowledge"][agent_id]``,
    whose "ally"/"enemy" entries map unit ids to info dicts with keys
    "health", "shield", "pos", "race", "last_seen" and optionally "directly".
    The agent itself and units with zero health are left out. A section
    header is emitted only when at least one unit is listed under it, and
    "No units currently visible." is emitted when nothing is listed at all.
    """
    own_unit_type = local_memory.unit_type
    agent_data = global_memory.recent_history[env_idx]["agent_knowledge"].get(agent_id, {})

    text = f"Ego-centric Minimap for Agent {agent_id} ({own_unit_type}):\n"

    # The ally table may contain the agent itself, which is skipped.
    ally_lines = _format_minimap_units(agent_data.get("ally") or {}, "Ally", timestep, skip_id=agent_id)
    enemy_lines = _format_minimap_units(agent_data.get("enemy") or {}, "Enemy", timestep)

    if ally_lines:
        text += "1. Visible Allies:\n" + "".join(ally_lines)
    if enemy_lines:
        text += "2. Visible Enemies:\n" + "".join(enemy_lines)
    if not ally_lines and not enemy_lines:
        text += "No units currently visible.\n"

    return text