import numpy as np
import scipy
import torch

import utils

class HeadDirection:
    """Random head-direction trajectories and training batches built from them.

    Each trajectory is a random walk of a heading angle: at every step the
    angle changes by ``dt * turn`` with ``turn ~ Normal(mu, sigma)``,
    optionally plus a drift that pulls the heading toward ``anchor_angle``.
    Batches present the angles either raw, as (sin, cos) pairs, and/or as the
    activations of a ``HeadDirectionCells`` population.
    """

    def __init__(
        self,
        dimensionality="2D",
        init_hd="uniform",
        biased=False,
        drift_const=0.05,
        anchor_angle=0,
        dt=0.02,
        sigma=11.52,
        mu=0,
        use_hd_cells=True,
        hd_cells_num=128,
        hd_cells_angular_spread=np.pi/6,
        sequence_length=100,
        batch_size=200
    ):
        """
        Args:
            dimensionality: Encoding mode. ``"sin-cos"`` makes batches carry
                angles as (sin, cos) pairs; the value is also forwarded to
                HeadDirectionCells.
            init_hd: Stored only; initial headings are always drawn uniformly
                from [0, 2*pi).
            biased: If True, add a drift toward ``anchor_angle`` every step.
            drift_const: Fraction of the wrapped angular distance to
                ``anchor_angle`` added per step when ``biased``.
            anchor_angle: Angle (rad) the drift pulls toward.
            dt: Time step increment (s).
            sigma: Std. dev. of the rotation velocity (rad/s).
            mu: Mean of the rotation velocity, i.e. a turn bias.
            use_hd_cells: If True, batches use head-direction cell activations
                for ``init_state`` and ``targets``.
            hd_cells_num: Number of head-direction cells.
            hd_cells_angular_spread: Tuning spread passed to the cells.
            sequence_length: Steps per trajectory.
            batch_size: Trajectories per batch.
        """
        self.dimensionality = dimensionality
        self.init_hd = init_hd
        self.biased = biased
        self.drift_const = drift_const
        self.anchor_angle = anchor_angle
        self.dt = dt  # time step increment (s)
        self.sigma = sigma  # std. dev. rotation velocity (rad/s)
        self.mu = mu  # turn angle bias
        self.use_hd_cells = use_hd_cells
        if self.use_hd_cells:
            self.hd_cells = HeadDirectionCells(
                dimensionality=self.dimensionality,
                num_cells=hd_cells_num,
                angular_spread=hd_cells_angular_spread
            )
        self.sequence_length = sequence_length
        self.batch_size = batch_size

    @staticmethod
    def _wrap_angle(angle):
        """Map angles (rad) onto the interval [-pi, pi)."""
        return np.mod(angle + np.pi, 2 * np.pi) - np.pi

    @staticmethod
    def _sin_cos(angles):
        """Replace a trailing size-1 angle axis by a [sin, cos] axis of size 2."""
        # Explicit concatenation on the last axis (instead of stack + squeeze())
        # keeps batch/time axes even when they have size 1.
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

    def generate_trajectory(self):
        """Generate a batch of random head-direction trajectories.

        Returns:
            dict of float32 tensors:
                "init_hd": (batch_size, 1) initial heading, wrapped to [-pi, pi).
                "ang_v": (batch_size, sequence_length, 1) per-step heading
                    change (rad), including the drift term when ``biased``.
                "target_hd": (batch_size, sequence_length, 1) heading after
                    each step, wrapped to [-pi, pi).
        """
        steps = self.sequence_length

        head_dir = np.zeros([self.batch_size, steps + 1])
        head_dir[:, 0] = np.random.uniform(0, 2 * np.pi, self.batch_size)
        updates = np.zeros([self.batch_size, steps])

        # Rotation velocities for every step, drawn up front.
        random_turn = np.random.normal(self.mu, self.sigma, [self.batch_size, steps])

        for t in range(steps):
            # A fresh array (not a view into another buffer), so adding the
            # drift below cannot silently mutate stored data.
            update = self.dt * random_turn[:, t]
            if self.biased:
                # head_dir is integrated unwrapped; wrapping the error makes the
                # drift take the shortest arc to the anchor instead of
                # unwinding whole turns.
                update = update + self.drift_const * self._wrap_angle(
                    self.anchor_angle - head_dir[:, t]
                )
            updates[:, t] = update
            head_dir[:, t + 1] = head_dir[:, t] + update

        # Periodic variable: report headings in [-pi, pi).
        head_dir = self._wrap_angle(head_dir)

        return {
            # Input variables
            "init_hd": torch.from_numpy(head_dir[:, 0, None]).float(),
            "ang_v": torch.from_numpy(updates[:, :, None]).float(),
            # Target variables
            "target_hd": torch.from_numpy(head_dir[:, 1:, None]).float(),
        }

    def _build_batch(self):
        """Generate one trajectory batch and package it for training.

        Returns:
            dict with keys "data" (angular updates), "init_state",
            "targets", "init_hd" and "target_hd". With ``use_hd_cells``,
            "init_state"/"targets" are cell activations; otherwise they are
            the (possibly sin-cos encoded) headings themselves.
        """
        traj = self.generate_trajectory()
        ang_v = traj["ang_v"]
        hd = traj["target_hd"]

        if self.dimensionality == "sin-cos":
            ang_v = self._sin_cos(ang_v)                # (B, T, 2)
            hd = self._sin_cos(hd)                      # (B, T, 2)
            init_hd = self._sin_cos(traj["init_hd"])    # (B, 2)
        else:
            init_hd = traj["init_hd"].unsqueeze(-1)     # (B, 1, 1)

        if self.use_hd_cells:
            targets = self.hd_cells.get_activation(hd)
            # NOTE(review): squeeze() without a dim also drops the batch axis
            # when batch_size == 1; the exact shape returned by
            # utils.von_mises is not visible here - confirm before narrowing.
            init_state = torch.squeeze(self.hd_cells.get_activation(init_hd))
        else:
            targets = hd
            init_state = init_hd

        return {
            "data": ang_v,
            "init_state": init_state,
            "targets": targets,
            "init_hd": init_hd,
            "target_hd": hd
        }

    def get_generator(self):
        """
        Returns a generator that yields batches of trajectories (forever).
        """
        while True:
            yield self._build_batch()

    def get_test_batch(self):
        """
        For testing performance, returns a batch of sample trajectories.
        """
        return self._build_batch()

    def compute_metrics(self, outputs, targets, aux=None):
        """Compute the training loss and logging metrics.

        Args:
            outputs: Model predictions.
            targets: Tensor compared with ``outputs`` by mean squared error.
            aux: Optional batch dict; when given and head-direction cells are
                in use, its ``"target_hd"`` is compared with the heading
                decoded from ``outputs``.

        Returns:
            (loss, metric): the MSE loss tensor (still attached to the graph)
            and a dict of Python floats with key "loss" and, when computed,
            "hd_mse".
        """
        criterion = torch.nn.MSELoss()
        loss = criterion(outputs, targets)
        metric = {
            "loss": loss.item()
        }

        if aux is not None and self.use_hd_cells:
            with torch.no_grad():
                target_hd = aux["target_hd"]
                # decode_hd returns a CPU tensor; move it next to the target so
                # this also works when training on an accelerator.
                decoded_hd = self.hd_cells.decode_hd(outputs).to(target_hd.device)
                # NOTE(review): plain MSE on angles ignores wrap-around at +-pi.
                metric["hd_mse"] = criterion(decoded_hd, target_hd).item()

        return loss, metric


class HeadDirectionCells(object):
    """Population of head-direction cells whose preferred directions tile the circle.

    Activations are computed by ``utils.von_mises`` (defined outside this
    file) from the preferred directions ``us`` and the spreads ``vs``.
    """

    def __init__(
        self,
        dimensionality="2D",
        num_cells=128,
        angular_spread=np.pi/6,
    ):
        """
        Args:
            dimensionality: ``"1D"`` replaces the population by a 2-cell
                placeholder with no preferred directions (``us``/``vs`` are
                None); any other value builds the full population.
            num_cells: Number of cells.
            angular_spread: Spread shared by every cell, stored per cell in
                ``vs`` and passed to ``utils.von_mises``.
        """
        self.dimensionality = dimensionality
        self.num_cells = num_cells
        self.angular_spread = angular_spread

        if self.dimensionality == "1D":
            # NOTE(review): with us/vs set to None, decode_hd cannot be used
            # (it indexes self.us), and whether get_activation works depends on
            # how utils.von_mises handles None - confirm this mode is intended.
            self.num_cells = 2
            self.angular_spread = None
            self.us = None
            self.vs = None
        else:
            # num_cells evenly spaced preferred directions over [-pi, pi).
            # The +pi endpoint is dropped because it is the same direction as
            # -pi; keeping both would create two identical cells at the seam.
            self.us = torch.linspace(-torch.pi, torch.pi, num_cells + 1)[:-1].float()
            self.vs = torch.tensor([angular_spread] * num_cells)

    def get_activation(self, hd):
        """
        Get head direction cell activations for a given head direction.

        Delegates to ``utils.von_mises`` with this population's preferred
        directions (``us``) and spreads (``vs``).
        """
        outputs = utils.von_mises(hd, self.us, self.vs, norm=1)
        return outputs

    def decode_hd(self, activation, k=3):
        """Decode head direction from cell activations.

        Takes the ``k`` most active cells along the last axis and returns the
        circular mean of their preferred directions.

        Args:
            activation: Tensor whose last axis indexes the cells.
            k: Number of most active cells to average.

        Returns:
            float32 CPU tensor of decoded angles in [-pi, pi], shaped like
            ``activation`` with the last axis replaced by a size-1 axis.
        """
        _, idxs = torch.topk(activation.detach(), k=k)
        centres = self.us.to(idxs.device)[idxs]
        # Circular mean: average unit vectors rather than raw angles so cells
        # on either side of +-pi do not cancel out.
        pred_hd = torch.atan2(
            torch.sin(centres).mean(dim=-1),
            torch.cos(centres).mean(dim=-1),
        )
        return pred_hd.unsqueeze(-1).float().cpu()
