from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional, Sequence

import cv2
import gymnasium
import numpy as np
import torch
import typer
from chasing_targets_gym.planner import Planner
from konductor.data import Split, get_dataset_config
from konductor.init import ExperimentInitConfig
from konductor.utilities.pbar import LivePbar
from matplotlib import pyplot as plt
from torch import Tensor
from torch.utils.data import DataLoader

from ..model import MotionPerceiver
from ..utils.eval_common import initialize
from ..utils.visual_occupancy import write_occupancy_batch
from ..utils.visual_target import write_target_batch


class VideoType(str, Enum):
    """Kind of inference video to produce.

    Being a ``str`` enum, members compare equal to their string values, which
    are also the choices accepted on the command line.
    """

    # Rendered by write_occupancy_batch
    OCCUPANCY = "occupancy"
    # Rendered by write_target_batch
    TARGETS = "targets"


@dataclass
class EvalConfig:
    """Options controlling how ``generate_videos`` writes inference videos."""

    # Root output directory; a "occupancy" or "targets" subfolder is appended per video type
    path: Path
    # Samples per loader batch; multiplied by the batch counter to count samples processed
    batch_size: int
    # Processing stops before the first batch at which this many samples have been seen
    n_videos: int = 128
    # Forwarded to the occupancy writer as its threshold argument
    video_thresh: float | None = None
    # Number of time steps handed to the video writers
    sequence_length: int = 1
    # Selects which writer is used
    video_type: VideoType = VideoType.OCCUPANCY


# Robot radius shared by run_env (passed to the ChasingTargets env and to
# Planner) and by update_step's steering computation.
# NOTE(review): units follow the environment's coordinate system — confirm.
ROBOT_RADIUS = 0.1


def update_step(robots: np.ndarray, target_pos: np.ndarray) -> dict[str, np.ndarray]:
    """Returns action for robot using pure pursuit algorithm

    Both wheel commands start at 0.5. A steering value
    ``steer = arctan2(2 * ROBOT_RADIUS * sin(bearing) * 3, 1)`` is computed from
    the bearing to ``target_pos`` minus the robot's heading. A positive
    ``steer`` is subtracted from the left command. A non-positive ``steer`` is
    added to the right command, which also lowers it. No clipping is applied.

    Args:
        robots: 2-D array of robot states. Column 0 is x, column 1 is y and the
            last column is read as the heading (radians assumed — TODO confirm).
        target_pos: Target position with x at index 0 and y at index 1.

    Returns:
        ``{"vL": ..., "vR": ...}``, each a float32 array with one entry per robot.
    """
    heading = robots[:, -1]
    dx = target_pos[0] - robots[:, 0]
    dy = target_pos[1] - robots[:, 1]

    bearing = np.arctan2(dy, dx)
    bearing -= heading
    # The factor of 3 is a fixed gain on the pure-pursuit curvature term
    steer = np.arctan2(2.0 * ROBOT_RADIUS * np.sin(bearing) * 3, 1.0)

    wheel_cmd = np.full([robots.shape[0], 2], 0.5, dtype=np.float32)
    positive = steer > 0
    non_positive = steer <= 0
    wheel_cmd[positive, 0] -= steer[positive]
    wheel_cmd[non_positive, 1] += steer[non_positive]

    left_cmd = wheel_cmd[:, 0]
    right_cmd = wheel_cmd[:, 1]
    return {"vL": left_cmd, "vR": right_cmd}


def simple_planner(obs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Steer every robot toward its assigned target with ``update_step``.

    Args:
        obs: Observation dict. Reads ``"current_robot"`` (one row per robot),
            ``"robot_target_idx"`` (per robot, an index into
            ``"future_target"``) and ``"future_target"``.

    Returns:
        Dict mapping each action key (``"vL"``, ``"vR"``) to an array with one
        entry per robot. When there are no robots, both arrays are empty
        float32 arrays instead of raising ``IndexError``.
    """
    targets = obs["future_target"]
    per_robot = [
        update_step(robot[None], targets[tgt_idx])
        for robot, tgt_idx in zip(obs["current_robot"], obs["robot_target_idx"])
    ]
    if not per_robot:
        # Nothing to concatenate; mirror update_step's keys and dtype
        empty = np.empty(0, dtype=np.float32)
        return {"vL": empty, "vR": empty.copy()}

    return {
        key: np.concatenate([act[key] for act in per_robot], axis=0)
        for key in per_robot[0]
    }


# Typer CLI application; each function below decorated with @app.command()
# is exposed as a subcommand.
app = typer.Typer()


@app.command()
def run_env(max_vel: float = 0.5, visualise: bool = False, n_steps: int = 5000):
    """Runs simulation of target chasers

    Steps a ChasingTargets-v0 environment (10 robots, 4 targets) for
    ``n_steps`` steps. The robots are driven with ``Planner``, and the
    environment is reset whenever an episode terminates or is truncated. When
    not visualising, the loop is profiled with scalene. The profiler is
    stopped and the environment closed even if the loop raises.
    """
    env = gymnasium.make(
        "ChasingTargets-v0",
        n_robots=10,
        n_targets=4,
        render_mode="human" if visualise else None,
        robot_radius=ROBOT_RADIUS,
        max_velocity=max_vel,
        barrier_velocity_range=max_vel,
    )
    try:
        # Only the first reset is explicitly seeded
        observation, info = env.reset(seed=2)

        # The planner needs the env's time step, which is only known after reset.
        # simple_planner(observation) is a lighter alternative with the same call shape.
        planner = Planner(ROBOT_RADIUS, info["dt"], max_velocity=max_vel)

        profiler = None
        if not visualise:
            # Imported lazily so scalene is only required when profiling
            from scalene import scalene_profiler

            profiler = scalene_profiler
            profiler.start()

        try:
            for _ in range(n_steps):
                action = planner(observation)
                observation, _, terminated, truncated, info = env.step(action)

                env.render()
                if terminated or truncated:
                    observation, info = env.reset()
        finally:
            # Stop profiling even if a step raised
            if profiler is not None:
                profiler.stop()
    finally:
        env.close()


def write_occupancy_sequence(occupancy: Tensor, timestamps: Tensor, path: Path) -> None:
    """Save each occupancy frame as ``occupancy_<t>.png`` inside ``path``.

    Frame ``i`` of ``occupancy`` is paired with ``timestamps[i]``; any surplus
    entries on either side are ignored, as with ``zip``. Pixel values are
    multiplied by 255 before writing. NOTE(review): this assumes the frames are
    CPU tensors with values in [0, 1] — confirm.
    """
    for frame_time, frame in zip(timestamps, occupancy):
        out_file = path / f"occupancy_{frame_time}.png"
        pixels = frame.numpy() * 255.0
        cv2.imwrite(str(out_file), pixels)


def write_motion_spline(
    agents: Tensor, masks: Tensor, path: Path, arrow_idx: int = 10
) -> None:
    """Scatter each agent's valid positions and save the plot to ``path/agents.png``.

    Args:
        agents: Per-agent sequences indexed as ``agents[i, t, k]``. ``k=0`` is
            x, ``k=1`` is y, and ``k=2`` is used as the heading angle for the
            arrow (radians assumed — TODO confirm).
        masks: Validity mask indexed as ``masks[i, t]``; non-zero means valid.
        path: Existing output directory.
        arrow_idx: Time step at which a heading arrow is drawn for every agent
            that is valid at that step. Defaults to 10.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        for agent, mask in zip(agents.numpy(), masks.numpy()):
            valid = mask != 0
            plt.scatter(agent[valid, 0], agent[valid, 1])

            # Skip the arrow if the sequence is too short or the agent is
            # invalid at arrow_idx.
            if arrow_idx >= len(mask) or not valid[arrow_idx]:
                continue
            # Index the unmasked sequence so arrow_idx is a true time step.
            # Indexing the masked rows would shift the index whenever an
            # earlier step is invalid.
            x, y, heading = agent[arrow_idx, :3]
            plt.arrow(
                x,
                y,
                np.cos(heading) / 10.0,
                np.sin(heading) / 10.0,
                width=0.01,
            )
        plt.tight_layout()
        plt.savefig(path / "agents.png")
    finally:
        # Called once per batch, so release the figure to avoid accumulating them
        plt.close(fig)


def write_images(dataloder: DataLoader, path: Path) -> None:
    """Write images of loaded data to visually check correctness

    For each batch, the first sample is written to ``path / "batch_<i>"``:
    one PNG per occupancy frame plus an ``agents.png`` trajectory plot. An
    existing batch folder is emptied first. The user is prompted after every
    batch, and entering ``e`` stops early.
    """
    for batch_idx, batch in enumerate(dataloder):
        batch_dir = path / f"batch_{batch_idx}"
        if batch_dir.exists():
            # Remove files left over from a previous run
            for stale_file in batch_dir.iterdir():
                stale_file.unlink()
        else:
            batch_dir.mkdir()

        # [0, 0] takes the first batch element and the first class channel
        occupancy = batch["heatmap"][0, 0]
        time_idx = batch["time_idx"][0, 0]
        write_occupancy_sequence(occupancy, time_idx, batch_dir)
        write_motion_spline(batch["agents"][0], batch["agents_valid"][0], batch_dir)

        answer = input('"e" to quit: ')
        if answer == "e":
            break


@app.command()
def visualize_dataloader(config_file: Path, split: Split) -> None:
    """Write data loaded as images for visualisation + checking

    Output goes to ``./disp`` under the current working directory. That
    folder is also passed to ``ExperimentInitConfig.from_config``. The batch
    size for ``split`` is forced to 1.
    """
    out_dir = Path.cwd() / "disp"
    if not out_dir.exists():
        out_dir.mkdir()

    exp_cfg = ExperimentInitConfig.from_config(out_dir, config_file)
    exp_cfg.set_batch_size(1, split)
    loader = get_dataset_config(exp_cfg).get_instance(split)

    write_images(loader, out_dir)


def generate_videos(
    model: MotionPerceiver,
    loader: Sequence[dict[str, Tensor] | list[dict[str, Tensor]]],
    config: EvalConfig,
    pbar: LivePbar,
):
    """Create Videos of Inference"""
    for data in loader:
        n_samples = pbar.n * config.batch_size
        if n_samples >= config.n_videos:
            break

        if isinstance(data, (tuple, list)):
            data = data[0]  # remove wrapper
        data = {k: v.cuda() for k, v in data.items()}
        outputs = model(**data)

        timestamps = list(range(0, config.sequence_length))

        match config.video_type:
            case VideoType.OCCUPANCY:
                write_occupancy_batch(
                    data,
                    outputs,
                    [t / 10 for t in timestamps],
                    config.path / "occupancy",
                    n_samples,
                    config.video_thresh,
                )
            case VideoType.TARGETS:
                write_target_batch(
                    data,
                    outputs,
                    [t / 10 for t in timestamps],
                    config.path / "targets",
                    n_samples,
                )

        pbar.update(1)


@app.command()
def visualize_prediction(
    run_path: Path,
    video_type: Annotated[VideoType, typer.Option()],
    split: Annotated[Split, typer.Option()] = Split.VAL,
    n_samples: Annotated[int, typer.Option()] = 16,
    workers: Annotated[int, typer.Option()] = 4,
    batch_size: Annotated[int, typer.Option()] = 8,
    threshold: Annotated[Optional[float], typer.Option()] = None,
) -> None:
    """Write videos of model predictions for a trained run.

    Loads the experiment at RUN_PATH and runs the model over SPLIT. It writes
    at least N_SAMPLES videos of the given type, rounded up to whole batches,
    under the experiment path in <dataset type>/<split>/. THRESHOLD is
    forwarded to the occupancy writer.
    """
    exp_cfg = ExperimentInitConfig.from_run(run_path)
    exp_cfg.set_workers(workers)
    exp_cfg.set_batch_size(batch_size, split)

    model, data_cfg = initialize(exp_cfg)
    if data_cfg.occupancy is not None:
        # NOTE(review): presumably turns off random occupancy sampling so that
        # every step of the sequence is rendered — confirm against the dataset config.
        data_cfg.occupancy.random_count = 0
        data_cfg.occupancy.times = list(range(data_cfg.sequence_length))
    dataloader = data_cfg.get_dataloader(split)

    config = EvalConfig(
        exp_cfg.exp_path / exp_cfg.data[0].dataset.type / split.name.lower(),
        batch_size,
        n_samples,
        threshold,
        sequence_length=data_cfg.sequence_length,
        video_type=video_type,
    )

    # generate_videos keeps running until pbar.n * batch_size >= n_videos,
    # i.e. ceil(n_videos / batch_size) batches. Floor division under-counted
    # that, and gave 0 when n_videos < batch_size.
    n_batches = (config.n_videos + config.batch_size - 1) // config.batch_size

    with (
        torch.inference_mode(),
        LivePbar(total=n_batches) as pbar,
    ):
        generate_videos(model, dataloader, config, pbar)
