import math
import numpy as np
from dataclasses import dataclass
from enum import Enum
import re

import shapely.affinity
from shapely import Point

# --- regex ---
# Captures the run of digits/'-' characters after "x=" (group 1) and after a
# later "y=" (group 2). NOTE(review): ".*" is greedy, so when the text holds
# several "y=" occurrences (e.g. another variable such as "vy=") the LAST one
# is captured — confirm against the valuation string format.
valuations_regex = r"x=([-\d]+).*y=([-\d]+)"

# --- dataclass ---
@dataclass(frozen=True, eq=True)
class SchedulerData:
    """Immutable record pairing a 2D point/vector with its rotation.

    ``frozen=True`` makes instances read-only; combined with the generated
    ``__eq__``, dataclasses also generate a field-based ``__hash__``.
    """

    point: Point  # 2D position or direction vector
    rot: int  # rotation attached to ``point`` (presumably degrees — confirm)

class OptimizationDirection(Enum):
    """Direction of an optimization objective: minimize or maximize."""

    Min = "min"
    Max = "max"

    @classmethod
    def from_str(cls, value):
        """Parse *value* into an ``OptimizationDirection``.

        Args:
            value: An ``OptimizationDirection`` member (returned unchanged) or a
                case-insensitive string naming one, e.g. ``"min"`` or ``"MAX"``.

        Returns:
            The matching enum member.

        Raises:
            ValueError: If *value* is neither a member nor a string naming one.
        """
        if isinstance(value, cls):
            return value
        choices = [e.value for e in cls]
        # Non-string input (e.g. None) used to escape as an AttributeError from
        # ``.lower()``; route it through the same ValueError as other bad input.
        if not isinstance(value, str):
            raise ValueError(f"Invalid mode: {value}. Choose from: {choices}")
        try:
            return cls(value.lower())
        except ValueError:
            # The enum's own lookup error adds nothing beyond this message.
            raise ValueError(f"Invalid mode: {value}. Choose from: {choices}") from None

# --- helper function ---

def point_add(p1: Point, p2: Point):
    """Return the component-wise sum ``p1 + p2`` as a new Point."""
    sum_x = p1.x + p2.x
    sum_y = p1.y + p2.y
    return Point(sum_x, sum_y)

def point_sub(p1: Point, p2: Point):
    """Return the component-wise difference ``p1 - p2`` as a new Point."""
    dx = p1.x - p2.x
    dy = p1.y - p2.y
    return Point(dx, dy)

def scalar_mult(p1: Point, scalar: int):
    """Scale the vector *p1* by *scalar*, returning a new Point.

    The ``int`` annotation is not enforced; any number that multiplies with
    the coordinates is accepted.
    """
    return Point(*(coord * scalar for coord in (p1.x, p1.y)))

def inverse_vector(p: Point):
    """Return the vector pointing opposite to *p*, i.e. ``(-p.x, -p.y)``."""
    neg_x = -p.x
    neg_y = -p.y
    return Point(neg_x, neg_y)

def round_coords(p: Point):
    """Snap a point to the integer grid.

    Each coordinate goes through Python's built-in ``round``, which rounds
    exact halves to the nearest even integer (``round(2.5) == 2``,
    ``round(3.5) == 4``).
    """
    rounded_x, rounded_y = round(p.x), round(p.y)
    return Point(rounded_x, rounded_y)

def rotate_point(x, y, angle_degrees, origin=(0, 0)):
    """Rotate the vector ``(x, y)`` about *origin* by *angle_degrees*.

    Delegates to ``shapely.affinity.rotate`` with the angle interpreted in
    degrees (positive angles are counter-clockwise, per shapely's docs).

    Returns:
        The rotated coordinates as a plain ``(x, y)`` tuple.
    """
    rotated = shapely.affinity.rotate(
        Point(x, y), angle_degrees, origin=origin, use_radians=False
    )
    return rotated.x, rotated.y

def scale_geometry(geometry: shapely.geometry, factor: float, origin=(0, 0)):
    """Uniformly scale *geometry* by *factor* about *origin*.

    Args:
        geometry: Any shapely geometry.
        factor: Scale factor applied to both the x and the y axis.
        origin: Fixed point of the scaling. Defaults to ``(0, 0)``, which was
            previously hard-coded; same convention as ``rotate_point``. Takes
            anything ``shapely.affinity.scale`` accepts for ``origin`` (a
            coordinate tuple, a Point, ``"center"`` or ``"centroid"``).

    Returns:
        A new scaled geometry; the input geometry is left untouched.
    """
    return shapely.affinity.scale(geometry, xfact=factor, yfact=factor, origin=origin)

def center_geometry(geometry: shapely.geometry, p: Point):
    """Translate *geometry* so that point *p* ends up at the origin ``(0, 0)``."""
    dx, dy = -p.x, -p.y
    return shapely.affinity.translate(geometry, dx, dy)

def round_up_custom(x, multiple=2):
    """Round ``int(x)`` up to the nearest multiple of *multiple*.

    With the default ``multiple=2`` this returns the next even number at or
    above ``int(x)`` (3 -> 4, 4 -> 4, -3 -> -2), i.e. the original behavior.

    *x* is first truncated toward zero with ``int()``, so a float such as 2.5
    becomes 2 (already even) rather than being rounded up.

    Args:
        x: Number to round; converted with ``int()``.
        multiple: Positive step to round up to. Defaults to 2.

    Returns:
        The smallest multiple of *multiple* that is ``>= int(x)``.

    Raises:
        ValueError: If *multiple* is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple!r}")
    n = int(x)
    # Python's % is non-negative for a positive divisor, so negative values
    # are also rounded toward +inf (e.g. -3 % 2 == 1 -> -2).
    remainder = n % multiple
    if remainder:
        n += multiple - remainder
    return n

def prism_sign(x):
    """Encode the sign of *x* as a ``p``/``n`` letter prefix instead of ``-``.

    Non-negative values get ``"p"`` and negative values get ``"n"``, followed
    by the magnitude: ``prism_sign(3) == "p3"``, ``prism_sign(-3) == "n3"``.
    This is needed for PRISM file parsing.

    The magnitude is taken with ``abs()`` rather than by deleting every ``-``
    from the formatted number. This way floats in exponent notation keep
    their exponent sign (``-1e-05 -> "n1e-05"``, not ``"n1e05"``), and
    ``-0.0`` maps to ``"p0.0"`` instead of ``"p-0.0"``.
    """
    if x >= 0:
        return f"p{abs(x)}"
    return f"n{abs(x)}"

def prism_max_min(variable, update):
    """Build a PRISM update expression that clamps ``variable + update``.

    The sum is clamped between the identifiers ``<VARIABLE>_MIN`` and
    ``<VARIABLE>_MAX`` (variable name upper-cased; presumably bounds declared
    in the generated model — confirm), e.g.
    ``prism_max_min("x", 1) == "max(min(x + 1, X_MAX), X_MIN)"``.
    """
    bound = variable.upper()
    clamped_above = f"min({variable} + {update}, {bound}_MAX)"
    return f"max({clamped_above}, {bound}_MIN)"

def prism_label(x, y, rot):
    """Build a PRISM label for a position and rotation, e.g. ``x_p3_y_n2_rot_90``.

    *x* and *y* are converted with ``int()`` (truncation toward zero) and
    sign-encoded via ``prism_sign``. *rot* is converted with ``int()`` and
    written unchanged, so a negative rotation keeps its ``-`` sign.
    """
    x_part = prism_sign(int(x))
    y_part = prism_sign(int(y))
    return f"x_{x_part}_y_{y_part}_rot_{int(rot)}"

def valuations_to_point(valuation: str):
    """Parse the x/y coordinates out of a state-valuation string.

    The text is matched against ``valuations_regex``: the value after ``x=``
    and the one after a later ``y=`` become the point's coordinates.

    Args:
        valuation: Text containing ``x=<int>`` followed somewhere by ``y=<int>``.

    Returns:
        ``Point(x, y)`` with the parsed integer coordinates.

    Raises:
        ValueError: If the text contains no ``x=...y=...`` pair, or a captured
            value is not a valid integer (e.g. ``"1-2"``).
    """
    match = re.search(valuations_regex, valuation)
    if match is None:
        # re.search returns None on no match; report it clearly instead of
        # failing with an AttributeError on ``None.group``.
        raise ValueError(f"No x/y coordinates found in valuation: {valuation!r}")
    x = int(match.group(1))
    y = int(match.group(2))
    return Point(x, y)

def get_point_from_label(action_label, scale=10):
    """Translate an action label into a point in the reference space.

    The label must contain ``action_<X>_<Y>``. Each of ``<X>``/``<Y>`` is a
    sign marker (``p`` = positive, ``n`` = negative) followed by an integer,
    e.g. ``"{'action_p5_n12'}"``. Quote (``'``) and closing-brace (``}``)
    characters are stripped first — presumably because labels arrive as the
    string form of a label set; confirm against callers.

    Args:
        action_label: Label text containing ``action_<X>_<Y>``.
        scale: Divisor applied to both integer coordinates. Defaults to 10,
            the previously hard-coded value.

    Returns:
        ``np.ndarray`` ``[x / scale, y / scale]``.

    Raises:
        ValueError: If the label lacks ``action_``, does not contain exactly
            two ``_``-separated coordinates, or a coordinate is not ``p``/``n``
            followed by an integer.
    """

    def signed(token):
        # The 'p'/'n' prefix encodes the sign; anything else is malformed
        # (previously any non-'p' prefix was silently read as negative).
        sign = {"p": 1, "n": -1}.get(token[:1])
        if sign is None:
            raise ValueError(
                f"Bad coordinate {token!r} in action label {action_label!r}"
            )
        return sign * int(token[1:])

    segments = action_label.split("action_")
    if len(segments) < 2:
        raise ValueError(f"Not an action label: {action_label!r}")
    coords = segments[1].replace("'", "").replace("}", "").split("_")
    if len(coords) != 2:
        raise ValueError(
            f"Expected exactly two coordinates in action label {action_label!r}"
        )
    x_str, y_str = coords
    return np.array([signed(x_str) / scale, signed(y_str) / scale])

def debug_scheduler(model, scheduler, path="debug.sched"):
    """Dump the deterministic choice of *scheduler* for every state of *model*.

    Writes one human-readable line per state to *path*, overwriting the file.
    Presumably used with stormpy model/scheduler objects — confirm.

    Args:
        model: Object whose ``states`` are iterated; each state must expose
            ``valuations``, ``labels`` (iterable of str) and an indexable
            ``actions`` whose items have ``labels``.
        scheduler: Object whose ``get_choice(state)`` returns a choice with
            ``get_deterministic_choice()`` giving an index into ``state.actions``.
        path: Output file. Defaults to ``"debug.sched"`` in the current working
            directory, which was previously hard-coded.
    """
    # Explicit encoding so labels are written the same way on every platform
    # rather than in the locale's default encoding.
    with open(path, "w", encoding="utf-8") as out:
        for state in model.states:
            action_index = scheduler.get_choice(state).get_deterministic_choice()
            action = state.actions[action_index]
            state_labels = ", ".join(state.labels)
            action_labels = ", ".join(action.labels)
            out.write(
                f"In state {state.valuations} ({state_labels}) "
                f"choose action {action} ({action_labels})\n"
            )