import time
from typing import Callable

import numpy as np
import torch
from konductor.data import ExperimentInitConfig
from konductor.utilities.pbar import IntervalPbar
from smacv2.starcraft2.distributions import WeightedTeamsDistribution, get_distribution
from torch import Tensor

from ..dataset.sc2_common import extract_minimap_roi
from ..dataset.sc2_dataset import (
    SC2BattleCfg,
    TorchSC2Data,
    UnitTypeToContiguous,
    create_unit_type_to_contiguous_map,
    make_unit_type_contiguous,
)
from ..model.sc2_perceiver import SC2IntentPredictor
from ..utils.eval_common import EnvResult
from ..utils.position_transforms import PositionTransform
from .simulator import (
    Outcome,
    PositionAssignment,
    SC2GameCfg,
    SC2ObsCfg,
    StarCraft2Env,
    TargetAssignment,
    UnitAction,
)


def create_env(game_cfg: SC2GameCfg, obs_cfg: SC2ObsCfg):
    """Build a ``StarCraft2Env`` driven by the SMACv2 distributions in ``game_cfg``.

    Side effect: ``game_cfg.pos_dist`` is modified in place. Each of its
    ``"team_gen"`` and ``"start_positions"`` sub-configs receives an
    ``"env_key"`` entry equal to its own key, plus copies of the top-level
    ``"n_units"`` and ``"n_enemies"`` values.

    Args:
        game_cfg: game configuration; its ``pos_dist`` mapping describes the
            team-composition and start-position distributions.
        obs_cfg: observation configuration forwarded to the environment.

    Returns:
        The constructed ``StarCraft2Env``.
    """
    # NOTE: difficulty doesn't seem to have a significant effect on outcome.
    pos_cfg = game_cfg.pos_dist
    for section in ("team_gen", "start_positions"):
        pos_cfg[section]["env_key"] = section
        pos_cfg[section]["n_units"] = pos_cfg["n_units"]
        pos_cfg[section]["n_enemies"] = pos_cfg["n_enemies"]

    # Keep the original construction order (teams first, then positions) in
    # case either distribution consumes shared random state on creation.
    teams = WeightedTeamsDistribution(pos_cfg["team_gen"])
    start_cfg = pos_cfg["start_positions"]
    make_positions = get_distribution(start_cfg["dist_type"])
    positions = make_positions(start_cfg)
    return StarCraft2Env(game_cfg, obs_cfg, positions, teams)


def add_observation(
    obs: dict[str, np.ndarray], data: TorchSC2Data, normalize: Callable[[Tensor], None]
):
    """Append one observation frame to the accumulated model input ``data``.

    For both the friendly (``"units"``) and enemy (``"enemy_units"``) entries of
    ``obs``, this does the following:

    * Build a boolean mask padded to the existing mask width. The first
      ``len(rows)`` slots are ``True``.
    * Build a zero-padded feature tensor matching the existing per-frame shape,
      copy the observed rows into it, and normalize it in place with
      ``normalize``.
    * Concatenate both onto ``data`` along dim 0. Each new frame gets two
      leading singleton dims.

    ``data`` is updated in place by attribute assignment.
    NOTE(review): currently unused by ``run_episode`` (its call is disabled).
    """
    assert data.enemy_mask is not None and data.enemy_units is not None

    for obs_key, mask_attr, feat_attr in (
        ("units", "units_mask", "units"),
        ("enemy_units", "enemy_mask", "enemy_units"),
    ):
        rows = obs[obs_key]
        n_present = rows.shape[0]

        # Mask for the new frame: valid slots first, padding afterwards.
        old_mask: Tensor = getattr(data, mask_attr)
        frame_mask = torch.zeros(old_mask.shape[-1], dtype=torch.bool).cuda()
        frame_mask[:n_present] = True
        setattr(data, mask_attr, torch.cat([old_mask, frame_mask[None, None]], dim=0))

        # Zero-padded features for the new frame, normalized in place.
        old_feats: Tensor = getattr(data, feat_attr)
        frame_feats = torch.zeros(old_feats.shape[-2:]).cuda()
        frame_feats[:n_present] = torch.as_tensor(rows).cuda()
        normalize(frame_feats)
        setattr(data, feat_attr, torch.cat([old_feats, frame_feats[None, None]], dim=0))


def intention_to_action(
    preds: dict[str, Tensor],
    unit_mask: Tensor,
    unit_pos: Tensor,
    pos_transform: PositionTransform,
    center_xy: Tensor,
    inv_half_size: Tensor,
    pos_threshold: float = 0.25,
) -> list[UnitAction | None]:
    """Decode the model's intention predictions into per-unit environment actions.

    Each unit slot, in the order of ``unit_mask``, maps to one entry:

    * Masked-out slot -> ``None``.
    * Predicted target class ``t != 0`` -> ``TargetAssignment(t - 1)``. Class 0
      is treated here as "no target".
    * Otherwise, if ``sigmoid(pos-logit)`` exceeds ``pos_threshold`` ->
      ``PositionAssignment(x, y)`` at the predicted position mapped back to map
      coordinates.
    * Otherwise -> ``None``.

    Args:
        preds: model outputs keyed by ``"unit-target"``, ``"pos-logit"`` and
            ``"position"``. Only index 0 of each tensor's leading (batch) dim
            is used.
        unit_mask: per-unit validity flags.
        unit_pos: unit positions in the model's normalized frame, passed to
            ``pos_transform`` together with the raw position prediction.
        pos_transform: callable applied to the position prediction. Its
            output is assumed to be in the normalized frame.
        center_xy: normalization offset.
        inv_half_size: normalization scale. Together with ``center_xy`` it
            undoes ``(p - center_xy) * inv_half_size``.
        pos_threshold: minimum position probability required to issue a
            ``PositionAssignment`` (default 0.25, the previous fixed value).

    Returns:
        One action, or ``None`` for "no action", per unit slot.
    """
    # Take batch element 0, then move results to the CPU once. Otherwise every
    # per-unit .item()/comparison below would force a separate device sync.
    pred_targets = torch.argmax(preds["unit-target"][0], dim=-1).tolist()
    pred_pos_probs = preds["pos-logit"][0].sigmoid().cpu()
    pred_pos_values = (
        pos_transform(preds["position"][0], unit_pos) / inv_half_size + center_xy
    ).cpu()
    valid_flags = unit_mask.cpu()

    actions: list[UnitAction | None] = []
    for target, pos_prob, pos_value, is_valid in zip(
        pred_targets, pred_pos_probs, pred_pos_values, valid_flags
    ):
        if not is_valid.item():
            actions.append(None)
        elif target != 0:
            # Shift by one: class 0 is the "no target" slot.
            actions.append(TargetAssignment(target - 1))
        elif pos_prob > pos_threshold:
            actions.append(PositionAssignment(pos_value[0].item(), pos_value[1].item()))
        else:
            actions.append(None)

    return actions


def _obs_to_model_data(
    obs: dict[str, np.ndarray],
    unit_type_map: UnitTypeToContiguous,
    normalize: Callable[[Tensor], None],
) -> TorchSC2Data:
    """Convert one environment observation into a batch-of-one model input.

    ``make_unit_type_contiguous`` is applied to ``obs["units"]`` and
    ``obs["enemy_units"]`` for its effect on those arrays; its return value is
    not used. Every unit present in the observation is marked valid. The
    resulting unit tensors are normalized in place with ``normalize``.
    """
    make_unit_type_contiguous(obs["units"], -1, unit_type_map)
    make_unit_type_contiguous(obs["enemy_units"], -1, unit_type_map)
    model_data = TorchSC2Data(
        units=torch.as_tensor(obs["units"]).cuda()[None],
        unit_targets=torch.empty(0),
        units_mask=torch.ones(
            1, *obs["units"].shape[:-1], dtype=torch.bool
        ).cuda(),
        enemy_units=torch.as_tensor(obs["enemy_units"]).cuda()[None],
        enemy_mask=torch.ones(
            1, *obs["enemy_units"].shape[:-1], dtype=torch.bool
        ).cuda(),
    )
    normalize(model_data.units)
    normalize(model_data.enemy_units)
    return model_data


def run_episode(
    env: StarCraft2Env,
    model: SC2IntentPredictor | None,
    data_cfg: SC2BattleCfg,
    pos_transform: PositionTransform,
    unit_type_map: UnitTypeToContiguous,
    visualize: bool,
):
    """Play one episode in ``env`` and return the ``info`` from its final step.

    With ``model=None`` an empty action list is passed to every step. Otherwise
    the model is reset and given the terrain minimap once. Then, on every step,
    it is fed the current observation incrementally and its predictions are
    decoded into unit actions.

    Args:
        env: environment to play in. It is reset at the start of the episode.
        model: intent predictor, or ``None`` to issue no actions.
        data_cfg: dataset config supplying ``roi_size`` and ``minimap_size``.
        pos_transform: forwarded to ``intention_to_action``.
        unit_type_map: unit-type id remapping applied to every observation.
        visualize: if true, render each step and sleep 0.05 s.

    Returns:
        The ``info`` returned by the terminating ``env.step`` call.
    """
    env.reset()
    env_info = env.get_env_info()

    # Unit (x, y) are re-expressed relative to the map centre and scaled by
    # 2 / roi_size, i.e. divided by half the ROI size.
    center_xy = torch.as_tensor(env_info["map_size"]).cuda() / 2
    inv_half_size = 2.0 / torch.as_tensor(data_cfg.roi_size).cuda()

    def normalize(data: Tensor):
        """Inplace normalize the first two (x, y) features of unit data"""
        data[..., :2] = (data[..., :2] - center_xy) * inv_half_size

    if model is not None:
        model.inc_reset(1)
        minimap = extract_minimap_roi(
            env_info["terrain_height"],
            center_xy.cpu().numpy(),
            data_cfg.roi_size,
            data_cfg.minimap_size,
        )
        minimap = torch.as_tensor(minimap, dtype=torch.float32).cuda()
        # (x - 127) / 128 maps 8-bit values to about [-1, 1]
        # (assumes the minimap is 0..255 — TODO confirm).
        model.inc_minimap((minimap[None, None] - 127) / 128)

    terminated = False
    while not terminated:
        if model is None:
            actions = []
        else:
            model_data = _obs_to_model_data(env.get_obs(), unit_type_map, normalize)
            pred = model.inc_forward(model_data)
            actions = intention_to_action(
                pred,
                model_data.units_mask[0],
                model_data.units[0, :, :2],
                pos_transform,
                center_xy,
                inv_half_size,
            )

        if visualize:
            env.render()
            time.sleep(0.05)

        _, terminated, info = env.step(actions)

    return info


def track_results(info: dict, results: EnvResult):
    """Tally the outcome of one finished episode into ``results``.

    ``Outcome.WIN`` increments ``results.wins`` and ``Outcome.LOSS`` increments
    ``results.losses``. Any other value of ``info["outcome"]`` increments
    ``results.draws``. Outcomes are compared by identity.
    """
    outcome: Outcome = info["outcome"]
    if outcome is Outcome.WIN:
        counter = "wins"
    elif outcome is Outcome.LOSS:
        counter = "losses"
    else:
        counter = "draws"
    setattr(results, counter, getattr(results, counter) + 1)


@torch.no_grad()
def run_evaluation(
    n_samples: int,
    game_cfg: SC2GameCfg,
    model: SC2IntentPredictor | None,
    cfg: ExperimentInitConfig,
    visualize: bool,
):
    """Run evaluation with the SMACv2 environment and tally episode outcomes.

    Builds the data config, position transform, unit-type map and observation
    config from ``cfg``, creates the environment, then plays ``n_samples``
    episodes with ``run_episode``. Gradient tracking is disabled throughout.

    Args:
        n_samples: number of episodes to play.
        game_cfg: game configuration passed to ``create_env``. Its
            ``pos_dist`` is modified in place there.
        model: intent predictor, or ``None`` to issue no actions.
        cfg: experiment configuration used to derive the data settings.
        visualize: render each step (see ``run_episode``).

    Returns:
        An ``EnvResult`` with accumulated win/loss/draw counts.
    """
    data_cfg = SC2BattleCfg.from_config(cfg)
    pos_transform = PositionTransform.from_config(cfg)
    unit_type_map = create_unit_type_to_contiguous_map(data_cfg.get_unit_type_file())
    obs_cfg = SC2ObsCfg(unit_features=[f.name for f in data_cfg.unit_features])
    obs_cfg.combine_health_shield = data_cfg.combine_health_shield
    env = create_env(game_cfg, obs_cfg)

    results = EnvResult()

    try:
        with IntervalPbar(n_samples, fraction=0.2, desc="Evaluating") as pbar:
            for _ in range(n_samples):
                info = run_episode(
                    env, model, data_cfg, pos_transform, unit_type_map, visualize
                )
                track_results(info, results)
                pbar.update(1)
    finally:
        # Always release the environment, even if an episode raises.
        env.close()

    return results
