import numpy as np
import scipy
import torch

class SpatialNavigation:
    """
    Path-integration task: batches of random-walk trajectories in a rectangular box.

    The box is centred on the origin and spans ``[-box_width / 2, box_width / 2]``
    by ``[-box_height / 2, box_height / 2]``. Each batch pairs per-step
    displacement vectors (the model input) with the positions they lead to
    (the targets), optionally encoded as place cell activations.

    All randomness comes from NumPy's global RNG (``np.random``).
    """

    def __init__(
        self,
        box_width=2.2,
        box_height=2.2,
        border_region=0.03,
        border_slow_factor=0.25,
        init_pos="uniform",
        biased=False,
        biased_ratio=1.0,
        drift_const=0.05,
        anchor_point=np.array([0, 0]),
        dt=0.02,
        sigma=11.52,
        b=0.26 * np.pi,
        mu=0,
        use_place_cells=True,
        place_cells_num=512,
        place_cells_sigma=0.2,
        place_cells_surround_scale=2,
        place_cells_dog=False,
        sequence_length=100,
        batch_size=200
    ):
        """
        Args:
            box_width: Extent of the box along x.
            box_height: Extent of the box along y.
            border_region: Distance from a wall below which the walker is
                treated as "near" that wall.
            border_slow_factor: Factor applied to the speed while near a wall
                and heading towards it.
            init_pos: ``"uniform"`` draws start positions uniformly over the
                box; any other value starts every trajectory at the origin.
            biased: If True, add a drift towards ``anchor_point`` at each step.
            biased_ratio: Fraction of the batch that receives the drift.
            drift_const: Gain of the drift towards ``anchor_point``.
            anchor_point: 2-vector the drift pulls towards.
            dt: Time step increment (s).
            sigma: Std. dev. of the rotation velocity (rad/s).
            b: Scale of the Rayleigh-distributed forward velocity (m/s).
            mu: Turn angle bias (mean of the rotation velocity).
            use_place_cells: If True, targets and initial state are encoded as
                place cell activations instead of raw positions.
            place_cells_num: Number of place cells.
            place_cells_sigma: Place cell tuning width.
            place_cells_surround_scale: Surround scale passed to ``PlaceCells``.
            place_cells_dog: Passed to ``PlaceCells`` as ``diff_of_gaussians``.
            sequence_length: Number of steps per returned trajectory.
            batch_size: Number of trajectories per batch.
        """
        self.box_width = box_width
        self.box_height = box_height
        self.border_region = border_region
        self.border_slow_factor = border_slow_factor
        self.init_pos = init_pos
        self.biased = biased
        self.biased_ratio = biased_ratio
        self.drift_const = drift_const
        self.anchor_point = anchor_point
        self.dt = dt  # time step increment (s)
        self.sigma = sigma  # std. dev. rotation velocity (rad/s)
        self.b = b  # forward velocity rayleigh dist. scale (m/s)
        self.mu = mu  # turn angle bias
        self.use_place_cells = use_place_cells
        if self.use_place_cells:
            self.place_cells = PlaceCells(
                num_cells=place_cells_num,
                sigma=place_cells_sigma,
                surround_scale=place_cells_surround_scale,
                box_width=self.box_width,
                box_height=self.box_height,
                diff_of_gaussians=place_cells_dog
            )
        self.sequence_length = sequence_length
        self.batch_size = batch_size

    def avoid_wall(self, position, hd):
        """
        Compute distance and angle to nearest wall, and the turn needed to avoid it.

        Args:
            position: Array of shape ``(batch, 2)`` holding x/y positions.
            hd: Array of shape ``(batch,)`` holding head directions (rad).

        Returns:
            Tuple ``(is_near_wall, turn_angle)``: a boolean array marking the
            walkers that are within ``border_region`` of their nearest wall
            and heading towards it, and the turn (rad) that brings each of
            those walkers parallel to that wall (zero for all others).
        """
        x = position[:, 0]
        y = position[:, 1]
        half_w = self.box_width / 2
        half_h = self.box_height / 2

        # Distances to the right, top, left and bottom walls, in that order,
        # so that wall i has its outward normal at angle i * pi / 2.
        dists = np.stack([half_w - x, half_h - y, half_w + x, half_h + y])
        d_wall = np.min(dists, axis=0)
        wall_angles = np.arange(4) * np.pi / 2
        theta = wall_angles[np.argmin(dists, axis=0)]

        # Heading relative to the wall normal, wrapped to [-pi, pi).
        hd = np.mod(hd, 2 * np.pi)
        a_wall = np.mod(hd - theta + np.pi, 2 * np.pi) - np.pi

        # Only walkers that are both close to the wall and facing it turn.
        is_near_wall = (d_wall < self.border_region) & (np.abs(a_wall) < np.pi / 2)
        turn_angle = np.zeros_like(hd)
        facing = a_wall[is_near_wall]
        turn_angle[is_near_wall] = np.sign(facing) * (np.pi / 2 - np.abs(facing))

        return is_near_wall, turn_angle

    def generate_trajectory(self):
        """
        Generate a batch of random walks in a rectangular box.

        ``sequence_length + 2`` positions are simulated per trajectory; the
        position at index 1 is reported as the initial position and the
        positions at indices 2 onwards as targets.

        Returns:
            Dict of float tensors (``B = batch_size``, ``T = sequence_length``):

            * ``init_hd`` ``(B, 1)``: head direction at index 0.
            * ``init_x``, ``init_y`` ``(B, 1)``: position at index 1.
            * ``ego_v`` ``(B, T)``: forward displacement per step (speed * dt).
            * ``v`` ``(B, T, 2)``: displacement vector per step.
            * ``phi_x``, ``phi_y`` ``(B, T)``: cosine / sine of the difference
              between consecutive (wrapped) head directions.
            * ``target_hd`` ``(B, T)``: head direction, wrapped to [-pi, pi).
            * ``target_x``, ``target_y`` ``(B, T)``: positions after each step.
        """
        samples = self.sequence_length
        batch = self.batch_size

        # Initialize variables
        position = np.zeros([batch, samples + 2, 2])
        head_dir = np.zeros([batch, samples + 2])

        if self.init_pos == "uniform":
            position[:, 0, 0] = np.random.uniform(-self.box_width / 2, self.box_width / 2, batch)
            position[:, 0, 1] = np.random.uniform(-self.box_height / 2, self.box_height / 2, batch)
        # Otherwise every trajectory starts at the origin (already zero).

        head_dir[:, 0] = np.random.uniform(0, 2 * np.pi, batch)
        velocity = np.zeros([batch, samples + 2])
        updates = np.zeros([batch, samples + 2, 2])

        # Generate sequence of random boosts and turns
        random_turn = np.random.normal(self.mu, self.sigma, [batch, samples + 1])
        random_vel = np.random.rayleigh(self.b, [batch, samples + 1])
        # This draw is unused; it is kept only so that the sequence of
        # global-RNG calls (and therefore seeded output) stays the same as in
        # earlier versions of this method.
        np.random.normal(0, self.b * np.pi / 2, batch)

        for t in range(samples + 1):
            # Update velocity (a view into random_vel, modified in place below)
            v = random_vel[:, t]

            # If in border region, turn and slow down
            is_near_wall, turn_angle = self.avoid_wall(position[:, t], head_dir[:, t])
            v[is_near_wall] *= self.border_slow_factor

            # Update turn angle
            turn_angle += self.dt * random_turn[:, t]

            # Take a step
            velocity[:, t] = v * self.dt
            heading = np.stack([np.cos(head_dir[:, t]), np.sin(head_dir[:, t])], axis=-1)
            update = velocity[:, t, None] * heading

            if self.biased:
                # Pull a fraction `biased_ratio` of the batch towards the anchor.
                # NOTE(review): the mask is reshuffled at every step, so which
                # trajectories drift changes over time when biased_ratio < 1
                # -- confirm this is intended.
                biased_mask = np.zeros((batch, 1))
                biased_mask[:int(batch * self.biased_ratio)] = 1
                np.random.shuffle(biased_mask)
                update += self.drift_const * (self.anchor_point - position[:, t]) * biased_mask

            updates[:, t] = update
            position[:, t + 1] = position[:, t] + update

            # Rotate head direction
            head_dir[:, t + 1] = head_dir[:, t] + turn_angle

        # Periodic variable, modify range to [-pi, pi)
        head_dir = np.mod(head_dir + np.pi, 2 * np.pi) - np.pi

        traj = {}

        # Input variables
        traj["init_hd"] = torch.from_numpy(head_dir[:, 0, None]).float()
        traj["init_x"] = torch.from_numpy(position[:, 1, 0, None]).float()
        traj["init_y"] = torch.from_numpy(position[:, 1, 1, None]).float()

        traj["ego_v"] = torch.from_numpy(velocity[:, 1:-1]).float()
        traj["v"] = torch.from_numpy(updates[:, 1:-1]).float()
        ang_v = np.diff(head_dir, axis=-1)
        traj["phi_x"] = torch.from_numpy(np.cos(ang_v)[:, :-1]).float()
        traj["phi_y"] = torch.from_numpy(np.sin(ang_v)[:, :-1]).float()

        # Target variables
        traj["target_hd"] = torch.from_numpy(head_dir[:, 1:-1]).float()
        traj["target_x"] = torch.from_numpy(position[:, 2:, 0]).float()
        traj["target_y"] = torch.from_numpy(position[:, 2:, 1]).float()

        return traj

    def _make_batch(self, traj, squeeze_init):
        """
        Turn a trajectory dict into a training/evaluation batch.

        Args:
            traj: Dict as returned by ``generate_trajectory``.
            squeeze_init: If True and place cells are in use, drop the
                singleton time axis of the initial place cell activation.

        Returns:
            Dict with keys ``data`` (step displacements), ``init_state``,
            ``targets``, ``init_pos`` and ``target_pos``. With place cells,
            ``init_state`` and ``targets`` are place cell activations;
            otherwise they are the raw positions.
        """
        pos = torch.stack([traj["target_x"], traj["target_y"]], dim=-1)
        init_pos = torch.stack([traj["init_x"], traj["init_y"]], dim=-1)

        if self.use_place_cells:
            targets = self.place_cells.get_activation(pos)
            init_state = self.place_cells.get_activation(init_pos)
            if squeeze_init:
                # Squeeze only the time axis so a batch of size 1 keeps its
                # batch dimension.
                init_state = init_state.squeeze(1)
        else:
            targets = pos
            init_state = init_pos

        return {
            "data": traj["v"],
            "init_state": init_state,
            "targets": targets,
            "init_pos": init_pos,
            "target_pos": pos
        }

    def get_generator(self):
        """
        Returns a generator that yields batches of trajectories forever.

        With place cells, ``init_state`` keeps its singleton time axis, i.e.
        has shape ``(batch, 1, num_cells)``.
        """
        while True:
            yield self._make_batch(self.generate_trajectory(), squeeze_init=False)

    def get_test_batch(self):
        """
        For testing performance, returns a single batch of sample trajectories.

        With place cells, ``init_state`` has its time axis removed, i.e. has
        shape ``(batch, num_cells)``.
        """
        return self._make_batch(self.generate_trajectory(), squeeze_init=True)

    def compute_metrics(self, outputs, targets, aux=None):
        """
        Compute the MSE loss and, optionally, the position decoding error.

        Args:
            outputs: Model predictions, same shape as ``targets``.
            targets: Ground-truth targets.
            aux: Optional dict containing ``"target_pos"`` (true positions).
                When given, a ``"pos_mse"`` entry is added to the metrics:
                with place cells the predicted position is decoded from the
                most active cells, otherwise ``outputs`` are taken to be
                positions already.

        Returns:
            Tuple ``(loss, metric)`` where ``loss`` is the MSE tensor and
            ``metric`` is a dict of Python floats.
        """
        criterion = torch.nn.MSELoss()
        loss = criterion(outputs, targets)
        metric = {
            "loss": loss.item()
        }

        if aux is not None:
            with torch.no_grad():
                if self.use_place_cells:
                    decoded_pos = self.place_cells.get_nearest_cell_pos(outputs)
                else:
                    # No place cell code to decode: outputs are positions.
                    decoded_pos = outputs.detach()
                # Decoding may have moved the prediction to the CPU.
                target_pos = aux["target_pos"].to(decoded_pos.device)
                metric["pos_mse"] = criterion(decoded_pos, target_pos).item()

        return loss, metric


class PlaceCells:
    """
    Population of place cells with random centres in a rectangular box.

    Centres are drawn uniformly over ``[-box_width / 2, box_width / 2]`` by
    ``[-box_height / 2, box_height / 2]`` using NumPy's global RNG and stored
    in ``self.us`` as a float tensor of shape ``(num_cells, 2)``.
    """

    def __init__(
        self,
        num_cells=512,
        sigma=0.2,
        surround_scale=2,
        box_width=2.2,
        box_height=2.2,
        diff_of_gaussians=False
    ):
        """
        Args:
            num_cells: Number of place cells.
            sigma: Width of the (centre) Gaussian tuning curve.
            surround_scale: Factor by which the surround Gaussian's variance
                exceeds the centre's (only used with ``diff_of_gaussians``).
            box_width: Extent of the box along x.
            box_height: Extent of the box along y.
            diff_of_gaussians: If True, use centre-minus-surround tuning.
        """
        self.num_cells = num_cells
        self.sigma = sigma
        self.surround_scale = surround_scale
        self.box_width = box_width
        self.box_height = box_height
        self.diff_of_gaussians = diff_of_gaussians

        # Randomly tile place cell centers across environment
        usx = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.num_cells,))
        usy = np.random.uniform(-self.box_height / 2, self.box_height / 2, (self.num_cells,))
        self.us = torch.from_numpy(np.stack([usx, usy], axis=-1)).float()

    def get_activation(self, pos):
        """
        Get place cell activations for a batch of positions.

        Args:
            pos: Tensor of shape ``(batch, time, 2)``.

        Returns:
            Tensor of shape ``(batch, time, num_cells)``; activations are
            normalised to sum to one over the cell axis.
        """
        centers = self.us.to(pos.device)
        diff = pos[:, :, None, :] - centers[None, None, ...]
        norm2 = torch.sum(diff ** 2, dim=-1)

        # Normalize with softmax (nearly equivalent to using prefactor)
        outputs = torch.softmax(-norm2 / (2 * self.sigma ** 2), dim=-1)

        if self.diff_of_gaussians:
            # Subtract the (softmax-normalized) broader surround Gaussian
            surround = torch.softmax(-norm2 / (2 * self.surround_scale * self.sigma ** 2), dim=-1)
            outputs = outputs - surround

            # Shift and scale outputs so that they lie in [0, 1].
            # torch.min with `dim` returns (values, indices); only the values
            # are wanted here.
            min_output = torch.min(outputs, dim=-1, keepdim=True).values
            outputs = outputs + torch.abs(min_output)
            outputs = outputs / torch.sum(outputs, dim=-1, keepdim=True)

        return outputs

    def get_nearest_cell_pos(self, activation, k=3):
        """
        Decode position using centers of k maximally active place cells.

        Args:
            activation: Tensor whose last axis indexes place cells.
            k: Number of most active cells whose centres are averaged.

        Returns:
            CPU float tensor with the last axis of ``activation`` replaced by
            the 2-D decoded position.
        """
        top_idx = torch.topk(activation, k=k).indices.cpu().detach().numpy()
        centers = self.us.cpu().detach().numpy()
        # Gather the k centres per entry, then average over those k.
        pred_pos = np.mean(np.take(centers, top_idx, axis=0), axis=-2)
        return torch.from_numpy(pred_pos).float()

    def grid_pc(self, pc_outputs, res=32):
        """
        Interpolate place cell outputs onto a regular grid.

        Args:
            pc_outputs: Tensor of place cell activations; it is flattened to
                rows of ``num_cells`` values.
            res: Number of grid points along each axis.

        Returns:
            Array of shape ``(rows, res, res)``. Values are interpolated
            between the place cell centres with ``scipy.interpolate.griddata``
            using its default method.
        """
        # Imported here because `import scipy` alone does not guarantee that
        # the `interpolate` submodule is loaded.
        from scipy.interpolate import griddata

        coordsx = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        coordsy = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        grid_x, grid_y = np.meshgrid(coordsx, coordsy)
        grid = np.stack([grid_x.ravel(), grid_y.ravel()]).T

        us_np = self.us.cpu().detach().numpy()
        pc_outputs = pc_outputs.cpu().detach().numpy().reshape(-1, self.num_cells)

        pc = np.zeros([pc_outputs.shape[0], res, res])
        for i, row in enumerate(pc_outputs):
            pc[i] = griddata(us_np, row, grid).reshape([res, res])

        return pc
