import numpy as np
from typing import Optional

from itertools import product

from .MoleculeTransitionData import MoleculeActionSpace, Min, Max, StepSize
from .MockUpData import LateralMockUpData

from loguru import logger

class LateralActionsMockUpData(LateralMockUpData):
    """Mock-up transition data for lateral actions.

    On construction two lookup tables are built and stored on the instance:

    * ``translation_means`` -- for every (x, y, dest_x, dest_y) grid index,
      the 2-vector ``(dest_x - x, dest_y - y)`` computed from the values of
      the corresponding action-space axes.
    * ``rotation_dists`` -- for every such grid index, a fixed six-entry
      distribution ``[0.75, 0.05, 0.05, 0.05, 0.05, 0.05]``.

    Both tables carry a leading axis of length ``self.symmetry``.

    NOTE(review): only index 0 of that leading axis is populated; every other
    entry keeps its zero initialisation -- confirm this is intended.
    """

    def __init__(self, **kwargs):
        """Initialise the base class under the name "Lat" and build the tables.

        Any ``name`` supplied by the caller is replaced by ``"Lat"``; all
        other keyword arguments are forwarded to the parent unchanged.
        """
        super().__init__(**{**kwargs, "name": "Lat"})

        space = self.action_space
        x_space = space.x_space
        y_space = space.y_space
        dest_x_space = space.dest_x_space
        dest_y_space = space.dest_y_space

        # Axes shared by both tables: (angle, x, y, dest_x, dest_y).
        grid_shape = (
            self.symmetry,
            len(x_space),
            len(y_space),
            len(dest_x_space),
            len(dest_y_space),
        )
        translation_means = np.zeros(grid_shape + (2,), dtype=float)
        rotation_dists = np.zeros(grid_shape + (len(self.rotations),), dtype=float)

        # Same distribution for every grid cell. It has six entries, so the
        # assignment below only succeeds when len(self.rotations) is 6.
        rotation_profile = [0.75, 0.05, 0.05, 0.05, 0.05, 0.05]

        grid = product(
            enumerate(x_space),
            enumerate(y_space),
            enumerate(dest_x_space),
            enumerate(dest_y_space),
        )
        for (ix, x_off), (iy, y_off), (jx, dest_x_off), (jy, dest_y_off) in grid:
            # Mean displacement is destination minus origin, per axis.
            displacement = np.array([dest_x_off - x_off, dest_y_off - y_off])
            translation_means[0, ix, iy, jx, jy] = displacement
            rotation_dists[0, ix, iy, jx, jy] = rotation_profile

        self.translation_means = translation_means
        self.rotation_dists = rotation_dists

    def get_translation_cov(self, angle: int, x: int, y: int, dest_x: int, dest_y: int):
        """Return the translation covariance for the given indices.

        The arguments are ignored: the result is always a freshly allocated
        2x2 diagonal matrix with 0.005 on the diagonal.
        """
        return np.diag([0.005, 0.005])


