from pathlib import Path
from typing import Sequence

import bottleneck as bn
import numpy as np
from sc2_serializer import (
    ReplayDataAllDatabase,
    ReplayDataAllParser,
    Score,
    StepDataSoA,
    set_replay_database_logger_level,
    spdlog_lvl,
)
from sc2_serializer.sampler import BasicSampler
from sc2_serializer.unit_features import UnitOH

from .dataset_utils import (
    get_target_coordinates,
    get_unique_target_coordinates,
    get_unit_coordinate_indices,
    get_unit_target_indices,
)
from .sc2_common import BattleEvent, SequenceData, get_default_sql_sampler

# Module-level side effect: importing this module sets the replay database
# logger level to `spdlog_lvl.err`.
set_replay_database_logger_level(spdlog_lvl.err)


class ReplayDataset:
    """Simple SC2 Replay Dataset that yields ReplayDataAll entries

    The sampler maps a flat dataset index to a `(database path, entry index)`
    pair; the database object is re-used between lookups and only re-loaded
    when the sampled path differs from the one loaded last.
    """

    __slots__ = ("sampler", "database", "_last_path")

    @staticmethod
    def sample(path: Path, index: int = 0):
        """Build a throw-away dataset over `path` and return the entry at `index`"""
        dataset = ReplayDataset(path)
        return dataset[index]

    def __init__(self, path: Path):
        """Create the dataset over `path`

        A directory is handed to `get_default_sql_sampler`; anything else is
        handed to `BasicSampler` with `train_ratio=1` and `is_train=True`.
        """
        self.sampler = (
            get_default_sql_sampler(path)
            if path.is_dir()
            else BasicSampler(path, train_ratio=1, is_train=True)
        )
        self.database = ReplayDataAllDatabase()
        # No database file has been loaded yet
        self._last_path = ""

    def __len__(self):
        """Number of samples reported by the sampler"""
        return len(self.sampler)

    def __getitem__(self, index: int):
        """Return the database entry the sampler associates with `index`"""
        source, entry_idx = self.sampler.sample(index)
        already_loaded = source == self._last_path
        if not already_loaded:
            # Switch the shared database object over to the new file
            self.database.load(source)
            self._last_path = source
        return self.database.getEntry(entry_idx)

    def __iter__(self):
        """Yield every entry in sampler order"""
        yield from (self[position] for position in range(len(self.sampler)))


def calculate_smooth_damage_dealt(replay: StepDataSoA, window_size: int = 10):
    """Calculate the smoothed per-step change in total damage over a replay

    For every entry of `replay.score`, all `Score` attributes whose name starts
    with `total_damage_` are added together. The step-to-step difference of
    that total is then averaged with a moving window.

    Args:
        replay (StepDataSoA): Replay data providing the `score` sequence.
        window_size (int): Length of the moving-average window.

    Returns:
        np.ndarray: Moving mean of the damage difference, one value per step.
    """
    damage_fields = tuple(
        name for name in Score.__dict__ if name.startswith("total_damage_")
    )
    # Add all the different damage types together at each timepoint
    per_step_total = [
        sum(getattr(step_score, name) for name in damage_fields)
        for step_score in replay.score
    ]
    totals = np.array(per_step_total)
    # Prepending the first value keeps the output length equal to the input
    # length and makes the first difference zero.
    step_delta = np.diff(totals, prepend=totals[0])
    smoothed: np.ndarray = bn.move_mean(step_delta, window_size, min_count=1)
    return smoothed


def find_directed_crossings(data: np.ndarray, threshold: float):
    """Find crossings of threshold and separate positive and negative direction

    Args:
        data (np.ndarray): One-dimensional signal to inspect.
        threshold (float): Level the signal is compared against
            (`data > threshold`).

    Returns:
        Tuple `(start, end)` of integer index arrays. `start` holds every index
        `i` where `data[i]` is not above the threshold but `data[i + 1]` is;
        `end` holds every index `i` where `data[i]` is above the threshold but
        `data[i + 1]` is not. If there are more upward than downward crossings,
        the last valid crossing index (`len(data) - 2`) is appended to `end`.
    """
    above = (data > threshold).astype(np.int8)
    # +1 marks an upward crossing, -1 a downward crossing
    dt = np.diff(above)
    start = np.argwhere(dt == 1).squeeze(-1)
    end = np.argwhere(dt == -1).squeeze(-1)
    # Player retires mid-battle: the signal never drops back below the
    # threshold, so close the final interval at the last crossing index.
    if len(start) > len(end):
        # np.concatenate expects a sequence of arrays as its first argument;
        # passing the scalar positionally would be interpreted as `axis`.
        closing = np.array([dt.shape[-1] - 1], dtype=end.dtype)
        end = np.concatenate([end, closing])
    return start, end


def find_battle_events(
    replay: StepDataSoA, window_size: int = 10, threshold: float = 10
):
    """Find time points of battle events within the game.

    A moving average window is applied to the amount of damage done throughout
    the game. Each pair of upward/downward crossings of `threshold` is turned
    into a `BattleEvent` whose score is the mean of the smoothed damage signal
    between the two crossing indices.

    Args:
        replay (StepDataSoA): Replay data to search for battles.
        window_size (int): Moving-average window used for smoothing.
        threshold (float): Smoothed-damage level that delimits a battle.

    Returns:
        list[BattleEvent]: One event per (start, end) crossing pair.
    """
    smoothed = calculate_smooth_damage_dealt(replay, window_size)
    starts, ends = find_directed_crossings(smoothed, threshold)

    events = []
    for begin, finish in zip(starts, ends):
        # Score the battle by its mean smoothed damage
        intensity = np.mean(smoothed[begin:finish]).item()
        events.append(BattleEvent(int(begin), int(finish), intensity))
    return events


def extract_battle_sequence(
    parser: ReplayDataAllParser,
    event: BattleEvent,
    unit_features: Sequence[UnitOH],
    position_targets: bool = False,
    position_values: bool = True,
    combine_health_shield: bool = False,
    normalize_props: bool = False,
):
    """Extract unit features and their targets from a battle event

    Unit data and unit-to-unit target indices are always gathered for every
    step in `range(event.start, event.end)`.

    Args:
        parser (ReplayDataAllParser): Parser filled with replay data to extract
        event (BattleEvent): Battle event to extract from the replay
        unit_features (Sequence[UnitOH]): Sequence of enum values to extract e.g.
            [sc2.UnitOH.x, sc2.UnitOH.y, sc2.UnitOH.t]
        position_targets (bool): Find the unique positional targets and the
            units assigned to them.
        position_values (bool): Include the target coordinates (as float32)
            for every step.
        combine_health_shield (bool): Add shield to health and shield_max to
            health_max before selecting features. `UnitOH.shield` must not be
            in `unit_features` when this is set.
        normalize_props (bool): Divide health (and shield, when not combined)
            by the matching `*_max` property; non-finite values in the selected
            features are replaced by 0 afterwards.

    Returns:
        SequenceData: holds `units` and `unit_targets`, plus `positions` and/or
        `positions_unique` and `position_targets` when requested.
    """
    steps = range(event.start, event.end)

    data = SequenceData(
        units=[parser.sample_units(step) for step in steps],
        unit_targets=[get_unit_target_indices(parser.data, step) for step in steps],
    )

    def _scale_by_max(value_idx, max_idx):
        """Divide one property by its maximum, in place, for every step"""
        for frame in data.units:
            frame[..., value_idx] /= frame[..., max_idx]

    if combine_health_shield:
        assert (
            UnitOH.shield not in unit_features
        ), "Shield should not be in unit_features if combined with health"
        for frame in data.units:
            # Fold shield(_max) into health(_max)
            frame[..., UnitOH.health] += frame[..., UnitOH.shield]
            frame[..., UnitOH.health_max] += frame[..., UnitOH.shield_max]
            if normalize_props:
                frame[..., UnitOH.health] /= frame[..., UnitOH.health_max]
    elif normalize_props:
        # Only normalize the properties that were actually requested
        if UnitOH.health in unit_features:
            _scale_by_max(UnitOH.health, UnitOH.health_max)
        if UnitOH.shield in unit_features:
            _scale_by_max(UnitOH.shield, UnitOH.shield_max)

    # Drop every feature that was not requested
    data.units = [frame[..., unit_features] for frame in data.units]

    if normalize_props:
        # Division may have produced nan/inf; zero those entries
        for frame in data.units:
            invalid = ~np.isfinite(frame)
            frame[invalid] = 0

    if position_values:
        data.positions = [
            get_target_coordinates(parser.data, step).astype(np.float32)
            for step in steps
        ]

    if position_targets:
        data.positions_unique = [
            get_unique_target_coordinates(parser.data, step).astype(np.float32)
            for step in steps
        ]
        data.position_targets = [
            get_unit_coordinate_indices(parser.data, step, unique_coords)
            for step, unique_coords in zip(steps, data.positions_unique)
        ]

    return data
