import tempfile
import hashlib
import re

from shapely.geometry import Point, Polygon, LineString, GeometryCollection

from loguru import logger
from molecule_movement.logging import log_and_raise

from molecule_movement.model_based_scheduling.Utils import *
from pytictoc import TicToc

# stormpy provides parsing, model checking and scheduler extraction; it is
# mandatory for this module, so a missing installation is reported (via
# log_and_raise) together with a pointer to the installation guide.
try:
    import stormpy
except ModuleNotFoundError as err:
    msg = (
        "stormpy is not installed. In order to use symbolic features, "
        "please install stormpy: \n\n\t"
        "https://moves-rwth.github.io/stormpy/installation.html\n"
    )
    log_and_raise(err, msg)

# Reward assigned to states that fall outside the corridor (see ModelBasedSynthesizer).
HIGH_REWARD = 100
# blake2b digest size (bytes) used to name out-of-bounds reward models.
HASH_SIZE = 15

class ModelBasedSynthesizer():
    """Steer a molecule to a goal with schedulers synthesized by stormpy.

    The PRISM model covers a bounded grid (``x_min..x_max`` x ``y_min..y_max``,
    multiplied by ``accuracy``) around a reference-frame centre. The path to
    the goal is split into waypoints that each lie inside one such frame. For
    every waypoint a scheduler is model-checked, flattened into a
    ``SchedulerData -> action`` dictionary and cached per
    ``(label, reward model name)``. Model states whose position falls outside
    the corridor geometry get ``HIGH_REWARD`` in an adjusted reward model.

    Positions are multiplied by ``accuracy`` (``self.factor``) and expressed
    relative to the (int-truncated, scaled) starting position.
    """

    def __init__(self,
                 model,
                 prism_filepath: tempfile.NamedTemporaryFile,
                 starting_position: Point,
                 goal_position: Point,
                 goal_rotation: int,
                 corridor_geometry: "shapely.geometry.base.BaseGeometry" = GeometryCollection(),
                 accuracy: int = 10,
                 default_reward_model_name: str = "minDistance",
                 optimization_direction: OptimizationDirection = OptimizationDirection.Min,
                 x_max: int = 3,
                 x_min: int = -3,
                 y_max: int = 3,
                 y_min: int = -3,
                 waypoint_tolerance: float = 0.05,
                 parking: bool = False
                 ):
        """Set up the synthesizer for one start/goal pair.

        Args:
            model: stormpy model exposing ``states``, ``reward_models`` and
                ``add_reward_model`` (presumably built from ``prism_filepath``;
                TODO confirm).
            prism_filepath: temporary file whose ``.name`` is parsed with
                ``stormpy.parse_prism_program`` to build property formulas.
            starting_position: start in world coordinates.
            goal_position: goal in world coordinates.
            goal_rotation: rotation requested at the goal when ``parking`` is on.
            corridor_geometry: allowed region in world coordinates.
            accuracy: scale factor from world units to grid units.
            default_reward_model_name: reward model used when no part of the
                current frame is out of bounds.
            optimization_direction: converted with ``OptimizationDirection.from_str``.
            x_max, x_min, y_max, y_min: extent of a reference frame in world
                units, relative to the frame centre.
            waypoint_tolerance: world-unit distance at which a waypoint counts
                as reached.
            parking: if True, the final leg must also reach ``goal_rotation``.
        """
        # Model related member vars
        self.model = model
        self.model_states = model.states
        self.default_reward_model_name = default_reward_model_name
        self.current_reward_model_name = default_reward_model_name
        self.optimization_direction = OptimizationDirection.from_str(optimization_direction)
        self.prism_filepath = prism_filepath
        # Matches "name=value" integer assignments inside a state valuation string.
        self.ints_regex = re.compile('([a-zA-Z][_a-zA-Z0-9]*)=(-?[0-9]+)')

        # Scheduler usage member vars
        self.factor = accuracy
        self.original_point = Point(int(starting_position.x * self.factor), int(starting_position.y * self.factor))
        self.goal_position = Point(int((goal_position.x - starting_position.x) * self.factor), int((goal_position.y - starting_position.y) * self.factor))
        self.goal_rotation = goal_rotation
        self.current_position = None
        self.current_waypoint = None
        self.prism_dimension_x = (x_min * self.factor, x_max * self.factor)
        self.prism_dimension_y = (y_min * self.factor, y_max * self.factor)
        self.waypoint_tolerance = waypoint_tolerance * self.factor
        self.parking = parking

        # Waypoint calculation and out of bounds / corridor member vars
        self.current_reference_frame = None
        self.current_out_of_bounds_geometry = None
        self.corridor_geometry = scale_geometry(center_geometry(corridor_geometry, starting_position), self.factor)
        self.ref_frame_center = Point(0, 0)
        self.current_waypoint_global = None

        # Dictionaries/Caches
        self.current_scheduler_dict = {}
        # (label, reward model name) -> scheduler dictionary; kept across targets.
        self.scheduler_collection_dict = {}

    def __build_formula(self, label: str):
        """Return the PRISM reward formula 'R{"<model>"}<dir>=? [ F (<label>) ]'."""
        return f"R{{\"{self.current_reward_model_name}\"}}{self.optimization_direction.value}=? [ F ({label}) ]"

    def __build_scheduler(self, label: str):
        """Model-check the reward formula for ``label`` and return the extracted scheduler."""
        t = TicToc()
        t.tic()

        prism_program = stormpy.parse_prism_program(self.prism_filepath.name)
        formulas = stormpy.parse_properties(self.__build_formula(label), prism_program)

        # Scheduler extraction resolves the model's nondeterminism.
        scheduler = stormpy.model_checking(self.model, formulas[0], extract_scheduler=True).scheduler

        # Measure once so the bound stat and the message report the same duration.
        elapsed = t.tocvalue()
        logger.bind(task="stats", scheduler_synthesis_time=round(elapsed, 3)).trace(f"Getting the scheduler took {elapsed} seconds.")
        return scheduler

    def __build_scheduler_dictionary(self, label: str):
        """Return (and make current) the ``SchedulerData -> action`` map for ``label``.

        Dictionaries are cached per ``(label, current reward model name)``
        because the scheduler depends on both.
        """
        cache_key = (label, self.current_reward_model_name)
        dictionary = self.scheduler_collection_dict.get(cache_key)

        if dictionary is None:
            scheduler = self.__build_scheduler(label)

            # Only time the actual build; on a cache hit tic() would never run
            # and pytictoc would report NaN.
            t = TicToc()
            t.tic()
            dictionary = {}
            for state in self.model.states:
                action_index = scheduler.get_choice(state).get_deterministic_choice()
                action = state.actions[action_index]

                ints = {name: int(value) for name, value in self.ints_regex.findall(state.valuations)}

                point = get_point_from_label(str(action.labels))
                dictionary[SchedulerData(Point(ints["x"], ints["y"]), ints["rot"])] = point

            self.scheduler_collection_dict[cache_key] = dictionary
            t.toc('Building the dictionary took')

        self.current_scheduler_dict = dictionary
        return dictionary

    def __adjusting_reward_model(self):
        """Select a reward model that gives out-of-bounds states ``HIGH_REWARD``.

        Each model state whose position, shifted by the centroid of
        ``self.current_reference_frame``, lies inside
        ``self.current_out_of_bounds_geometry`` receives ``HIGH_REWARD``. The
        sorted affected positions are hashed into the reward-model name, so
        an identical pattern maps to the same name across calls.

        Raises:
            Exception: if the model can neither be added nor selected.
        """
        points_inside = []
        # NOTE(review): starts from the object returned for the default reward
        # model; if stormpy returns a reference (not a copy), set_state_reward
        # below also alters the default model. Confirm.
        reward_model_to_add = self.model.reward_models[self.default_reward_model_name]
        frame_center = self.current_reference_frame.centroid  # loop-invariant

        for state in self.model_states:
            p = valuations_to_point(state.valuations)
            if self.current_out_of_bounds_geometry.contains(point_add(p, frame_center)):
                points_inside.append((p.x, p.y))
                reward_model_to_add.set_state_reward(state, HIGH_REWARD)

        # Sort the points so the hash key is independent of state order
        points_inside.sort()
        reward_model_key = "reward_model_" + hashlib.blake2b(str(points_inside).encode(), digest_size=HASH_SIZE).hexdigest()

        # Adding fails when a model with this key already exists; fall back to selecting it.
        try:
            self.model.add_reward_model(reward_model_key, reward_model_to_add)
            self.set_reward_model(reward_model_key)
        except Exception:
            try:
                self.set_reward_model(reward_model_key)
            except Exception as err:
                raise Exception(f"Couldn't add reward model with key {reward_model_key} to model and class dictionary. This should not happen") from err

    def __checking_boundaries(self):
        """Compute the part of the current frame outside the corridor and choose the reward model."""
        # NOTE(review): for the default empty corridor (assuming center_geometry/
        # scale_geometry keep it empty) the difference is the whole frame, so
        # every interior state gets HIGH_REWARD. Confirm that is intended.
        outer_bounds_polygons = self.current_reference_frame.difference(self.corridor_geometry)
        self.current_out_of_bounds_geometry = outer_bounds_polygons

        if outer_bounds_polygons.is_empty:
            self.set_reward_model(self.default_reward_model_name)  # nothing out of bounds
        else:
            self.__adjusting_reward_model()

    def __build_rectangle(self, x_coord, y_coord):
        """Return the frame rectangle (PRISM grid extent) centred on ``(x_coord, y_coord)``."""
        px = (x_coord + self.prism_dimension_x[0], x_coord + self.prism_dimension_x[1])  # (x_min, x_max)
        py = (y_coord + self.prism_dimension_y[0], y_coord + self.prism_dimension_y[1])  # (y_min, y_max)

        return Polygon([
            (px[0], py[0]),  # (x_min, y_min)
            (px[1], py[0]),  # (x_max, y_min)
            (px[1], py[1]),  # (x_max, y_max)
            (px[0], py[1]),  # (x_min, y_max)
        ])

    def __calculate_next_waypoint(self):
        """Place the next reference frame and return the waypoint inside it.

        Returns:
            ``(waypoint, goal_in_frame)``: ``waypoint`` relative to the frame
            centre, and True when that waypoint is the goal itself.
        """
        position = self.current_position
        rectangle = self.__build_rectangle(position.x, position.y)

        # Goal already inside the frame centred on the molecule: target it directly.
        # NOTE(review): this branch updates neither current_reference_frame nor the
        # reward model, so the one selected for the previous frame stays active.
        if rectangle.contains(self.goal_position):
            self.ref_frame_center = position
            self.current_waypoint_global = self.goal_position
            return point_sub(self.current_waypoint_global, position), True

        # TODO: Handle overshoot of goal. If we calculate the next waypoint and mirror the current position on the other edge
        # of the rectangle, we may overshoot the goal because the goal may be inside the rectangle now.

        # The segment molecule -> goal starts at the rectangle centre and the goal is
        # outside, so it crosses the boundary once; that point centres the new frame
        # (putting the molecule on the new frame's boundary).
        to_goal = LineString([(position.x, position.y), self.goal_position])
        ref_frame_center = Point(to_goal.intersection(rectangle.boundary).coords[-1])
        ref_frame = self.__build_rectangle(ref_frame_center.x, ref_frame_center.y)
        self.ref_frame_center = Point(round(ref_frame_center.x), round(ref_frame_center.y))

        # Where the segment new centre -> goal leaves the new frame is the next waypoint.
        center_to_goal = LineString([(ref_frame_center.x, ref_frame_center.y), self.goal_position])
        frame_exit = center_to_goal.intersection(ref_frame.boundary)

        self.current_reference_frame = ref_frame
        self.__checking_boundaries()

        # No boundary crossing but the segment touches the frame: goal lies inside the new frame.
        if frame_exit.is_empty and center_to_goal.intersects(ref_frame):
            self.current_waypoint_global = self.goal_position
            return point_sub(self.current_waypoint_global, self.ref_frame_center), True

        waypoint_global = Point(frame_exit.coords[-1])
        self.current_waypoint_global = waypoint_global
        # NOTE(review): relative to the unrounded centre, whereas the goal branch
        # above and get_next_action use the rounded self.ref_frame_center. Confirm.
        return point_sub(waypoint_global, ref_frame_center), False

    def get_scheduler_dict(self, current_position: Point):
        """Plan the next leg from ``current_position`` and return its scheduler dictionary.

        The target label requires ``goal_rotation`` only on the final leg with
        parking enabled; otherwise any of the six rotations 0..300 (step 60)
        satisfies it.
        """
        self.current_position = current_position
        next_waypoint, goal_is_in_ref_room = self.__calculate_next_waypoint()
        self.current_waypoint = round_coords(next_waypoint)

        wx = round_up_custom(next_waypoint.x)
        wy = round_up_custom(next_waypoint.y)
        if goal_is_in_ref_room and self.parking:
            rotations = [self.goal_rotation]
        else:
            rotations = range(0, 360, 60)
        label = " | ".join("\"" + prism_label(wx, wy, rot) + "\"" for rot in rotations)

        return self.__build_scheduler_dictionary(label)

    def __lookup_action(self, state: SchedulerData):
        """Look up the action for ``state`` relative to the current frame centre (KeyError if absent)."""
        sched_point = point_sub(state.point, self.ref_frame_center)
        return self.current_scheduler_dict[SchedulerData(sched_point, state.rot)]

    def get_next_action(self, current_state: SchedulerData):
        """Return the scheduled action for a state given in world coordinates.

        The state's point is scaled by ``factor`` and shifted by
        ``original_point``. A new leg is planned when an intermediate waypoint
        is reached within ``waypoint_tolerance``, or when the current scheduler
        dictionary has no entry for the state.
        """
        current_state = SchedulerData(Point(point_sub(scalar_mult(current_state.point, self.factor), self.original_point)), current_state.rot)

        # Explicit None check: before the first plan there is no waypoint to measure against.
        reached_intermediate_waypoint = (
            self.current_waypoint_global is not None
            and self.current_waypoint_global != self.goal_position
            and current_state.point.distance(self.current_waypoint_global) < self.waypoint_tolerance
        )
        if reached_intermediate_waypoint:
            self.get_scheduler_dict(current_state.point)
            return self.__lookup_action(current_state)

        try:
            return self.__lookup_action(current_state)
        except KeyError:
            # State not covered by the current frame's scheduler: re-plan from here.
            self.get_scheduler_dict(current_state.point)
            return self.__lookup_action(current_state)

    def set_reward_model(self, reward_model_name: str):
        """Make ``reward_model_name`` current; raise Exception if the model does not have it."""
        if reward_model_name in self.model.reward_models.keys():
            self.current_reward_model_name = reward_model_name
        else:
            raise Exception(f"Reward model {reward_model_name} not found. This should not happen")

    def get_reward_model_list(self):
        """Return the names of all reward models currently in the model."""
        return list(self.model.reward_models.keys())

    def set_optimization_direction(self, optimization_direction: OptimizationDirection):
        """Set the optimization direction used in subsequent formulas.

        NOTE(review): unlike ``__init__`` this does not go through
        ``OptimizationDirection.from_str``; callers must pass an enum member.
        """
        self.optimization_direction = optimization_direction

    def reset_with_new_target(self,
                              starting_position: Point,
                              goal_position: Point,
                              goal_rotation: int,
                              corridor_geometry: "shapely.geometry.base.BaseGeometry" = GeometryCollection(),
                              waypoint_tolerance: float = 0.05,
                              parking: bool = True):
        """Reset per-run state for a new start/goal; the scheduler cache is kept.

        NOTE(review): ``parking`` defaults to True here but to False in
        ``__init__`` — confirm this asymmetry is intended.
        """
        self.original_point = Point(int(starting_position.x * self.factor), int(starting_position.y * self.factor))
        self.goal_position = Point(int((goal_position.x - starting_position.x) * self.factor), int((goal_position.y - starting_position.y) * self.factor))
        self.goal_rotation = goal_rotation
        self.ref_frame_center = Point(0, 0)
        self.current_reward_model_name = self.default_reward_model_name
        self.waypoint_tolerance = waypoint_tolerance * self.factor
        self.parking = parking
        self.corridor_geometry = scale_geometry(center_geometry(corridor_geometry, starting_position), self.factor)

        self.current_waypoint_global = None
        self.current_reference_frame = None
        self.current_out_of_bounds_geometry = None
        self.current_position = None
        self.current_waypoint = None
        self.current_scheduler_dict = {}
