""" 
Generate a sequence to play back for the motion perceiver so that
it is able to learn to predict the motion of the obstacles and robots.
"""

import enum
import functools
import random
from dataclasses import dataclass, field
from typing import Any

from chasing_targets_gym.planner import Planner
import cv2
import gymnasium
import numpy as np
import torch
from konductor.data import DATASET_REGISTRY, DatasetConfig, ExperimentInitConfig, Split
from konductor.data._pytorch.dataloader import DataloaderV1Config
from torch import Tensor
from torch.utils.data import Dataset

from .torch_batch_sampler import BatchSamplerSyncRandom


@dataclass(slots=True)
class SampleRange:
    min: int
    max: int

    @classmethod
    def make(cls, other):
        match other:
            case dict():
                return cls(**other)
            case list() | tuple():
                return cls(*other)
            case SampleRange():
                return cls(other.min, other.max)
            case _:
                raise NotImplementedError

    def __post_init__(self):
        assert self.min <= self.max, "Min is > than max???"

    def __iter__(self):
        return iter([self.min, self.max])


# Everything SampleRange.make() accepts: a [min, max] list, an existing
# SampleRange, a (min, max) tuple or a {"min": ..., "max": ...} dict.
# The tuple form is a pair, hence tuple[int, int] (tuple[int] is a 1-tuple).
SampleType = list[int] | SampleRange | tuple[int, int] | dict[str, int]


@dataclass
class OccupancyParams:
    """
    Settings for the occupancy images generated alongside a sequence.

    When ``random_count`` is positive, ``random_range`` must be provided and
    is normalised to a SampleRange instance.
    """

    shape: int  # Side length in pixels of the square occupancy image
    times: list[int]  # Time indexes that always get an occupancy image
    random_count: int = 0  # Number of additional, randomly chosen time indexes
    random_range: SampleRange | None = None  # Bounds for the random time indexes

    def __post_init__(self):
        assert self.random_count >= 0, f"Invalid random count {self.random_count}"

        if self.random_count <= 0:
            return  # No random sampling, nothing else to check

        assert self.random_range is not None, "Random count > 0 but range is None"

        if isinstance(self.random_range, SampleRange):
            return  # Already the right struct

        # Config files hand us a dict/list/tuple, convert it
        self.random_range = SampleRange.make(self.random_range)

    @property
    def n_time(self):
        """Total number of occupancy frames per sample (fixed + random)"""
        return self.random_count + len(self.times)


class MultiRobotGym(Dataset):
    """
    Dataset that runs a fresh "ChasingTargets-v0" simulation for every sample.

    Robots ("agents") and targets are returned in separate buffers, each with
    its own validity mask. Every state row holds six features
    [x,y,t,vx,vy,vt]; there is no class channel because the two kinds of
    entity never share a buffer.

    When occupancy is configured, robots and targets are rasterised into
    separate occupancy images ("agents_occ" and "targets_occ") at the
    requested time indexes.
    """

    class RenderType(enum.Enum):
        CIRCLE = enum.auto()
        RECTANGLE = enum.auto()

    def __init__(
        self,
        robot_radius: float,
        max_velocity: float,
        n_robots: SampleType,
        n_targets: SampleType,
        n_iter: int,
        sandbox_extent: float,
        dataset_len: int = 10000,
        skip_iter: int = 0,
        render_style: RenderType = RenderType.CIRCLE,
        occupancy: OccupancyParams | None = None,
    ) -> None:
        """
        :param robot_radius: Radius of a robot, also used when rendering occupancy
        :param max_velocity: Maximum velocity handed to the planner and simulator
        :param n_robots: Min/max number of robots sampled per simulation
        :param n_targets: Min/max number of targets sampled per simulation
        :param n_iter: Number of simulation steps recorded per sample
        :param sandbox_extent: Full side length of the square sandbox
        :param dataset_len: Reported length of the dataset
        :param skip_iter: Number of initial simulation steps that are not recorded
        :param render_style: How an entity is drawn on the occupancy image
        :param occupancy: Occupancy settings, None disables occupancy outputs
        """
        super().__init__()
        self.dataset_len = dataset_len
        self.n_robots = SampleRange.make(n_robots)
        self.n_targets = SampleRange.make(n_targets)
        self.robot_radius = robot_radius
        self.max_velocity = max_velocity
        # NOTE(review): the meaning of the 0.1 argument is defined by Planner
        # and is not visible from this file - confirm before changing it
        self.planner = Planner(robot_radius, 0.1, max_velocity)
        self.n_iter = n_iter
        self.skip_iter = skip_iter
        self.occ_cfg = occupancy
        self.sandbox_extent = sandbox_extent
        self.render_fn = {
            MultiRobotGym.RenderType.CIRCLE: self.render_agent_circle,
            MultiRobotGym.RenderType.RECTANGLE: self.render_agent_rectangle,
        }[render_style]

    def __len__(self):
        return self.dataset_len  # arbitrary

    @property
    def ppm(self) -> float:
        """Pixels per meter (occupancy)"""
        assert self.occ_cfg is not None
        return self.occ_cfg.shape / self.sandbox_extent

    @property
    def sandbox_dims(self):
        """Sandbox limits as [-half, -half, half, half] of the extent"""
        extents = [self.sandbox_extent / 2] * 4
        extents[0] *= -1
        extents[1] *= -1
        return extents

    def make_data_buffer(self) -> dict[str, Tensor]:
        """
        Initialise training data buffers
        Agent/target state tensors are [num entity,time step,[x,y,t,vx,vy,vt]]
        and are sized for the maximum number of robots/targets.
        """
        n_feats = 6
        buffs = {
            "agents": torch.zeros((self.n_robots.max, self.n_iter, n_feats)),
            "agents_valid": torch.zeros((self.n_robots.max, self.n_iter)),
            "targets": torch.zeros((self.n_targets.max, self.n_iter, n_feats)),
            "targets_valid": torch.zeros((self.n_targets.max, self.n_iter)),
            # -1 marks "no target assigned"
            "agent_target": torch.full(
                (self.n_robots.max, self.n_iter), fill_value=-1, dtype=torch.int64
            ),
        }
        if self.occ_cfg is not None:
            buffs["agents_occ"] = torch.zeros(
                (self.occ_cfg.n_time, self.occ_cfg.shape, self.occ_cfg.shape)
            )
            buffs["targets_occ"] = torch.zeros(
                (self.occ_cfg.n_time, self.occ_cfg.shape, self.occ_cfg.shape)
            )
            buffs["time_idx"] = torch.zeros((1, self.occ_cfg.n_time), dtype=torch.int64)

        return buffs

    def _sim2pix(self, position: Tensor) -> np.ndarray:
        """Transform an [x,y] coordinate from sim space to pixel space"""
        assert self.occ_cfg is not None
        position = position.clone()  # Do not modify the caller's state
        position *= self.ppm  # scale to image space
        position[1] *= -1  # flip y axis
        position += self.occ_cfg.shape / 2  # translate to image center
        return position.to(torch.int32).numpy()

    def render_agent_circle(
        self, occ_img: np.ndarray, agent_state: Tensor
    ) -> np.ndarray:
        """Renders agent on occupancy image as a filled circle"""
        radius_px = int(self.ppm * self.robot_radius)
        # Hand OpenCV a plain tuple of python ints, not every build accepts
        # a numpy array as the point argument
        center = tuple(int(coord) for coord in self._sim2pix(agent_state[:2]))
        occ_img = cv2.circle(occ_img, center, radius_px, 1, thickness=-1)
        return occ_img

    def render_agent_rectangle(
        self, occ_img: np.ndarray, agent_state: Tensor
    ) -> np.ndarray:
        """Renders agent on occupancy image as directed rectangle"""
        raise NotImplementedError

    def update_data_collector(
        self,
        tidx: int,
        data_buffer: dict[str, Tensor],
        observed_state: dict[str, np.ndarray],
        info: dict[str, Any],
        occ_idxs: list[int] | None,
    ) -> None:
        """Write the observation of one simulation step into the data buffers

        :param tidx: Time index of the simulation
        :param data_buffer: Data to be returned as ground truth
        :param observed_state: The observed state returned by the simulator
        :param info: Extra information about the simulation (currently unused)
        :param occ_idxs: Sorted list of time indexes where occupancy should be saved
        """
        obs = {k: torch.as_tensor(v) for k, v in observed_state.items()}

        # Add agents to the state
        n_robot: int = obs["current_robot"].shape[0]
        data_buffer["agents"][:n_robot, tidx] = obs["current_robot"]

        # Add targets to the state, targets only report [x,y,vx,vy] so the
        # heading is derived from the velocity direction
        targets = obs["current_target"]
        n_target = targets.shape[0]
        target_theta = torch.arctan2(targets[:, 3], targets[:, 2]).unsqueeze(-1)
        data_buffer["targets"][:n_target, tidx] = torch.cat(
            [
                targets[:, :2],  # x, y
                target_theta,  # theta
                targets[:, 2:],  # dx, dy
                torch.zeros_like(target_theta),  # dtheta
            ],
            dim=1,
        )

        # Add target assignment
        data_buffer["agent_target"][:n_robot, tidx] = obs["robot_target_idx"]

        if self.occ_cfg is None:
            return  # Skip occupancy things

        # Add occupancy images
        assert occ_idxs is not None
        try:
            occ_idx = occ_idxs.index(tidx)
        except ValueError:  # Not found, no occupancy wanted at this time step
            return

        occ_img = np.zeros((self.occ_cfg.shape, self.occ_cfg.shape))
        for agent_pos in obs["current_robot"]:
            occ_img = self.render_fn(occ_img, agent_pos)
        # Assignment copies into the buffer, so occ_img can be reused below
        data_buffer["agents_occ"][occ_idx] = torch.as_tensor(occ_img)

        occ_img.fill(0)  # Reset
        for target_pos in targets:
            occ_img = self.render_fn(occ_img, target_pos)
        data_buffer["targets_occ"][occ_idx] = torch.as_tensor(occ_img)

        data_buffer["time_idx"][:, occ_idx] = tidx

    def run_simulation(
        self, seed: int, occ_idxs: list[int] | None
    ) -> dict[str, Tensor]:
        """
        Run one seeded simulation and return the filled data buffers

        :param seed: Seeds both the robot/target counts and the environment
        :param occ_idxs: Sorted time indexes to render occupancy at, None
            when occupancy is disabled
        """
        data_buffer = self.make_data_buffer()

        # Need to construct gym every time to generate new random robots and targets
        rand = random.Random(seed)
        env = gymnasium.make(
            "ChasingTargets-v0",
            n_robots=rand.randint(self.n_robots.min, self.n_robots.max),
            n_targets=rand.randint(self.n_targets.min, self.n_targets.max),
            robot_radius=self.robot_radius,
            max_velocity=self.max_velocity,
            target_velocity_std=self.max_velocity,
            sandbox_dimensions=self.sandbox_dims,
            max_episode_steps=self.n_iter + self.skip_iter,
        )

        try:
            observation, info = env.reset(seed=seed)
            done = False
            tidx = 0
            while not done:
                action = self.planner(observation)
                observation, _, terminated, truncated, info = env.step(action)
                if tidx >= self.skip_iter:
                    self.update_data_collector(
                        tidx - self.skip_iter,
                        data_buffer,
                        observation,
                        info,
                        occ_idxs,
                    )
                tidx += 1
                done = terminated or truncated
        finally:
            # A new environment is created for every sample, release it
            env.close()

        # Add valid mask, only for the steps that were actually recorded so
        # an early termination does not flag the zero-filled tail as valid
        n_recorded = max(tidx - self.skip_iter, 0)
        data_buffer["agents_valid"][: info["n_robots"], :n_recorded].fill_(1)
        data_buffer["targets_valid"][: info["n_targets"], :n_recorded].fill_(1)

        # Normalise position [-1,1]
        data_buffer["agents"][..., :2] /= self.sandbox_extent * 0.5
        data_buffer["targets"][..., :2] /= self.sandbox_extent * 0.5

        # Normalise angle [-1,1]
        data_buffer["agents"][..., 2] /= np.pi
        data_buffer["targets"][..., 2] /= np.pi

        return data_buffer

    def __getitem__(self, idx: int | tuple[int, list[int]]):
        """
        Simulate the sample with seed ``idx``.

        ``idx`` may also be an (index, occupancy time indexes) tuple, in which
        case the given time indexes are used instead of the configured ones.
        """
        if isinstance(idx, tuple):
            idx, occ_idxs = idx
        elif self.occ_cfg is not None:
            occ_idxs = sorted(self.occ_cfg.times)
        else:
            occ_idxs = None

        data = self.run_simulation(idx, occ_idxs)

        return data


@dataclass
@DATASET_REGISTRY.register_module("gym-predict")
class GymPredictConfig(DatasetConfig):
    # Dataloader we expect to use
    train_loader: DataloaderV1Config
    val_loader: DataloaderV1Config

    robot_radius: float = 0.1
    max_velocity: float = 0.5
    n_robots: SampleRange = field(default_factory=lambda: SampleRange(5, 10))
    n_targets: SampleRange = field(default_factory=lambda: SampleRange(3, 5))
    n_iter: int = 50
    sandbox_extent: float = 4
    occupancy: OccupancyParams | None = None
    skip_iter: int = 0
    classes: list[str] = field(default_factory=lambda: ["agents", "targets"])

    @classmethod
    def from_config(cls, config: ExperimentInitConfig, idx: int = 0):
        return super().from_config(config, idx)

    @property
    def sequence_length(self):
        return self.n_iter

    def __post_init__(self):
        if isinstance(self.occupancy, dict):
            self.occupancy = OccupancyParams(**self.occupancy)
        if not isinstance(self.n_robots, SampleRange):
            self.n_robots = SampleRange.make(self.n_robots)
        if not isinstance(self.n_targets, SampleRange):
            self.n_targets = SampleRange.make(self.n_targets)

        if self.occupancy is not None and self.occupancy.random_count > 0:
            assert self.occupancy.random_range is not None
            assert (
                self.n_iter > self.occupancy.random_range.max
            ), "n_iter must be greater than the largest random occupancy timestep"

            def rand_gen() -> list[int]:
                assert self.occupancy is not None
                rand_time_idxs = set(range(*self.occupancy.random_range))
                rand_time_idxs.difference_update(set(self.occupancy.times))
                rand_time_idxs = list(rand_time_idxs)
                random.shuffle(rand_time_idxs)
                rand_time_idxs = rand_time_idxs[: self.occupancy.random_count]
                time_idxs = sorted([*self.occupancy.times, *rand_time_idxs])
                return time_idxs

            self.train_loader.batch_sampler = functools.partial(
                BatchSamplerSyncRandom, random_generator=rand_gen
            )
            self.val_loader.batch_sampler = functools.partial(
                BatchSamplerSyncRandom, random_generator=rand_gen
            )

    @property
    def properties(self) -> dict[str, Any]:
        return self.__dict__

    def get_dataloader(self, split: Split) -> Any:
        dataset_len = 100000
        if split != Split.TRAIN:
            dataset_len //= 100

        known_unused = {"train_loader", "val_loader", "basepath", "classes"}
        dataset = self.init_auto_filter(
            MultiRobotGym, known_unused, dataset_len=dataset_len
        )

        match split:
            case Split.TRAIN:
                return self.train_loader.get_instance(dataset)
            case Split.TEST | Split.VAL:
                return self.val_loader.get_instance(dataset)
            case _:
                raise RuntimeError("How did I get here?")
