import pymunk
from dataclasses import dataclass
from pathlib import Path
import ffmpeg
import tempfile
from examples.marble_drop.plotting import plot_space
import torch


# Simulation rate: each physics step advances 1/FPS and rendered animations
# are encoded at FPS frames per second.
FPS = 60.0
# Radius given to the platform segments and to the puzzle boundary segments.
PLATFORM_RADIUS = 0.01
# Width of a button box; also the length scale applied to action offsets.
BUTTON_WIDTH = 1.0
# Radius of the marble, derived from the button width.
MARBLE_RADIUS = BUTTON_WIDTH/5.0
# Vertical distance between two consecutive rows.
ROW_HEIGHT = 2.0


# A platform described by the coordinates of its two end points, in the order
# unpacked by `add_platform`.
Action = tuple[
    float,  # platform x1
    float,  # platform y1
    float,  # platform x2
    float,  # platform y2
]


@dataclass
class DynamicButtonRow:
    """Marker for a row type that carries no data.

    Rows that are not `ButtonRow` are counted as a single button when sizing
    the puzzle, and `get_pymunk_space` raises `NotImplementedError` for them.
    """


@dataclass
class ButtonRow:
    """A puzzle row made of evenly spaced static buttons."""
    # Number of buttons in this row
    button_n: int


# A puzzle is a list of rows ordered from top to bottom: `get_pymunk_space`
# walks the list in reverse to place the rows from the bottom up.
Puzzle = list[DynamicButtonRow | ButtonRow]


@dataclass
class PymunkPuzzle:
    """Bundle of the Pymunk objects created for one puzzle."""
    # Physics space holding every body and shape of the puzzle
    space: pymunk.Space
    # Collision shape of the marble
    marble: pymunk.Shape
    # One entry per row, in the same order as the puzzle's rows when built by
    # `get_pymunk_space`; a button row is the list of its button shapes
    rows: list[list[pymunk.Shape] | pymunk.Shape]  # if dynamic row, only body


def get_puzzle_button_n(puzzle: Puzzle) -> int:
    """Count the buttons of a puzzle.

    A `ButtonRow` contributes its `button_n`; any other kind of row
    contributes exactly one button.
    """
    button_total = 0
    for puzzle_row in puzzle:
        if isinstance(puzzle_row, ButtonRow):
            button_total += puzzle_row.button_n
        else:
            button_total += 1
    return button_total


def get_puzzle_width(puzzle: Puzzle) -> float:
    """Return the width of the puzzle.

    The widest row decides the width: every button of that row is allotted
    three button widths, plus one extra unit overall. A `ButtonRow`
    contributes its `button_n` buttons; any other row counts as one button.

    An empty puzzle is treated as having zero buttons (width of 1) instead of
    raising `ValueError` from `max()` on an empty sequence.
    """
    widest_row = max(
        (row.button_n if isinstance(row, ButtonRow) else 1 for row in puzzle),
        default=0,
    )
    return widest_row*BUTTON_WIDTH*3+1


def get_puzzle_height(puzzle: Puzzle) -> float:
    """Compute the total height of the puzzle.

    The height spans one `ROW_HEIGHT` per row plus two additional ones.
    """
    level_n = len(puzzle) + 2
    return level_n*ROW_HEIGHT


def get_pymunk_space(puzzle: Puzzle) -> PymunkPuzzle:
    """Build a fresh Pymunk space containing the puzzle's buttons, the marble
    and the four boundary walls.

    The returned `PymunkPuzzle.rows` follows the order of `puzzle`. Rows that
    are not `ButtonRow` are not supported and raise `NotImplementedError`.
    """
    world = pymunk.Space()
    world.gravity = (0.0, -9.81)

    width = get_puzzle_width(puzzle)
    height = get_puzzle_height(puzzle)

    # Buttons: the last puzzle row is the lowest one, so walk the puzzle
    # backwards and stack the rows upwards, one ROW_HEIGHT apart.
    button_thickness = 0.3
    rows_bottom_up = list[list[pymunk.Shape] | pymunk.Shape]()
    for level, row in enumerate(reversed(puzzle), start=1):
        if not isinstance(row, ButtonRow):
            # TODO
            raise NotImplementedError()

        row_y = level*ROW_HEIGHT
        spacing = width/(row.button_n+1)
        buttons = list[pymunk.Shape]()
        for slot in range(1, row.button_n+1):
            # Every button is a box attached to its own static body
            anchor = pymunk.Body(body_type=pymunk.Body.STATIC)
            anchor.position = (slot*spacing, row_y)
            box = pymunk.Poly.create_box(
                anchor,
                (BUTTON_WIDTH, button_thickness),
            )
            box.elasticity = 0.95
            box.friction = 0.5
            world.add(anchor, box)
            buttons.append(box)
        rows_bottom_up.append(buttons)

    # Marble: a dynamic circle starting near the left wall, one row height
    # above the topmost row.
    marble_mass = 1
    marble_body = pymunk.Body(
        marble_mass,
        pymunk.moment_for_circle(marble_mass, 0, MARBLE_RADIUS, (0, 0)),
    )
    marble_body.position = (MARBLE_RADIUS*2, (len(puzzle)+1)*ROW_HEIGHT)
    marble = pymunk.Circle(marble_body, MARBLE_RADIUS, (0, 0))
    marble.elasticity = 0.5
    marble.friction = 0.5
    world.add(marble_body, marble)

    # Boundary: four segments joining consecutive corners of the rectangle
    # (bottom, right, top, left).
    corners = [(0, 0), (width, 0), (width, height), (0, height)]
    walls = []
    for start, end in zip(corners, corners[1:] + corners[:1]):
        wall = pymunk.Segment(world.static_body, start, end, PLATFORM_RADIUS)
        wall.elasticity = 0.95
        wall.friction = 0.5
        walls.append(wall)
    world.add(*walls)

    # Flip the rows back so that they line up with the puzzle's own order
    return PymunkPuzzle(world, marble, rows_bottom_up[::-1])


def add_platform(
        action: Action | torch.Tensor,
        space: pymunk.Space,
        max_x: float,
        max_y: float
        ):
    """Insert a static platform segment into the space.

    `action` holds the end points as `(x1, y1, x2, y2)`; a tensor is converted
    with `tolist()` first. Each coordinate is clipped to `[0, max_x]` or
    `[0, max_y]` before the segment is created on the space's static body.
    """
    coords = action.tolist() if isinstance(action, torch.Tensor) else action
    x1, y1, x2, y2 = coords

    def clamp(value: float, upper: float) -> float:
        # Keep the value inside [0, upper]
        return max(min(value, upper), 0.0)

    start = (clamp(x1, max_x), clamp(y1, max_y))
    end = (clamp(x2, max_x), clamp(y2, max_y))

    segment = pymunk.Segment(space.static_body, start, end, PLATFORM_RADIUS)
    segment.elasticity = 0.95
    segment.friction = 0.95
    space.add(segment)


# Every raw action component is clamped to [-limit, +limit] before scaling.
_RAW_ACTION_LIMIT = 5.0
# Centre-to-centre distance under which the marble counts as pressing a
# button even when the two shapes are not in contact.
_BUTTON_PRESS_DISTANCE = 0.3


def _resolve_actions(
        raw_actions: list[Action] | torch.Tensor,
        start_x: float,
        start_y: float,
        ) -> list[Action]:
    """Convert relative raw actions into absolute platform end points."""
    actions = list[Action]()
    xp, yp = start_x, start_y
    for raw_action in raw_actions:
        # `float()` also unwraps 0-d tensors, so plain floats reach pymunk
        dx1, dy1, dx2, dy2 = (
            min(max(float(delta), -_RAW_ACTION_LIMIT), _RAW_ACTION_LIMIT)
            for delta in raw_action
        )
        # The first point is offset from the anchor, the second point from
        # the first one (on a five times smaller scale)
        x1 = xp + dx1*BUTTON_WIDTH*5
        y1 = yp + dy1*BUTTON_WIDTH*5
        x2 = x1 + dx2*BUTTON_WIDTH
        y2 = y1 + dy2*BUTTON_WIDTH
        actions.append((x1, y1, x2, y2))
        # The next platform is anchored at the end of this one. The original
        # code stored these in unused `px`/`py`, so the anchor never moved.
        xp, yp = x2, y2
    return actions


def _marble_presses_button(
        button: pymunk.Shape,
        marble: pymunk.Shape,
        ) -> bool:
    """Tell whether the marble touches the button or is close enough to it."""
    if button.shapes_collide(marble).points:
        return True
    # Soft check: being near the button's centre also counts as a press
    distance = button.body.position.get_distance(marble.body.position)
    return distance < _BUTTON_PRESS_DISTANCE  # ID: collision_check


def simulation(
        raw_actions: list[Action] | torch.Tensor,
        puzzle: Puzzle,
        output_animation_path: Path | None,
        timestep_n: int,
        ) -> list[tuple[int, int]]:
    """Run the marble simulation and report which buttons were pressed.

    Raw actions are relative: each component is clamped to [-5, 5]; the first
    two components offset the platform's first point from an anchor and the
    last two offset its second point from the first. The anchor starts one
    `ROW_HEIGHT` below the marble's initial position and then moves to the
    end of each platform.

    One platform may be placed at the start; every newly pressed button
    grants one more. At most one platform is placed per time step.

    Args:
        raw_actions: Relative platform actions, consumed in order.
        puzzle: Rows of the puzzle, top to bottom.
        output_animation_path: If given, every step is plotted and the frames
            are encoded into a video at this path with ffmpeg.
        timestep_n: Number of physics steps of `1/FPS` each.

    Returns:
        The buttons pressed as `(row_i, row_button_i)`, in the order in which
        they were first pressed.

    Raises:
        NotImplementedError: If the puzzle contains a row that is not a
            `ButtonRow`.
    """
    # Create pymunk space
    pymunk_puzzle = get_pymunk_space(puzzle)
    max_x = get_puzzle_width(puzzle)
    max_y = get_puzzle_height(puzzle)

    marble_x, marble_y = pymunk_puzzle.marble.body.position
    actions = _resolve_actions(raw_actions, marble_x, marble_y - ROW_HEIGHT)

    dt = 1.0/FPS
    # Zero-pad frame names so that the glob's lexicographic order is also the
    # chronological order, whatever the number of steps
    frame_digits = len(str(timestep_n))
    action_i = 0
    budget = 1
    activated_buttons = list[tuple[int, int]]()

    # Setup plotting
    tdir = None
    frame_dir = None
    if output_animation_path is not None:
        tdir = tempfile.TemporaryDirectory()
        frame_dir = Path(tdir.name)

    try:
        # Step animation
        for i in range(timestep_n):
            # If we have budget, create platform
            if budget > 0 and action_i < len(actions):
                add_platform(
                    actions[action_i],
                    pymunk_puzzle.space,
                    max_x=max_x,
                    max_y=max_y,
                )
                action_i += 1
                budget -= 1

            # Step physics
            pymunk_puzzle.space.step(dt)

            # Handle contacts
            for row_i, pymunk_row in enumerate(pymunk_puzzle.rows):
                if not isinstance(pymunk_row, list):
                    # TODO
                    raise NotImplementedError()
                for button_i, shape in enumerate(pymunk_row):
                    button_key = (row_i, button_i)
                    if button_key in activated_buttons:
                        continue
                    if _marble_presses_button(shape, pymunk_puzzle.marble):
                        # A newly pressed button pays for one more platform
                        activated_buttons.append(button_key)
                        budget += 1

            if frame_dir is not None:
                path = frame_dir/f"{i:0{frame_digits}d}.png"
                plot_space(pymunk_puzzle.space, path, margin=MARBLE_RADIUS)

        # Save plot
        if frame_dir is not None and output_animation_path is not None:
            (
                ffmpeg
                .input(
                    str(frame_dir/"*.png"),
                    pattern_type="glob",
                    framerate=FPS,
                )
                .output(str(output_animation_path))
                .overwrite_output()
                .run(quiet=True)
            )
    finally:
        # Remove the frames even when stepping, plotting or encoding fails
        if tdir is not None:
            tdir.cleanup()

    return activated_buttons


if __name__ == "__main__":
    # Demo: time a headless run, then render the same run to a video.
    import time

    puzzle: Puzzle = [
        ButtonRow(2),
        ButtonRow(3),
        ButtonRow(1),
    ]
    # NOTE(review): these values look like absolute coordinates, but
    # `simulation` clamps every component to [-5, 5] and treats them as
    # relative offsets - confirm they are still the intended demo input.
    actions = [
        (0.1, 8, 1.8, 6.3),
        (4.3, 6.3, 2.7, 6.0),
        (7.0, 6, 5.5, 4.3),
        (4.0, 4.1, 5.0, 3.95),
        (2.1, 4.15, 3.0, 3.99),
        (0.0, 3.9, 2.7, 2.4),
    ]
    timestep_n = int(FPS*8)

    # Measure simulation time (without plotting); perf_counter is monotonic,
    # so clock adjustments cannot distort the measurement
    start_t = time.perf_counter()
    pressed_buttons = simulation(actions, puzzle, None, timestep_n)
    end_t = time.perf_counter()
    # max(1, ...) avoids dividing by zero for a puzzle without buttons
    score = len(pressed_buttons)/max(1, get_puzzle_button_n(puzzle))
    print(f"Raw simulation time: {end_t-start_t}s")
    print(f"Buttons pressed: {pressed_buttons}")
    print(f"Fraction of buttons pressed: {score}")

    # Plot simulation
    animation_path = Path()/"debugging_marble_drop.mp4"
    simulation(actions, puzzle, animation_path, timestep_n)
    print(f"Wrote {animation_path}")
