import numpy as np
from scipy.special import softmax

class PlaceCells:
    """Population of ``n`` place cells with random 2-D centers.

    Centers are drawn uniformly from ``[0, env_dim)`` along each of the two
    axes. The response of every cell to a position is computed from the
    squared Euclidean distance ``d`` between the position and the cell center,
    using one of three profiles selected by ``method``:

    * ``'gaus'``: ``exp(-d / (2 * width**2))``.
    * ``'dos'``: difference of softmaxes. For each position, a softmax over
      the cell axis of ``-d / (2 * width**2)`` minus a softmax over the cell
      axis of ``-d / (2 * (surround_scale * width)**2)``.
    * ``'dog'``: difference of area-normalized 2-D Gaussians with widths
      ``width`` and ``surround_scale * width``.

    The raw response is then post-processed per position according to
    ``normalize`` (see ``_normalize``).
    """

    def __init__(
        self, n: int, env_dim: float, width: float,
        method: str = 'dos', normalize: str = 'sum', surround_scale: float = 2,
        rng=None,
    ):
        """Create the population and draw its centers.

        Args:
            n: Number of place cells.
            env_dim: Exclusive upper bound of the center coordinates on both
                axes; the lower bound is 0.
            width: Width of the center profile.
            method: Response profile: 'gaus', 'dos' or 'dog'.
            normalize: Post-processing: 'softmax', 'sum', 'positive', or None
                to return the raw response.
            surround_scale: Surround width as a multiple of ``width`` (used by
                'dos' and 'dog').
            rng: Optional random generator with a ``uniform(low, high, size)``
                method (e.g. ``np.random.default_rng(seed)``) used to draw the
                centers. Defaults to the global ``np.random`` state.

        Raises:
            ValueError: If ``method`` or ``normalize`` is not recognized.
        """
        self.n = n
        sampler = np.random if rng is None else rng
        self.centers = sampler.uniform(0, env_dim, (n, 2))
        self.width = width
        self.surround_scale = surround_scale

        # Not used by the methods below (they call the module-level
        # ``softmax``); presumably kept for external code that reads it.
        self.softmax = softmax

        if method not in ['gaus', 'dos', 'dog']:
            raise ValueError("Method must be 'gaus', 'dos' or 'dog'")
        self.method = method

        if normalize is not None and normalize not in ['softmax', 'sum', 'positive']:
            raise ValueError("Normalize must be 'softmax', 'sum' or 'positive'")
        self.normalize = normalize

    def get_centers(self):
        """Return the ``(n, 2)`` array of cell centers (not a copy)."""
        return self.centers

    def _normalize(self, activity):
        """Post-process ``activity`` along its last (cell) axis.

        * 'softmax': softmax over cells (returns a new array).
        * 'positive': add ``|row min|`` to every row so all entries are >= 0.
        * 'sum': the same shift, then divide each row by its sum.
        * None: return ``activity`` unchanged.

        The shift adds ``|min|`` rather than subtracting ``min``. So a row whose
        minimum is already positive is shifted up, not down to zero. A row
        whose shifted sum is 0 becomes NaN under 'sum'. For 'sum' and
        'positive', ``activity`` is modified in place.
        """
        if self.normalize == 'softmax':
            return softmax(activity, axis=-1)
        if self.normalize in ('sum', 'positive'):
            activity += np.abs(activity.min(axis=-1, keepdims=True))
            if self.normalize == 'sum':
                activity /= activity.sum(axis=-1, keepdims=True)
        return activity

    def _compute_gaus(self, dist):
        """Gaussian response ``exp(-dist / (2 * width**2))``, then normalize.

        ``dist`` holds squared distances with cells on the last axis.
        """
        activity = np.exp(-dist / (2 * self.width ** 2))
        return self._normalize(activity)

    def _compute_dos(self, dist):
        """Difference of softmaxes over the cell axis, then normalize.

        Each term is a softmax over the last (cell) axis, computed separately
        for every position. A position's response therefore does not depend on
        which other positions are in the same batch. ``scipy.special.softmax``
        is numerically stable, so positions far from every center do not
        underflow to 0/0 (NaN).
        """
        center = softmax(-dist / (2 * self.width ** 2), axis=-1)
        surround = softmax(
            -dist / (2 * (self.surround_scale * self.width) ** 2), axis=-1
        )
        return self._normalize(center - surround)

    def _compute_dog(self, dist):
        """Difference of area-normalized 2-D Gaussians, then normalize.

        Each Gaussian ``exp(-dist / (2 s**2))`` is divided by
        ``2 * pi * s**2``. ``s`` is ``width`` for the center and
        ``surround_scale * width`` for the surround.
        """
        center_var = self.width ** 2
        surround_var = (self.surround_scale * self.width) ** 2
        activity = (
            np.exp(-dist / (2 * center_var)) / (2 * np.pi * center_var)
            - np.exp(-dist / (2 * surround_var)) / (2 * np.pi * surround_var)
        )
        return self._normalize(activity)

    def get_state(self, pos):
        '''
        Get place cell activations for the given position(s).

        Args:
            pos: 2-D position(s), array-like of shape [..., 2], e.g.
                [batch_size, sequence_length, 2] or a single [2] position.

        Returns:
            outputs: Place cell activations with shape [..., Np], where the
                leading dimensions match those of ``pos``.

        Raises:
            ValueError: If the last dimension of ``pos`` is not 2.
        '''
        pos = np.asarray(pos)
        if pos.ndim == 0 or pos.shape[-1] != 2:
            raise ValueError(
                f"pos must have a trailing dimension of size 2, got shape {pos.shape}"
            )
        lead_shape = list(pos.shape[:-1])
        flat = pos.reshape(-1, 2)

        # (P, 1, 2) - (n, 2) broadcasts to (P, n, 2); no tiled copy needed.
        diff = flat[:, None, :] - self.centers
        dist = np.sum(diff ** 2, axis=-1)  # squared distances, shape (P, n)

        compute = {
            'gaus': self._compute_gaus,
            'dos': self._compute_dos,
            'dog': self._compute_dog,
        }[self.method]
        activity = compute(dist)

        return activity.reshape(lead_shape + [self.n])
