import time

import numpy as np
from konductor.utilities.pbar import IntervalPbar
from smacv2.starcraft2.wrapper import StarCraftCapabilityEnvWrapper

from ..model.sc2_perceiver import SC2IntentPredictor, TorchSC2Data
from ..utils.eval_common import EnvResult
from .simulator import SC2GameCfg


def create_env(game_cfg: SC2GameCfg):
    """Build a SMACv2 capability-wrapped StarCraft II environment.

    ``game_cfg.pos_dist`` is passed as the wrapper's ``capability_config`` and
    ``game_cfg.map_name`` selects the map. The remaining wrapper options
    (field of view, own-position observation, unit ranges, minimum attack
    range) are fixed here rather than read from the config.

    Args:
        game_cfg: game configuration providing ``pos_dist`` and ``map_name``.

    Returns:
        The constructed ``StarCraftCapabilityEnvWrapper``.
    """
    wrapper_kwargs = dict(
        capability_config=game_cfg.pos_dist,
        map_name=game_cfg.map_name,
        conic_fov=False,
        obs_own_pos=True,
        use_unit_ranges=True,
        min_attack_range=2,
    )
    return StarCraftCapabilityEnvWrapper(**wrapper_kwargs)


def convert_smac_to_train_data(data) -> TorchSC2Data:
    """Convert smacv2-format data into ``TorchSC2Data`` for ``SC2IntentPredictor``.

    Placeholder: the conversion is not implemented yet.

    Args:
        data: smacv2-format input (its exact structure is not yet defined
            here — TODO confirm once implemented).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


def run_episode(
    env: StarCraftCapabilityEnvWrapper, model: SC2IntentPredictor, visualize: bool
):
    """Play one episode in ``env`` until it terminates.

    NOTE(review): ``model`` is accepted but not consulted yet. Every agent
    picks uniformly at random among its currently available actions.

    Args:
        env: SMACv2 environment to play in; it is reset at the start.
        model: intended controller (currently unused).
        visualize: if True, render each step and pause 0.1 s between steps.

    Returns:
        The ``info`` object returned by the final ``env.step`` call.
    """
    env.reset()
    agent_count = env.get_env_info()["n_agents"]
    total_reward = 0
    done = False

    while not done:
        # Observation and global state are fetched each step but not used yet
        # (presumably future model inputs).
        _obs = env.get_obs()
        _state = env.get_state()
        if visualize:
            env.render()
            time.sleep(0.1)

        # Query availability and sample per agent, in agent order.
        chosen = [
            np.random.choice(np.nonzero(env.get_avail_agent_actions(idx))[0])
            for idx in range(agent_count)
        ]

        reward, done, info = env.step(chosen)
        total_reward += reward

    return info


def run_evaluation(
    n_samples: int, game_cfg: SC2GameCfg, model: SC2IntentPredictor, visualize: bool
):
    """Run ``n_samples`` SMACv2 episodes and tally wins/losses.

    Args:
        n_samples: number of episodes to play.
        game_cfg: configuration used to build the environment via ``create_env``.
        model: controller passed through to ``run_episode``.
        visualize: if True, episodes are rendered step by step.

    Returns:
        An ``EnvResult`` accumulated by ``track_results`` over all episodes.
    """
    results = EnvResult()
    env = create_env(game_cfg)

    try:
        with IntervalPbar(n_samples, fraction=0.2, desc="Evaluating") as pbar:
            for _ in range(n_samples):
                info = run_episode(env, model, visualize)
                track_results(info, results)
                pbar.update(1)
    finally:
        # Release the environment even if an episode or tallying step raises;
        # previously an exception skipped close() and leaked the env.
        env.close()

    return results


def track_results(info: dict, results: EnvResult):
    """Accumulate SMACv2 Results

    Add one finished episode to ``results``: increment ``results.wins`` if the
    battle was won, otherwise ``results.losses``.

    Args:
        info: the info dict from the episode's final ``env.step`` call.
        results: accumulator, mutated in place.
    """
    # NOTE(review): SMAC's step() returns an empty info dict when it aborts
    # an episode after a game connection/protocol error, so "battle_won" can
    # be missing. Count that episode as a loss instead of raising KeyError.
    if info.get("battle_won", False):
        results.wins += 1
    else:
        results.losses += 1
