import os
import numpy as np
import torch
import cv2

from dataprocessing.data_utils import (
    rotate,
    colorjitter,
    create_graph,
    simplify_graph,
    better_mapextract,
    create_single_segmentation,
    load_vertices_and_lines_in_crop,
)


class CityEngineDataset(torch.utils.data.Dataset):
    """Road-graph dataset built from synthetic renderings in a single directory.

    Every sample is an image file in ``base_dir`` whose name contains ``"rgb"``.
    The matching road mask is the same file name with ``"rgb"`` replaced by
    ``"roads"``; the optional tree mask replaces it with ``"plants"``. A graph
    (vertices + lines) is extracted from the road mask and, optionally, from a
    precomputed segmentation prediction.
    """

    def __init__(
        self,
        base_dir,
        mode="train",
        new_image_size=300,
        crop_size=400,
        max_num_vertices=200,
        max_num_lines=400,
        augment=False,
        load_segmentation=None,
        insert_intermediate_vertices_probs=1,
        min_distance_between_intermediate_vertices=None,
        num_intermediate_vertices=(1, 3),
        max_samples=None,
        line_width=None,
        return_tree_mask=False,
        test_size=0.2,
        simplify_graph=True,
        smooth_dist=2,
    ):
        """Index the dataset directory and store sampling options.

        Args:
            base_dir: Directory containing the ``*rgb*`` images and their masks.
            mode: ``"train"`` (first part of the split) or ``"val"`` (last part).
            new_image_size: Side length that crops are resized to.
            crop_size: Side length of the square crop taken from each image.
            max_num_vertices: Maximum number of vertices kept per graph.
            max_num_lines: Maximum number of lines kept per graph.
            augment: Enable color jitter, rotation and random crop offsets.
            load_segmentation: Optional directory of ``.npy`` segmentation
                predictions; when given, a second graph is extracted from them.
            insert_intermediate_vertices_probs: Forwarded to
                ``load_vertices_and_lines_in_crop``.
            min_distance_between_intermediate_vertices: Forwarded to
                ``load_vertices_and_lines_in_crop``.
            num_intermediate_vertices: Forwarded to
                ``load_vertices_and_lines_in_crop``.
            max_samples: Optional upper bound on ``len(self)``.
            line_width: If set, the road mask is re-rendered from its own
                extracted graph using this line width.
            return_tree_mask: Also return the ``"plants"`` mask.
            test_size: Fraction of the (sorted) file list used for validation.
            simplify_graph: Whether to run ``simplify_graph`` on extracted graphs.
            smooth_dist: Forwarded to ``better_mapextract``.

        Raises:
            ValueError: If ``mode`` is not ``"train"`` or ``"val"``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if mode not in ("train", "val"):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

        super().__init__()
        self.augment = augment
        self.new_image_size = new_image_size
        self.crop_size = crop_size
        self.line_width = line_width
        self.return_tree_mask = return_tree_mask

        self.max_num_vertices = max_num_vertices
        self.max_num_lines = max_num_lines
        self.load_segmentation = load_segmentation

        # Options for inserting intermediate vertices along graph edges; they are
        # only forwarded to ``load_vertices_and_lines_in_crop``.
        self.insert_intermediate_vertices_probs = insert_intermediate_vertices_probs
        self.min_distance_between_intermediate_vertices = (
            min_distance_between_intermediate_vertices
        )
        self.num_intermediate_vertices = num_intermediate_vertices

        self.max_samples = max_samples

        # Boolean flag; the module-level ``simplify_graph`` function is still the
        # one called in ``get_vertices_lines_from_mask``.
        self.simplify_graph = simplify_graph

        # Not used inside this class; kept as a public attribute.
        self.masks_path = []

        self.base_dir = base_dir
        # ``os.listdir`` order is arbitrary and filesystem dependent; sort so the
        # train/val split is reproducible and the two parts never overlap.
        all_images = sorted(name for name in os.listdir(base_dir) if "rgb" in name)

        num_train = int(len(all_images) * (1 - test_size))
        if mode == "train":
            self.images_path = all_images[:num_train]
        else:
            self.images_path = all_images[num_train:]

        self.smooth_dist = smooth_dist

    def __len__(self):
        """Return the number of samples, capped at ``max_samples`` when set."""
        if self.max_samples is not None:
            return min(len(self.images_path), self.max_samples)

        return len(self.images_path)

    @staticmethod
    def _imread(path, *flags):
        """Read ``path`` with ``cv2.imread``, raising if the file is unreadable.

        ``cv2.imread`` signals failure by returning ``None``; without this check
        the error would surface later as an unrelated ``cv2.error``/``TypeError``.
        """
        image = cv2.imread(path, *flags)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    def __getitem__(
        self, idx, specific_sample=None, original_sizes=False, add_segmentation=None
    ):
        """Load one sample and extract its road graph(s).

        Args:
            idx: Index into ``self.images_path`` (a tensor index is converted
                with ``tolist``). Ignored when ``specific_sample`` is given.
            specific_sample: Optional ``(idx, x_crop_start, y_crop_start)`` tuple
                that fixes both the sample and the crop offset.
            original_sizes: If True, image and masks are neither cropped nor
                resized and graph coordinates are computed with ``mask.shape``.
            add_segmentation: Optional array that is thresholded at 0.5 and OR-ed
                into the loaded segmentation prediction.

        Returns:
            dict with ``"image"`` (channels first), ``"vertices"``, ``"lines"``,
            ``"mask"``, ``"city"`` and, when enabled, ``"tree_mask"``,
            ``"vertices_pred"``, ``"lines_pred"`` and ``"segmentation"``.

        Raises:
            FileNotFoundError: If the image or one of its masks cannot be read.
        """
        if specific_sample is None:
            if torch.is_tensor(idx):
                idx = idx.tolist()
            x_crop_start, y_crop_start = 0, 0
        else:
            idx, x_crop_start, y_crop_start = specific_sample

        file_name = self.images_path[idx]

        img = self._imread(os.path.join(self.base_dir, file_name))
        # cv2 loads BGR; swap to RGB (same constant as COLOR_RGB2BGR).
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.augment:
            img = colorjitter(img)
        img = img / 255.0

        mask = (
            self._imread(
                os.path.join(self.base_dir, file_name.replace("rgb", "roads")), 0
            )
            / 255.0
        )

        tree_mask = None
        if self.return_tree_mask:
            tree_mask = (
                self._imread(
                    os.path.join(self.base_dir, file_name.replace("rgb", "plants")), 0
                )
                / 255.0
            )

        segmentation = None
        if self.load_segmentation is not None:
            # NOTE(review): relies on the third "_"-separated token of the file
            # name being the prediction id — confirm against the naming scheme.
            segmentation_path = os.path.join(
                self.load_segmentation, file_name.split("_")[2] + ".npy"
            )
            segmentation = np.load(segmentation_path) > 0.5
            if add_segmentation is not None:
                segmentation |= add_segmentation > 0.5
            # ``np.float`` was removed in NumPy 1.24; builtin float is the same type.
            segmentation = segmentation.astype(float)

        if not original_sizes:
            if self.augment:
                img, mask, tree_mask, segmentation = rotate(
                    img, mask, tree_mask, segmentation
                )
                # randint's upper bound is exclusive: +1 keeps the last valid
                # offset and avoids a ValueError when the side equals crop_size.
                x_crop_start = np.random.randint(0, img.shape[0] - self.crop_size + 1)
                y_crop_start = np.random.randint(0, img.shape[1] - self.crop_size + 1)
            # Without augmentation the offset stays fixed (0/0 or from
            # ``specific_sample``) so evaluation results are comparable.
            img = img[
                x_crop_start : x_crop_start + self.crop_size,
                y_crop_start : y_crop_start + self.crop_size,
            ]
            img = cv2.resize(img, (self.new_image_size, self.new_image_size))

        if self.line_width is not None and not original_sizes:
            # Replace the road mask by a rendering of its own extracted graph at
            # ``line_width`` (full-size; it is cropped further below).
            vertices, lines = better_mapextract(mask, smooth_dist=self.smooth_dist)
            mask = create_single_segmentation(
                [create_graph(vertices, lines)],
                [0],
                [0],
                mask.shape,
                mask.shape,
                mask.shape,
                line_width=self.line_width,
            )

        dic = {"image": np.transpose(img, (2, 0, 1))}

        if self.return_tree_mask:
            # Follow the same size convention as the image and road mask.
            dic["tree_mask"] = (
                tree_mask
                if original_sizes
                else self.crop_and_resize(tree_mask, x_crop_start, y_crop_start)
            )

        # Graphs are extracted from the full-size masks; the crop offset is
        # applied inside get_vertices_lines_from_mask.
        dic["vertices"], dic["lines"] = self.get_vertices_lines_from_mask(
            mask, x_crop_start, y_crop_start, original_sizes=original_sizes
        )

        if segmentation is not None:
            dic["vertices_pred"], dic["lines_pred"] = self.get_vertices_lines_from_mask(
                segmentation, x_crop_start, y_crop_start, original_sizes=original_sizes
            )

        if not original_sizes:
            mask = self.crop_and_resize(mask, x_crop_start, y_crop_start)
            if segmentation is not None:
                segmentation = self.crop_and_resize(
                    segmentation, x_crop_start, y_crop_start
                )

        dic["mask"] = mask
        if segmentation is not None:
            dic["segmentation"] = segmentation

        dic["city"] = "synthetic"
        return dic

    def crop_and_resize(self, mask, x_crop_start, y_crop_start):
        """Crop a ``crop_size`` square at the given offset and resize it.

        Returns a float array of shape ``(new_image_size, new_image_size)``.
        """
        mask = mask[
            x_crop_start : x_crop_start + self.crop_size,
            y_crop_start : y_crop_start + self.crop_size,
        ]
        return cv2.resize(
            mask.astype(float), (self.new_image_size, self.new_image_size)
        )

    def get_vertices_lines_from_mask(
        self, mask, x_crop_start, y_crop_start, original_sizes=False
    ):
        """Extract a graph from ``mask`` and express it in crop coordinates.

        The mask is binarized at 0.5 and converted to a graph with
        ``better_mapextract``. When ``self.simplify_graph`` is set it is
        simplified, then mapped into the crop with
        ``load_vertices_and_lines_in_crop``. Finally it is cleaned up by
        ``process_vertices_lines``.
        """
        vertices, lines = better_mapextract(mask > 0.5, smooth_dist=self.smooth_dist)

        if self.simplify_graph:
            # Module-level helper, not the boolean attribute of the same name.
            vertices, lines = simplify_graph(
                create_graph(vertices, lines),
                min_distance_between_vertices=5,
                max_distance=160,
            )

        # NOTE(review): offsets are passed as (y, x) — presumably matching the
        # helper's argument order; verify against load_vertices_and_lines_in_crop.
        vertices, lines = load_vertices_and_lines_in_crop(
            None,
            None,
            y_crop_start,
            x_crop_start,
            self.crop_size if not original_sizes else mask.shape,
            self.new_image_size if not original_sizes else mask.shape,
            self.insert_intermediate_vertices_probs,
            self.min_distance_between_intermediate_vertices,
            self.num_intermediate_vertices,
            vertices=np.array(vertices),
            lines=np.array(lines),
        )

        return self.process_vertices_lines(vertices, lines)

    def process_vertices_lines(self, vertices, lines):
        """Deduplicate a graph and convert its lines to 1-based vertex indices.

        Vertices with identical coordinates (e.g. after quantization) are merged.
        The merged vertices are ordered lexicographically by their first two
        coordinates and truncated to ``max_num_vertices``. Lines are then
        processed as follows:

        - remapped to the merged vertices and treated as undirected, stored as
          ``(low, high)``;
        - deduplicated and sorted, with self-loops and lines touching truncated
          vertices removed;
        - truncated to ``max_num_lines`` and shifted by +1, so index 0 stays
          free for the end-of-sequence token.

        Args:
            vertices: Array of shape ``(N, >=2)`` with vertex coordinates.
            lines: Array of shape ``(M, >=2)`` with index pairs into ``vertices``.

        Returns:
            ``(vertices, lines)``: a float array of shape ``(V, 2)`` and an
            int64 array of shape ``(L, 2)`` holding 1-based vertex indices.
        """
        # Visit vertices in lexicographic order so merged ids are sorted as well.
        order = sorted(
            range(vertices.shape[0]),
            key=lambda i: (vertices[i][0], vertices[i][1]),
        )
        coord_to_new = {}
        old_to_new = {}
        for old_id in order:
            coord = tuple(vertices[old_id])
            old_to_new[old_id] = coord_to_new.setdefault(coord, len(coord_to_new))

        merged = np.zeros((len(coord_to_new), 2), dtype=float)
        for coord, new_id in coord_to_new.items():
            merged[new_id] = coord[0], coord[1]
        merged = merged[: self.max_num_vertices]

        # Undirected, deduplicated edges as (low, high); self-loops are dropped.
        edges = set()
        for line in lines:
            a, b = old_to_new[line[0]], old_to_new[line[1]]
            if a != b:
                edges.add((min(a, b), max(a, b)))
        edge_array = np.array(sorted(edges), dtype=np.int64).reshape(-1, 2)

        # Drop lines that reference vertices removed by the truncation above.
        # Always applied, so max_num_vertices=0 cannot leave dangling lines.
        edge_array = edge_array[(edge_array < merged.shape[0]).all(axis=1)]

        # Shift by one: index 0 denotes the end of the sequence.
        return merged, edge_array[: self.max_num_lines] + 1

    def revert_normalization(self, image, city):
        """Return ``image`` unchanged.

        This class applies no normalization besides dividing by 255 in
        ``__getitem__``, so there is nothing to undo. ``city`` is unused here
        (presumably accepted for API parity with other datasets).
        """
        return image
