import math
import random
import numpy as np
import networkx as nx

from shapely import geometry
from scipy.spatial.distance import directed_hausdorff
from networkx.exception import NetworkXError
from shapely.geometry import Point, LineString


def create_graph(keypoints, lines, node_bias=0):
    """Build an undirected ``nx.MultiGraph`` from keypoints and index-pair lines.

    Node ``k + node_bias`` carries the coordinates of ``keypoints[k]`` as ``x``
    and ``y`` attributes. Every entry ``lin`` of ``lines`` becomes an edge between
    ``lin[0] + node_bias`` and ``lin[1] + node_bias`` whose ``geometry`` is the
    straight shapely ``LineString`` between the two keypoints and whose
    ``length`` is that segment's length.
    """
    graph = nx.MultiGraph()

    graph.add_nodes_from(
        (idx + node_bias, {"x": keypoints[idx][0], "y": keypoints[idx][1]})
        for idx in range(keypoints.shape[0])
    )

    for lin in lines:
        head, tail = lin[0], lin[1]
        segment = LineString([Point(keypoints[head]), Point(keypoints[tail])])
        graph.add_edge(
            head + node_bias,
            tail + node_bias,
            length=segment.length,
            geometry=segment,
        )

    return graph


def keypoints_is_in_boundary(point, image_size=300, pad_boundary=3):
    """Return whether ``point`` lies in the border band of the image.

    ``point`` is a mapping with ``"x"`` and ``"y"`` entries. A coordinate is in
    the band when it is below ``pad_boundary`` or above
    ``image_size - 1 - pad_boundary``. The point is in the band when either
    coordinate is.
    """
    upper = image_size - 1 - pad_boundary

    def _outside(coord):
        return coord < pad_boundary or coord > upper

    return _outside(point["x"]) or _outside(point["y"])


def remove_edges_in_small_segments(
    G, remove_length=100, image_size=300, pad_boundary=3
):
    """Remove short dangling chains that end at an interior leaf, in place.

    For every node with exactly one neighbour that is not in the image border
    band (``keypoints_is_in_boundary``), the chain of degree-2 nodes starting
    there is followed by ``remove_edges_in_small_segments_traverse``. If the
    accumulated edge ``length`` is below ``remove_length``, every edge on the
    chain is removed from ``G``.

    Args:
        G: keypoint multigraph whose nodes carry ``x``/``y`` attributes and
            whose edges carry ``length`` (as built by ``create_graph``).
            Modified in place.
        remove_length: chains strictly shorter than this are removed.
        image_size: image side length, forwarded to the border test.
        pad_boundary: border band width, forwarded to the border test.
    """
    edges_to_remove = []
    for node in G.nodes():
        if len(G[node]) != 1:
            continue
        # Leaves in the border band are kept; presumably their segment
        # continues outside the image crop (TODO confirm intent).
        if keypoints_is_in_boundary(
            G.nodes[node], image_size=image_size, pad_boundary=pad_boundary
        ):
            continue
        for neigh in G[node]:
            remove_edges_in_small_segments_traverse(
                G,
                neigh,
                [node],
                G[node][neigh][0]["length"],
                edges_to_remove,
                remove_length=remove_length,
            )

    for start, end in edges_to_remove:
        # Best effort: an edge may already be gone (e.g. a chain between two
        # interior leaves is queued from both ends), so a missing edge is ignored.
        # remove_edge is tried for (u, v) and then (v, u); on an undirected
        # multigraph the second call removes a further parallel edge if present.
        for u, v in ((start, end), (end, start)):
            try:
                G.remove_edge(u, v)
            except NetworkXError:
                pass


def remove_edges_in_small_segments_traverse(
    G, node, current_path, current_path_len, edges_to_remove, remove_length=100
):
    """Walk a chain of degree-2 nodes and queue its edges if the chain is short.

    Starting at ``node`` (reached from ``current_path[-1]``), the walk keeps
    stepping to the neighbour it did not come from while the current node has
    exactly two neighbours. It stops at the first node with any other neighbour
    count. If the accumulated length is then below ``remove_length``, every
    consecutive pair on the visited path is appended to ``edges_to_remove`` as
    ``[u, v]``.

    Implemented as a loop rather than recursion so that long chains cannot
    exceed Python's recursion limit.

    Args:
        G: graph supporting ``G[u]`` (mapping of neighbours) and
            ``G[u][v][0]["length"]`` (length of the edge with key 0).
        node: next node to visit.
        current_path: nodes visited so far; extended in place.
        current_path_len: length accumulated so far.
        edges_to_remove: output list; extended in place.
        remove_length: chains strictly shorter than this are queued.
    """
    while True:
        current_path.append(node)
        neighbours = G[node]
        if len(neighbours) != 2:
            # Reached the end of the chain (a leaf or a junction).
            if current_path_len < remove_length:
                edges_to_remove.extend(
                    [start, end] for start, end in zip(current_path, current_path[1:])
                )
            return

        # Continue to the neighbour we did not arrive from.
        previous = current_path[-2]
        for neigh in neighbours:
            if neigh != previous:
                break

        current_path_len = current_path_len + neighbours[neigh][0]["length"]
        node = neigh


# matching issue not important
def transform_graph(old_G):
    """Convert a keypoint graph into a simple ``nx.Graph`` with ``pos`` attributes.

    Every node of ``old_G`` keeps its id and gets ``pos=(x, y)`` taken from its
    ``x``/``y`` attributes. Every edge is re-added without attributes; because
    ``nx.Graph`` is simple, parallel edges collapse into one.
    """
    simple = nx.Graph()

    simple.add_nodes_from(
        (node, {"pos": (attrs["x"], attrs["y"])})
        for node, attrs in old_G.nodes(data=True)
    )
    simple.add_edges_from((u, v) for u, v in old_G.edges())

    return simple


def _unseen_targets_in_range(target_keypoints, current_keypoints, range_one):
    """Return target vertices inside ``range_one`` not yet covered by a keypoint.

    A target vertex counts as covered when some entry of ``current_keypoints``
    lies within 1e-4 (Euclidean distance) of it.
    """
    placed = np.stack(current_keypoints) if len(current_keypoints) > 0 else None
    return [
        target
        for target in target_keypoints
        if keypoint_is_in_range(target, range_one)
        and (
            placed is None
            or np.linalg.norm(placed - target, axis=1).min() > 0.0001
        )
    ]


def _linestring_parts(collection):
    """Return the LineString members of a shapely GeometryCollection, in order.

    Uses ``.geoms`` because multi-part geometries are not iterable in Shapely 2.x.
    """
    return [part for part in collection.geoms if isinstance(part, geometry.LineString)]


def sample_new_keypoint(
    target_keypoints,
    current_keypoints,
    range_one,
    select_corner_prob=0.5,
    return_only_corners=False,
    pessimistic_rewards=False,
):
    """Sample a new keypoint on the target polygon's boundary inside a window.

    Args:
        target_keypoints: polygon vertices, indexed as rows ``[x, y]`` (assumed
            an ``(N, 2)`` numpy array; TODO confirm). The polygon is closed
            by repeating the first vertex.
        current_keypoints: keypoints placed so far (sequence of 2-vectors).
        range_one: window ``((x_lo, x_hi), (y_lo, y_hi))``, or ``None``.
        select_corner_prob: probability of first trying to return an unseen
            polygon vertex inside the window.
        return_only_corners: always try a vertex; if none is available return
            ``(None, 0)`` instead of falling back to random sampling.
        pessimistic_rewards: see "reward" below.

    Returns:
        ``(point, reward)``. ``point`` is either an unseen target vertex inside
        the window, or a random point on the intersection of the window with
        the polygon boundary, or ``None``. ``reward`` is 0 when that
        intersection has zero length. Otherwise:

        - no current keypoints: the intersection length, or the window's
          x-width under ``pessimistic_rewards``;
        - pessimistic, with keypoints: the window's x-width if an unseen
          vertex is inside it, else 0;
        - otherwise: the directed Hausdorff distance from points sampled every
          2 units along the intersection to the current keypoints.
    """
    if range_one is None:
        return None, 0

    rectangle = geometry.Polygon(
        [
            [range_one[0][0], range_one[1][0]],
            [range_one[0][1], range_one[1][0]],
            [range_one[0][1], range_one[1][1]],
            [range_one[0][0], range_one[1][1]],
        ]
    )

    # Repeat the first vertex so the closing edge of the polygon is included.
    linestring = geometry.LineString(
        [
            geometry.Point(point[0].item(), point[1].item())
            for point in np.concatenate([target_keypoints, target_keypoints[0][None]])
        ]
    )

    res = rectangle.intersection(linestring)

    reward = 0

    # The farther the intersection is from the current point cloud, the larger the reward.
    if res.length > 0:
        if len(current_keypoints) <= 0:
            # First keypoint detected.
            reward = res.length
            if pessimistic_rewards:
                reward = range_one[0][1] - range_one[0][0]
        elif pessimistic_rewards:
            # Under pessimistic rewards only unseen corners are rewarded.
            if _unseen_targets_in_range(target_keypoints, current_keypoints, range_one):
                reward = range_one[0][1] - range_one[0][0]
        else:
            samples_every_pixels = 2

            if isinstance(res, geometry.GeometryCollection):
                lines = _linestring_parts(res)
            else:
                lines = [res]

            points = []
            for line in lines:
                for per in range(math.ceil(line.length / samples_every_pixels)):
                    x, y = line.interpolate(per * samples_every_pixels).xy
                    points.append([x[0], y[0]])

            # directed_hausdorff rejects an empty point set; keep reward 0 then.
            if points:
                reward, _, _ = directed_hausdorff(
                    np.array(points), np.stack(current_keypoints)
                )

    # With probability select_corner_prob, pick an unseen target vertex inside the window.
    if random.random() < select_corner_prob or return_only_corners:
        available_targets = _unseen_targets_in_range(
            target_keypoints, current_keypoints, range_one
        )
        if len(available_targets) > 0:
            return (
                available_targets[random.randint(0, len(available_targets) - 1)],
                reward,
            )
        if return_only_corners:
            return None, 0
        # All corners in the window are taken: fall back to random sampling.

    if isinstance(res, geometry.GeometryCollection):
        parts = _linestring_parts(res)
        # Sample a random line from the intersection.
        res = parts[random.randint(0, len(parts) - 1)] if parts else None

    if res is not None and res.length > 0.01:
        x, y = res.interpolate(random.random(), normalized=True).xy
        return np.array([x[0], y[0]]), reward
    return None, reward


def get_point_info(target_keypoints, point):
    """Locate ``point`` on the closed polygon given by ``target_keypoints``.

    Edge ``i`` runs from vertex ``i`` to vertex ``(i + 1) % N``. The first edge
    within a near-exact tolerance (1e-6) of ``point`` is used. If no edge is
    that close, the first edge within 1 unit is used.

    Returns:
        ``(i, lambda_val)``, where ``lambda_val`` is the point's offset from the
        edge's start divided by the edge's extent, measured on the axis where
        the edge spans the most. For a degenerate edge (extent <= 1e-4 on both
        axes) ``lambda_val`` is 0.

    Raises:
        AssertionError: if ``point`` is farther than 1 unit from every edge.
    """
    query = geometry.Point(point)
    point_xy = np.asarray(point, dtype=float)
    num_vertices = target_keypoints.shape[0]

    # A zero-width buffer of a LineString is an empty polygon, so
    # ``buffer(0).intersects`` can never succeed. Measure the distance to the
    # segment instead. The tight pass runs first so a point lying on one edge
    # is not attributed to an earlier edge that is merely within one unit.
    for tolerance in (1e-6, 1.0):
        for i in range(num_vertices):
            start = target_keypoints[i]
            end = target_keypoints[(i + 1) % num_vertices]
            start_xy = np.array([float(start[0]), float(start[1])])
            end_xy = np.array([float(end[0]), float(end[1])])

            if geometry.LineString([start_xy, end_xy]).distance(query) > tolerance:
                continue

            extent = np.abs(end_xy - start_xy)
            span = extent.max()
            if span > 1e-4:
                # Normalised position along the edge's dominant axis.
                axis = extent.argmax()
                return i, np.abs(point_xy - start_xy)[axis] / span
            # Degenerate (near zero-length) edge.
            return i, 0

    # Raised explicitly (not `assert`) so the check survives `python -O`.
    raise AssertionError("Should not reach this point...")


def sample_new_edge(
    target_keypoints,
    current_keypoints,
    current_keypoints_info,
    current_edges,
    ranges_one,
    new_edge_has_keypoints_in_area=False,
):
    """Sample a not-yet-existing edge between consecutive keypoints on a polygon side.

    ``current_keypoints_info[j]`` is ``[side_index, position]`` for keypoint
    ``j`` (presumably as returned by ``get_point_info``; TODO confirm).

    Candidate edges connect keypoints that are adjacent, by position, along the
    same side. A keypoint within 0.01 of the start of the next side also counts
    as this side's end (position 1). One within 0.01 of the end of the previous
    side counts as this side's start (position 0). Edges already in
    ``current_edges`` are excluded in either orientation.

    The remaining candidates are filtered against the window ``ranges_one``:

    - ``new_edge_has_keypoints_in_area`` True: at least one endpoint must lie
      inside the window;
    - otherwise: the segment must intersect the window.

    Returns:
        ``(edge, length)`` for a uniformly random surviving candidate, where
        ``edge`` is a sorted index pair. Returns ``(None, 0)`` when there is
        none, when ``ranges_one`` is ``None``, or when fewer than two keypoints
        exist.
    """
    if ranges_one is None or len(current_keypoints_info) < 2:
        return None, 0

    info = np.array(current_keypoints_info)
    num_sides = target_keypoints.shape[0]

    all_possible_edges = set()
    for side in range(num_sides):
        next_side = (side + 1) % num_sides
        prev_side = (side - 1 + num_sides) % num_sides

        # (position along this side, keypoint index) for every keypoint on it.
        on_side = []
        for j in range(info.shape[0]):
            side_j, pos_j = info[j, 0], info[j, 1]
            if side_j == side:
                on_side.append((pos_j, j))
            # A keypoint at the very start of the next side is this side's end point.
            if side_j == next_side and abs(pos_j) < 0.01:
                on_side.append((1, j))
            # A keypoint at the very end of the previous side is this side's start point.
            if side_j == prev_side and abs(pos_j - 1) < 0.01:
                on_side.append((0, j))

        if len(on_side) <= 1:
            continue

        # Sort by position only (stable for ties), then link neighbours along the side.
        ordered = [j for _, j in sorted(on_side, key=lambda pair: pair[0])]
        for a, b in zip(ordered, ordered[1:]):
            all_possible_edges.add(sort_edge([a, b]))

    # Drop edges that already exist; normalise orientation so (v, u) matches (u, v).
    for edge in current_edges:
        all_possible_edges.discard(sort_edge(edge))

    if new_edge_has_keypoints_in_area:
        possible_edges_in_range = [
            edge
            for edge in all_possible_edges
            if keypoint_is_in_range(current_keypoints[edge[0]], ranges_one)
            or keypoint_is_in_range(current_keypoints[edge[1]], ranges_one)
        ]
    else:
        possible_edges_in_range = [
            edge
            for edge in all_possible_edges
            if edge_is_in_range(
                current_keypoints[edge[0]],
                current_keypoints[edge[1]],
                ranges_one,
            )
        ]

    if not possible_edges_in_range:
        return None, 0

    edge = possible_edges_in_range[random.randint(0, len(possible_edges_in_range) - 1)]
    return (
        edge,
        np.linalg.norm(current_keypoints[edge[0]] - current_keypoints[edge[1]]),
    )


def sort_edge(e):
    """Return the first two entries of edge ``e`` as an ascending tuple."""
    u, v = e[0], e[1]
    if not u < v:
        u, v = v, u
    return u, v


def get_ranges_from_actions(action, range_x, range_y, quantize_image_size):
    """Map a discrete action index to the grid cell it addresses.

    With ``q = quantize_image_size``, actions ``[0, q*q)`` and ``[q*q, 2*q*q)``
    address the same ``q x q`` grid; the two halves differ only in action type,
    which is ignored here. Action ``2*q*q`` yields ``None``.

    Returns:
        ``((x_lo, x_hi), (y_lo, y_hi))`` taken from consecutive entries of
        ``range_x``/``range_y``, or ``None``.
    """
    cells = quantize_image_size * quantize_image_size
    if action == cells * 2:
        return None

    if action >= cells:
        # Second action type: fold back onto the same grid.
        action = action - cells

    ix = action // quantize_image_size
    iy = action % quantize_image_size
    x_range = (range_x[ix], range_x[ix + 1])
    y_range = (range_y[iy], range_y[iy + 1])
    return x_range, y_range


def keypoint_is_in_range(x, ranges):
    """Return whether point ``x`` lies in the half-open window ``ranges``.

    ``ranges`` is ``((x_lo, x_hi), (y_lo, y_hi))``. Lower bounds are inclusive
    and upper bounds exclusive on both axes.
    """
    x_bounds = ranges[0]
    inside_x = x[0] >= x_bounds[0] and x[0] < x_bounds[1]
    if not inside_x:
        return inside_x
    y_bounds = ranges[1]
    return x[1] >= y_bounds[0] and x[1] < y_bounds[1]


def get_keypoints_not_seen(set_, ranges):
    """Return the entries of ``set_`` whose first element lies inside ``ranges``.

    Each entry's ``[0]`` item is the point tested with ``keypoint_is_in_range``.
    Returns ``None`` when ``ranges`` is ``None``.
    """
    if ranges is None:
        return None

    return [entry for entry in set_ if keypoint_is_in_range(entry[0], ranges)]


def get_edges_not_seen(set_, keypoints, keypoints_map, ranges):
    """Return edges of ``set_`` between observed keypoints touching ``ranges``.

    An edge ``(a, b)`` is kept when both ``a`` and ``b`` are keys of
    ``keypoints_map`` and at least one of ``keypoints[keypoints_map[a]]`` and
    ``keypoints[keypoints_map[b]]`` lies inside ``ranges``.

    Returns ``None`` when ``ranges`` is ``None``.
    """
    if ranges is None:
        return None

    # Keep only edges whose endpoints have both been observed. Membership is
    # tested on the mapping itself (hash lookup) rather than a list of its keys.
    possible_edges = [
        edge
        for edge in set_
        if edge[0] in keypoints_map and edge[1] in keypoints_map
    ]

    return [
        edge
        for edge in possible_edges
        if keypoint_is_in_range(keypoints[keypoints_map[edge[0]]], ranges)
        or keypoint_is_in_range(keypoints[keypoints_map[edge[1]]], ranges)
    ]


def edge_is_in_range(point1, point2, ranges):
    """Return whether the segment from ``point1`` to ``point2`` touches a window.

    ``ranges`` is ``((x_lo, x_hi), (y_lo, y_hi))``. The test uses shapely's
    ``intersects``, so contact with the window border counts.
    """
    x_lo, x_hi = ranges[0][0], ranges[0][1]
    y_lo, y_hi = ranges[1][0], ranges[1][1]
    window = geometry.Polygon(
        [[x_lo, y_lo], [x_hi, y_lo], [x_hi, y_hi], [x_lo, y_hi]]
    )

    endpoints = (point1, point2)
    segment = geometry.LineString([geometry.Point(p[0], p[1]) for p in endpoints])

    return segment.intersects(window)
