import math
import random
import numpy as np
from shapely.geometry import LineString
import networkx as nx
import cv2
import copy
from scipy.stats import truncnorm
from functools import reduce

from simplification.cutil import simplify_coords
from dataprocessing import sknw
from dataprocessing import graph_utils


def _saturating_shift(channel, value):
    """Add ``value`` to ``channel`` in place, saturating at 0 and 255.

    Entries that would exceed 255 are set to 255 and entries that would drop
    below 0 are set to 0, so a uint8 channel never wraps around.
    """
    if value >= 0:
        lim = 255 - value
        channel[channel > lim] = 255
        channel[channel <= lim] += value
    else:
        lim = np.absolute(value)
        channel[channel < lim] = 0
        channel[channel >= lim] -= lim


def colorjitter(img, cj_type=None):
    """
    ### Different Color Jitter ###
    img: image; converted with cv2.COLOR_BGR2HSV for "b"/"s", so presumably
         an 8-bit BGR image.
    cj_type: {b: brightness, s: saturation, c: contrast}; drawn uniformly at
             random when None.

    "b" / "s" shift the HSV value / saturation channel by a random amount from
    {-50, -40, -30, 30, 40, 50}, saturating at 0 and 255.
    "c" computes ``img * (contrast / 127 + 1) - contrast + 10`` with a random
    integer contrast in [40, 100], clipped to [0, 255] and cast to uint8.

    Returns the jittered image.
    Raises ValueError for any other cj_type (previously the function fell
    through and silently returned None).
    """
    if cj_type is None:
        cj_type = random.choice(["b", "s", "c"])

    if cj_type in ("b", "s"):
        value = np.random.choice(np.array([-50, -40, -30, 30, 40, 50]))
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(hsv)
        # "b" modifies the value channel, "s" the saturation channel.
        _saturating_shift(v if cj_type == "b" else s, value)
        final_hsv = cv2.merge((h, s, v))
        return cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)

    if cj_type == "c":
        brightness = 10
        contrast = random.randint(40, 100)
        # int16 avoids uint8 overflow before the final clip.
        dummy = np.int16(img)
        dummy = dummy * (contrast / 127 + 1) - contrast + brightness
        dummy = np.clip(dummy, 0, 255)
        return np.uint8(dummy)

    raise ValueError(f"unknown cj_type {cj_type!r}; expected 'b', 's' or 'c'")


def rotate(*args):
    """Apply one randomly chosen rotation to every image in ``args``.

    One of ``cv2.ROTATE_180``, ``cv2.ROTATE_90_CLOCKWISE``,
    ``cv2.ROTATE_90_COUNTERCLOCKWISE`` or "no rotation" is drawn with
    ``random.choice``, and the same choice is applied to all inputs so that
    paired images (e.g. an image and its label) stay aligned.

    Returns:
        A list with one (possibly rotated) image per input, in input order.
        The previous version returned the untouched ``args`` tuple in the
        "no rotation" case and a list otherwise.
    """
    rotation = random.choice(
        [cv2.ROTATE_180, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, -1]
    )
    if rotation == -1:
        # sentinel for "leave the images as they are"
        return list(args)

    return [cv2.rotate(x, rotation) for x in args]


def split_graph_multilines(G, seen, parent_node, current_node, current_line, all_lines):
    """Split a graph into polylines that run between non-degree-2 nodes.

    Walks ``G`` depth-first from ``current_node`` and appends one
    ``[node_list, ends_on_leaf]`` pair to ``all_lines`` per polyline:

      * a polyline is closed when the walk reaches a node whose degree
        (``len(G[node])``) is not 2, i.e. a leaf or a junction; ``ends_on_leaf``
        is True iff that node has degree 1, and a new polyline starts there;
      * when the walk reaches an already-visited node, the polyline is closed
        at that node with ``ends_on_leaf`` False (a cycle).

    The walk uses an explicit stack instead of recursion so that long chains
    of degree-2 nodes cannot exceed Python's recursion limit. Children are
    visited in ``G[node]`` iteration order and each child's subtree is fully
    explored before the next child. This is the same pre-order as the former
    recursive version, so ``all_lines`` gets the same contents in the same order.

    Args:
        G: graph-like object; ``G[node]`` must be a sized iterable of
            neighbours (e.g. a networkx Graph).
        seen: set of visited nodes; updated in place.
        parent_node: node the walk came from (e.g. -1 at the start); the walk
            does not step straight back to it.
        current_node: node to start from.
        current_line: partial polyline preceding ``current_node``; the start
            node is appended to it in place.
        all_lines: output list; polylines are appended to it.
    """
    stack = [(parent_node, current_node, current_line)]
    while stack:
        parent, node, line = stack.pop()

        if node in seen:
            if len(line) > 0:
                # end of a cycle
                line.append(node)
                all_lines.append([line, False])
            continue

        line.append(node)

        degree = len(G[node])
        if degree != 2 and len(line) > 1:
            # either a leaf or a junction node: the current line is finished
            all_lines.append([line, degree == 1])
            line = [node]

        seen.add(node)

        # Each child gets its own copy of the line. Pushing in reverse makes
        # them pop in G[node] order.
        children = [(node, neigh, list(line)) for neigh in G[node] if neigh != parent]
        stack.extend(reversed(children))


def find_interesting_vertex(G):
    """Return the first node (in ``G.nodes()`` order) whose degree is not 2.

    The degree is taken as ``len(G[node])``. Returns None when every node has
    exactly two neighbours, e.g. when the component is a simple cycle.
    """
    return next((node for node in G.nodes() if len(G[node]) != 2), None)


def create_graph(vertices, lines):
    """Build an undirected networkx graph from coordinates and index pairs.

    Node ``i`` corresponds to ``vertices[i]`` and stores its coordinates as
    the node attributes ``x`` (``vertices[i][0]``) and ``y``
    (``vertices[i][1]``). Each entry of ``lines`` adds one edge between node
    ids ``line[0]`` and ``line[1]``.
    """
    G = nx.Graph()
    G.add_nodes_from(
        (idx, {"x": vtx[0], "y": vtx[1]}) for idx, vtx in enumerate(vertices)
    )
    G.add_edges_from((ln[0], ln[1]) for ln in lines)
    return G


def _snap_to_existing_vertex(x, y, vertices, min_distance):
    """Return the nearest entry of ``vertices`` if it is closer than
    ``min_distance`` to (x, y); otherwise return (x, y) unchanged."""
    if len(vertices) == 0:
        return x, y
    distances = np.linalg.norm(np.array(vertices) - np.array([x, y])[None], axis=1)
    if distances.min() < min_distance:
        nearest = vertices[distances.argmin()]
        return nearest[0], nearest[1]
    return x, y


def simplify_graph(
    pred_graph,
    max_distance=70,
    ignore_lines_length=10,
    epsilon=10.0,
    min_distance_between_vertices=5,
):
    """Simplify a dense graph into a sparse set of integer vertices and lines.

    Each connected component is split into polylines with
    ``split_graph_multilines`` (starting from ``find_interesting_vertex``).
    Each polyline is simplified with ``simplify_coords`` (tolerance
    ``epsilon``), and its points become integer vertices (coordinates
    truncated with ``int``). Along the way:

      * polylines shorter than ``ignore_lines_length`` that end on a leaf are
        dropped;
      * a point closer than ``min_distance_between_vertices`` to an
        already-emitted vertex is snapped onto that vertex;
      * segments longer than ``max_distance`` are subdivided into
        ``ceil(length / max_distance)`` pieces;
      * zero-length segments (both ends snapped onto the same vertex) are
        skipped instead of being emitted as self-loops.

    Args:
        pred_graph: networkx graph whose nodes carry ``x`` and ``y`` attributes.

    Returns:
        vertices: ``(N, 2)`` int array of ``[x, y]`` rows.
        lines: ``(M, 2)`` int array of index pairs into ``vertices``.
        Both are ``(0, 2)`` int arrays when empty, so they remain usable as
        indices.
    """
    # simplification should be done in the whole graph ...
    sub_graphs = (
        pred_graph.subgraph(c).copy() for c in nx.connected_components(pred_graph)
    )

    vertex_index = {}  # (x, y) -> row in `vertices`
    vertices = []
    lines = []

    def add_vertex(x, y):
        if (x, y) not in vertex_index:
            vertex_index[(x, y)] = len(vertex_index)
            vertices.append([x, y])

    def add_line(p, q):
        # Snapping can map both ends onto one vertex; a self-loop has no geometry.
        if p != q:
            lines.append([vertex_index[p], vertex_index[q]])

    for graph in sub_graphs:
        start_vertex = find_interesting_vertex(graph)
        if start_vertex is None:
            print(
                "interesting_vertex vertex is None!, means everything is in a circle "
            )
            start_vertex = next(iter(graph.nodes()))

        # split graphs into lines
        all_lines = []
        split_graph_multilines(graph, set(), -1, start_vertex, [], all_lines)

        # based on all lines simplify the graphs
        for one_line, is_leaf in all_lines:
            line_coords = [
                [graph.nodes[n]["x"], graph.nodes[n]["y"]] for n in one_line
            ]

            # is_leaf is true when the line ends on a leaf node (approximately ...)
            if LineString(line_coords).length < ignore_lines_length and is_leaf:
                continue

            prev = None
            for point in simplify_coords(line_coords, epsilon):
                x, y = _snap_to_existing_vertex(
                    int(point[0]), int(point[1]), vertices, min_distance_between_vertices
                )
                add_vertex(x, y)

                if prev is not None:
                    vector = np.array([x - prev[0], y - prev[1]])
                    length = np.linalg.norm(vector)
                    if length > max_distance:
                        # Subdivide; the last subdivision point lands on (x, y).
                        num_inside = math.ceil(length / max_distance)
                        origin_x, origin_y = prev
                        last = prev
                        for num in range(num_inside):
                            pos = vector * (num + 1) / num_inside
                            mid = _snap_to_existing_vertex(
                                origin_x + int(pos[0]),
                                origin_y + int(pos[1]),
                                vertices,
                                min_distance_between_vertices,
                            )
                            add_vertex(*mid)
                            add_line(last, mid)
                            last = mid
                    else:
                        add_line(prev, (x, y))

                prev = (x, y)

    vertices = np.array(vertices) if vertices else np.empty((0, 2), dtype=int)
    lines = np.array(lines) if lines else np.empty((0, 2), dtype=int)
    return vertices, lines


def create_line_image(start, end, h, w, line_width=3):
    """Draw one line segment onto a fresh ``(h, w)`` float mask.

    The mask is all zeros except for the segment from ``start`` to ``end``,
    which ``cv2.line`` draws with value 1 and thickness ``line_width``. Only
    the first two entries of each endpoint are used, truncated with ``int``.
    """
    canvas = np.zeros((h, w))
    pt_from = (int(start[0]), int(start[1]))
    pt_to = (int(end[0]), int(end[1]))
    cv2.line(canvas, pt_from, pt_to, 1, line_width)
    return canvas


def create_single_segmentation(
    graphs,
    x_crops,
    y_crops,
    crop_size,
    cropped_image_size,
    new_image_size,
    line_width=10,
    ignore_close_to_border=None,
    max_x_crop=900,
    max_y_crop=900,
):
    """Render per-crop graphs into one global segmentation mask.

    Each graph's node coordinates (attributes ``x`` and ``y``) are scaled by
    ``crop_size / cropped_image_size`` per axis and offset by the crop
    origin ``(x_crop, y_crop)``. Each edge is then drawn with
    ``create_line_image``, and the drawings are summed and clipped to [0, 1].

    Args:
        graphs: graph-like objects exposing ``nodes()``, ``nodes[n]["x"/"y"]``
            and ``edges()`` (e.g. networkx graphs); zipped with the crop
            origins.
        x_crops, y_crops: crop origins.
        crop_size, cropped_image_size, new_image_size: an int (same value on
            both axes, numpy integers accepted) or an ``(x, y)`` pair.
        line_width: stroke thickness passed to ``create_line_image``.
        ignore_close_to_border: if not None, each graph is first clipped with
            ``load_vertices_and_lines_in_crop`` using this margin. The margin
            applies on a side only when the crop is not at that image edge
            (``x_crop > 0``, ``x_crop < max_x_crop``, etc.).
        max_x_crop, max_y_crop: largest crop origins; crops at these origins
            get no margin on their far side.

    Returns:
        ``(new_image_size_x, new_image_size_y)`` float array with values in [0, 1].
    """

    def _xy(size):
        # An integer (Python or numpy) means the same size on both axes.
        if isinstance(size, (int, np.integer)):
            return size, size
        return size[0], size[1]

    new_image_size_x, new_image_size_y = _xy(new_image_size)
    cropped_image_size_x, cropped_image_size_y = _xy(cropped_image_size)
    crop_size_x, crop_size_y = _xy(crop_size)

    scaling_factor_x = crop_size_x / cropped_image_size_x
    scaling_factor_y = crop_size_y / cropped_image_size_y

    # NOTE(review): the mask has shape (size_x, size_y) and is passed to
    # create_line_image as (h, w), so rows follow the x size. Confirm this
    # orientation is intended for non-square outputs.
    global_segmentation = np.zeros((new_image_size_x, new_image_size_y))

    for graph, x_crop, y_crop in zip(graphs, x_crops, y_crops):
        if ignore_close_to_border is not None:
            vertices = []
            lines = []

            vertices_map = {}
            for node in graph.nodes():
                vertices_map[node] = len(vertices)
                vertices.append([graph.nodes[node]["x"], graph.nodes[node]["y"]])

            for u, v in graph.edges():
                lines.append([vertices_map[u], vertices_map[v]])

            pad_length_x_start = ignore_close_to_border if x_crop > 0 else 0
            pad_length_x_end = ignore_close_to_border if x_crop < max_x_crop else 0
            pad_length_y_start = ignore_close_to_border if y_crop > 0 else 0
            pad_length_y_end = ignore_close_to_border if y_crop < max_y_crop else 0
            # NOTE(review): crop_size and new_image_size are hard-coded to 300,
            # i.e. the graph is assumed to live in a 300x300 frame. Presumably
            # this should follow cropped_image_size; confirm before changing.
            # Insert probability 0 keeps every clipped edge as a single line.
            vertices, lines = load_vertices_and_lines_in_crop(
                None,
                None,
                0,
                0,
                300,
                300,
                0,
                10000,
                (0, 0),
                vertices=vertices,
                lines=lines,
                pad_length_x_start=pad_length_x_start,
                pad_length_x_end=pad_length_x_end,
                pad_length_y_start=pad_length_y_start,
                pad_length_y_end=pad_length_y_end,
            )
            graph = create_graph(vertices, lines)

        for u, v in graph.edges():
            # x coordinates use the x scaling and y coordinates the y scaling
            # (previously the start point used x-scaling for both and the end
            # point y-scaling for both).
            start = [
                graph.nodes[u]["x"] * scaling_factor_x + x_crop,
                graph.nodes[u]["y"] * scaling_factor_y + y_crop,
            ]
            end = [
                graph.nodes[v]["x"] * scaling_factor_x + x_crop,
                graph.nodes[v]["y"] * scaling_factor_y + y_crop,
            ]
            global_segmentation += create_line_image(
                start, end, new_image_size_x, new_image_size_y, line_width=line_width
            )

    return np.clip(global_segmentation, 0, 1)


def better_mapextract(prediction, smooth_dist=4):
    """Extract a deduplicated vertex/edge list from a prediction mask.

    Axes 0 and 1 of ``prediction`` are swapped, and the result is
    skeletonised with ``skimage.morphology.skeletonize``. The skeleton is
    converted to a graph with ``sknw.build_sknw(..., multi=True)``, and
    ``graph_utils.simplify_graph(graph, smooth_dist)`` turns that graph into
    point sequences (segments).

    Returns:
        vertices: unique segment points, in first-seen order (the point
            objects themselves, as yielded by the segments).
        lines: ``[i, j]`` index pairs into ``vertices`` for consecutive
            points of each segment; each unordered pair is kept only once.
    """
    from skimage.morphology import skeletonize

    skeleton = skeletonize(np.swapaxes(prediction, axis1=0, axis2=1)).astype(np.uint16)
    segments = graph_utils.simplify_graph(sknw.build_sknw(skeleton, multi=True), smooth_dist)

    index_of = {}
    edge_keys = set()
    vertices, lines = [], []

    def vertex_index(point):
        key = tuple(point)
        if key not in index_of:
            index_of[key] = len(index_of)
            vertices.append(point)
        return index_of[key]

    for segment in segments:
        for k in range(1, len(segment)):
            a = vertex_index(segment[k - 1])
            b = vertex_index(segment[k])
            # frozenset makes (a, b) and (b, a) the same edge
            edge = frozenset((a, b))
            if edge in edge_keys:
                continue
            edge_keys.add(edge)
            lines.append([a, b])

    return vertices, lines


def random_shift(vertices, shift_factor=0.1, quantization_bits=8, multiply_factor=0.85):
    """Apply a random global shift to vertices.

    The vertices are first scaled by ``multiply_factor``. One shift per
    coordinate axis is then drawn from a truncated normal distribution
    (``scipy.stats.truncnorm``) with standard deviation
    ``2**quantization_bits * shift_factor``, bounded per axis to
    ``[-max(min_coord, 1e-9), max(2**quantization_bits - max_coord, 1e-9)]``.
    The shifted values are clamped to ``[1, 2**quantization_bits]`` and
    rounded to int.

    Args:
        vertices: ``(N, D)`` array of coordinates. Any D is supported; the
            previous version hard-coded ``loc=(0, 0)`` and only worked for D == 2.

    Returns:
        ``(N, D)`` int array, or ``vertices`` unchanged when N == 0.
    """
    if vertices.shape[0] <= 0:
        return vertices

    vertices = vertices * multiply_factor

    max_value = 2 ** quantization_bits
    # Largest per-axis shift that keeps the maximum coordinate <= max_value ...
    max_shift_pos = np.maximum((max_value - np.max(vertices, axis=0)).astype(float), 1e-9)
    # ... and the most negative shift that keeps the minimum coordinate >= 0.
    max_shift_neg = -np.maximum((np.min(vertices, axis=0)).astype(float), 1e-9)

    my_std = max_value * shift_factor
    # truncnorm takes its bounds in units of the standard deviation.
    a, b = max_shift_neg / my_std, max_shift_pos / my_std

    # A scalar loc broadcasts against the per-axis bounds, so the draw has one
    # value per axis for any D (and matches the old loc=(0, 0) for D == 2).
    shift = truncnorm.rvs(a, b, loc=0, scale=my_std)
    vertices = vertices + shift

    # make sure this holds..
    vertices[vertices < 1] = 1
    vertices[vertices > max_value] = max_value

    return np.around(vertices).astype(int)


def ccw(A, B, C):
    """Return True if A -> B -> C turns with a strictly positive cross product.

    Compares ``(B - A).x * (C - A).y`` with ``(B - A).y * (C - A).x``, so
    collinear points yield False.
    """
    lhs = (B[0] - A[0]) * (C[1] - A[1])
    rhs = (B[1] - A[1]) * (C[0] - A[0])
    return lhs > rhs


def intersect(A, B, C, D):
    """Return True if line segments AB and CD intersect.

    Uses the orientation test: the segments cross when A and B lie on
    different sides of CD, and C and D lie on different sides of AB. The
    orientation predicate is strict (same as ``ccw``), so configurations
    where a point is exactly collinear with the other segment may not be
    reported as intersecting.
    """

    def turn(p, q, r):
        # identical predicate to ccw(p, q, r)
        return (r[1] - p[1]) * (q[0] - p[0]) > (q[1] - p[1]) * (r[0] - p[0])

    ab_split_by_cd = turn(A, C, D) != turn(B, C, D)
    return ab_split_by_cd and turn(A, B, C) != turn(A, B, D)


def point_is_inside(x, y, x_crop, y_crop, img_size):
    """Return True if (x, y) lies in the half-open square
    ``[x_crop, x_crop + img_size) x [y_crop, y_crop + img_size)``."""
    return x_crop <= x < x_crop + img_size and y_crop <= y < y_crop + img_size


def add_lines_between(
    point_1,
    point_2,
    new_lines,
    new_vertices,
    insert_intermediate_vertices_probs,
    min_distance_between_intermediate_vertices,
    num_intermediate_vertices,
):
    """Connect point_1 to point_2, optionally through random intermediate vertices.

    If ``random.random() > insert_intermediate_vertices_probs``, a single
    edge point_1 -> point_2 is appended. Otherwise the segment is cut at
    random proportions into ``k`` pieces, with ``k`` drawn uniformly from
    ``[num_intermediate_vertices[0] + 1, num_intermediate_vertices[1] + 1]``.
    A cut point is kept only if it is farther than
    ``min_distance_between_intermediate_vertices`` from both the previously
    kept point and point_2. Kept points are chained with edges ending at
    point_2.

    Args:
        point_1, point_2: endpoints; ``tuple(point)`` must already be a key of
            ``new_vertices``. They must support numpy arithmetic when
            intermediate vertices are inserted.
        new_lines: list of ``[i, j]`` vertex-index pairs; appended to.
        new_vertices: dict from coordinate tuple to vertex index. New
            intermediate points get index ``len(new_vertices)``; a point that
            is already a key keeps its existing index.
        insert_intermediate_vertices_probs: probability of subdividing.
        min_distance_between_intermediate_vertices: minimum spacing of kept points.
        num_intermediate_vertices: ``(low, high)`` range used to draw ``k``.
    """
    if random.random() > insert_intermediate_vertices_probs:
        new_lines.append([new_vertices[tuple(point_1)], new_vertices[tuple(point_2)]])
        return

    rand_lens = np.random.rand(
        random.randint(
            num_intermediate_vertices[0] + 1,
            num_intermediate_vertices[1] + 1,
        )
    )
    start_point = point_1
    prev_point = start_point
    last_point = point_2

    # Cumulative proportions along the segment; the final 1.0 is point_2 itself.
    for proportion in (rand_lens.cumsum() / rand_lens.sum())[:-1]:
        new_point = proportion * (last_point - start_point) + start_point
        if (
            min(
                np.linalg.norm(new_point - prev_point),
                np.linalg.norm(new_point - last_point),
            )
            > min_distance_between_intermediate_vertices
        ):
            key = tuple(new_point)
            # Re-indexing an existing key would orphan the index that earlier
            # lines already reference.
            if key not in new_vertices:
                new_vertices[key] = len(new_vertices)
            new_lines.append([new_vertices[tuple(prev_point)], new_vertices[key]])
            prev_point = new_point

    new_lines.append([new_vertices[tuple(prev_point)], new_vertices[tuple(last_point)]])


def line_intersection(line1, line2, x_crop, y_crop, img_size_x, img_size_y):
    """Return the rounded intersection of two segments if it lies in the crop.

    ``line1`` and ``line2`` are pairs of 2D points. Returns None when
    ``intersect`` reports no crossing, when the segments are parallel (zero
    determinant), or when the rounded point falls outside
    ``[x_crop, x_crop + img_size_x) x [y_crop, y_crop + img_size_y)``.
    Otherwise returns ``(x, y)`` rounded with the builtin ``round``.
    """
    p1, p2 = line1[0], line1[1]
    q1, q2 = line2[0], line2[1]
    if not intersect(p1, p2, q1, q2):
        return None

    dx = (p1[0] - p2[0], q1[0] - q2[0])
    dy = (p1[1] - p2[1], q1[1] - q2[1])

    def cross(u, v):
        return u[0] * v[1] - u[1] * v[0]

    denom = cross(dx, dy)
    if denom == 0:
        return None

    # Cramer's rule on the two line equations.
    d = (cross(p1, p2), cross(q1, q2))
    x = round(cross(d, dx) / denom)
    y = round(cross(d, dy) / denom)

    inside = x_crop <= x < x_crop + img_size_x and y_crop <= y < y_crop + img_size_y
    return (x, y) if inside else None


def load_vertices_and_lines_in_crop(
    vertices_path,
    lines_path,
    x_crop,
    y_crop,
    crop_size,
    new_image_size,
    insert_intermediate_vertices_probs,
    min_distance_between_intermediate_vertices,
    num_intermediate_vertices,
    rescale=True,
    vertices=None,
    lines=None,
    pad_length_x_start=0,  # do not go too close to the borders
    pad_length_x_end=0,  # do not go too close to the borders
    pad_length_y_start=0,  # do not go too close to the borders
    pad_length_y_end=0,  # do not go too close to the borders
):
    """Clip a vertex/line graph to a crop window and rescale it.

    Vertices inside the half-open, optionally padded window
    ``[x_crop + pad_x_start, x_crop + crop_size_x - pad_x_end) x
    [y_crop + pad_y_start, y_crop + crop_size_y - pad_y_end)`` are kept.

    Lines are handled according to how many endpoints are kept:
      * both kept: the line is kept;
      * one kept: the line is cut at every window-border hit;
      * none kept: the line is kept as the chord between its first two
        border hits, and only when it has at least two.
    Border hits come from ``line_intersection`` and are rounded to
    integers. Every kept line goes through ``add_lines_between``, which may
    subdivide it.

    Coordinates are then shifted by ``-(x_crop, y_crop)``, rescaled from
    ``[0, crop_size - 1]`` to ``[0, new_image_size - 1]`` per axis and
    rounded to int. Vertices that collapse onto the same integer coordinate
    are merged, zero-length lines are dropped and vertices left without
    lines are removed.

    Args:
        vertices_path: file read with ``load_vertices`` when ``vertices`` is None.
        lines_path: ``np.load``-able index pairs used when ``lines`` is None;
            indices are shifted by -1 (presumably stored 1-based).
        x_crop, y_crop: origin of the crop window.
        crop_size: int (numpy integers accepted), or a pair read as
            ``(y_size, x_size)``.
        new_image_size: int (numpy integers accepted), or a pair read as
            ``(x_size, y_size)``.
        insert_intermediate_vertices_probs, num_intermediate_vertices:
            forwarded to ``add_lines_between``.
        min_distance_between_intermediate_vertices: spacing in output pixels;
            converted to crop pixels with ``crop_size_x / new_image_size_x``.
        rescale: currently unused; output is always rescaled (kept for API
            compatibility).
        vertices, lines: in-memory alternatives to the two paths.
        pad_length_*: margins that shrink the window on each side.

    Returns:
        ``(vertices, lines)``: int array of ``[x, y]`` rows and array of
        ``[i, j]`` index pairs into it. Both are empty 1-D arrays when
        nothing survives.
    """
    if vertices is None:
        # vertices here have values 0 -> 1299
        vertices = load_vertices(vertices_path)
    if lines is None:
        lines = np.load(lines_path) - 1

    if isinstance(crop_size, (int, np.integer)):
        crop_size_x = crop_size_y = crop_size
    else:
        crop_size_x = crop_size[1]
        crop_size_y = crop_size[0]

    if isinstance(new_image_size, (int, np.integer)):
        new_image_size_x = new_image_size_y = new_image_size
    else:
        new_image_size_x = new_image_size[0]
        new_image_size_y = new_image_size[1]

    # Half-open window used to decide which vertices are inside ...
    x_lo = x_crop + pad_length_x_start
    x_hi = x_crop + crop_size_x - pad_length_x_end
    y_lo = y_crop + pad_length_y_start
    y_hi = y_crop + crop_size_y - pad_length_y_end
    # ... and the last inclusive pixel, used for the border segments.
    x_last = x_crop + crop_size_x - 1 - pad_length_x_end
    y_last = y_crop + crop_size_y - 1 - pad_length_y_end

    # Border segments, in the order their intersections are consumed below.
    borders = (
        ([x_lo, y_lo], [x_last, y_lo]),  # edge y = y_lo
        ([x_lo, y_last], [x_last, y_last]),  # edge y = y_last
        ([x_lo, y_lo], [x_lo, y_last]),  # edge x = x_lo
        ([x_last, y_lo], [x_last, y_last]),  # edge x = x_last
    )

    # Minimum spacing of intermediate vertices, converted to crop pixels.
    min_dist = (
        min_distance_between_intermediate_vertices * crop_size_x / new_image_size_x
    )

    new_vertices = {}  # coordinate tuple -> index, in insertion order
    new_lines = []

    for vertex in vertices:
        if x_lo <= vertex[0] < x_hi and y_lo <= vertex[1] < y_hi:
            if tuple(vertex) not in new_vertices:
                # vertex belongs in the box
                new_vertices[tuple(vertex)] = len(new_vertices)

    for line in lines:
        start, end = vertices[line[0]], vertices[line[1]]
        start_inside = tuple(start) in new_vertices
        end_inside = tuple(end) in new_vertices
        num_inside = start_inside + end_inside

        if num_inside == 2:
            add_lines_between(
                start,
                end,
                new_lines,
                new_vertices,
                insert_intermediate_vertices_probs,
                min_dist,
                num_intermediate_vertices,
            )
            continue

        results = [
            line_intersection(
                (start, end), border, x_crop, y_crop, crop_size_x, crop_size_y
            )
            for border in borders
        ]

        if num_inside == 0:
            # Both ends outside: keep the chord between the first two border hits.
            hits = [res for res in results if res is not None]
            if len(hits) >= 2:
                for hit in hits[:2]:
                    if hit not in new_vertices:
                        new_vertices[hit] = len(new_vertices)
                add_lines_between(
                    np.array(hits[0]),
                    np.array(hits[1]),
                    new_lines,
                    new_vertices,
                    insert_intermediate_vertices_probs,
                    min_dist,
                    num_intermediate_vertices,
                )
        else:
            # One end inside: connect it to every border hit.
            point_inside = start if start_inside else end
            for hit in results:
                if hit is None:
                    continue
                if hit not in new_vertices:
                    new_vertices[hit] = len(new_vertices)
                add_lines_between(
                    point_inside,
                    np.array(hit),
                    new_lines,
                    new_vertices,
                    insert_intermediate_vertices_probs,
                    min_dist,
                    num_intermediate_vertices,
                )

    new_vertices = np.array(
        [list(k) for k, v in sorted(new_vertices.items(), key=lambda item: item[1])]
    )

    if new_vertices.shape[0] > 0:
        # Work in float: if every key was an integer the array would be int and
        # the assignments below would truncate instead of being rounded later.
        new_vertices = new_vertices.astype(float) - np.array([x_crop, y_crop])
        # Each axis is normalised by its own crop extent (x was previously
        # divided by crop_size_y and y by crop_size_x).
        new_vertices[:, 0] = (
            new_vertices[:, 0] / (crop_size_x - 1) * (new_image_size_x - 1)
        )
        new_vertices[:, 1] = (
            new_vertices[:, 1] / (crop_size_y - 1) * (new_image_size_y - 1)
        )

    new_vertices = np.around(new_vertices).astype(int)

    def _coords(index):
        return tuple(new_vertices[index])

    # After quantisation, vertices may have collapsed onto each other and some
    # lines may connect a vertex to itself. Keep only vertices touched by at
    # least one non-degenerate line.
    vertices_with_lines = set()
    for i, j in new_lines:
        a, b = _coords(i), _coords(j)
        if a != b:
            vertices_with_lines.update((a, b))

    # Give each surviving coordinate exactly one index. Previously a coordinate
    # shared by several collapsed vertices was re-indexed on every occurrence,
    # which shifted line indices off the filtered vertex list.
    new_vertices_filtered = []
    position = {}
    for ver in new_vertices:
        key = tuple(ver)
        if key in vertices_with_lines and key not in position:
            position[key] = len(new_vertices_filtered)
            new_vertices_filtered.append(ver)

    new_lines_filtered = []
    for i, j in new_lines:
        a, b = _coords(i), _coords(j)
        if a != b:  # drop zero-length lines
            new_lines_filtered.append([position[a], position[b]])

    return np.array(new_vertices_filtered), np.array(new_lines_filtered)


def load_vertices(file):
    """Load a 2-D array with ``np.load`` and return its first two columns.

    The columns come back in swapped order (column 1, then column 0) as a
    new int32 array.
    """
    raw = np.load(file)
    swapped = raw[:, [1, 0]]
    return swapped.astype(np.int32)
