"""SMAC is slow, has a lot of unnecessary cruft, and doesn't really fit
what I need, so I will start from scratch..."""

import enum
from dataclasses import dataclass, field
from operator import attrgetter

import numpy as np
from pysc2 import maps, run_configs
from pysc2.lib.remote_controller import RemoteController
from pysc2.lib.sc_process import StarcraftProcess
from pysc2.lib.units import Protoss, Terran, Zerg
from s2clientprotocol import common_pb2 as sc_common
from s2clientprotocol import debug_pb2 as dbg_pb
from s2clientprotocol import raw_pb2 as raw_pb
from s2clientprotocol import sc2api_pb2 as sc_pb
from smacv2.starcraft2.distributions import Distribution, WeightedTeamsDistribution
from smacv2.starcraft2.render import StarCraft2Renderer


class Outcome(enum.IntEnum):
    """Final result of an episode, seen from the player's side.

    Stored under the ``"outcome"`` key of the ``info`` dict returned by
    ``StarCraft2Env.step`` once an episode has terminated.
    """

    WIN = 1
    DRAW = 2
    LOSS = 3


class HotStart(str, enum.Enum):
    """Order issued to every player unit right after the units are spawned."""

    # No initial order is issued
    none = "none"
    # Attack the enemy unit closest to the mean position of the player units
    closest = "closest"
    # Head for the mean position of the enemy units
    mean_pos = "mean_pos"


@dataclass(slots=True)
class TargetAssignment:
    """Unit action is targeting another"""

    other: int


UnitAction = Outcome | TargetAssignment | None


@dataclass(slots=True)
class PositionAssignment:
    """Unit action is moving to a position"""

    x: float
    y: float


@dataclass(slots=True)
class SC2GameCfg:
    """Configuration used for launching the game and pacing episodes.

    Read by ``launch_sc2`` (map, races, difficulty) and by ``StarCraft2Env``
    (step size, episode length, hot start).
    """

    # NOTE(review): not read by any code visible in this file; presumably the
    # settings for the start-position distribution - confirm against callers.
    pos_dist: dict
    # Map name resolved through ``pysc2.maps.get`` when the game is created
    map_name: str = "10gen_terran"
    # Race requested for the participant when joining the game
    player_race: int = sc_common.Random
    # Race of the computer opponent added to the game
    enemy_race: int = sc_common.Random
    # Difficulty of the computer opponent added to the game
    difficulty: int = sc_pb.Medium
    # Number of game steps advanced by each ``StarCraft2Env.step`` call
    step_size: int = 5
    # An episode terminates once its step counter exceeds this value
    max_duration: int = 150
    # Optional order given to all player units right after spawning
    hotstart: HotStart = HotStart.none


def _default_unit_features() -> list[str]:
    """Build a fresh list holding the default per-unit feature names."""
    return ["x", "y", "t", "unitType"]


@dataclass
class SC2ObsCfg:
    """Options controlling what ``StarCraft2Env.get_obs`` reports per unit."""

    # Feature names, each one a key of ``TRAIN_FEAT_TO_OBS``; the order given
    # here is the column order of the observation arrays
    unit_features: list[str] = field(default_factory=_default_unit_features)
    # NOTE(review): not read by the code visible in this file - confirm usage
    separate_enemy: bool = True
    # NOTE(review): not read by the code visible in this file - confirm usage
    height_map: bool = True
    # Add each unit's shield to its "health" feature
    combine_health_shield: bool = False
    # Divide health/shield features by the unit's maximum values
    normalize_health_shield: bool = True


# Maps the feature names accepted in ``SC2ObsCfg.unit_features`` to the dotted
# attribute path that ``StarCraft2Env.get_obs`` reads from each raw unit via
# ``operator.attrgetter``.
TRAIN_FEAT_TO_OBS = {
    "x": "pos.x",
    "y": "pos.y",
    "t": "facing",
    "unitType": "unit_type",
    "health": "health",
    "shield": "shield",
    # Read from ``alliance`` but then overwritten in ``get_obs`` with the
    # boolean ``owner == 1``
    "alliance_self": "alliance",
}

# Lower-cased unit name -> pysc2 unit enum member, for every unit type that
# ``StarCraft2Env._initialize_team`` is able to spawn.
UNIT_NAME_TO_ID = {
    unit_type.name.lower(): unit_type
    for unit_type in (
        Protoss.Colossus,
        Protoss.Stalker,
        Protoss.Zealot,
        Terran.Marauder,
        Terran.Medivac,
        Terran.Marine,
        Zerg.Baneling,
        Zerg.Hydralisk,
        Zerg.Zergling,
    )
}


def launch_sc2(cfg: SC2GameCfg, seed=None):
    """Launch a StarCraft II process, create a game on ``cfg.map_name`` and join it.

    The game is created in non-realtime mode with one participant and one
    computer opponent (``cfg.enemy_race`` / ``cfg.difficulty``). The
    participant joins as ``cfg.player_race`` with the raw interface enabled
    and score reporting disabled.

    Args:
        cfg: Game configuration.
        seed: Forwarded as ``random_seed`` of the create-game request.

    Returns:
        The running process; its ``controller`` has already joined the game.

    Raises:
        Whatever pysc2 raises while creating or joining the game. The
        process is shut down before the exception propagates so that a
        failed launch does not leave an orphaned game instance behind.
    """
    run_config = run_configs.get()
    map_ = maps.get(cfg.map_name)
    proc: StarcraftProcess = run_config.start(want_rgb=False)
    try:
        controller = proc.controller

        create_req = sc_pb.RequestCreateGame(
            local_map=sc_pb.LocalMap(
                map_path=map_.path, map_data=run_config.map_data(map_.path)
            ),
            realtime=False,
            random_seed=seed,
        )
        create_req.player_setup.add(type=sc_pb.Participant)
        create_req.player_setup.add(
            type=sc_pb.Computer, race=cfg.enemy_race, difficulty=cfg.difficulty
        )
        controller.create_game(create_req)

        join_req = sc_pb.RequestJoinGame(
            race=cfg.player_race,
            options=sc_pb.InterfaceOptions(raw=True, score=False),
        )
        controller.join_game(join_req)
    except BaseException:
        # The caller never receives ``proc`` on failure, so close it here
        proc.close()
        raise
    return proc


def make_unit_action_request(action_raw_unit_command):
    """Wrap one raw unit command in a ``RequestAction`` holding a single action."""
    raw_action = raw_pb.ActionRaw(unit_command=action_raw_unit_command)
    wrapped = sc_pb.Action(action_raw=raw_action)
    return sc_pb.RequestAction(actions=[wrapped])


def get_player_and_enemy_positions(units):
    """Split unit positions into player (owner 1) and enemy (owner 2) arrays.

    Args:
        units: Iterable of objects exposing ``owner``, ``pos.x`` and ``pos.y``.

    Returns:
        ``(player_pos, enemy_pos)``: two ``float32`` arrays of shape
        ``[n, 2]`` holding ``(x, y)`` rows in the order the units were given.
        A side without units yields shape ``[0, 2]``.

    Raises:
        KeyError: If a unit belongs to an owner other than 1 or 2.
    """
    player_pos = []
    enemy_pos = []
    for unit in units:
        pos = (unit.pos.x, unit.pos.y)
        if unit.owner == 1:
            player_pos.append(pos)
        elif unit.owner == 2:
            enemy_pos.append(pos)
        else:
            raise KeyError(f"Unknown {unit.owner=}")
    # reshape keeps the [n, 2] layout when a side is empty; without it an
    # empty list becomes a 1-D array that breaks axis-wise maths downstream
    player_pos = np.array(player_pos, dtype=np.float32).reshape(-1, 2)
    enemy_pos = np.array(enemy_pos, dtype=np.float32).reshape(-1, 2)
    return player_pos, enemy_pos


def get_move_average_command(units):
    """Build a request sending every player unit to the enemy's mean position.

    The order uses ability id 23 with a world-space target
    (NOTE(review): presumably the generic attack ability - confirm) and
    replaces any queued orders.
    """
    _, enemy_pos = get_player_and_enemy_positions(units)
    enemy_mean = enemy_pos.mean(axis=0)
    destination = sc_common.Point2D(x=enemy_mean[0], y=enemy_mean[1])
    player_tags = [unit.tag for unit in units if unit.owner == 1]
    order = raw_pb.ActionRawUnitCommand(
        ability_id=23,
        target_world_space_pos=destination,
        unit_tags=player_tags,
        queue_command=False,
    )
    return make_unit_action_request(order)


def get_attack_closest_command(units):
    """Build a request ordering every player unit to attack a single enemy.

    The target is the enemy unit closest (Euclidean distance) to the mean
    position of the player units. The order replaces any queued orders.
    """
    player_pos, enemy_pos = get_player_and_enemy_positions(units)
    centroid = player_pos.mean(axis=0)[np.newaxis]  # [1,2]
    distances = np.linalg.norm(centroid - enemy_pos, ord=2, axis=1)
    # Enemy tags are collected in the same order as the rows of ``enemy_pos``
    enemy_tags = [unit.tag for unit in units if unit.owner == 2]
    target_tag = enemy_tags[np.argmin(distances)]
    player_tags = [unit.tag for unit in units if unit.owner == 1]
    order = raw_pb.ActionRawUnitCommand(
        ability_id=23,
        target_unit_tag=target_tag,
        unit_tags=player_tags,
        queue_command=False,
    )
    return make_unit_action_request(order)


class StarCraft2Env:
    """Basic environment for playing against AI"""

    def __init__(
        self,
        game_cfg: SC2GameCfg,
        obs_cfg: SC2ObsCfg,
        pos_dist: Distribution,
        team_dist: WeightedTeamsDistribution,
        seed=None,
    ) -> None:
        self.game_cfg = game_cfg
        self.obs_cfg = obs_cfg
        self.seed = seed
        self._episode_steps = 0
        self._proc = launch_sc2(game_cfg, seed)
        self._controller: RemoteController = self._proc.controller
        self._obs = self._controller.observe()

        self._team_dist = team_dist
        self._pos_dist = pos_dist

        map_info = self._controller.game_info().start_raw
        self.map_size = (map_info.map_size.x, map_info.map_size.y)
        self.terrain_height = (
            np.frombuffer(map_info.terrain_height.data, dtype=np.uint8)
            .reshape(*self.map_size)
            .transpose()
        )
        self.terrain_height = np.flip(self.terrain_height, 1) / 255.0
        self.renderer: StarCraft2Renderer | None = None
        self._initialize_game()

        # Needed for renderer
        self.window_size = (1024, 1024)
        self.reward = 1

    @property
    def _observation(self):
        return self._obs.observation

    @property
    def map_name(self):
        return self.game_cfg.map_name

    def _initialize_team(self, team, positions, player_id):
        """Initialize team"""
        positions = positions * self.map_size
        debug_command = [
            dbg_pb.DebugCommand(
                create_unit=dbg_pb.DebugCreateUnit(
                    unit_type=UNIT_NAME_TO_ID[u],
                    owner=player_id,
                    pos=sc_common.Point2D(x=p[0], y=p[1]),
                    quantity=1,
                )
            )
            for u, p in zip(team, positions)
        ]
        self._controller.debug(debug_command)

    def _initialize_game(self):
        """Initialize the game state with units"""
        team_dist = self._team_dist.generate()["team_gen"]
        pos_dist = self._pos_dist.generate()
        for k in ["ally", "enemy"]:
            self._initialize_team(
                team_dist[f"{k}_team"],
                pos_dist[f"{k}_start_positions"]["item"],
                1 if k == "ally" else 2,
            )

        # Ensure expected units are spawned
        all_spawned = False
        retry_count = 0
        while not all_spawned:
            self._controller.step(1)
            self._obs = self._controller.observe()

            self_unit = 0
            enemy_units = 0
            for unit in self._observation.raw_data.units:
                if unit.owner == 1:
                    self_unit += 1
                elif unit.owner == 2:
                    enemy_units += 1
                else:
                    raise KeyError(f"Unknown unit owner {unit.owner}")

            all_spawned = self_unit == len(team_dist["ally_team"])
            all_spawned &= enemy_units == len(team_dist["enemy_team"])

            retry_count += 1
            if retry_count > 1000:
                raise RuntimeError("Exceeded max retry creating game")

        if self.game_cfg.hotstart is HotStart.closest:
            command = get_attack_closest_command(self._observation.raw_data.units)
            self._controller.actions(command)
        if self.game_cfg.hotstart is HotStart.mean_pos:
            command = get_move_average_command(self._observation.raw_data.units)
            self._controller.actions(command)

    def close(self):
        """Stop the sc2 process (and renderer)"""
        if self.renderer is not None:
            self.renderer.close()
            self.renderer = None
        self._proc.close()

    def kill_units(self):
        """Killing all units usually resets the game"""
        kill_command = [
            dbg_pb.DebugCommand(
                kill_unit=dbg_pb.DebugKillUnit(
                    tag=[u.tag for u in self._observation.raw_data.units]
                )
            )
        ]
        self._controller.debug(kill_command)
        # Spin until all dead
        while len(self._controller.observe().observation.raw_data.units) > 0:
            self._controller.step(2)

    def reset(self):
        """Reset the environment"""
        self.kill_units()
        try:
            self._initialize_game()
        except RuntimeError:  # Retry once
            print("Failed to restart env, retrying one more time")
            self.kill_units()
            self._initialize_game()
        self._episode_steps = 0

    def _submit_actions(self, actions: list[UnitAction]):
        """Submit actions to game"""
        commands = []

        units = []
        enemy_ids = []
        for u in self._observation.raw_data.units:
            if u.owner == 1:
                units.append(u)
            if u.owner == 2:
                enemy_ids.append(u.tag)

        for unit, action in zip(units, actions):
            if action is None:
                continue

            cmd_kwargs = {"unit_tags": [unit.tag], "queue_command": False}
            if isinstance(action, TargetAssignment):
                enemy_tag = enemy_ids[action.other]
                if len(unit.orders) > 0:  # Don't repeat attack order
                    if unit.orders[0].target_unit_tag == enemy_tag:
                        continue
                command = raw_pb.ActionRawUnitCommand(
                    ability_id=23, target_unit_tag=enemy_tag, **cmd_kwargs
                )
            elif isinstance(action, PositionAssignment):
                # Not sure if 23 Attack or 24 AttackTowards is appropriate
                command = raw_pb.ActionRawUnitCommand(
                    ability_id=23,
                    target_world_space_pos=sc_common.Point2D(x=action.x, y=action.y),
                    **cmd_kwargs,
                )
            commands.append(
                sc_pb.Action(action_raw=raw_pb.ActionRaw(unit_command=command))
            )

        action_req = sc_pb.RequestAction(actions=commands)
        self._controller.actions(action_req)

    def step(self, actions: list[UnitAction]):
        """Apply actions and step simulator"""
        self._episode_steps += 1

        self._submit_actions(actions)
        self._controller.step(self.game_cfg.step_size)
        self._obs = self._controller.observe()

        ally_count = 0
        enemy_count = 0
        for unit in self._observation.raw_data.units:
            if unit.unit_type == Terran.Medivac:
                continue
            if unit.owner == 1:
                ally_count += 1
            elif unit.owner == 2:
                enemy_count += 1
            else:
                raise KeyError(f"Unknown owner {unit.owner}")

        reward = None
        terminated = (
            ally_count == 0
            or enemy_count == 0
            or self._episode_steps > self.game_cfg.max_duration
        )
        info = {}
        if terminated:
            if self._episode_steps > self.game_cfg.max_duration:
                info["outcome"] = Outcome.DRAW
            elif ally_count > 0:
                info["outcome"] = Outcome.WIN
            elif enemy_count > 0:
                info["outcome"] = Outcome.LOSS
            else:
                info["outcome"] = Outcome.DRAW

        return reward, terminated, info

    def render(self, mode="human"):
        """Render simulation"""
        if self.renderer is None:
            self.renderer = StarCraft2Renderer(self, mode)
        assert mode == self.renderer.mode, f"Inconsistent render {mode=}"
        return self.renderer.render(mode)

    def get_obs(self):
        """Return observation of the units"""
        self_units = []
        enemy_units = []

        get_feats = attrgetter(
            *[TRAIN_FEAT_TO_OBS[f] for f in self.obs_cfg.unit_features]
        )

        try:
            health_idx = self.obs_cfg.unit_features.index("health")
        except ValueError:
            health_idx = None
        try:
            shield_idx = self.obs_cfg.unit_features.index("shield")
        except ValueError:
            shield_idx = None

        for unit in self._observation.raw_data.units:
            unit_feat = np.array(get_feats(unit), dtype=np.float32)
            # Add shield value to health then normalize
            if self.obs_cfg.combine_health_shield:
                assert health_idx is not None
                unit_feat[health_idx] += unit.shield
                if self.obs_cfg.normalize_health_shield:
                    unit_feat[health_idx] /= unit.health_max + unit.shield_max
            # If normalize, check health/shield_max and normalize
            elif self.obs_cfg.normalize_health_shield:
                if health_idx is not None and unit.health_max > 0:
                    unit_feat[health_idx] /= unit.health_max
                if shield_idx is not None and unit.shield_max > 0:
                    unit_feat[shield_idx] /= unit.shield_max

            try:  # Change "alliance" to flag true
                idx = self.obs_cfg.unit_features.index("alliance_self")
                unit_feat[idx] = unit.owner == 1
            except ValueError:
                pass

            if unit.owner == 1:
                self_units.append(unit_feat)
            elif unit.owner == 2:
                enemy_units.append(unit_feat)
            else:
                raise KeyError(f"Unknown owner: {unit.owner}")

        return {
            "units": np.stack(self_units, axis=0),
            "enemy_units": np.stack(enemy_units, axis=0),
        }

    def get_env_info(self):
        """Get some basic environment information"""
        env_info = {}
        env_info["n_units"] = self._team_dist.n_units
        env_info["n_enemies"] = self._team_dist.n_enemies
        env_info["unit_dim"] = len(self.obs_cfg.unit_features)
        env_info["terrain_height"] = self.terrain_height
        env_info["map_size"] = self.map_size
        return env_info
