import random
from dataclasses import dataclass, field

import numpy as np
from konductor.data import DATASET_REGISTRY, DatasetConfig, ModuleInitConfig, Split
from konductor.data.dali import (
    DALIExternalSource,
    DaliExternalSourceParams,
    DaliLoaderConfig,
)
from konductor.init import ExperimentInitConfig
from nvidia.dali import fn
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.types import DALIDataType
from sc2_serializer import (
    GAME_INFO_FILE,
    ReplayDataAllDatabase,
    ReplayDataAllParser,
    set_replay_database_logger_level,
    spdlog_lvl,
)
from sc2_serializer.sampler import BasicSampler, ReplaySampler
from sc2_serializer.unit_features import UnitOH
from torch import Tensor

from .dataset_utils import (
    UnitTypeToContiguous,
    create_unit_type_to_contiguous_map,
    make_unit_type_contiguous,
)
from .sc2_common import (
    BattleEvent,
    apply_roi_transforms,
    find_main_roi,
    get_default_sql_sampler,
    normalize_minimap_coordinates,
)
from .sc2_preproc import extract_battle_sequence, find_battle_events

# Defaults for SC2BattleCfg.damage_window / damage_thresh, which are forwarded to
# find_battle_events when searching a replay for battles.
BATTLE_WINDOW_SIZE_DEFAULT: int = 20
BATTLE_THRESHOLD_DEFAULT: float = 3.0

# Import-time side effect: set the replay database logger to spdlog's `err` level.
set_replay_database_logger_level(spdlog_lvl.err)


@dataclass(slots=True, kw_only=True)
class TorchSC2Data:
    """Holds either a single sample, or batch of samples.The sequence extent is always the first
    dimension for convenient iteration over time (except minimap if its static)."""

    units: Tensor
    unit_targets: Tensor
    units_mask: Tensor
    enemy_units: Tensor | None = None
    enemy_mask: Tensor | None = None
    positions: Tensor | None = None
    positions_unique: Tensor | None = None
    positions_mask: Tensor | None = None
    position_targets: Tensor | None = None
    minimap: Tensor | None = None

    @classmethod
    def from_dali(cls, data: list[dict[str, Tensor]]):
        """Create TorchSC2Data from data loaded with DALI by transposing to sequence first"""
        return cls(
            **{
                k: v.transpose(0, 1) if k != "minimap" else v
                for k, v in data[0].items()
            }
        )

    @property
    def is_batch(self):
        """True if this object contains a batch of sequences"""
        return self.units.ndim == 4

    @property
    def batch_size(self):
        """Get the batch size if this is a batch of sequences"""
        if not self.is_batch:
            raise RuntimeError("This instance is not a batch of SC2Data")
        return self.units.shape[1]

    @property
    def sequence_len(self):
        """Return the lenght of the sequence"""
        return self.units.shape[0]

    def get_sample(self, index: int):
        """Returns single sample from the batch"""
        if not self.is_batch:
            raise RuntimeError("This instance is not a batch of SC2Data")
        return TorchSC2Data(
            **{
                k: (
                    getattr(self, k)[:, index]
                    if k != "minimap"
                    else getattr(self, k)[index]
                )
                for k in self.__slots__
                if getattr(self, k) is not None
            }
        )


@dataclass
@DATASET_REGISTRY.register_module("sc2-battle")
class SC2BattleCfg(DatasetConfig):
    """Configuration for DALI StarCraft II Battle Dataset"""

    train_loader: DaliLoaderConfig
    val_loader: DaliLoaderConfig

    # Length (in replay steps) of each extracted battle clip
    clip_length: int
    # Per-unit features; must begin with x, y, t and unitType (if used) must be last
    unit_features: list[UnitOH] = field(
        default_factory=lambda: [UnitOH.x, UnitOH.y, UnitOH.t, UnitOH.alliance_self]
    )
    # Window size and damage threshold forwarded to find_battle_events
    damage_window: int = BATTLE_WINDOW_SIZE_DEFAULT
    damage_thresh: float = BATTLE_THRESHOLD_DEFAULT

    enable_pos_values: bool = True
    enable_pos_targets: bool = False

    # Focus into ROI of size (in-game units)
    roi_size: np.ndarray | None = None

    # If not none, enables heightMap context at specified resolution
    minimap_size: np.ndarray | None = None

    # Separates self from enemy units
    separate_alliance: bool = False

    # Add additional properties encoded based on unit type (not implemented yet)
    unit_properties: bool = False

    # External Source Variables
    prefetch_queue_depth: int = 2

    cast_fp16: bool = False  # Cast float type data to fp16

    contiguous_unit_type: bool = False  # Map unit type id to contiguous mapping

    # Number of unit type ids listed in the unit type file (set in __post_init__)
    n_classes: int = field(init=False)

    # Sum shield into health; shield must then not be listed in unit_features
    combine_health_shield: bool = False

    @classmethod
    def from_config(cls, config: ExperimentInitConfig, idx: int = 0):
        """Build from the experiment config, enabling fp16 casting when the trainer
        config contains an "amp" entry."""
        if "amp" in config.trainer:
            config.data[idx].dataset.args["cast_fp16"] = True
        return super().from_config(config, idx)

    @property
    def properties(self):
        """Feature counts plus every config attribute, as a dict."""
        props = {
            "num_pos_features": self.num_pos_features,
            "num_other_features": self.num_other_features,
        }
        props.update(self.__dict__)
        return props

    @property
    def num_pos_features(self):
        """Number of positional features (x, y, t) leading unit_features."""
        return 3

    @property
    def num_other_features(self):
        """Number of unit features that are neither positional nor unitType."""
        # Unit type and position aren't counted
        not_others = {UnitOH.x, UnitOH.y, UnitOH.t, UnitOH.unitType}
        return sum(1 for f in self.unit_features if f not in not_others)

    def __post_init__(self):
        """Validate and normalize the config, then read the number of unit type ids.

        Raises:
            FileNotFoundError: if basepath does not exist.
            NotImplementedError: if unit_properties is enabled.
        """
        if not self.basepath.exists():
            raise FileNotFoundError(self.basepath)
        if isinstance(self.roi_size, (tuple, list)):
            self.roi_size = np.array(self.roi_size)
        if self.roi_size is not None:
            assert len(self.roi_size) == 2
        if isinstance(self.minimap_size, (tuple, list)):
            self.minimap_size = np.array(self.minimap_size)
        self.unit_features = [
            UnitOH[f] if isinstance(f, str) else f for f in self.unit_features
        ]
        # zip() truncates to the shorter list, so the length must be checked too,
        # otherwise e.g. [x] alone would pass the "begins with x,y,t" check.
        required_prefix = [UnitOH.x, UnitOH.y, UnitOH.t]
        assert len(self.unit_features) >= len(required_prefix) and all(
            a is b for a, b in zip(self.unit_features, required_prefix)
        ), "unit_features must begin with x,y,t"

        if self.combine_health_shield:
            assert (
                UnitOH.shield not in self.unit_features
            ), "Shield should not be specified as a feature if it is summed with health"

        try:
            type_idx = self.unit_features.index(UnitOH.unitType)
        except ValueError:
            pass
        else:
            assert (
                type_idx == len(self.unit_features) - 1
            ), "unitType must be last value"

        self.n_classes = self.get_num_unit_type_ids()

        if self.unit_properties:
            raise NotImplementedError(
                "Properties based on unit type is currently not implemented"
            )

    def get_unit_type_file(self):
        """Get the path to the file with list of unit type ids.

        The file sits inside basepath when it is a directory, otherwise next to it.
        """
        if self.basepath.is_dir():
            return self.basepath / "unique_unit_list.txt"
        return self.basepath.parent / "unique_unit_list.txt"

    def make_source(self, split: Split):
        """Make data external source class for ``split``."""
        loader = self.train_loader if split is Split.TRAIN else self.val_loader
        pipe_kwargs = loader.pipe_kwargs()
        # The pipeline uses cfg.prefetch_queue_depth for the external source, so drop
        # the loader's value; pop() tolerates loaders that don't provide the key.
        pipe_kwargs.pop("prefetch_queue_depth", None)
        source = self.init_auto_filter(
            SC2BattleDataset, cfg=self, split=split, **pipe_kwargs
        )
        return source

    def _get_size(self, split: Split):
        """Return num_iterations * batch_size of a freshly initialized source."""
        inst = self.make_source(split)
        inst._post_init()
        return inst.num_iterations * inst.batch_size

    def get_ext_params(self):
        """Get external parameters (dtype, ndim, layout, names) for the outputs
        enabled by this config, in output order."""
        ext = DaliExternalSourceParams(
            dtype=[DALIDataType.FLOAT, DALIDataType.INT64, DALIDataType.BOOL],
            ndim=[3, 2, 2],
            layout=["", "", ""],
            pipe_names=["units", "unit_targets", "units_mask"],
            out_names=["units", "unit_targets", "units_mask"],
        )
        if self.enable_pos_values:
            ext.append(
                dtype=DALIDataType.FLOAT,
                ndim=3,
                layout="",
                pipe_names="positions",
                out_names="positions",
            )

        if self.separate_alliance:
            ext.extend(
                dtype=[DALIDataType.FLOAT, DALIDataType.BOOL],
                ndim=[3, 2],
                layout=["", ""],
                pipe_names=["enemy_units", "enemy_mask"],
                out_names=["enemy_units", "enemy_mask"],
            )

        if self.enable_pos_targets:
            ext.extend(
                dtype=[DALIDataType.FLOAT, DALIDataType.INT64, DALIDataType.BOOL],
                ndim=[3, 2, 2],
                layout=["", "", ""],
                pipe_names=["positions_unique", "position_targets", "positions_mask"],
                out_names=["positions_unique", "position_targets", "positions_mask"],
            )

        if self.minimap_size is not None:
            ext.append(
                dtype=DALIDataType.FLOAT,
                ndim=3,
                layout="chw",
                pipe_names="minimap",
                out_names="minimap",
            )

        return ext

    def get_num_unit_type_ids(self):
        """Return the number of comma-separated unit type ids in the unit type file
        (see get_unit_type_file)."""
        with open(self.get_unit_type_file(), "r", encoding="utf-8") as f:
            id_str = f.read().strip()
        return len(id_str.split(","))

    def get_dataloader(self, split: Split):
        """Build the DALI pipeline for ``split`` and wrap it in the loader's iterator."""
        loader = self.train_loader if split is Split.TRAIN else self.val_loader
        pipeline = sc2_data_pipeline(cfg=self, split=split, **loader.pipe_kwargs())
        size = self._get_size(split)
        return loader.get_instance(
            pipeline, output_map=self.get_ext_params().out_names, size=size
        )


class SC2BattleDataset(DALIExternalSource):
    """DALI External Source for Battle Dataset.

    Each sample is a ``cfg.clip_length`` clip cut from a battle event of one replay,
    returned as a tuple of contiguous arrays (plus the minimap when present).
    """

    def __init__(
        self,
        cfg: SC2BattleCfg,
        split: Split,
        batch_size: int,
        shard_id: int,
        num_shards: int,
        random_shuffle: bool,
    ):
        super().__init__(
            batch_size, shard_id, num_shards, random_shuffle, yields_batch=False
        )
        self.cfg = cfg
        self.split = split
        # Created lazily in _post_init()
        self.sampler: ReplaySampler | None = None
        self.parser: ReplayDataAllParser | None = None
        self.replay_db: ReplayDataAllDatabase | None = None
        self._last_db: str = ""  # full path of the database currently in replay_db
        self._roi_center: dict[str, np.ndarray] | None = None  # replay hash -> center
        self._type_id_to_contiguous: UnitTypeToContiguous | None = None

    @property
    def is_preprocessed(self):
        """True when basepath is a single file, i.e. a preprocessed battle dataset."""
        return self.cfg.basepath.is_file()

    def _maybe_load_roi_center_file(self):
        """Load precalculated ROI centers (replay hash -> center) if the file exists.

        The file ``sc2_roi_center_{w}_{h}.csv`` lives in the dataset directory (or next
        to a preprocessed dataset file) and has a header line followed by ``replay,x,y``
        rows. If it is missing, ROI centers are computed per sample in get_data().
        """
        assert self.cfg.roi_size is not None
        basepath = (
            self.cfg.basepath.parent if self.is_preprocessed else self.cfg.basepath
        )
        roi_center_file = (
            basepath
            / f"sc2_roi_center_{self.cfg.roi_size[0]}_{self.cfg.roi_size[1]}.csv"
        )
        if not roi_center_file.exists():
            print(f"File doesn't exist: {roi_center_file}")
            return

        roi_center: dict[str, np.ndarray] = {}
        with open(roi_center_file, "r", encoding="utf-8") as f:
            f.readline()  # skip header
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank / trailing lines
                    continue
                replay, x, y = line.split(",")
                roi_center[replay] = np.array([float(x), float(y)])
        self._roi_center = roi_center

    def _post_init(self):
        """Create sampler, parser and replay database, load optional ROI centers and
        the unit type mapping, then run the base class post-init."""
        # If "dataset" is single replay file then it must be a preprocessed battle dataset
        if self.is_preprocessed:
            # NOTE(review): 0.8 presumably is the train fraction of the split — confirm
            self.sampler = BasicSampler(
                self.cfg.basepath, 0.8, self.split is Split.TRAIN
            )
        else:
            self.sampler = get_default_sql_sampler(
                self.cfg.basepath, self.split is Split.TRAIN
            )
        self.parser = ReplayDataAllParser(GAME_INFO_FILE)
        self.replay_db = ReplayDataAllDatabase()
        if self.cfg.roi_size is not None:
            self._maybe_load_roi_center_file()

        self._type_id_to_contiguous = create_unit_type_to_contiguous_map(
            self.cfg.get_unit_type_file()
        )
        super()._post_init()

    def __len__(self):
        assert self.sampler is not None, "_post_init() must be called first"
        return len(self.sampler)

    def load_replay_into_parser(self, index: int):
        """Load the replay referenced by dataset ``index`` into ``self.parser``.

        The database file is only (re)loaded when it differs from the previous call.

        Raises:
            RuntimeError: if the replay database fails to load.
        """
        assert self.sampler is not None
        assert self.parser is not None
        assert self.replay_db is not None
        db_path, db_index = self.sampler.sample(index)
        # Key on the full path so same-named databases in different directories
        # are not mistaken for the one already loaded.
        db_key = str(db_path)
        if self._last_db != db_key:
            # Must not live inside an assert: the load has to run under `python -O` too.
            if not self.replay_db.load(db_path):
                self._last_db = ""  # database state unknown, force reload next time
                raise RuntimeError(f"Failed to load replay database: {db_path}")
            self._last_db = db_key
        replay_data = self.replay_db.getEntry(db_index)
        self.parser.parse_replay(replay_data)

    def _get_padded_event(self, events: list[BattleEvent]):
        """Get first event that leaves room for a full clip before the replay ends.

        Raises:
            IndexError: if no event fits. A bare StopIteration is avoided on purpose,
                since DALI's external source treats StopIteration as end of the epoch.
        """
        assert self.parser is not None
        replay_len = len(self.parser)
        event = next(
            (e for e in events if e.start + self.cfg.clip_length < replay_len), None
        )
        if event is None:
            raise IndexError(
                f"No battle event fits clip_length={self.cfg.clip_length} "
                f"in replay of length {replay_len}"
            )
        return event

    def _load_battle_event(self):
        """Select a battle event in the parsed replay and extract its clip.

        Preprocessed datasets use the clip starting at step 0. Otherwise events are
        found with find_battle_events: training picks a random event longer than
        clip_length (falling back to the first event), validation picks the longest.

        Raises:
            IndexError: if the replay has no battle events or no clip fits the replay.
        """
        assert self.parser is not None
        if self.is_preprocessed:
            event = BattleEvent(0, self.cfg.clip_length, self.cfg.damage_thresh)
        else:
            events = find_battle_events(
                self.parser.data, self.cfg.damage_window, self.cfg.damage_thresh
            )
            if not events:
                raise IndexError("No battle events found in replay")
            if self.split is Split.TRAIN:
                events_ = [e for e in events if e.duration > self.cfg.clip_length]
                event = random.choice(events_) if len(events_) > 0 else events[0]
            else:
                event = max(events, key=lambda e: e.duration)
            # If selected event is less than requested length, get the next best one
            if event.duration < self.cfg.clip_length:
                event = self._get_padded_event(events)

        # Trim event to requested length and ensure it ends within the replay
        event.end = event.start + self.cfg.clip_length
        if event.end > len(self.parser):
            raise IndexError(
                f"Battle event ends {event.end} but replay length is {len(self.parser)}"
            )

        return extract_battle_sequence(
            self.parser,
            event,
            self.cfg.unit_features,
            self.cfg.enable_pos_targets,
            self.cfg.enable_pos_values,
            self.cfg.combine_health_shield,
            normalize_props=True,
        )

    def get_data(self, index: int):
        """Load sample ``index`` and return it as a tuple of arrays.

        After extracting the battle clip, the config optionally adds the minimap,
        applies ROI transforms, remaps unit type ids to contiguous values and separates
        self from enemy units.
        """
        assert self.parser is not None
        self.load_replay_into_parser(index)
        data = self._load_battle_event()

        if self.cfg.minimap_size is not None:
            data.minimap = normalize_minimap_coordinates(
                self.parser.info.heightMap.data,
                self.parser.info.mapHeight,
                self.parser.info.mapWidth,
            )

        if self.cfg.roi_size is not None:
            if self._roi_center is None:
                # No precalculated centers: find the ROI from this sample's units
                map_size = np.array(
                    (self.parser.info.mapWidth, self.parser.info.mapHeight)
                )
                roi_center = find_main_roi(
                    data.units, data.unit_targets, self.cfg.roi_size, map_size
                )
            else:
                roi_center = self._roi_center[self.parser.info.replayHash]
            apply_roi_transforms(
                data, self.cfg.roi_size, roi_center, self.cfg.minimap_size
            )

        if self.cfg.contiguous_unit_type:
            assert self._type_id_to_contiguous is not None
            type_ch = self.cfg.unit_features.index(UnitOH.unitType)
            for units in data.units:
                make_unit_type_contiguous(units, type_ch, self._type_id_to_contiguous)

        if self.cfg.separate_alliance:
            data.separate_self_units(self.cfg.unit_features.index(UnitOH.alliance_self))

        result = data.as_contiguous_arrays()

        # Rescale minimap with (x - 127) / 128, which maps a 0-255 range to roughly
        # [-1, 1], and add a leading channel dim
        if data.minimap is not None:
            data.minimap = (data.minimap.astype(np.float32) - 127) / 128
            result.append(data.minimap[None])

        return tuple(result)


@pipeline_def(py_start_method="spawn")
def sc2_data_pipeline(
    cfg: SC2BattleCfg,
    split: Split,
    shard_id: int,
    num_shards: int,
    random_shuffle: bool,
    augmentations: list[ModuleInitConfig],
):
    """Create the DALI pipeline for the SC2 battle dataset.

    Samples are read from ``cfg.make_source(split)`` with a parallel external source.
    Every output is zero-padded to the largest shape in the batch and moved to the GPU.
    Float outputs are cast to fp16 when ``cfg.cast_fp16`` is set.

    ``shard_id``, ``num_shards``, ``random_shuffle`` and ``augmentations`` are not read
    here. They are accepted so ``loader.pipe_kwargs()`` can be forwarded unchanged; the
    source gets its sharding settings itself inside ``cfg.make_source``.
    """
    source = cfg.make_source(split)
    ext_params = cfg.get_ext_params()
    outputs = fn.external_source(
        source=source,
        num_outputs=ext_params.num_outputs,
        parallel=True,
        batch=source.yields_batch,
        batch_info=source.yields_batch,
        dtype=ext_params.dtype,
        ndim=ext_params.ndim,
        layout=ext_params.layout,
        prefetch_queue_depth=cfg.prefetch_queue_depth,
    )

    # Pad every output, including "minimap", so variable-sized samples can be batched.
    # (There is no "height_map" output to exempt; the height map is named "minimap".)
    pipe_data = {
        name: fn.pad(data, fill_value=0)
        for name, data in zip(ext_params.pipe_names, outputs)
    }

    def get_output(key: str):
        """Move ``key``'s data to the GPU, casting floats to fp16 if configured."""
        data = pipe_data[key].gpu()
        param_idx = ext_params.out_names.index(key)
        if ext_params.dtype[param_idx] == DALIDataType.FLOAT and cfg.cast_fp16:
            data = fn.cast(data, dtype=DALIDataType.FLOAT16)
        return data

    out_data = tuple(get_output(k) for k in ext_params.out_names)

    return out_data
