from dataclasses import astuple
from typing import Optional

import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from repo_anonymized.microrts.vec_env.microrts_socket_env import MicroRTSSocketEnv
from repo_anonymized.microrts.vec_env.microrts_space_transform import (
    MicroRTSSpaceTransform,
)
from repo_anonymized.microrts.wrappers.microrts_stats_recorder import (
    MicrortsStatsRecorder,
)
from repo_anonymized.runner.config import Config, EnvHyperparams
from repo_anonymized.wrappers.action_mask_stats_recorder import ActionMaskStatsRecorder
from repo_anonymized.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from repo_anonymized.wrappers.additional_win_loss_reward import (
    AdditionalWinLossRewardWrapper,
)
from repo_anonymized.wrappers.episode_stats_writer import EpisodeStatsWriter
from repo_anonymized.wrappers.hwc_to_chw_observation import HwcToChwObservation
from repo_anonymized.wrappers.is_vector_env import IsVectorEnv
from repo_anonymized.wrappers.score_reward_wrapper import ScoreRewardWrapper
from repo_anonymized.wrappers.self_play_wrapper import SelfPlayWrapper
from repo_anonymized.wrappers.vectorable_wrapper import VecEnv


def make_microrts_bots_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a MicroRTS vec env in which ``reference_bot`` plays scripted bots.

    ``hparams.bots`` maps the name of a ``microrts_ai`` attribute to a count.
    Each unit of count adds one pairing (two entries: the reference bot and
    the opponent) to ``make_kwargs["ais"]``, with the order inside the pair
    alternating from one pairing to the next. When ``hparams.map_paths`` is
    set, pairings are split evenly across the maps and a parallel
    ``make_kwargs["map_paths"]`` list (one entry per ``ais`` entry) is built.

    Args:
        config: Run config; supplies the seed and ``additional_keys_to_log``.
        hparams: Environment hyperparameters (unpacked positionally below).
        training: Selects the training seed and enables ``EpisodeStatsWriter``.
        render: Not used by this factory.
        normalize_load_path: Not used by this factory.
        tb_writer: TensorBoard writer; required when ``training`` is True.

    Returns:
        A ``MicroRTSBotGridVecEnv`` wrapped with space transforms, action
        masking, episode statistics recorders and the optional reward
        wrappers selected by ``hparams``.

    Raises:
        AssertionError: If ``bots`` or ``reference_bot`` is empty, a bot name
            is not an attribute of ``microrts_ai``, a bot count is not a
            multiple of the required modulus when ``map_paths`` is given, or
            ``tb_writer`` is missing while ``training``.
    """
    # astuple yields fields in declaration order, so this unpacking must stay
    # in sync with EnvHyperparams. It also returns fresh copies of dict/list
    # fields, so mutating make_kwargs below does not modify hparams.
    (
        _,  # env_type
        _,  # n_envs
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
        _,  # self_play_kwargs,
        _,  # selfplay_bots,
        additional_win_loss_reward,
        map_paths,
        score_reward_kwargs,
        _,  # is_agent,
        valid_sizes,
        paper_planes_sizes,
        fixed_size,
        terrain_overrides,
        _,  # time_budget_ms,
        video_frames_per_second,
        reference_bot,
        _,  # self_play_reference_kwargs,
        _,  # additional_win_loss_smoothing_factor,
    ) = astuple(hparams)

    seed = config.seed(training=training)

    from repo_anonymized.microrts import microrts_ai
    from repo_anonymized.microrts.vec_env.microrts_bot_vec_env import (
        MicroRTSBotGridVecEnv,
    )

    make_kwargs = make_kwargs or {}
    if "reward_weight" in make_kwargs:
        # Reward Weights:
        # RAIWinLossRewardFunction
        # ResourceGatherRewardFunction
        # ProduceWorkerRewardFunction
        # ProduceBuildingRewardFunction
        # AttackRewardFunction
        # ProduceLightUnitRewardFunction
        # ProduceHeavyUnitRewardFunction
        # ProduceRangedUnitRewardFunction
        # ScoreRewardFunction
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])

    assert bots, "Must specify opponent bots"
    assert reference_bot, "Must specify reference_bot"

    def lookup_ai(ai_name: str):
        # Default to None so an unknown name reaches the assertion; a plain
        # getattr would raise AttributeError before the assert could run.
        ai = getattr(microrts_ai, ai_name, None)
        assert ai, f"{ai_name} not in microrts_ai"
        return ai

    ref_ai = lookup_ai(reference_bot)

    def alternating_pairs(opp_ai, n_games: int) -> list:
        # One [ref, opp] pair per game, flipping the order on every odd game
        # so the reference bot alternates between the two positions.
        pairs = []
        for i in range(n_games):
            pairs.extend([opp_ai, ref_ai] if i % 2 else [ref_ai, opp_ai])
        return pairs

    if map_paths:
        _map_paths = []
        _ais = []
        for ai_name, n in bots.items():
            # Opponents (but not the reference bot playing itself) must yield
            # an even number of games per map, presumably so both pair orders
            # appear equally often on each map.
            modulus = len(map_paths) * (1 if ai_name == reference_bot else 2)
            assert (
                n % modulus == 0
            ), f"Expect number of {ai_name} bots ({n}) to be a multiple of {modulus}"
            env_per_map = 2 * n // len(map_paths)
            opp_ai = lookup_ai(ai_name)
            for mp in map_paths:
                # Keep map_paths and ais index-aligned: env_per_map entries each.
                _map_paths.extend([mp] * env_per_map)
                _ais.extend(alternating_pairs(opp_ai, env_per_map // 2))
        make_kwargs["map_paths"] = _map_paths
        make_kwargs["ais"] = _ais
    else:
        _ais = []
        for ai_name, n in bots.items():
            _ais.extend(alternating_pairs(lookup_ai(ai_name), n))
        make_kwargs["ais"] = _ais

    # Every slot holding ref_ai is recorded, including both slots of a
    # reference_bot-vs-reference_bot pairing (same object from lookup_ai).
    make_kwargs["reference_indexes"] = [
        idx for idx, ai in enumerate(make_kwargs["ais"]) if ai == ref_ai
    ]

    envs = MicroRTSBotGridVecEnv(
        **make_kwargs, video_frames_per_second=video_frames_per_second
    )

    envs = MicroRTSSpaceTransform(
        envs,
        valid_sizes=valid_sizes,
        paper_planes_sizes=paper_planes_sizes,
        fixed_size=fixed_size,
        terrain_overrides=terrain_overrides,
    )
    envs = HwcToChwObservation(envs)
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(
        envs,
        bots,
        make_kwargs.get("map_paths"),
    )
    envs = ActionMaskStatsRecorder(envs)
    if training:
        assert tb_writer, "tb_writer is required when training"
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )

    if additional_win_loss_reward:
        envs = AdditionalWinLossRewardWrapper(envs)
    if score_reward_kwargs:
        envs = ScoreRewardWrapper(envs, **score_reward_kwargs)

    return envs
