import numpy as np
from typing import Optional

from itertools import product

from loguru import logger

from .MoleculeTransitionData import MoleculeActionSpace, Min, Max, StepSize
from .MockUpData import VerticalMockUpData

class FePcMockUpData(VerticalMockUpData):
    """Vertical mock-up transition data with a hand-crafted translation field.

    Passes ``name="FePcMockUp"`` to the parent and then builds two tables:

    * ``translation_means`` -- a ``(len(x_space), len(y_space), 2)`` array of
      mean translation vectors, one per (x, y) cell of the action space.
    * ``rotation_dists`` -- an ``(11, 13, 3)`` array of weights over
      ``self.rotations`` (see ``assign_rotation_dists``).
    """

    def __init__(self, **kwargs):
        """Initialise the parent and precompute the per-cell translation means.

        All keyword arguments are forwarded to the parent constructor.
        """
        logger.info(kwargs)
        super().__init__(name="FePcMockUp", **kwargs)
        x_space = self.action_space.x_space
        y_space = self.action_space.y_space
        translation_means = np.zeros((len(x_space), len(y_space), 2), dtype=float)

        for (i, x_offset), (j, y_offset) in product(enumerate(x_space), enumerate(y_space)):
            translation_means[i, j] = self._translation_mean(x_offset, y_offset)

        self.translation_means = translation_means
        self.assign_rotation_dists()

    @staticmethod
    def _translation_mean(x_offset: float, y_offset: float) -> np.ndarray:
        """Return the mean translation vector (shape ``(2,)``) for one grid cell.

        A scalar ``multiplier`` combines two terms:

        * a radial ``sinc(d - 0.5)**2`` profile, where ``d`` is the distance
          of the offset from the origin;
        * a four-fold angular modulation.

        If ``multiplier`` is negative, the mean is zero. Otherwise it is
        ``-multiplier * offset``, rounded to one decimal. That vector points
        back towards the origin.
        """
        distance = np.linalg.norm(np.array([x_offset, y_offset]))
        angle = np.angle(x_offset + y_offset * 1j) + np.pi
        # Four-fold angular modulation; its value lies in [-1/3, 1/3], so
        # np.sin(1 + angular) below stays strictly positive.
        angular = np.sin(4 * angle + np.pi / 2) / 3
        # sin(u)/u -> 1 as u -> 0. Without this guard a cell at distance exactly
        # 0.5 evaluates 0/0, yields a NaN multiplier (NaN < 0 is False) and
        # ends up with a NaN mean vector.
        u = distance - 0.5
        radial = 1.0 if u == 0 else np.sin(u) / u
        multiplier = 4 / 3 * (radial**2 * ((1 + angular) / np.sin(1 + angular))**2) - 1.25

        if multiplier < 0:
            return np.zeros(2)

        vector = np.array([round(multiplier * -x_offset, 1), round(multiplier * -y_offset, 1)])
        if np.all(vector == 0.0):
            # get_translation_cov returns an all-zero covariance for zero means;
            # presumably this small non-zero mean keeps such cells (e.g. the
            # origin) stochastic -- confirm intent.
            vector = np.array([0.1, 0.1])
        return vector

    def get_translation_cov(self, x: int, y: int):
        """Return the 2x2 translation covariance for grid cell ``(x, y)``.

        Cells whose mean translation is exactly zero get an all-zero
        covariance. Every other cell gets the isotropic ``self.cov * I``.
        ``self.cov`` is not set in this class; it presumably comes from the
        parent -- verify.
        """
        logger.debug("translation cov: {}", self.cov)
        cov = np.array([[self.cov, 0.0], [0.0, self.cov]])
        if np.any(self.translation_means[x][y] != [0.0, 0.0]):
            return cov
        return np.array([[0.0, 0.0], [0.0, 0.0]])

    @property
    def rotations(self):
        """Candidate rotations; each ``rotation_dists`` entry is indexed in this order."""
        # NOTE: an earlier variant used [0, 30, -30, 60, -60]; if this list
        # changes, the z/r/l/R/L vectors in assign_rotation_dists must be
        # resized to match.
        return [0,    30 , -30]

    def assign_rotation_dists(self):
        """Build ``self.rotation_dists``, an ``(11, 13, 3)`` table of rotation weights.

        Each entry is one of the vectors below, aligned with
        ``self.rotations`` ([0, 30, -30]); every vector sums to 1:

        * ``z`` -- always 0.
        * ``r``/``R`` -- mostly 0, sometimes -30.
        * ``l``/``L`` -- mostly 0, sometimes +30.

        Upper- and lower-case variants currently hold identical values.
        """
        z   =    [1.00, 0.00 , 0.00]
        r   =    [0.85, 0.00 , 0.15]
        l   =    [0.85, 0.15 , 0.00]
        R   =    [0.85, 0.00 , 0.15]
        L   =    [0.85, 0.15 , 0.00]

        # Column labels are x-offsets and row labels are y-offsets. So this
        # table is laid out [y][x], while translation_means is [x][y].
        # NOTE(review): confirm the consumer indexes rotation_dists accordingly.
        self.rotation_dists = np.array([
          #-1.5  -1.2  -0.9  -0.6   -0.3   0.0   0.3   0.6   0.9   1.2   1.5   1.8   2.1
          [    L,    R,    R,    R,    R,    R,    L,    L,    L,    L,    L,    R,    R], #  1.5 3
          [    L,    L,    r,    r,    r,    z,    l,    l,    l,    L,    R,    R,    R], #  1.2 4
          [    L,    l,    z,    r,    r,    z,    l,    l,    z,    r,    R,    R,    R], #  0.9 5
          [    L,    l,    l,    z,    r,    z,    l,    z,    r,    r,    R,    R,    R], #  0.6 6
          [    L,    l,    l,    l,    z,    z,    z,    r,    r,    r,    R,    R,    R], #  0.3 7
          [    R,    R,    R,    z,    z,    z,    z,    z,    R,    R,    R,    R,    R], #  0.0 8
          [    R,    r,    r,    r,    z,    z,    z,    l,    l,    l,    L,    L,    L], # -0.3 9
          [    R,    r,    r,    z,    l,    z,    r,    z,    l,    l,    L,    L,    L], # -0.6 10
          [    R,    r,    z,    l,    l,    z,    r,    r,    z,    l,    L,    L,    L], # -0.9 11
          [    R,    L,    l,    l,    l,    z,    r,    r,    r,    L,    L,    L,    L], # -1.2 12
          [    L,    L,    L,    L,    L,    R,    R,    R,    R,    R,    L,    L,    L]]) # -1.5 13
          #-1.5  -1.2  -0.9  -0.6   -0.3   0.0   0.3   0.6   0.9   1.2   1.5   1.8   2.1
          #-1.5  -1.2  -0.9  -0.6   -0.3   0.0   0.3   0.6   0.9   1.2   1.5   1.8   2.1


