from __future__ import absolute_import, division, print_function

import time
from os import replace
import random
from typing import Tuple, List

import numpy as np
from absl import logging
from harl.envs.smacv2.starcraft2 import StarCraft2Env
from harl.envs.smacv2.wrapper import StarCraftCapabilityEnvWrapper
from pysc2.lib.units import Neutral, Protoss, Terran, Zerg
from harl.envs.smacv2.render import StarCraft2Renderer
from harl.utils.envs_tools import get_relative_direction
from PIL import Image
from typing import List
logging.set_verbosity(logging.DEBUG)
import os.path as osp
from pathlib import Path
import yaml

from gym.spaces import Box, Discrete


class SMACv2Env:
    """SMACv2 adapter that also returns renderings and text descriptions.

    Wraps ``StarCraftCapabilityEnvWrapper``. ``reset``/``step`` return the usual
    observation/state/availability data together with cropped per-agent images,
    the full image, and natural-language descriptions of every agent's
    observation and of the global state.

    ``seed`` must be called before ``reset``/``step``: it is what creates
    ``self.env`` and sets ``self.n_agents``.
    """

    def __init__(self, args):
        """Load the map configuration named by ``args["map_name"]``.

        Args:
            args (dict): Environment arguments; only ``"map_name"`` is read here.
        """
        self.renderer = None
        self.map_config = self.load_map_config(args["map_name"])

    def step(self, actions):
        """Advance the environment by one step.

        Args:
            actions: One numeric action per agent, or one ``(action_name, params)``
                tuple per agent (see ``extract_action``).

        Returns:
            tuple: ``(obs, state, rewards, dones, infos, avail_actions,
            cropped_images, images, obs_texts, state_texts)``; ``state``,
            ``rewards``, ``dones``, ``infos``, ``images`` and ``state_texts`` are
            per-agent lists of length ``n_agents``.
        """
        if isinstance(actions[0], tuple):
            processed_actions = self.extract_action(actions)
        else:
            processed_actions = actions
        reward, terminated, info = self.env.step(processed_actions)
        obs = self.env.get_obs()
        state = self.repeat(self.env.get_state())
        avail_actions = self.env.get_avail_actions()

        obs_texts, state_texts = self._build_texts(obs, state[0], avail_actions)

        rewards = [[reward]] * self.n_agents
        dones = [terminated] * self.n_agents
        if terminated:
            # The wrapped env counts timeouts; an increment during this step marks
            # the transition as a "bad transition" (presumably truncation rather
            # than a true terminal state -- confirm against the consumer).
            if self.env.env.timeouts > self.timeouts:
                assert (
                    self.env.env.timeouts - self.timeouts == 1
                ), "Change of timeouts unexpected."
                info["bad_transition"] = True
                self.timeouts = self.env.env.timeouts
        infos = [info] * self.n_agents
        # Render the current game state (full image + per-agent crops).
        image, cropped_images = self.renderer.render(mode="rgb_array")
        return (
            obs, state, rewards, dones, infos, avail_actions,
            cropped_images, self.repeat(image), obs_texts, state_texts,
        )

    def _build_texts(self, obs, state, avail_actions):
        """Return (per-agent observation texts, per-agent copies of the state text)."""
        obs_texts = [
            self.obs2text(obs[agent_id], avail_actions[agent_id], agent_id)
            for agent_id in range(self.n_agents)
        ]
        state_texts = [self.state2text(state, avail_actions)] * self.n_agents
        return obs_texts, state_texts

    def extract_action(self, actions: List[Tuple[str, dict]]) -> List[int]:
        """Extract and process text-based action commands into numeric action indices.

        Args:
            actions (List[Tuple[str, dict]]): One ``(action_name, params_dict)``
                tuple per agent, in agent order:
                - action_name (str): e.g. "noop", "stop", "move_north",
                  "attack", "heal".
                - params_dict (dict): may contain
                    - "unit_id" (int): unit performing the action; defaults to
                      the tuple's position in ``actions`` when missing (e.g. the
                      bare "noop" command produced by ``_get_action_name``).
                    - "target_id" (int): target unit for attack/heal actions.

        Returns:
            List[int]: One numeric action index per agent:
                - 0: no-op
                - 1: stop
                - 2: move north
                - 3: move south
                - 4: move east
                - 5: move west
                - n_actions_no_attack+: attack/heal (target_id + n_actions_no_attack)
            Any command that is unknown, malformed, out of range, or currently
            unavailable is replaced by a uniformly random available action.
        """
        n_actions_no_attack = self.env.env.n_actions_no_attack
        fixed_actions = {
            "noop": 0,
            "stop": 1,
            "move_north": 2,
            "move_south": 3,
            "move_east": 4,
            "move_west": 5,
        }

        processed_actions = []
        for agent_idx, (action_name, params) in enumerate(actions):
            params = params or {}
            unit_id = params.get("unit_id", agent_idx)
            avail_actions = self.env.env.get_avail_agent_actions(unit_id)

            if action_name in ("attack", "heal"):
                try:
                    action = int(params["target_id"]) + n_actions_no_attack
                except (KeyError, TypeError, ValueError):
                    action = None
            else:
                action = fixed_actions.get(action_name)

            # Bounds-check explicitly: a negative index would silently wrap.
            if (
                action is not None
                and 0 <= action < len(avail_actions)
                and avail_actions[action] == 1
            ):
                processed_actions.append(action)
            else:
                available_indices = [i for i, x in enumerate(avail_actions) if x == 1]
                processed_actions.append(random.choice(available_indices))

        return processed_actions

    def reset(self):
        """Reset the environment and rebuild the renderer.

        Returns:
            tuple: ``(obs, state, avail_actions, cropped_images, images,
            obs_texts, state_texts)``.
        """
        self.env.reset()
        if hasattr(self.env.env, 'ally_unit_map'):
            self.ally_unit_type_to_name = {v: k for k, v in self.env.env.ally_unit_map.items()}
            self.enemy_unit_type_to_name = {v: k for k, v in self.env.env.enemy_unit_map.items()}
        else:
            self.ally_unit_type_to_name = self.env.env.id_to_unit_name_map
            self.enemy_unit_type_to_name = self.env.env.id_to_unit_name_map
        # NOTE(review): a new renderer is created on every reset without closing
        # the previous one; confirm StarCraft2Renderer holds no resources needing release.
        self.renderer = StarCraft2Renderer(self.env, mode="rgb_array")
        obs = self.env.get_obs()
        state = self.repeat(self.env.get_state())
        avail_actions = self.env.get_avail_actions()
        image, cropped_images = self.renderer.render(mode="rgb_array")
        obs_texts, state_texts = self._build_texts(obs, state[0], avail_actions)
        return obs, state, avail_actions, cropped_images, self.repeat(image), obs_texts, state_texts

    def seed(self, seed):
        """Create the wrapped env with ``seed`` and derive the gym spaces."""
        self.env = StarCraftCapabilityEnvWrapper(seed=seed, **self.map_config)
        env_info = self.env.get_env_info()
        n_actions = env_info["n_actions"]
        state_shape = env_info["state_shape"]
        obs_shape = env_info["obs_shape"]
        self.n_agents = env_info["n_agents"]
        self.timeouts = self.env.env.timeouts

        self.share_observation_space = self.repeat(
            Box(low=-np.inf, high=np.inf, shape=(state_shape,))
        )
        self.observation_space = self.repeat(
            Box(low=-np.inf, high=np.inf, shape=(obs_shape,))
        )
        self.action_space = self.repeat(Discrete(n_actions))

    def close(self):
        """Close the wrapped env (if ``seed`` created one) and the renderer."""
        env = getattr(self, "env", None)
        if env is not None:
            env.close()
        if self.renderer is not None:
            self.renderer.close()

    def load_map_config(self, map_name):
        """Load ``configs/envs_cfgs/smacv2_map_config/<map_name>.yaml``.

        The ``configs`` directory is resolved two levels above the directory
        that contains this module.
        """
        base_path = osp.split(osp.split(osp.dirname(osp.abspath(__file__)))[0])[0]
        map_config_path = (
            Path(base_path)
            / "configs"
            / "envs_cfgs"
            / "smacv2_map_config"
            / f"{map_name}.yaml"
        )
        with open(str(map_config_path), "r", encoding="utf-8") as file:
            map_config = yaml.load(file, Loader=yaml.FullLoader)
        return map_config

    def repeat(self, a):
        """Return a list holding ``self.n_agents`` references to the same ``a``."""
        return [a for _ in range(self.n_agents)]

    def _get_action_name(self, action_id: int, unit_id: int) -> str:
        """Render an action index as the text command format parsed by ``extract_action``.

        Args:
            action_id (int): Numeric action index.
            unit_id (int): Agent the action belongs to.

        Returns:
            str: e.g. ``"noop"``, ``"move_north, {'unit_id': 0}"`` or
            ``"attack, {'unit_id': 0, 'target_id': 2}"``.
        """
        inner = self.env.env
        if action_id == 0:
            return "noop"
        simple_names = {1: "stop", 2: "move_north", 3: "move_south", 4: "move_east", 5: "move_west"}
        if action_id in simple_names:
            return f"{simple_names[action_id]}, {{'unit_id': {unit_id}}}"
        if inner.conic_fov and action_id in range(6, 6 + inner.n_fov_actions):
            return f"change field of view direction {action_id - 6}"

        target_id = action_id - inner.n_actions_no_attack
        # Only medivacs on Terran maps use their targeted action to heal.
        unit = inner.get_unit_by_id(unit_id)
        if inner.map_type in ["MMM", "terran_gen"] and unit.unit_type == inner.medivac_id:
            return f"heal, {{'unit_id': {unit_id}, 'target_id': {target_id}}}"
        return f"attack, {{'unit_id': {unit_id}, 'target_id': {target_id}}}"

    def get_action_names(self, actions: np.ndarray) -> List[str]:
        """Return the text command for each agent's action.

        Args:
            actions (np.ndarray): One action per agent, shape ``(n,)`` or ``(n, 1)``.

        Returns:
            List[str]: ``_get_action_name`` output for each agent, in order.
        """
        # ``.item()`` accepts both scalars and size-1 rows, avoiding NumPy's
        # deprecated implicit ndim>0 -> scalar conversion.
        return [
            self._get_action_name(int(np.asarray(action).item()), agent_id)
            for agent_id, action in enumerate(actions)
        ]

    def obs2text(self, agent_obs: np.ndarray, available_actions: np.ndarray, agent_id: int) -> str:
        """Convert observation array to text format for an agent.

        The flat vector is decoded section by section (movement, enemies,
        allies, own unit, timestep); each section's width is recomputed from
        the wrapped env's feature flags so the running index stays aligned.

        Args:
            agent_obs (np.ndarray): Observation array for the agent.
            available_actions (np.ndarray): 0/1 mask of which actions are available.
            agent_id (int): ID of the agent whose observation is being converted.

        Returns:
            str: Text description of the agent's observation including movement
                options, visible enemies with their stats, ally information, own
                unit information, and available actions.
        """
        inner = self.env.env
        text = "Current Observation:\n"
        idx = 0  # current read position in agent_obs

        # 1. Movement Information: first 4 elements flag the move directions.
        text += "1. Movement Information:\n"
        move_actions = ["North", "South", "East", "West"]
        for i, direction_name in enumerate(move_actions):
            text += f"- Can move {direction_name}: {'yes' if agent_obs[i] == 1 else 'no'}\n"
        idx += 4

        # Pathing and terrain features are skipped, not described.
        if inner.obs_pathing_grid:
            idx += inner.n_obs_pathing
        if inner.obs_terrain_height:
            idx += inner.n_obs_height

        # 2. Enemy Units Information
        text += "\n2. Enemy Units Information:\n"
        enemy_feats_length = 4 + inner.unit_type_bits  # base features: available, distance, rel_x, rel_y
        if inner.obs_all_health:
            enemy_feats_length += 1 + inner.shield_bits_enemy

        for e_id in range(inner.n_enemies):
            start_idx = idx + e_id * enemy_feats_length
            # An all-zero slot is treated as "enemy not visible".
            if any(agent_obs[start_idx:start_idx + enemy_feats_length] != 0):
                text += f"- Enemy #{e_id}:\n"
                text += f"  * Can be attacked: {'yes' if agent_obs[start_idx] == 1 else 'no'}\n"
                text += f"  * Distance: {agent_obs[start_idx + 1]} units\n"
                text += f"  * Relative position: ({agent_obs[start_idx + 2]}, {agent_obs[start_idx + 3]})\n"
                direction = get_relative_direction(agent_obs[start_idx + 2], agent_obs[start_idx + 3])
                text += f"  * Relative direction: {direction}\n"
                feat_idx = 4
                if inner.obs_all_health:
                    text += f"  * Health: {agent_obs[start_idx + feat_idx]:%}\n"
                    feat_idx += 1
                    if inner.shield_bits_enemy > 0:
                        text += f"  * Shield: {agent_obs[start_idx + feat_idx]:%}\n"
                        feat_idx += 1
                if inner.unit_type_bits > 0:
                    unit_type = np.argmax(agent_obs[start_idx + feat_idx:start_idx + feat_idx + inner.unit_type_bits])
                    unit_name = self.get_unit_name(unit_type, ally=False)
                    text += f"  * Unit type: {unit_name}\n"

        idx += inner.n_enemies * enemy_feats_length

        # 3. Ally Units Information
        text += "\n3. Ally Units Information:\n"
        ally_feats_length = 4  # base features: visible, distance, rel_x, rel_y
        if inner.obs_all_health:
            ally_feats_length += 1 + inner.shield_bits_ally
        # These slots exist when the value is observed OR zero-padded.
        has_attack_slot = inner.stochastic_attack and (
            inner.zero_pad_stochastic_attack or inner.observe_attack_probs
        )
        has_health_level_slot = inner.stochastic_health and (
            inner.observe_teammate_health or inner.zero_pad_health
        )
        if has_attack_slot:
            ally_feats_length += 1
        if has_health_level_slot:
            ally_feats_length += 1
        if inner.unit_type_bits > 0:
            ally_feats_length += inner.unit_type_bits
        if inner.obs_last_action:
            ally_feats_length += inner.n_actions

        al_ids = [al_id for al_id in range(inner.n_agents) if al_id != agent_id]

        for i, al_id in enumerate(al_ids):  # Excluding self
            start_idx = idx + i * ally_feats_length
            if agent_obs[start_idx] == 1:  # If ally is visible
                text += f"- Ally #{al_id}:\n"
                text += f"  * Distance: {agent_obs[start_idx + 1]} units\n"
                text += f"  * Relative position: ({agent_obs[start_idx + 2]}, {agent_obs[start_idx + 3]})\n"
                direction = get_relative_direction(agent_obs[start_idx + 2], agent_obs[start_idx + 3])
                text += f"  * Relative direction: {direction}\n"
                feat_idx = 4
                if inner.obs_all_health:
                    text += f"  * Health: {agent_obs[start_idx + feat_idx]:%}\n"
                    feat_idx += 1
                    if inner.shield_bits_ally > 0:
                        text += f"  * Shield: {agent_obs[start_idx + feat_idx]:%}\n"
                        feat_idx += 1
                if has_attack_slot:
                    if inner.observe_attack_probs:
                        text += f"  * Attack probability: {agent_obs[start_idx + feat_idx]}\n"
                    # Advance even for a zero-padded slot to stay aligned.
                    feat_idx += 1
                if has_health_level_slot:
                    if inner.observe_teammate_health:
                        text += f"  * Health level: {agent_obs[start_idx + feat_idx]}\n"
                    feat_idx += 1
                if inner.unit_type_bits > 0:
                    unit_type = np.argmax(agent_obs[start_idx + feat_idx:start_idx + feat_idx + inner.unit_type_bits])
                    unit_name = self.get_unit_name(unit_type, ally=True)
                    text += f"  * Unit type: {unit_name}\n"
                    feat_idx += inner.unit_type_bits
                if inner.obs_last_action:
                    last_action = np.argmax(agent_obs[start_idx + feat_idx:start_idx + feat_idx + inner.n_actions])
                    action_text = self._get_action_name(last_action, al_id)
                    text += f"  * Last action: {action_text}\n"

        idx += (inner.n_agents - 1) * ally_feats_length

        # 4. Own Unit Information
        sight_range = inner.unit_sight_range(agent_id)
        shoot_range = inner.unit_shoot_range(agent_id)
        text += "\n4. Own Unit Information:\n"
        if inner.obs_own_health:
            text += f"- Health: {agent_obs[idx]:%}\n"
            idx += 1
            if inner.shield_bits_ally > 0:
                text += f"- Shield: {agent_obs[idx]:%}\n"
                idx += 1
        if sight_range:
            text += f"- Sight range: {sight_range} units\n"
        if shoot_range:
            text += f"- Shoot range: {shoot_range} units\n"
        if inner.stochastic_attack:
            text += f"- Attack probability: {agent_obs[idx]}\n"
            idx += 1
        if inner.stochastic_health:
            text += f"- Health level: {agent_obs[idx]}\n"
            idx += 1
        if inner.obs_own_pos:
            text += f"- Position: ({agent_obs[idx]}, {agent_obs[idx + 1]})\n"
            idx += 2
        if inner.conic_fov:
            text += f"- Field of view direction: ({agent_obs[idx]}, {agent_obs[idx + 1]})\n"
            idx += 2
        if inner.unit_type_bits > 0:
            unit_type = np.argmax(agent_obs[idx:idx + inner.unit_type_bits])
            unit_name = self.get_unit_name(unit_type, ally=True)
            text += f"- Unit type: {unit_name}\n"
            idx += inner.unit_type_bits
        if inner.obs_last_action:
            # Own last action is read from the env's one-hot record, not agent_obs.
            last_action = np.argmax(inner.last_action[agent_id])
            action_text = self._get_action_name(last_action, agent_id)
            text += f"  * Last action: {action_text}\n"

        # 5. Available Actions
        text += "\n5. Available Actions:\n"
        for i, flag in enumerate(available_actions):
            text += f"- {self._get_action_name(i, agent_id)}: {'yes' if flag == 1 else 'no'}\n"

        # 6. Time Information (the timestep fraction is the last obs entry)
        if inner.obs_timestep_number:
            text += f"\n6. Time Information:\n"
            text += f"- Current timestep: {agent_obs[-1]:%} of episode limit\n"

        return text

    def state2text(self, state: np.ndarray, available_actions: np.ndarray) -> str:
        """Convert global state array to text format.

        Args:
            state (np.ndarray): Flat global state vector from the env.
            available_actions: Per-agent 0/1 availability masks, indexed as
                ``available_actions[agent_id][action_id]``.

        Returns:
            str: Text formatted global state.
        """
        inner = self.env.env
        # If using observations instead of state, the state is the agents'
        # observation vectors laid out back to back in agent order.
        if inner.obs_instead_of_state:
            text = "Global State (Combined Observations):\n"
            obs_size = inner.get_obs_size()
            for agent_id in range(inner.n_agents):
                text += f"\nAgent {agent_id} Observation:\n"
                agent_obs = state[agent_id * obs_size:(agent_id + 1) * obs_size]
                text += self.obs2text(agent_obs, available_actions[agent_id], agent_id)
            return text

        text = "Global State:\n"
        idx = 0

        # 1. Ally Information
        text += "1. Ally Units:\n"
        ally_feats = inner.get_ally_num_attributes()
        for al_id in range(inner.n_agents):
            start_idx = idx + al_id * ally_feats
            text += f"- Ally #{al_id}:\n"
            feat_idx = 0

            # Health
            text += f"  * Health: {state[start_idx + feat_idx]:%}\n"
            feat_idx += 1

            # Energy for medivacs, weapon cooldown otherwise (same medivac test
            # as used for heal actions in _get_action_name).
            uses_energy = (
                inner.map_type in ['MMM', 'terran_gen']
                and inner.get_unit_by_id(al_id).unit_type == inner.medivac_id
            )
            text += f"  * {'Energy' if uses_energy else 'Cooldown'}: {state[start_idx + feat_idx]}\n"
            feat_idx += 1

            # Relative position
            text += f"  * Relative position to map center: ({state[start_idx + feat_idx]}, {state[start_idx + feat_idx + 1]})\n"
            feat_idx += 2

            # Shield if present
            if inner.shield_bits_ally > 0:
                text += f"  * Shield: {state[start_idx + feat_idx]:%}\n"
                feat_idx += 1

            # Stochastic attack probability
            if inner.stochastic_attack:
                text += f"  * Attack probability: {state[start_idx + feat_idx]}\n"
                feat_idx += 1

            # Stochastic health level
            if inner.stochastic_health:
                text += f"  * Health level: {state[start_idx + feat_idx]}\n"
                feat_idx += 1

            # Field of view direction
            if inner.conic_fov:
                text += f"  * Field of view: ({state[start_idx + feat_idx]}, {state[start_idx + feat_idx + 1]})\n"
                feat_idx += 2

            # Unit type
            if inner.unit_type_bits > 0:
                unit_type = np.argmax(state[start_idx + feat_idx:start_idx + feat_idx + inner.unit_type_bits])
                unit_name = self.get_unit_name(unit_type, ally=True)
                text += f"  * Unit type: {unit_name}\n"

        idx += inner.n_agents * ally_feats

        # 2. Enemy Information
        text += "\n2. Enemy Units:\n"
        enemy_feats = inner.get_enemy_num_attributes()
        for e_id in range(inner.n_enemies):
            start_idx = idx + e_id * enemy_feats
            text += f"- Enemy #{e_id}:\n"
            feat_idx = 0

            # Health
            text += f"  * Health: {state[start_idx + feat_idx]:%}\n"
            feat_idx += 1

            # Relative position
            text += f"  * Relative position: ({state[start_idx + feat_idx]}, {state[start_idx + feat_idx + 1]})\n"
            feat_idx += 2

            # Shield if present
            if inner.shield_bits_enemy > 0:
                text += f"  * Shield: {state[start_idx + feat_idx]:%}\n"
                feat_idx += 1

            # Unit type
            if inner.unit_type_bits > 0:
                unit_type = np.argmax(state[start_idx + feat_idx:start_idx + feat_idx + inner.unit_type_bits])
                unit_name = self.get_unit_name(unit_type, ally=False)
                text += f"  * Unit type: {unit_name}\n"

        idx += inner.n_enemies * enemy_feats

        # 3. Last Actions: one one-hot vector of width n_actions per agent
        # (the same encoding obs2text decodes for allies' last actions).
        if inner.state_last_action:
            text += "\n3. Last Actions:\n"
            n_actions = inner.n_actions
            for al_id in range(inner.n_agents):
                one_hot = state[idx + al_id * n_actions:idx + (al_id + 1) * n_actions]
                action = int(np.argmax(one_hot))
                action_text = self._get_action_name(action, al_id)
                text += f"- Agent {al_id}: {action_text}\n"
            idx += inner.n_agents * n_actions

        # 4. Available Actions
        text += "\n4. Available Actions:\n"
        for agent_id in range(inner.n_agents):
            text += f"- Agent {agent_id}:\n"
            for i, flag in enumerate(available_actions[agent_id]):
                text += f"  * {self._get_action_name(i, agent_id)}: {'yes' if flag == 1 else 'no'}\n"

        # 5. Time Information
        if inner.state_timestep_number:
            text += "\n5. Time Information:\n"
            text += f"- Current timestep: {state[idx]:%} of episode limit\n"

        return text

    def get_unit_name(self, unit_id: int, ally) -> str:
        """Map a one-hot unit-type index to a unit name.

        Args:
            unit_id (int): Unit-type index (position of the hot bit), not a unit ID.
            ally (bool): Whether to look up the ally or enemy name table.
        """
        unit_type = self.get_id_unit_type(unit_id, ally)
        if ally:
            return self.ally_unit_type_to_name[unit_type]
        return self.enemy_unit_type_to_name[unit_type]

    def get_id_unit_type(self, type_id, ally):
        """Returns the unit type for given type ID in the scenario.

        This inverts the env's unit-type -> type-ID encoding.

        Args:
            type_id (int): Unit type ID (index in the unit-type one-hot).
            ally (bool): Whether unit is ally.

        Returns:
            unit_type: SC2 unit type id (int or pysc2 unit enum value).

        Raises:
            AttributeError: If ``type_id`` is invalid for a ``*_gen`` map, or the
                map type is not supported.
        """
        inner = self.env.env
        map_type = inner.map_type

        if map_type == "protoss_gen":
            if type_id == 0:
                return inner.stalker_id if ally else Protoss.Stalker
            if type_id == 1:
                return inner.zealot_id if ally else Protoss.Zealot
            if type_id == 2:
                return inner.colossus_id if ally else Protoss.Colossus
            raise AttributeError(f"Invalid type_id {type_id} for protoss")

        if map_type == "terran_gen":
            if type_id == 0:
                return inner.marine_id if ally else Terran.Marine
            if type_id == 1:
                return inner.marauder_id if ally else Terran.Marauder
            if type_id == 2:
                return inner.medivac_id if ally else Terran.Medivac
            raise AttributeError(f"Invalid type_id {type_id} for terran")

        if map_type == "zerg_gen":
            if type_id == 0:
                return inner.zergling_id if ally else Zerg.Zergling
            if type_id == 1:
                return inner.hydralisk_id if ally else Zerg.Hydralisk
            if type_id == 2:
                return inner.baneling_id if ally else Zerg.Baneling
            raise AttributeError(f"Invalid type_id {type_id} for zerg")

        if ally:
            # Allies use map-defined unit types offset from _min_unit_type.
            return type_id + inner._min_unit_type

        if map_type == "stalkers_and_zealots":
            # id(Stalker) = 74, id(Zealot) = 73
            return type_id + 73
        if map_type == "colossi_stalkers_zealots":
            # id(Colossus) = 4, id(Stalker) = 74, id(Zealot) = 73
            return {0: 4, 1: 74}.get(type_id, 73)
        if map_type == "bane":
            # id(Baneling) = 9, id(Zergling) = 105
            return 9 if type_id == 0 else 105
        if map_type == "MMM":
            # id(Marauder) = 51, id(Marine) = 48, id(Medivac) = 54
            return {0: 51, 1: 48}.get(type_id, 54)

        raise AttributeError(f"Invalid map_type {map_type}")

    def get_unit_types(self) -> List[str]:
        """Get the unit types for all agents in the environment.

        Returns:
            List[str]: Unit type name for each agent, in agent order.
        """
        inner = self.env.env
        return [
            self.ally_unit_type_to_name[inner.get_unit_by_id(al_id).unit_type]
            for al_id in range(inner.n_agents)
        ]