import gym
import random
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString
import networkx as nx
import math

from utils.metrics_utils import get_all_metrics
from utils.apls import make_graphs, compute_apls_metric, calculate_apls_metric
from dataprocessing.cityengine_dataset import CityEngineDataset
from environment import env_utils


def bound_metric(x):
    """Clamp a metric value into the closed interval [0, 1].

    NaN inputs are mapped to 0 so that undefined metrics never propagate
    into the values derived from them.
    """
    if math.isnan(x):
        return 0
    # Cap from above first, then from below.
    capped = 1 if x > 1 else x
    return 0 if capped < 0 else capped


# Registry mapping the `dataset_type` name accepted by RoadKeypoints to the
# dataset class that gets instantiated for it.
datasets = {
    "cityengine": CityEngineDataset,
}


class RoadKeypoints(gym.Env):
    def __init__(
        self,
        dataset_type="cityengine",
        dataset_config=None,
        max_step=40,
        image_size=300,
        reward_weights=None,
        initialize_to_pred_graph=False,
        min_connected_components_length=-1,
        min_linestrings_length=-1,
    ):
        """Create the environment and instantiate its dataset.

        Args:
            dataset_type: key into the module-level ``datasets`` registry.
            dataset_config: keyword arguments forwarded to the dataset
                constructor; ``None`` means no arguments.
            max_step: step budget; ``step`` terminates the episode once the
                step counter exceeds it.
            image_size: image side length, used for the plot limits in
                ``render`` and forwarded to the small-segment filter in
                ``reset``.
            reward_weights: five weights, ``None`` means ``[1, 0, 0, 0, 0]``.
                They indicate how to weight rewards based on:
                    - Binary reward, encourages predictions as the provided
                      segmentation mask
                    - APLS
                    - Junction based f1
                    - Path based f1
                    - Graph based f1
            initialize_to_pred_graph: if True, ``reset`` starts each episode
                from the graph given by the segmentation prediction.
            min_connected_components_length: when initializing from the
                predicted graph, only edges belonging to a connected component
                longer than this are kept.
            min_linestrings_length: when positive, dead ends of small length
                are removed from the initial graph.
        """
        self.dataset_type = dataset_type
        assert self.dataset_type in datasets, (
            "Unknown dataset_type: " + repr(dataset_type)
        )

        # initialize according to segmentation prediction
        self.initialize_to_pred_graph = initialize_to_pred_graph
        # only add edges that belong to a connected component of this length
        self.min_connected_components_length = min_connected_components_length
        # remove deadends of small length
        self.min_linestrings_length = min_linestrings_length

        self.image_size = image_size
        self.max_step = max_step

        # binary, apls, optj, optp, optg
        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every environment built with defaults.
        self.reward_weights = (
            [1, 0, 0, 0, 0] if reward_weights is None else reward_weights
        )

        if dataset_config is None:
            dataset_config = {}
        self.dataset = datasets[dataset_type](**dataset_config)

    def observation(self):
        return_dic = {
            "image": self.image,
            "keypoints": self.keypoints,
            "keypoints_len": np.array(self.keypoints.shape[0]).astype(int),
            "current_edges": np.array(self.current_edges) + 1,
            "current_edges_len": np.array(len(self.current_edges)).astype(int),
            "step_percentage": np.array(self.current_step * 1.0 / self.max_step),
        }

        return return_dic

    def metrics(self):
        """Compute all evaluation metrics of the predicted graph vs. the ground truth.

        Returns:
            dict with key ``"apls"`` plus 18 named metrics. The named metrics
            are the raw string values parsed from the lines returned by
            ``get_all_metrics``; if computing or parsing them fails for any
            reason, all 18 are reported as 0 instead.

        Note that APLS is computed outside the guarded section, so an error
        raised by ``calculate_apls_metric`` propagates to the caller.
        """
        # One row per result line, in order; each row names the last three
        # space-separated "key=value" fields of that line.
        metric_names = (
            ("OPTJ_precision", "OPTJ_recall", "OPTJ_f1"),
            ("Corr", "Comp", "Qual"),
            ("2long-2short_correct", "2long-2short_2l_2s", "2long-2short_inf"),
            (
                "OPTP_con_prob_precision",
                "OPTP_con_prob_recall",
                "OPTP_con_prob_f1",
            ),
            ("hole_marbles_spurious", "hole_marbles_missing", "hole_marbles_f1"),
            ("OPTG_spurious", "OPTG_missings", "OPTG_f1"),
        )

        metrics = {"apls": calculate_apls_metric(self.G_gt, self.G_pred)}

        try:
            results = get_all_metrics(
                env_utils.transform_graph(self.G_gt),
                env_utils.transform_graph(self.G_pred),
                self.dataset.new_image_size,
                self.dataset.new_image_size,
            )

            parsed = {}
            for row, names in enumerate(metric_names):
                fields = results[row].split(" ")
                for offset, name in zip((-3, -2, -1), names):
                    parsed[name] = fields[offset].split("=")[1]
        except Exception:
            # Best effort: a failure anywhere zeroes *all* named metrics, so
            # no partially parsed values leak out. Unlike a bare ``except``
            # this lets KeyboardInterrupt / SystemExit through.
            parsed = {name: 0 for names in metric_names for name in names}

        metrics.update(parsed)
        return metrics

    def get_training_metrics(self):
        """Return the four metrics used for reward shaping, each clamped to [0, 1].

        The order is (apls, OPTJ f1, OPTP f1, OPTG f1), which lines up with
        ``reward_weights[1:]``.
        """
        all_metrics = self.metrics()
        selected = ("apls", "OPTJ_f1", "OPTP_con_prob_f1", "OPTG_f1")
        return tuple(bound_metric(float(all_metrics[key])) for key in selected)

    def statistics(self):
        """Summarize the episode as a 6-tuple.

        Should only be called after done and before reset.

        Returns:
            (current step, number of correct edges, number of overlapping
            edges, number of wrong edges, steps taken since the starting
            step, whether no lines remain).
        """
        steps_since_start = self.current_step - self.starting_step
        nothing_remaining = len(self.lines_remaining) == 0
        return (
            self.current_step,
            self.num_correct_edges,
            self.num_overlapping_edges,
            self.num_wrong_edges,
            steps_since_start,
            nothing_remaining,
        )

    def step(self, action):
        """Apply one action and return ``(observation, reward, done, info)``.

        Action 0 terminates the episode; any other action ``a`` selects
        keypoint ``a - 1``. Edges are built from two consecutive selections:
        the first one only records the starting keypoint (reward 0), the
        second one closes the edge and is rewarded.

        The reward of a closed edge is the sum of
            - a binary reward (positive for a not-yet-found target line,
              negative for an already-found or wrong one, normalized by the
              number of target lines) times ``reward_weights[0]``, and
            - the change of each training metric since the previous edge,
              weighted by ``reward_weights[1:]``.

        ``info`` is always ``None``.
        """
        self.current_step += 1

        done = action == 0 or self.current_step > self.max_step or self.done

        if done:
            self.done = True
            return self.observation(), 0, True, None

        if self.next_action_is_select_first_keypoint:
            # select starting point for an edge
            self.next_action_is_select_first_keypoint = False
            self.current_edges.append(action - 1)

            return self.observation(), 0, False, None

        self.next_action_is_select_first_keypoint = True

        action1 = self.current_edges[-1]
        action2 = action - 1

        new_edge = (action1, action2)
        sorted_new_edge = env_utils.sort_edge(new_edge)

        # A sample without any target line would otherwise divide by zero.
        num_true_lines = max(len(self.true_lines), 1)

        # binary reward based on the target lines
        if sorted_new_edge in self.lines_remaining:
            reward = 1 / num_true_lines

            self.lines_remaining.remove(sorted_new_edge)
            self.num_correct_edges += 1
        elif sorted_new_edge in self.true_lines:
            # a correct line that was already selected before
            reward = -0.5 / num_true_lines

            self.num_overlapping_edges += 1
        else:
            reward = -1 / num_true_lines
            self.num_wrong_edges += 1

        # Self loops are rewarded (negatively) above but never added to the graph.
        if action1 != action2:
            for node in (action1, action2):
                coords = tuple(self.keypoints[node])
                if coords not in self.keypoints_seen:
                    self.keypoints_seen.add(coords)
                    self.G_pred.add_node(
                        node, x=self.keypoints[node][0], y=self.keypoints[node][1]
                    )

            if not self.G_pred.has_edge(sorted_new_edge[0], sorted_new_edge[1]):
                line = LineString(
                    [
                        Point(self.keypoints[sorted_new_edge[0]]),
                        Point(self.keypoints[sorted_new_edge[1]]),
                    ]
                )
                self.G_pred.add_edge(
                    sorted_new_edge[0],
                    sorted_new_edge[1],
                    length=line.length,
                    geometry=line,
                )

        binary_reward = reward * self.reward_weights[0]

        new_metrics = self.get_training_metrics()
        connectivity_reward = sum(
            (new - old) * weight
            for old, new, weight in zip(
                self.old_metrics, new_metrics, self.reward_weights[1:]
            )
        )
        self.old_metrics = new_metrics

        reward = binary_reward + connectivity_reward

        self.current_edges.append(action2)
        self.total_reward += reward

        return self.observation(), reward, False, None

    def reset(self, idx=None, specific_sample=None, add_segmentation=None):
        """Start a new episode on a dataset sample and return the first observation.

        Args:
            idx: dataset index; a random one is drawn when ``None``.
            specific_sample: forwarded to the dataset's ``__getitem__``.
            add_segmentation: forwarded to the dataset's ``__getitem__``.

        If the sample contains ``"vertices_pred"``, the agent acts on the
        predicted keypoints/lines while the ground-truth graph is still built
        from ``"vertices"``/``"lines"``; otherwise both come from the ground
        truth. With ``initialize_to_pred_graph`` the episode starts from the
        (filtered) predicted graph instead of an empty one.
        """
        if idx is None:
            idx = random.randint(0, len(self.dataset) - 1)

        sample = self.dataset.__getitem__(
            idx, specific_sample=specific_sample, add_segmentation=add_segmentation
        )

        # This is due to bug in apls calculation ..
        gt_node_bias = 100000

        if "vertices_pred" not in sample:
            # There is no provided segmentation prediction
            # Then train based on the gt key point locations
            self.image, self.keypoints, lines, self.city = (
                sample["image"],
                sample["vertices"],
                sample["lines"] - 1,
                sample["city"],
            )
            G_gt = env_utils.create_graph(
                self.keypoints, lines, node_bias=gt_node_bias
            )
        else:
            # There is a provided segmentation mask
            # Train based on keypoints from the provided graph
            (
                self.image,
                self.gt_keypoints,
                gt_lines,
                self.city,
                self.keypoints,
                lines,
            ) = (
                sample["image"],
                sample["vertices"],
                sample["lines"] - 1,
                sample["city"],
                sample["vertices_pred"],
                sample["lines_pred"] - 1,
            )

            G_gt = env_utils.create_graph(
                self.gt_keypoints, gt_lines, node_bias=gt_node_bias
            )

        self.true_lines = {env_utils.sort_edge(e) for e in lines}
        self.lines_remaining = set(self.true_lines)

        self.G_gt = G_gt
        self.G_pred = nx.MultiGraph()
        self.keypoints_seen = set()

        self.total_road_length = sum(
            np.linalg.norm(self.keypoints[lin[0]] - self.keypoints[lin[1]])
            for lin in lines
        )

        if self.initialize_to_pred_graph:
            assert (
                "vertices_pred" in sample
            ), "Preprocessing only valid if segmentation prediction is provided"
            self.current_edges = []
            self._initialize_from_pred_graph()
            self.old_metrics = self.get_training_metrics()
        else:
            self.current_edges = []
            self.old_metrics = (0, 0, 0, 0)

        self.done = False
        # Pre-populated edges count as already-taken steps (two per edge).
        self.current_step = len(self.current_edges)
        assert self.current_step % 2 == 0

        self.starting_step = self.current_step
        self.num_correct_edges = 0
        self.num_overlapping_edges = 0
        self.num_wrong_edges = 0
        self.total_reward = 0
        self.next_action_is_select_first_keypoint = True

        return self.observation()

    def _add_edge_to_graph(self, graph, edge):
        """Insert ``edge`` (a pair of keypoint indices) into ``graph``.

        Endpoints whose coordinates are not yet in ``keypoints_seen`` are
        registered there and added as nodes with ``x``/``y`` attributes.
        Returns the sorted edge if it was added, or ``None`` if ``graph``
        already contained it.
        """
        for node in edge:
            coords = tuple(self.keypoints[node])
            if coords not in self.keypoints_seen:
                self.keypoints_seen.add(coords)
                graph.add_node(
                    node, x=self.keypoints[node][0], y=self.keypoints[node][1]
                )

        sorted_edge = env_utils.sort_edge(edge)
        if graph.has_edge(sorted_edge[0], sorted_edge[1]):
            return None

        line = LineString(
            [
                Point(self.keypoints[sorted_edge[0]]),
                Point(self.keypoints[sorted_edge[1]]),
            ]
        )
        graph.add_edge(
            sorted_edge[0], sorted_edge[1], length=line.length, geometry=line
        )
        return sorted_edge

    def _initialize_from_pred_graph(self):
        """Seed ``G_pred`` and ``current_edges`` with the filtered target graph."""
        tmp_G_pred = nx.MultiGraph()
        for new_edge in self.lines_remaining:
            self._add_edge_to_graph(tmp_G_pred, new_edge)

        self.lines_remaining = self.true_lines.copy()
        self.keypoints_seen = set()

        # filter based on nodes by their connected components length
        for component in list(nx.connected_components(tmp_G_pred)):
            length = 0
            edges_considered = set()
            for node in component:
                for neigh, edge in tmp_G_pred[node].items():
                    sorted_edge = env_utils.sort_edge((node, neigh))
                    # Every edge is reachable from both of its endpoints;
                    # record it so its length is only counted once.
                    if sorted_edge in edges_considered:
                        continue
                    edges_considered.add(sorted_edge)
                    length += edge[0]["length"]

            if length > self.min_connected_components_length:
                for node in component:
                    for neigh in tmp_G_pred[node]:
                        added = self._add_edge_to_graph(self.G_pred, (node, neigh))
                        if added is None:
                            continue

                        self.lines_remaining.remove(added)
                        self.current_edges.append(node)
                        self.current_edges.append(neigh)

        if self.min_linestrings_length > 0:
            # Rebuild the prediction from scratch out of whatever survives the
            # small-segment removal.
            self.lines_remaining = self.true_lines.copy()
            self.keypoints_seen = set()
            self.current_edges = []

            tmp_G_pred = self.G_pred.copy()
            self.G_pred = nx.MultiGraph()
            env_utils.remove_edges_in_small_segments(
                tmp_G_pred,
                self.min_linestrings_length,
                image_size=self.image_size,
            )

            for x, y in tmp_G_pred.edges():
                added = self._add_edge_to_graph(self.G_pred, (x, y))
                if added is None:
                    continue

                self.lines_remaining.remove(added)
                self.current_edges.append(x)
                self.current_edges.append(y)

    def legal_actions(self):
        return list(range(self.keypoints.shape[0] + 1))

    def render(
        self,
        mode="human",
        close=False,
        plot_edges=True,
        show_numbers=True,
        show_title=True,
    ):
        """Draw the current state and return it as an RGB ``uint8`` image array.

        Args:
            mode: accepted for interface compatibility; not used.
            close: accepted for interface compatibility; not used.
            plot_edges: currently not consulted; edges are always drawn.
            show_numbers: annotate every keypoint with its one-based index.
            show_title: put the last selection (or "Terminated") in the title.

        Returns:
            Array of shape (height, width, 3) holding the rendered figure.
        """
        plt.ioff()

        obs = self.observation()
        # here we want zero indexing ..
        obs["current_edges"] = obs["current_edges"] - 1

        fig, ax = plt.subplots(figsize=(9, 9))
        try:
            ax.axis("off")

            ax.imshow(
                self.dataset.revert_normalization(
                    obs["image"].transpose((1, 2, 0)), self.city
                ),
                zorder=0,
            )

            ax.scatter(
                obs["keypoints"][:, 0],
                obs["keypoints"][:, 1],
                color="blue",
                s=80,
                zorder=1,
            )

            if hasattr(self, "gt_keypoints"):
                ax.scatter(
                    self.gt_keypoints[:, 0],
                    self.gt_keypoints[:, 1],
                    color="green",
                    s=80,
                    zorder=1,
                )

            if show_numbers:
                for i in range(1, obs["keypoints"].shape[0] + 1):
                    ax.annotate(
                        str(i),
                        (obs["keypoints"][i - 1][0], obs["keypoints"][i - 1][1]),
                        fontsize=26,
                        zorder=10,
                        color="orange",
                    )

            if obs["current_edges_len"] % 2 == 1:
                if show_title:
                    ax.set_title(
                        "Last keypoint selected: " + str(obs["current_edges"][-1] + 1)
                    )
                # last action was to select this vertex ...
                ax.scatter(
                    [obs["keypoints"][obs["current_edges"][-1], 0]],
                    [obs["keypoints"][obs["current_edges"][-1], 1]],
                    color="red",
                    s=100,
                    zorder=2,
                )

            # Completed edges are stored as consecutive index pairs; a trailing
            # unpaired entry (a selected start keypoint) is skipped here.
            for i in range(0, obs["current_edges"].shape[0] // 2 * 2, 2):
                e1 = obs["current_edges"][i]
                e2 = obs["current_edges"][i + 1]
                ax.plot(
                    [obs["keypoints"][e1, 0], obs["keypoints"][e2, 0]],
                    [obs["keypoints"][e1, 1], obs["keypoints"][e2, 1]],
                    color="blue",
                    linewidth=3,
                    zorder=1,
                )

            if (
                obs["current_edges_len"] % 2 == 0
                and len(obs["current_edges"]) > 0
                and not self.done
            ):
                # last action was to select this edge
                e1 = obs["current_edges"][-2]
                e2 = obs["current_edges"][-1]
                edge = (e1, e2)
                if show_title:
                    ax.set_title(
                        "Last edge selected: " + str((edge[0] + 1, edge[1] + 1))
                    )
                ax.plot(
                    [obs["keypoints"][edge[0], 0], obs["keypoints"][edge[1], 0]],
                    [obs["keypoints"][edge[0], 1], obs["keypoints"][edge[1], 1]],
                    color="red",
                    linewidth=4,
                    zorder=2,
                )

            if self.done and show_title:
                ax.set_title("Terminated")

            # to allow all annotations to be seen
            ax.set_xlim([-1, self.image_size + 1])
            ax.set_ylim([-1, self.image_size + 1])
            fig.tight_layout()

            # If we haven't already shown or saved the plot, then we need to
            # draw the figure first...
            canvas = fig.canvas
            canvas.draw()

            # Now we can save it to a numpy array. ``np.frombuffer`` replaces
            # the deprecated binary mode of ``np.fromstring``; the copy keeps
            # the result writable, as it was before.
            if hasattr(canvas, "tostring_rgb"):
                data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
                data = data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()
            else:
                # Newer Matplotlib releases no longer provide tostring_rgb;
                # take the RGBA buffer and drop the alpha channel instead.
                data = np.asarray(canvas.buffer_rgba())[..., :3].copy()

            return data
        finally:
            # Close exactly this figure, also when drawing failed half-way.
            plt.close(fig)

    def seed(self, seed):
        random.seed(seed)
        random.seed(seed)
