import math

import numpy as np


# --- Animation timing -------------------------------------------------------
# render() advances `animation_progress` by one per call and fetches a new
# snapshot once it exceeds animation_total + animation_stop.
animation_total = 2
animation_stop = 0

# --- Colours ----------------------------------------------------------------
background_rgb = (255, 255, 255)
attack_line_rgb = (0, 0, 0)
attack_dot_rgb = (0, 0, 0)
text_rgb = (0, 0, 0)
# presumably an RGB colour paired with an alpha value — confirm against users
grid_rgba = ((0, 0, 0), 30)

# --- Geometry ---------------------------------------------------------------
# Pixels per map cell.
grid_size = 8
# Side of an attack marker, as a fraction of grid_size.
attack_dot_size = 0.3

# --- Fonts (point sizes passed to pygame.font.SysFont) and spacing -----------
text_size = 16
text_spacing = 3
banner_size = 20
# Vertical gap, in pixels, between consecutive banner lines.
banner_spacing = 3
bigscreen_size = 72
bigscreen_spacing = 0


def draw_line(surface, color, a, b):
    """Draw a straight line on a pygame surface.

    Args:
        surface: the pygame surface to draw on.
        color: line colour, in any form pygame accepts.
        a, b: ``(x, y)`` endpoints, possibly fractional; each coordinate is
            rounded to the nearest integer pixel before drawing.
    """
    import pygame

    def to_pixel(point):
        # int() is kept because round() on numpy floats may return a float.
        return int(round(point[0])), int(round(point[1]))

    start = to_pixel(a)
    end = to_pixel(b)
    pygame.draw.line(surface, color, start, end)


def draw_rect(surface, color, a, w, h):
    """Fill an axis-aligned rectangle on a pygame surface.

    The top-left corner ``a`` is rounded to the nearest pixel. The width and
    height absorb the corner's rounding error, so the far edge stays close to
    ``a + (w, h)``.

    Args:
        surface: the pygame surface to draw on.
        color: fill colour, in any form pygame accepts.
        a: ``(x, y)`` of the top-left corner, possibly fractional.
        w, h: rectangle width and height in pixels.
    """
    import pygame

    left = round(a[0])
    top = round(a[1])
    width = round(w + a[0] - left)
    height = round(h + a[1] - top)
    rect = pygame.Rect(int(left), int(top), int(width), int(height))
    pygame.draw.rect(surface, color, rect)


def draw_rect_matrix(matrix, color, a, w, h, resolution):
    """Fill an axis-aligned rectangle in an ``[x, y, channel]`` pixel array.

    Rounding matches ``draw_rect``. The rectangle is clipped to
    ``[0, resolution[0]) x [0, resolution[1])``; any part outside is dropped.
    A rectangle lying entirely off-screen leaves ``matrix`` untouched.

    Args:
        matrix: array indexed ``[x, y, channel]``; modified in place.
        color: value assigned along the channel axis of every covered pixel.
        a: ``(x, y)`` of the top-left corner, possibly fractional.
        w, h: rectangle width and height in pixels.
        resolution: ``(width, height)`` bounds used for clipping.
    """
    x = int(round(a[0]))
    y = int(round(a[1]))
    width = int(round(w + a[0] - round(a[0])))
    height = int(round(h + a[1] - round(a[1])))
    # Clamp both ends of each slice into [0, resolution]. Without the outer
    # max(..., 0), a negative stop would be read by numpy as an offset from
    # the end of the axis, and a large unrelated region would get painted.
    x0, x1 = max(x, 0), max(min(x + width, resolution[0]), 0)
    y0, y1 = max(y, 0), max(min(y + height, resolution[1]), 0)
    if x0 < x1 and y0 < y1:
        matrix[x0:x1, y0:y1, :] = color


def draw_line_matrix(matrix, color, a, b, resolution):
    """Draw an axis-aligned line into an ``[x, y, channel]`` pixel array.

    Each endpoint is clamped into ``[0, resolution - 1]`` on both axes, then
    rounded to an integer pixel. Both endpoints are included in the line.

    Args:
        matrix: array indexed ``[x, y, ...]``; modified in place.
        color: value assigned to every covered pixel.
        a, b: ``(x, y)`` endpoints, possibly fractional.
        resolution: ``(width, height)`` bounds used for clamping.

    Raises:
        NotImplementedError: if the line is neither horizontal nor vertical.
    """

    def to_pixel(point):
        px = min(max(0, point[0]), resolution[0] - 1)
        py = min(max(0, point[1]), resolution[1] - 1)
        # Materialise a tuple: the previous map() object was not
        # subscriptable, so every call raised TypeError.
        return int(round(px)), int(round(py))

    (ax, ay), (bx, by) = to_pixel(a), to_pixel(b)
    if ax == bx:
        lo, hi = min(ay, by), max(ay, by)
        matrix[ax, lo : hi + 1] = color
    elif ay == by:
        lo, hi = min(ax, bx), max(ax, bx)
        matrix[lo : hi + 1, ay] = color
    else:
        raise NotImplementedError(
            "only horizontal and vertical lines are supported"
        )


class Renderer:
    """Pygame renderer for a grid-world environment.

    With ``mode == "human"`` frames are drawn into a window sized to fit the
    screen. With ``mode == "rgb_array"`` they are drawn to an off-screen
    surface. ``render`` paints walls, agents, attack events and a banner.
    """

    def __init__(self, env, map_size, mode):
        """Create the pygame surfaces and fonts used for rendering.

        Args:
            env: the environment being drawn. It must provide
                ``get_handles``, ``get_alive``, ``_get_groups_info``,
                ``_get_render_info`` and ``_get_walls_info``.
            map_size: side length of the square map, in cells.
            mode: ``"human"`` or ``"rgb_array"``.

        Raises:
            ValueError: if ``mode`` is not a supported render mode.
        """
        import pygame

        # Fail early: for any other mode no display/canvas would be created,
        # and later calls would break with unrelated errors.
        if mode not in ("human", "rgb_array"):
            raise ValueError(f"unsupported render mode: {mode!r}")

        self.env = env
        self.mode = mode
        self.handles = self.env.get_handles()
        # grid_size pixels per cell (the same constant render() uses). The 15
        # extra rows are presumably headroom for the banner; confirm.
        base_resolution = (map_size * grid_size, map_size * grid_size + 15)
        if mode == "human":
            pygame.init()
            pygame.display.init()
            info = pygame.display.Info()
            # Keep a 50px margin so the window fits on the current screen.
            screen_size = (info.current_w - 50, info.current_h - 50)
            self.resolution = resolution = np.min(
                [screen_size, base_resolution], axis=0
            )
            self.display = pygame.display.set_mode(resolution, pygame.DOUBLEBUF, 0)
            self.canvas = pygame.Surface((resolution[0], resolution[1]))
            pygame.display.set_caption("MAgent2 Renderer Window")
        else:  # "rgb_array"
            pygame.font.init()
            self.resolution = base_resolution
            self.display = pygame.Surface(base_resolution)
            self.canvas = pygame.Surface((base_resolution[0], base_resolution[1]))

        self.text_formatter = pygame.font.SysFont(None, text_size, True)
        self.banner_formatter = pygame.font.SysFont(None, banner_size, True)
        self.bigscreen_formatter = pygame.font.SysFont(None, bigscreen_size, True)

        self.map_size = (map_size, map_size)

        self.frame_id = 0

        # Previous and current snapshots returned by env._get_render_info.
        self.old_data = None
        self.new_data = None

        self.need_static_update = True
        self.animation_progress = 0

    def get_banners(self, frame_id, resolution):
        """Build the banner showing how many agents of each group are alive.

        ``frame_id`` and ``resolution`` are accepted for interface
        compatibility and are not used.

        Returns:
            A list of banner lines. Each line is a tuple of ``(text, rgb)``
            segments, rendered left to right.

        Raises:
            RuntimeError: if the number of handles is not 1, 2 or 4.
        """
        groups = self.env._get_groups_info()
        separator_rgb = (0, 0, 0)

        def form_txt(index):
            """Return ``(alive count as text, group colour)`` for one handle."""
            handle = self.handles[index]
            # Entries from index 2 onward of a group record give its colour.
            color = tuple(int(a) for a in groups[index][2:])
            return f"{np.sum(self.env.get_alive(handle).astype(np.int32))}", color

        n_handles = len(self.handles)
        # A single if/elif chain: previously a second `if` made the 1-handle
        # case fall through to the `else` branch and raise.
        if n_handles == 1:
            return [(form_txt(0),)]
        vs = " vs ", separator_rgb
        if n_handles == 2:
            return [(form_txt(0), vs, form_txt(1))]
        if n_handles == 4:
            comma = ", ", separator_rgb
            return [
                (
                    form_txt(0),
                    comma,
                    form_txt(1),
                    vs,
                    form_txt(2),
                    comma,
                    form_txt(3),
                )
            ]
        raise RuntimeError("bad number of handles")

    def close(self):
        """Shut down the pygame display and pygame itself."""
        import pygame

        pygame.display.quit()
        pygame.quit()

    def render(self, mode):
        """Draw one frame and optionally return it as an image array.

        A new snapshot is requested from the environment whenever the
        current one has been shown for more than
        ``animation_total + animation_stop`` calls.

        Args:
            mode: when ``"rgb_array"``, the frame is returned.

        Returns:
            For ``mode == "rgb_array"``, a copy of the display pixels with the
            first two axes swapped (rows first). Otherwise ``None``.
        """
        import os

        os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide"
        import pygame

        env = self.env
        self.groups = groups = env._get_groups_info()
        resolution = self.resolution
        banner_formatter = self.banner_formatter

        # Scratch pixel buffer indexed [x, y, channel], used to paint walls.
        grid_map = np.zeros((resolution[0], resolution[1], 3), dtype=np.int16)
        # Map-pixel coordinate of the display's top-left corner, chosen so
        # that the map is centred in the display.
        view_position = [
            self.map_size[0] / 2 * grid_size - resolution[0] / 2,
            self.map_size[1] / 2 * grid_size - resolution[1] / 2,
        ]
        # x_range: which vertical gridlines should be shown on the display
        # y_range: which horizontal gridlines should be shown on the display
        x_range = (
            max(0, int(math.floor(max(0, view_position[0]) / grid_size))),
            min(
                self.map_size[0],
                int(math.ceil(max(0, view_position[0] + resolution[0]) / grid_size)),
            ),
        )
        y_range = (
            max(0, int(math.floor(max(0, view_position[1]) / grid_size))),
            min(
                self.map_size[1],
                int(math.ceil(max(0, view_position[1] + resolution[1]) / grid_size)),
            ),
        )

        self.canvas.fill(background_rgb)
        self.display.fill(background_rgb)

        if (
            self.new_data is None
            or self.animation_progress > animation_total + animation_stop
        ):
            pos, event = env._get_render_info(x_range, y_range)
            self.old_data = self.new_data
            self.new_data = pos, event
            self.animation_progress = 0

        # From here on self.new_data is always a (positions, events) pair.
        if self.old_data is None and self.animation_progress == 0:
            self.animation_progress = animation_total

        # The canvas was just cleared, so walls are repainted on every call.
        pygame.pixelcopy.surface_to_array(grid_map, self.canvas)
        for wall in env._get_walls_info():
            x, y = wall[0], wall[1]
            if x_range[0] <= x <= x_range[1] and y_range[0] <= y <= y_range[1]:
                draw_rect_matrix(
                    grid_map,
                    (127, 127, 127),
                    (
                        x * grid_size - view_position[0],
                        y * grid_size - view_position[1],
                    ),
                    grid_size,
                    grid_size,
                    resolution,
                )
        pygame.pixelcopy.array_to_surface(self.canvas, grid_map)

        agents, events = self.new_data

        # Agents: one rectangle per agent, in its group's colour and size.
        for key in agents:
            prop = agents[key]
            group = groups[prop[2]]
            draw_rect(
                self.canvas,
                (int(group[2]), int(group[3]), int(group[4])),
                (
                    prop[0] * grid_size - view_position[0],
                    prop[1] * grid_size - view_position[1],
                ),
                group[0] * grid_size,
                group[1] * grid_size,
            )

        # Attack events: a line from the attacker's centre to the target cell,
        # plus a small dot on the target. Events of unknown agents are skipped.
        for key, event_x, event_y in events:
            if key not in agents:
                continue
            prop = agents[key]
            group = groups[prop[2]]
            draw_line(
                self.canvas,
                attack_line_rgb,
                (
                    prop[0] * grid_size
                    - view_position[0]
                    + group[0] / 2 * grid_size,
                    prop[1] * grid_size
                    - view_position[1]
                    + group[1] / 2 * grid_size,
                ),
                (
                    event_x * grid_size - view_position[0] + grid_size / 2,
                    event_y * grid_size - view_position[1] + grid_size / 2,
                ),
            )
            draw_rect(
                self.canvas,
                attack_dot_rgb,
                (
                    event_x * grid_size
                    - view_position[0]
                    + grid_size / 2
                    - attack_dot_size * grid_size / 2,
                    event_y * grid_size
                    - view_position[1]
                    + grid_size / 2
                    - attack_dot_size * grid_size / 2,
                ),
                attack_dot_size * grid_size,
                attack_dot_size * grid_size,
            )

        self.animation_progress += 1

        self.display.blit(self.canvas, (0, 7))

        # Banner lines, each horizontally centred, stacked from the top.
        height_now = 0
        for texts in self.get_banners(self.frame_id, resolution):
            content = []
            width, height = 0, 0
            for text in texts:
                text = banner_formatter.render(text[0], True, pygame.Color(*text[1]))
                content.append((text, width))
                width += text.get_width()
                height = max(height, text.get_height())
            start = (resolution[0] - width) / 2.0
            for surface, offset in content:
                self.display.blit(surface, (start + offset, height_now))
            height_now += height + banner_spacing

        self.need_static_update = False

        # pixels3d locks the surface, so take a copy and release the view.
        observation = pygame.surfarray.pixels3d(self.display)
        new_observation = np.copy(observation)
        del observation
        if self.mode == "human":
            pygame.display.flip()
        return (
            np.transpose(new_observation, axes=(1, 0, 2))
            if mode == "rgb_array"
            else None
        )
