from dataclasses import dataclass
from pathlib import Path

import cv2
import numpy as np
from sc2_serializer import StepDataSoA, set_replay_database_logger_level, spdlog_lvl
from sc2_serializer.sampler import SQLSampler
from scipy.cluster.vq import kmeans2

from .dataset_utils import in_bounds_2d, normalize_coordinates_inplace

# Module-level side effect executed on import.
# NOTE(review): presumably restricts sc2_serializer's replay-database logging to
# error-level messages - confirm against the sc2_serializer documentation.
set_replay_database_logger_level(spdlog_lvl.err)


@dataclass(slots=True)
class BattleEvent:
    """Indices in replay where battle occurs and its average intensity"""

    start: int
    end: int
    score: float

    @property
    def duration(self):
        """Duration of battle"""
        return self.end - self.start


@dataclass(slots=True)
class SequenceData:
    """Data from typical training sequence"""

    units: list[np.ndarray]
    unit_targets: list[np.ndarray]
    enemy_units: list[np.ndarray] | None = None
    positions: list[np.ndarray] | None = None
    positions_unique: list[np.ndarray] | None = None
    position_targets: list[np.ndarray] | None = None
    minimap: np.ndarray | None = None

    def __len__(self):
        return len(self.units)

    def as_contiguous_arrays(self, include_masking: bool = True):
        """Transform data as stacks of padded arrays"""
        result = [make_padded_array(self.units), make_padded_array(self.unit_targets)]

        if include_masking:
            result.append(get_padded_array_mask(self.units))

        if self.positions is not None:
            result.append(make_padded_array(self.positions))

        if self.enemy_units is not None:
            result.append(make_padded_array(self.enemy_units))
            if include_masking:
                result.append(get_padded_array_mask(self.enemy_units))

        if self.positions_unique is None:
            return result

        assert self.position_targets is not None
        result.extend(
            [
                make_padded_array(self.positions_unique),
                make_padded_array(self.position_targets),
            ]
        )
        if include_masking:
            result.append(get_padded_array_mask(self.positions_unique))

        return result

    def separate_self_units(self, self_idx: int):
        """Separate out self units and only keep their target assignments"""
        self.enemy_units = []
        for idx in range(len(self)):
            mask = self.units[idx][:, self_idx] == 1
            self.enemy_units.append(self.units[idx][~mask])
            # Invert mask since we are only considering enemy targets to reindex
            self.unit_targets[idx] = reindex_targets(
                self.unit_targets[idx][mask], ~mask
            )
            self.units[idx] = self.units[idx][mask]
            if self.positions is not None:
                self.positions[idx] = self.positions[idx][mask]
            if self.position_targets is not None:
                self.position_targets[idx] = self.position_targets[idx][mask]


def get_default_sql_sampler(dataset: Path, is_train: bool = True):
    """Build an ``SQLSampler`` over ``dataset`` with a fixed set of filters.

    Args:
        dataset (Path): Folder containing ``gamedata.db``; it is also handed to
            the sampler as its second argument.
        is_train (bool): Forwarded to the sampler together with a fixed
            ``train_ratio`` of 0.8.

    Returns:
        SQLSampler: The configured sampler.
    """
    database = str(dataset / "gamedata.db")
    filters = [
        "game_length > 6720",
        "read_success = 1",
        "parse_success = 1",
        "playerAPM > 120",
    ]
    return SQLSampler(database, dataset, filters, train_ratio=0.8, is_train=is_train)


def get_unit_positions_with_targets_or_targeted(
    units_sequence: list[np.ndarray], targets_sequence: list[np.ndarray]
):
    """Collect, as one homogeneous array, the positions of every unit in the
    sequence that either has a target or is itself targeted.

    Rows for units that have a target come first (in sequence order), followed by
    one row per targeting relation for the units being targeted.

    Args:
        units_sequence (list[np.ndarray]): Sequence of unit observations as [N,M] where (x,y)
            must be the first 2 dims using in-game coordinates.
        targets_sequence (list[np.ndarray]): The corresponding sequence of unit target
            indices, with -1 marking "no target".

    Returns:
        np.ndarray: Position of all selected units (x,y)
    """
    stacked_units = np.concatenate(units_sequence, axis=0)
    stacked_targets = np.concatenate(targets_sequence, axis=0)
    has_target = stacked_targets != -1

    # Per step, look up the rows that the valid target indices point at
    targeted = [
        step_units[step_targets[step_targets != -1]]
        for step_units, step_targets in zip(units_sequence, targets_sequence)
    ]

    selected = np.concatenate([stacked_units[has_target], *targeted])
    return selected[:, :2]  # Only grab (x,y)


def find_main_roi(
    units_sequence: list[np.ndarray],
    targets_sequence: list[np.ndarray],
    roi_size: np.ndarray,
    map_size: np.ndarray,
    num_clusters: int = 5,
):
    """From a sequence of units and their targets, return the center of the ROI.

    Positions of units that have targets or are targeted are clustered with kmeans. If
    the bounding box of all centroids does not fit inside ``roi_size``, the least
    populated clusters are dropped one at a time until it does. Note that kmeans picks
    its initial centroids at random, so results can differ between calls.

    Args:
        units_sequence (list[np.ndarray]): Sequence of unit observations as [N,M] where (x,y)
            must be the first 2 dims using in-game coordinates.
        targets_sequence (list[np.ndarray]): The corresponding sequence of unit targets indices.
        roi_size (np.ndarray): The size of the ROI as in-game units to find, if the clusters the
            algorithm finds is larger than this roi, we will prune the smaller clusters.
        map_size (np.ndarray): The size of the full map as [width,height]
        num_clusters (int): Number of clusters for kmeans algorithm, capped at the number
            of positions available.

    Returns:
        np.ndarray: The center point of the ROI we want to use, clipped so the ROI stays
            inside the map.
    """
    unit_pos = get_unit_positions_with_targets_or_targeted(
        units_sequence, targets_sequence
    )

    # Well no units are targets/targeted...just base off general unit clusters
    if unit_pos.size == 0:
        unit_pos = np.concatenate(units_sequence)[..., :2]

    def roi_size_and_center(centroids: np.ndarray):
        """Size and center of the axis-aligned box enclosing the centroids"""
        topleft: np.ndarray = np.min(centroids, axis=0)
        bottomright: np.ndarray = np.max(centroids, axis=0)
        size = bottomright - topleft
        center = (topleft + bottomright) / 2
        return size, center

    def restrict_center(centroid: np.ndarray):
        """Clip so that roi is within map bounds"""
        min_centroid = roi_size / 2
        max_centroid = map_size - min_centroid
        return np.clip(centroid, min_centroid, max_centroid)

    num_clusters = min(num_clusters, unit_pos.shape[0])

    centroids, labels = kmeans2(unit_pos, num_clusters, minit="points")
    prop_size, prop_center = roi_size_and_center(centroids)
    if (prop_size < roi_size).all():
        return restrict_center(prop_center)

    # kmeans2 can leave clusters empty (it only warns). Without minlength, bincount
    # stops at the highest label in use, so trailing empty clusters would be missing
    # from the ordering, never pruned, and could even be returned as the ROI center.
    cluster_sizes = np.bincount(labels, minlength=num_clusters)
    cluster_indices_sorted = np.argsort(cluster_sizes)
    for num_remove in range(1, num_clusters):
        prop_size, prop_center = roi_size_and_center(
            np.delete(centroids, cluster_indices_sorted[:num_remove], axis=0)
        )
        if (prop_size < roi_size).all():
            return restrict_center(prop_center)

    # Fallback: a single remaining centroid spans zero size, so this is only reached
    # when even that fails the check (e.g. a non-positive roi_size). Use the most
    # populated cluster.
    return restrict_center(centroids[cluster_indices_sorted[-1]])


def extract_minimap_roi(
    minimap: np.ndarray, center: np.ndarray, size: np.ndarray, resolution: np.ndarray
):
    """Crop a region of interest out of the minimap and resize it.

    The crop bounds are rounded to the nearest pixel instead of interpolating,
    and the resize uses nearest-neighbour sampling.

    Args:
        minimap (np.ndarray): Original minimap
        center (np.ndarray): Center of ROI as (x,y)
        size (np.ndarray): Size of ROI as (width,height)
        resolution (np.ndarray): Output resolution of ROI, must share the aspect
            ratio of ``size``

    Returns:
        np.ndarray: Minimap ROI
    """
    assert (
        size[0] / size[1] == resolution[0] / resolution[1]
    ), "aspect ratios between roi and resolution are not equal"

    extent = size / 2
    left, top = (round(edge) for edge in center - extent)
    right, bottom = (round(edge) for edge in center + extent)

    # Rows are indexed by y and columns by x
    window = minimap[..., top:bottom, left:right]
    return cv2.resize(window, tuple(resolution), interpolation=cv2.INTER_NEAREST)


def normalize_minimap_coordinates(minimap: np.ndarray, height: int, width: int):
    """Resize the minimap so that the coordinate system matches the pixels

    The minimap is flipped vertically, the part that does not belong to the
    ``width`` x ``height`` map is cropped away, and the remainder is resized with
    nearest-neighbour sampling so that one pixel covers one in-game unit.

    Args:
        minimap (np.ndarray): minimap to resize
        height (int): height of map in ingame units
        width (int): width of map in ingame units

    Returns:
        np.ndarray: Minimap normalized to ingame coordinates
    """

    # Need to flip y-coordinates
    minimap = cv2.flip(minimap, 0)

    rows, cols = minimap.shape[-2], minimap.shape[-1]

    # The map is scaled uniformly into the minimap, so the axis with the smaller
    # pixels-per-unit ratio is fully used and the other axis carries the padding.
    # Each image axis has to be compared against its own map dimension; mixing them
    # up only gives the right answer for square minimaps.
    if rows / height < cols / width:
        # Rows are fully used, the surplus columns are padding
        new_width = int(rows / height * width)
        unpad_minimap = minimap[..., :new_width]
    else:
        # Columns are fully used, the surplus rows are padding
        new_height = int(cols / width * height)
        unpad_minimap = minimap[..., :new_height, :]

    # interpolation must be passed by keyword: the third positional parameter of
    # cv2.resize is dst, so a positional flag silently falls back to bilinear.
    minimap = cv2.resize(
        unpad_minimap, (width, height), interpolation=cv2.INTER_NEAREST
    )
    return minimap


## REPLACED BY CPP NATIVE with 11% overall speedup in end-to-end dataloader performance
# def normalize_coordinates_inplace(
#     sequence: list[np.ndarray], center: np.ndarray, size: np.ndarray
# ):
#     """Inplace apply normalization to coordinates [-1,1] assumes [x,y] are first two elements of
#     last dimension"""
#     s = [s.copy() for s in sequence]
#     inv_half_size = 1 / (size / 2)
#     for data in sequence:
#         data[..., :2] = (data[..., :2] - center) * inv_half_size
#     from .dataset_utils import normalize_coordinates_inplace as optim_norm
#     optim_norm(s, center, size)
#     for a, b in zip(sequence, s):
#         assert np.allclose(a, b)


def get_unit_target_indices_old(data: StepDataSoA, index: int):
    """Map every unit at replay step ``index`` to the position of its target.

    The result holds, for each unit in ``data.units[index]``, the index within
    that same list of the unit whose ``id`` equals its ``tgtId``, or -1 when no
    such unit is present. This slow python method is replaced with a c++ native
    method.
    """
    step_units = data.units[index]

    # Lookup from unit id to its position in this step
    position_of = {}
    for position, unit in enumerate(step_units):
        position_of[unit.id] = position

    target_positions = [position_of.get(unit.tgtId, -1) for unit in step_units]
    return np.array(target_positions, dtype=np.int32)


def reindex_targets(targets: np.ndarray, valid: np.ndarray):
    """Remove invalid targets and change indexing based on this filter

    Args:
        targets (np.ndarray): Target indices referring to the unfiltered array, with -1
            meaning "no target".
        valid (np.ndarray): Boolean mask over the unfiltered array marking the entries
            that are kept.

    Returns:
        np.ndarray: Targets re-expressed as indices into the filtered array. Targets
            that were -1, or that point at an entry removed by the filter, are -1.
    """
    keep = np.asarray(valid, dtype=bool)

    # Lookup table shifted by one: slot 0 serves the -1 sentinel and slot i + 1
    # serves old index i. Everything starts as -1 so removed entries stay "no
    # target" instead of aliasing onto the previous kept entry.
    mapping = np.full(keep.shape[0] + 1, -1, dtype=targets.dtype)
    mapping[1:][keep] = np.arange(np.count_nonzero(keep), dtype=targets.dtype)
    return mapping[targets + 1]


def remove_units_outside_roi(data: SequenceData):
    """Drop every unit (and its target data) that lies outside the [-1,1] ROI.

    Unique positions, when present, are filtered by the same bounds and the
    position targets are re-indexed against the surviving positions.
    Mutates the input argument.
    """
    for step, step_units in enumerate(data.units):
        unit_inside = in_bounds_2d(step_units, -1, 1)
        data.units[step] = step_units[unit_inside]
        data.unit_targets[step] = reindex_targets(
            data.unit_targets[step][unit_inside], unit_inside
        )

        if data.positions is not None:
            data.positions[step] = data.positions[step][unit_inside]

        if data.positions_unique is None:
            continue

        assert data.position_targets is not None
        unique_pos = data.positions_unique[step]
        pos_inside = in_bounds_2d(unique_pos, -1, 1)
        data.positions_unique[step] = unique_pos[pos_inside]
        # Keep the targets of surviving units, re-indexed to surviving positions
        data.position_targets[step] = reindex_targets(
            data.position_targets[step][unit_inside], pos_inside
        )


def apply_roi_transforms(
    data: SequenceData,
    roi_size: np.ndarray,
    roi_center: np.ndarray,
    minimap_res: np.ndarray | None = None,
):
    """Restrict the sequence data to the ROI, modifying ``data`` in place.

    Coordinates of units and of any positions present are normalized with
    respect to the ROI, units outside the ROI are removed, and the minimap (if
    any) is replaced by its ROI crop at ``minimap_res``.
    """
    normalize_coordinates_inplace(data.units, roi_center, roi_size)
    for optional_coords in (data.positions, data.positions_unique):
        if optional_coords is not None:
            normalize_coordinates_inplace(optional_coords, roi_center, roi_size)

    remove_units_outside_roi(data)

    if data.minimap is None:
        return

    assert minimap_res is not None, "Specify minimap_res if minimap is present"
    data.minimap = extract_minimap_roi(
        data.minimap, roi_center, roi_size, minimap_res
    )


def make_padded_array(data: list[np.ndarray]):
    """Stack a list of arrays into one contiguous, zero-padded array.

    Entries may differ in their first dimension; shorter ones are padded with
    zeros at the end. Trailing dimensions and dtype are taken from the first
    entry.
    """
    template = data[0]
    longest = max(len(sample) for sample in data)
    padded = np.zeros((len(data), longest, *template.shape[1:]), dtype=template.dtype)
    for row, sample in zip(padded, data):
        # row is a view into padded, so this writes in place
        row[: len(sample)] = sample
    return padded


def get_padded_array_mask(data: list[np.ndarray]):
    """Create the 'valid' mask array associated with a list of arrays.

    Entry ``[i, j]`` is True when ``j`` is smaller than the length of the i-th
    array, i.e. when that slot holds real data rather than padding.
    """
    lengths = [len(sample) for sample in data]
    slots = np.arange(max(lengths))
    return slots[np.newaxis, :] < np.asarray(lengths)[:, np.newaxis]
