import numpy as np

from shapely import Point

from typing import Optional
from abc import abstractmethod

from itertools import product

from loguru import logger

from .MoleculeTransitionData import MoleculeTransitionData, MoleculeActionSpace, Min, Max, StepSize
from molecule_movement.Objects import VerticalAction, LateralAction

class MockUpData:
    """Common base for mock-up molecule response data sources.

    Concrete subclasses implement :meth:`get_response_map`, which assembles
    the :class:`MoleculeTransitionData` describing how the mocked molecule
    reacts to each action.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: display name of this data source. Any falsy value (``None``
                or ``""``) falls back to the concrete class name.
        """
        self.name = name if name else type(self).__name__
        # Starts below any vector norm (norms are >= 0), so the first norm a
        # subclass compares against it will replace it.
        self.maximum_movement = -1

    @classmethod
    def shortname(cls) -> str:
        """Return the class name with a trailing ``"MockUpData"`` stripped (if present)."""
        suffix = "MockUpData"
        full_name = cls.__name__
        if full_name.endswith(suffix):
            return full_name[:-len(suffix)]
        return full_name

    @abstractmethod
    def get_response_map(self) -> MoleculeTransitionData:
        """Build and return the transition data for this mock-up.

        NOTE(review): this class does not derive from ``abc.ABC``, so the
        ``@abstractmethod`` marker is not enforced at instantiation time; the
        base implementation simply returns ``None``.
        """

class VerticalMockUpData(MockUpData):
    """Mock-up response data for vertical (single-point) molecule actions.

    Every ``(x, y)`` index of the action-space grid gets a mocked response
    made of a 2-D Gaussian translation (``translation_means[x][y]`` of shape
    ``(2,)`` and ``translation_covs[x][y]`` of shape ``(2, 2)``) plus a
    categorical distribution over the angles in :attr:`rotations`
    (``rotation_dists[x][y]``). All three arrays are zero-initialised here and
    must be filled in before the samplers produced by
    :meth:`get_response_map` are used (presumably by subclasses -- TODO
    confirm). An all-zero ``rotation_dists`` row is not a valid probability
    vector for ``np.random.choice``.
    """

    def __init__(self,
                 dimensions_x : tuple[Min, Max],
                 dimensions_y : tuple[Min, Max],
                 step_x : StepSize,
                 step_y : StepSize,
                 dimensions_z : Optional[tuple[Min, Max]] = None,
                 dimensions_V : Optional[tuple[Min, Max]] = None,
                 step_z : Optional[StepSize] = None,
                 step_V : Optional[StepSize] = None,
                 symmetry : Optional[int] = 0,
                 cov: Optional[float] = 0.025,
                 **kwargs
                 ):
        """
        Args:
            dimensions_x, dimensions_y: ``(min, max)`` extent of the action grid.
            step_x, step_y: grid spacing along x and y.
            dimensions_z, dimensions_V, step_z, step_V: optional extra axes,
                forwarded unchanged to :class:`MoleculeActionSpace`.
            symmetry: stored as ``self.symmetry``; not used by this class itself.
            cov: stored as ``self.cov``; not used by this class itself
                (presumably read by subclasses when filling
                ``translation_covs`` -- TODO confirm).
            **kwargs: only ``name`` is read; it is passed to :class:`MockUpData`.
        """
        super().__init__(kwargs.get("name", None))
        self.symmetry = symmetry
        self.action_space = MoleculeActionSpace(dimensions_x=dimensions_x,
                                                dimensions_y=dimensions_y,
                                                step_x=step_x,
                                                step_y=step_y,
                                                dimensions_z=dimensions_z,
                                                dimensions_V=dimensions_V,
                                                step_z=step_z,
                                                step_V=step_V)
        n_x = len(self.action_space.x_space)
        n_y = len(self.action_space.y_space)
        self.translation_means = np.zeros((n_x, n_y, 2), dtype=float)
        self.translation_covs = np.zeros((n_x, n_y, 2, 2), dtype=float)
        self.rotation_dists = np.zeros((n_x, n_y, len(self.rotations)), dtype=float)
        self.cov = cov

    def get_translation_mean(self, x: int, y: int):
        """Return the 2-vector translation mean for grid index ``(x, y)`` (a view, not a copy)."""
        return self.translation_means[x][y]

    def get_translation_cov(self, x: int, y: int):
        """Return the 2x2 translation covariance for grid index ``(x, y)`` (a view, not a copy)."""
        return self.translation_covs[x][y]

    def get_rotation_dist(self, x: int, y: int):
        """Return the probability vector over :attr:`rotations` for grid index ``(x, y)``."""
        return self.rotation_dists[x][y]

    @property
    def rotations(self):
        """Candidate rotation angles (presumably degrees), in the order used to index ``rotation_dists``."""
        return [0,    60 , -60 , 120 ,-120 , 180]

    def get_response_map(self) -> MoleculeTransitionData:
        """Build (once) and return the transition data for all vertical actions.

        For each grid index ``(x, y)`` in ``action_space.action_indices``, the
        action at ``(x_min + x * step_x, y_min + y * step_y)`` (each rounded to
        one decimal) is mapped to:

        * a zero-argument sampler returning an ``(dx, dy)`` tuple drawn from
          ``N(translation_means[x][y], translation_covs[x][y])``, and
        * a zero-argument sampler returning one angle from :attr:`rotations`
          drawn with probabilities ``rotation_dists[x][y]``.

        ``maximum_movement`` is raised to the largest norm among the means.
        The result is cached on ``self.moleculeData``; later calls return the
        cached object without rebuilding it.
        """
        if hasattr(self, "moleculeData"):
            return self.moleculeData

        space = self.action_space
        translations = dict()
        rotations = dict()

        for x, y in space.action_indices:
            x_action = round(space.x_space[0] + x * space.step_x, 1)
            y_action = round(space.y_space[0] + y * space.step_y, 1)
            # One key object for both maps, so translations and rotations are
            # guaranteed to be keyed identically.
            action = VerticalAction(Point(x_action, y_action))
            mean = self.get_translation_mean(x, y)
            cov = self.get_translation_cov(x, y)
            # Default arguments freeze this iteration's values; a plain closure
            # would late-bind to the last entry of the loop.
            translations[action] = lambda mean=mean, cov=cov: tuple(*np.random.multivariate_normal(mean, cov, 1))
            rot_dist = self.get_rotation_dist(x, y)
            rotations[action] = lambda rot_dist=rot_dist: self.rotations[np.random.choice(len(self.rotations), 1, p=rot_dist)[0]]
            length = np.linalg.norm(mean)
            if length > self.maximum_movement:
                self.maximum_movement = length

        self.moleculeData = MoleculeTransitionData(translations,
                                                   rotations,
                                                   self.action_space,
                                                   self.name,
                                                   self.maximum_movement)
        return self.moleculeData

    def __str__(self) -> str:
        """Render the translation-mean grid and the grid of their norms.

        Both grids print one row per line: a newline follows the entry whose
        ``y`` index is the last one in ``y_space``. Means are shown with one
        decimal and norms with two.
        """
        mean_parts = []
        length_parts = []
        last_y = len(self.action_space.y_space) - 1
        # Scope the float format to this call instead of permanently changing
        # NumPy's process-wide print options.
        with np.printoptions(formatter={'float': '{: .1f}'.format}):
            for x, y in self.action_space.action_indices:
                mean = self.get_translation_mean(x, y)
                mean_parts.append(f"{mean} ")
                length_parts.append(f"{np.linalg.norm(mean):0.2f} ")
                if y == last_y:
                    mean_parts.append("\n")
                    length_parts.append("\n")
        means = "".join(mean_parts)
        lengths = "".join(length_parts)
        return f"{self.name}:\n{means}\n\n{lengths}\n"

    def as_vector_field(self):
        """Print ``x y dx dy`` for each action whose translation mean is non-zero.

        ``(x, y)`` come from ``action_space.actions`` and ``(dx, dy)`` is the
        translation mean of the matching index in ``action_space.action_indices``.
        """
        for (x_i, y_i), (x, y) in zip(self.action_space.action_indices, self.action_space.actions):
            movement = self.get_translation_mean(x_i, y_i)
            if np.all(movement == [0.0, 0.0]):
                continue
            print(f"{x} {y} {movement[0]} {movement[1]}")

class LateralMockUpData(MockUpData):
    """Mock-up response data for lateral (source -> destination) molecule actions.

    Responses are indexed by ``(angle, x, y, dest_x, dest_y)``. ``angle``
    ranges over ``range(symmetry)``, and the other indices address the source
    and destination grids of the action space. Each entry holds a 2-D Gaussian
    translation (mean of shape ``(2,)``, covariance of shape ``(2, 2)``) and a
    probability vector over :attr:`rotations`. The arrays are zero-initialised
    here and must be filled in before the produced samplers are used.

    NOTE(review): with the default ``symmetry=0`` every array has a
    zero-length first axis, and :meth:`get_response_map` produces empty maps;
    callers presumably always pass a positive symmetry -- confirm.
    """

    def __init__(self,
                 dimensions_x : tuple[Min, Max],
                 dimensions_y : tuple[Min, Max],
                 step_x : StepSize,
                 step_y : StepSize,
                 dimensions_dest_x : tuple[Min, Max],
                 dimensions_dest_y : tuple[Min, Max],
                 step_dest_x : StepSize,
                 step_dest_y : StepSize,
                 dimensions_z : Optional[tuple[Min, Max]] = None,
                 dimensions_V : Optional[tuple[Min, Max]] = None,
                 step_z : Optional[StepSize] = None,
                 step_V : Optional[StepSize] = None,
                 symmetry : int = 0,
                 **kwargs
                 ):
        """
        Args:
            dimensions_x, dimensions_y, step_x, step_y: extent and spacing of
                the source-point grid.
            dimensions_dest_x, dimensions_dest_y, step_dest_x, step_dest_y:
                extent and spacing of the destination-point grid.
            dimensions_z, dimensions_V, step_z, step_V: optional extra axes,
                forwarded unchanged to :class:`MoleculeActionSpace`.
            symmetry: number of angle indices; sizes the first axis of every
                response array.
            **kwargs: only ``name`` is read; it is passed to :class:`MockUpData`.
        """
        super().__init__(kwargs.get("name", None))
        self.symmetry = symmetry
        self.action_space = MoleculeActionSpace(dimensions_x=dimensions_x,
                                                dimensions_y=dimensions_y,
                                                step_x=step_x,
                                                step_y=step_y,
                                                dimensions_z=dimensions_z,
                                                dimensions_V=dimensions_V,
                                                step_z=step_z,
                                                step_V=step_V,
                                                dimensions_dest_x=dimensions_dest_x,
                                                dimensions_dest_y=dimensions_dest_y,
                                                step_dest_x=step_dest_x,
                                                step_dest_y=step_dest_y)

        grid_shape = (self.symmetry,
                      len(self.action_space.x_space),
                      len(self.action_space.y_space),
                      len(self.action_space.dest_x_space),
                      len(self.action_space.dest_y_space))
        self.translation_means = np.zeros(grid_shape + (2,), dtype=float)
        self.translation_covs = np.zeros(grid_shape + (2, 2), dtype=float)
        self.rotation_dists = np.zeros(grid_shape + (len(self.rotations),), dtype=float)

    def get_translation_mean(self, angle: int, x: int, y: int, dest_x: int, dest_y: int):
        """Return the 2-vector translation mean for the given indices (a view, not a copy)."""
        return self.translation_means[angle][x][y][dest_x][dest_y]

    def get_translation_cov(self, angle: int, x: int, y: int, dest_x: int, dest_y: int):
        """Return the 2x2 translation covariance for the given indices (a view, not a copy)."""
        return self.translation_covs[angle][x][y][dest_x][dest_y]

    def get_rotation_dist(self, angle: int, x: int, y: int, dest_x: int, dest_y: int):
        """Return the probability vector over :attr:`rotations` for the given indices."""
        return self.rotation_dists[angle][x][y][dest_x][dest_y]

    def get_response_map(self) -> MoleculeTransitionData:
        """Build and return the transition data for all lateral actions.

        The translation and rotation maps are nested dicts of the form
        ``{angle_index: {LateralAction: sampler}}``. Each action goes from the
        source point ``(x_min + x * step_x, y_min + y * step_y)`` to the
        destination point ``(dest_x_min + dest_x * step_dest_x,
        dest_y_min + dest_y * step_dest_y)``, with every coordinate rounded to
        one decimal. Each translation sampler returns an ``(dx, dy)`` tuple
        drawn from that entry's Gaussian. Each rotation sampler returns one
        angle from :attr:`rotations`. ``maximum_movement`` is raised to the
        largest mean norm.

        The result is rebuilt on every call, and ``self.moleculeData`` is
        overwritten each time.
        """
        space = self.action_space
        # Coordinates depend only on the grid index, not on the angle (nor,
        # for destinations, on the source cell), so compute them once instead
        # of inside the nested loops. Iteration order matches the original
        # nested product loops.
        sources = [
            (x, y,
             round(space.x_space[0] + x * space.step_x, 1),
             round(space.y_space[0] + y * space.step_y, 1))
            for x, y in product(range(len(space.x_space)), range(len(space.y_space)))
        ]
        destinations = [
            (dest_x, dest_y,
             round(space.dest_x_space[0] + dest_x * space.step_dest_x, 1),
             round(space.dest_y_space[0] + dest_y * space.step_dest_y, 1))
            for dest_x, dest_y in product(range(len(space.dest_x_space)), range(len(space.dest_y_space)))
        ]

        translations = dict()
        rotations = dict()
        for angle in range(self.symmetry):
            angle_translations = translations[angle] = dict()
            angle_rotations = rotations[angle] = dict()
            for x, y, x_action, y_action in sources:
                for dest_x, dest_y, dest_x_action, dest_y_action in destinations:
                    action = LateralAction(Point(x_action, y_action), Point(dest_x_action, dest_y_action))
                    mean = self.get_translation_mean(angle, x, y, dest_x, dest_y)
                    cov = self.get_translation_cov(angle, x, y, dest_x, dest_y)
                    # Default arguments freeze this iteration's values; a plain
                    # closure would late-bind to the last entry of the loop.
                    angle_translations[action] = lambda mean=mean, cov=cov: tuple(*np.random.multivariate_normal(mean, cov, 1))
                    rot_dist = self.get_rotation_dist(angle, x, y, dest_x, dest_y)
                    angle_rotations[action] = lambda rot_dist=rot_dist: self.rotations[np.random.choice(len(self.rotations), 1, p=rot_dist)[0]]
                    length = np.linalg.norm(mean)
                    if length > self.maximum_movement:
                        self.maximum_movement = length

        self.moleculeData = MoleculeTransitionData(translations,
                                                   rotations,
                                                   self.action_space,
                                                   self.name,
                                                   self.maximum_movement)
        return self.moleculeData

    @property
    def rotations(self):
        """Candidate rotation angles (presumably degrees), in the order used to index ``rotation_dists``."""
        return [0,    60 , -60 , 120 ,-120 , 180]
