from loguru import logger
from molecule_movement.logging import log_and_raise
import numpy as np
from itertools import product
from sklearn.cluster import KMeans
from molecule_movement.model_based_scheduling.Utils import *
from pytictoc import TicToc
import tempfile
import os
import shutil
from molecule_movement import VerticalAction
from molecule_movement.parsing import MoleculeTransitionData, MockUpData, DataParser


# stormpy is required to parse and build the generated PRISM model (see
# ModelBuilder.build). If it is missing, hand the error and an installation
# hint to log_and_raise (presumably logs the message and re-raises; confirm
# against molecule_movement.logging).
try:
    import stormpy
except ModuleNotFoundError as e:
    msg = "stormpy is not installed. In order to use symbolic features, please install stormpy: \n\n\thttps://moves-rwth.github.io/stormpy/installation.html\n"
    log_and_raise(e, msg)

from memory_profiler import profile
from datetime import datetime
# Output stream passed to memory_profiler's @profile decorator on
# ModelBuilder.build (stream=log_file). The file is opened in the current
# working directory as a side effect of importing this module, and its name
# embeds the import timestamp.
# NOTE(review): str(datetime.now()) contains ':' characters, which are not
# valid in Windows file names - confirm the supported platforms.
# NOTE(review): the handle is not closed anywhere in the visible code.
log_file = open(f"memory_{str(datetime.now()).replace(' ', '_')}", "w+")

class ModelBuilder():
    """Generate a PRISM MDP describing molecule movement and build it with stormpy.

    For every action ``(x, y)`` of the action space the builder draws
    ``sample_size`` translation and rotation samples from the molecule
    transition data, clusters the translations into ``bins`` K-Means
    clusters and writes the resulting discrete distributions as PRISM
    commands. The PRISM file is then parsed and built by stormpy.
    """

    def __init__(self,
                 data_processor: DataParser | MockUpData,
                 prism_filepath=None,
                 x_max: int = 3,
                 x_min: int = -3,
                 y_max: int = 3,
                 y_min: int = -3,
                 accuracy: int = 10,
                 label_grid_size: float = 0.2,
                 sample_size: int = 100,
                 bins: int = 5
                 ):
        """Store the configuration; no sampling or file generation happens here.

        Args:
            data_processor: Provides ``get_molecular_data()`` and
                ``action_space.x_space`` / ``action_space.y_space``.
            prism_filepath: Writable text file object with a ``name``
                attribute that receives the PRISM program. When ``None``
                (the default) a fresh ``NamedTemporaryFile`` is created for
                this instance.
            x_max, x_min, y_max, y_min: Bounds of the modelled area; they are
                multiplied by ``accuracy`` to obtain the PRISM variable range.
            accuracy: Integer scale factor applied to coordinates before they
                are written to the PRISM file.
            label_grid_size: Spacing of the position labels, in the same
                units as the bounds.
            sample_size: Number of samples drawn per action.
            bins: Number of K-Means clusters per action.
        """
        # A default created in the signature would be evaluated only once and
        # shared (and closed / deleted) across all instances, so the temp
        # file is created per instance instead.
        if prism_filepath is None:
            prism_filepath = tempfile.NamedTemporaryFile(mode="w", delete=False)

        self.prism_filepath = prism_filepath
        self.data_processor = data_processor

        self.molecule_transition_data = data_processor.get_molecular_data()

        self.x_space = self.data_processor.action_space.x_space
        self.y_space = self.data_processor.action_space.y_space
        self.factor = accuracy
        self.prism_dimension_x = (x_min * self.factor, x_max * self.factor)
        self.prism_dimension_y = (y_min * self.factor, y_max * self.factor)

        self.samples_translation = None
        self.samples_rotation = None
        self.label_grid_size = label_grid_size
        self.sample_size = sample_size
        self.bins = bins

        self._model = None

    def __sample_data(self):
        """Draw ``sample_size`` translation and rotation samples per action.

        The results are stored in ``self.samples_translation`` and
        ``self.samples_rotation``, both keyed by ``(x_offset, y_offset)``.
        """
        samples_translations_array = {}
        samples_rotations_array = {}

        for x_offset, y_offset in product(self.x_space, self.y_space):
            action = VerticalAction(Point(x_offset, y_offset))
            draw_translation = self.molecule_transition_data.vertical_translations[action]
            draw_rotation = self.molecule_transition_data.rotations[action]

            samples_translations_array[(x_offset, y_offset)] = np.array(
                [draw_translation() for _ in range(self.sample_size)]
            )
            samples_rotations_array[(x_offset, y_offset)] = np.array(
                [draw_rotation() for _ in range(self.sample_size)]
            )

        self.samples_translation = samples_translations_array
        self.samples_rotation = samples_rotations_array

    @staticmethod
    def _action_label(x, y):
        """Return the PRISM action label ``[action_<x>_<y>]`` for an offset pair."""
        # NOTE(review): the factor 10 is hard-coded here (it is not
        # self.factor) and int() truncates - kept as-is so label names stay
        # stable for any consumer that parses them.
        x_write = f"p{abs(int(x * 10))}" if x >= 0 else f"n{abs(int(x * 10))}"
        y_write = f"p{abs(int(y * 10))}" if y >= 0 else f"n{abs(int(y * 10))}"
        return f"[action_{x_write}_{y_write}]"

    def _discretise_action(self, x, y):
        """Turn the samples of action ``(x, y)`` into discrete distributions.

        Returns:
            ``(translation_probs, centroids, rotation_values, rotation_probs)``
            where ``centroids[i]`` is the cluster centre that occurs with
            probability ``translation_probs[i]``.
        """
        translations = self.samples_translation[(x, y)]
        kmeans = KMeans(n_clusters=self.bins, n_init=10)
        kmeans.fit(translations)

        # Relative cluster sizes are the outcome probabilities. Centroids are
        # selected by cluster id so they stay aligned with the probabilities
        # even if some cluster id does not occur in labels_.
        cluster_ids, cluster_sizes = np.unique(kmeans.labels_, return_counts=True)
        translation_probs = cluster_sizes / len(translations)
        centroids = kmeans.cluster_centers_[cluster_ids]

        rotation_samples = self.samples_rotation[(x, y)]
        rotation_values, rotation_counts = np.unique(rotation_samples, return_counts=True)
        rotation_probs = rotation_counts / len(rotation_samples)

        return translation_probs, centroids, rotation_values, rotation_probs

    def __generate_prism_file(self):
        """Sample the transition data and write the PRISM program.

        Raises:
            ValueError: If ``accuracy * label_grid_size`` is smaller than 1,
                i.e. the label grid has no positive integer step.
        """
        # Rounding to 9 decimals removes float noise (e.g. 28.999999999999996)
        # before int() truncates, so the step matches the intended product.
        label_step = int(round(self.factor * self.label_grid_size, 9))
        if label_step < 1:
            raise ValueError(
                "accuracy * label_grid_size must be at least 1, got "
                f"{self.factor} * {self.label_grid_size}"
            )

        self.__sample_data()

        model_type = "mdp"
        constants = {
            "X_MIN": self.prism_dimension_x[0],
            "X_MAX": self.prism_dimension_x[1],
            "Y_MIN": self.prism_dimension_y[0],
            "Y_MAX": self.prism_dimension_y[1]
        }

        # Variables. Start, End, Initial value
        variables = {
            "x": ["X_MIN", "X_MAX", 0],
            "y": ["Y_MIN", "Y_MAX", 0]
        }
        module_name = "molecule_movement"
        rotation_states = range(0, 360, 60)

        # Discretise every action exactly once. The samples do not depend on
        # the rotation state, so the same clustering is reused for all of
        # them; each action also keeps its own rotation distribution.
        action_models = []
        rotation_by_label = {}
        action_to_distance_moved = dict()
        for x, y in product(self.x_space, self.y_space):
            action_write = self._action_label(x, y)
            translation_probs, centroids, rotation_values, rotation_probs = \
                self._discretise_action(x, y)
            action_models.append((action_write, translation_probs, centroids))
            rotation_by_label[action_write] = (rotation_values, rotation_probs)
            # Expected distance moved by the action. It is computed from the
            # unrotated centroids, assuming rotate_point is a pure rotation
            # and therefore does not change the norm.
            action_to_distance_moved[action_write] = sum(
                prob * np.linalg.norm(centroid)
                for prob, centroid in zip(translation_probs, centroids)
            )

        # The handle is closed after the first generation (or by the
        # `filepath` property); reopen it by name so generation can be
        # repeated instead of failing on a closed file.
        if getattr(self.prism_filepath, "closed", False):
            handle = open(self.prism_filepath.name, "w")
        else:
            handle = self.prism_filepath

        # --- Generate Prism file ---
        with handle as f:
            # Writing the model type (default is mdp)
            f.write(f"{model_type}\n\n")

            # Writing the constants from the dictionary to prism file
            for constant, value in constants.items():
                f.write(f"const int {constant} = {value};\n")
            f.write("\n")

            # Module starts from here
            f.write(f"module {module_name}\n\n")

            for variable, data in variables.items():
                f.write(f"\t{variable} : [{data[0]}..{data[1]}] init {data[2]};\n")
            f.write("\n")

            # One command per (rotation state, action): the clustered
            # translation is rotated by the current rotation state.
            for rot in rotation_states:
                for action_write, translation_probs, centroids in action_models:
                    f.write(f"\t{action_write} rot = {rot} -> \n")
                    for i, (prob, centroid) in enumerate(zip(translation_probs, centroids)):
                        if i != 0:
                            f.write(" + \n")
                        centroid_x, centroid_y = centroid
                        centroid_x, centroid_y = rotate_point(centroid_x, centroid_y, rot)
                        # NOTE(review): int(np.round(v, 1)) rounds to one
                        # decimal and then truncates towards zero; confirm
                        # that rounding to the nearest integer is not meant.
                        new_x = prism_max_min('x', int(np.round(centroid_x * self.factor, 1)))
                        new_y = prism_max_min('y', int(np.round(centroid_y * self.factor, 1)))
                        f.write(f"\t\t{prob} : (x' = {new_x}) & (y' = {new_y})")
                    f.write(";\n")

            f.write("\nendmodule")

            f.write("\n\nmodule rotation\n")
            f.write("\trot : [0..300] init 0;\n\n")

            # Each action synchronises with its own rotation distribution.
            for action_label, (rotation_values, rotation_probs) in rotation_by_label.items():
                f.write(f"\t{action_label} true -> \n")
                for i, (rotation, prob) in enumerate(zip(rotation_values, rotation_probs)):
                    if i != 0:
                        f.write(" + \n")
                    f.write(f"\t\t{prob} : (rot' = mod(rot + {rotation}, 360) )")
                f.write(";\n")
            f.write("\nendmodule\n")

            f.write("\nrewards \"state\"\n")
            f.write("\ttrue: 1;\n")
            f.write("endrewards")

            f.write("\nrewards \"minDistance\"\n")
            f.write("\ttrue: 10;\n")
            for action, distance in action_to_distance_moved.items():
                f.write(f"\t{action} true : {distance};\n")
            f.write("endrewards\n\n")

            # One label per grid position and rotation state.
            # NOTE(review): the grid runs up to MAX + factor (exclusive), so
            # it also labels positions beyond X_MAX / Y_MAX - kept as-is.
            x_grid = np.arange(constants["X_MIN"], constants["X_MAX"] + self.factor, label_step)
            y_grid = np.arange(constants["Y_MIN"], constants["Y_MAX"] + self.factor, label_step)
            for rot in rotation_states:
                for x in x_grid:
                    for y in y_grid:
                        label = f"label \"x_{prism_sign(x)}_y_{prism_sign(y)}_rot_{rot}\""
                        f.write(f"{label} = x = {x} & y = {y} & rot = {rot};\n")

    @profile(stream=log_file)
    def build(self):
        """Generate the PRISM file, build the sparse stormpy model and return it.

        The model is also stored on ``self.model``. The whole run is timed
        with ``TicToc`` and memory-profiled into the module-level ``log_file``.
        """
        t = TicToc()
        t.tic()
        self.__generate_prism_file()
        prism_program = stormpy.parse_prism_program(self.prism_filepath.name)

        options = stormpy.BuilderOptions()
        options.set_build_all_labels()
        options.set_build_state_valuations()
        options.set_build_choice_labels()
        options.set_build_with_choice_origins()

        model = stormpy.build_sparse_model_with_options(prism_program, options)
        self.model = model
        t.toc()
        return model

    def save_prism_file(self, filepath: str):
        """Copy the generated PRISM file to ``filepath``.

        Best effort: a failure is logged and not raised.
        """
        try:
            shutil.copy(self.prism_filepath.name, filepath)
        except Exception as e:
            # Logged through loguru like the rest of this module.
            logger.error(f"Failed to copy temp file to filepath: {e}")

    @property
    def filepath(self):
        """Close the PRISM file handle and return the (closed) file object.

        Use the ``name`` attribute of the returned object to get the path.
        """
        self.prism_filepath.close()
        return self.prism_filepath

    @property
    def model(self):
        """The stormpy model built by :meth:`build`, or ``None`` before that."""
        return self._model

    @model.setter
    def model(self, value):
        self._model = value

    def __del__(self):
        """Close and delete the PRISM file when the builder is collected."""
        try:
            self.prism_filepath.close()
            # os.remove is the same operation as os.unlink; calling both made
            # the second call fail on every cleanup.
            os.unlink(self.prism_filepath.name)
        except Exception as e:
            logger.trace(f"Failed to delete temp file: {e}")
