import numpy as np
import torch
import torch.distributions as td
import pdb
import copy

# from model.dynamics import Dynamic
#
# from model.components import GMM2D


def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """
    Pad the shape of ``v`` with singleton dimensions on either side.

    :param v: Tensor to reshape.
    :param n_dim_to_prepend: Number of size-1 dimensions added in front.
    :param n_dim_to_append: Number of size-1 dimensions added at the end.
    :return: ``v`` reshaped to
        ``[1] * n_dim_to_prepend + v.shape + [1] * n_dim_to_append``.
    """
    leading = (1,) * n_dim_to_prepend
    trailing = (1,) * n_dim_to_append
    return v.reshape(leading + tuple(v.shape) + trailing)


def block_diag(m):
    """
    Make a block diagonal matrix along dim=-3.

    EXAMPLE:
    block_diag(torch.ones(4,3,2))
    should give a 12 x 8 matrix with blocks of 3 x 2 ones.
    Prepend batch dimensions if needed.
    You can also give a list or tuple of equally shaped matrices; they are
    stacked along a new dim=-3 first.

    :param m: Tensor of shape [..., n, r, c], or a sequence of n tensors of
        shape [..., r, c].
    :type m: torch.Tensor, list, tuple
    :return: Tensor of shape [..., n * r, n * c] whose i-th diagonal block is
        ``m[..., i, :, :]`` and whose off-diagonal blocks are zero.
    :rtype: torch.Tensor
    """
    if isinstance(m, (list, tuple)):
        m = torch.cat([block.unsqueeze(-3) for block in m], -3)

    n = m.shape[-3]
    batch_shape = tuple(m.shape[:-3])
    rows, cols = m.shape[-2:]

    # Identity mask of shape [n, 1, n, 1]; broadcasting against
    # m.unsqueeze(-2) of shape [..., n, r, 1, c] zeroes every (i, j) block
    # with i != j and takes care of any leading batch dimensions.
    eye = torch.eye(n, device=m.device).view(n, 1, n, 1)
    blocks = m.unsqueeze(-2) * eye  # [..., n, r, n, c]
    return blocks.reshape(batch_shape + (n * rows, n * cols))


def to_one_hot(labels, n_labels):
    """
    One-hot encode labels by looking up rows of an identity matrix.

    :param labels: Index tensor of class labels; the identity matrix is
        created on its device.
    :param n_labels: Length of each one-hot vector (number of classes).
    :return: ``labels`` with one extra trailing dimension of size ``n_labels``.
    """
    identity = torch.eye(n_labels, device=labels.device)
    return identity[labels]


class GMM2D(td.Distribution):
    r"""
    Gaussian Mixture Model using 2D Multivariate Gaussians each of as N components:
    Cholesky decompesition and affine transformation for sampling:

    .. math:: Z \sim N(0, I)

    .. math:: S = \mu + LZ

    .. math:: S \sim N(\mu, \Sigma) \rightarrow N(\mu, LL^T)

    where :math:`L = chol(\Sigma)` and

    .. math:: \Sigma = \left[ {\begin{array}{cc} \sigma^2_x & \rho \sigma_x \sigma_y \\ \rho \sigma_x \sigma_y & \sigma^2_y \\ \end{array} } \right]

    such that

    .. math:: L = chol(\Sigma) = \left[ {\begin{array}{cc} \sigma_x & 0 \\ \rho \sigma_y & \sigma_y \sqrt{1-\rho^2} \\ \end{array} } \right]

    :param log_pis: Log Mixing Proportions :math:`log(\pi)`. [..., N]
    :param mus: Mixture Components mean :math:`\mu`. [..., N * 2]
    :param log_sigmas: Log Standard Deviations :math:`log(\sigma_d)`. [..., N * 2]
    :param corrs: Cholesky factor of correlation :math:`\rho`. [..., N]
    :param clip_lo: Clips the lower end of the standard deviation.
    :param clip_hi: Clips the upper end of the standard deviation.
    """

    def __init__(self, log_pis, mus, log_sigmas, corrs):
        super(GMM2D, self).__init__(
            batch_shape=log_pis.shape[0],
            event_shape=log_pis.shape[1:],
            validate_args=False,
        )
        self.components = log_pis.shape[-1]
        self.dimensions = 2
        self.device = log_pis.device

        log_pis = torch.clamp(log_pis, min=-1e5)
        self.log_pis = log_pis - torch.logsumexp(
            log_pis, dim=-1, keepdim=True
        )  # [..., N]
        self.mus = self.reshape_to_components(mus)  # [..., N, 2]
        self.log_sigmas = self.reshape_to_components(log_sigmas)  # [..., N, 2]
        self.sigmas = torch.exp(self.log_sigmas)  # [..., N, 2]
        self.one_minus_rho2 = 1 - corrs**2  # [..., N]
        self.one_minus_rho2 = torch.clamp(
            self.one_minus_rho2, min=1e-5, max=1
        )  # otherwise log can be nan
        self.corrs = corrs  # [..., N]

        self.L = torch.stack(
            [
                torch.stack(
                    [self.sigmas[..., 0], torch.zeros_like(self.log_pis)], dim=-1
                ),
                torch.stack(
                    [
                        self.sigmas[..., 1] * self.corrs,
                        self.sigmas[..., 1] * torch.sqrt(self.one_minus_rho2),
                    ],
                    dim=-1,
                ),
            ],
            dim=-2,
        )

        self.pis_cat_dist = td.Categorical(logits=log_pis)

    @classmethod
    def from_log_pis_mus_cov_mats(cls, log_pis, mus, cov_mats):
        corrs_sigma12 = cov_mats[..., 0, 1]
        sigma_1 = torch.clamp(cov_mats[..., 0, 0], min=1e-8)
        sigma_2 = torch.clamp(cov_mats[..., 1, 1], min=1e-8)
        sigmas = torch.stack([torch.sqrt(sigma_1), torch.sqrt(sigma_2)], dim=-1)
        log_sigmas = torch.log(sigmas)
        corrs = corrs_sigma12 / (torch.prod(sigmas, dim=-1))
        return cls(log_pis, mus, log_sigmas, corrs)

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.

        :param sample_shape: Shape of the samples
        :return: Samples from the GMM.
        """
        mvn_samples = self.mus + torch.squeeze(
            torch.matmul(
                self.L,
                torch.unsqueeze(
                    torch.randn(size=sample_shape + self.mus.shape, device=self.device),
                    dim=-1,
                ),
            ),
            dim=-1,
        )
        component_cat_samples = self.pis_cat_dist.sample(sample_shape)
        selector = torch.unsqueeze(
            to_one_hot(component_cat_samples, self.components), dim=-1
        )
        return torch.sum(mvn_samples * selector, dim=-2)

    def traj_sample(self, sample_shape=torch.Size()):
        mvn_samples = self.mus
        component_cat_samples = self.pis_cat_dist.sample(sample_shape)
        selector = torch.unsqueeze(
            to_one_hot(component_cat_samples, self.components), dim=-1
        )
        return torch.sum(mvn_samples * selector, dim=-2)

    def log_prob(self, value):
        r"""
        Calculates the log probability of a value using the PDF for bivariate normal distributions:

        .. math::
            f(x | \mu, \sigma, \rho)={\frac {1}{2\pi \sigma _{x}\sigma _{y}{\sqrt {1-\rho ^{2}}}}}\exp
            \left(-{\frac {1}{2(1-\rho ^{2})}}\left[{\frac {(x-\mu _{x})^{2}}{\sigma _{x}^{2}}}+
            {\frac {(y-\mu _{y})^{2}}{\sigma _{y}^{2}}}-{\frac {2\rho (x-\mu _{x})(y-\mu _{y})}
            {\sigma _{x}\sigma _{y}}}\right]\right)

        :param value: The log probability density function is evaluated at those values.
        :return: Log probability
        """
        # x: [..., 2]
        value = torch.unsqueeze(value, dim=-2)  # [..., 1, 2]
        dx = value - self.mus  # [..., N, 2]

        exp_nominator = torch.sum(
            (dx / self.sigmas) ** 2, dim=-1
        ) - 2 * self.corrs * torch.prod(  # first and second term of exp nominator
            dx, dim=-1
        ) / torch.prod(
            self.sigmas, dim=-1
        )  # [..., N]

        component_log_p = (
            -(
                2 * np.log(2 * np.pi)
                + torch.log(self.one_minus_rho2)
                + 2 * torch.sum(self.log_sigmas, dim=-1)
                + exp_nominator / self.one_minus_rho2
            )
            / 2
        )

        return torch.logsumexp(self.log_pis + component_log_p, dim=-1)

    def get_for_node_at_time(self, n, t):
        return self.__class__(
            self.log_pis[:, n : n + 1, t : t + 1],
            self.mus[:, n : n + 1, t : t + 1],
            self.log_sigmas[:, n : n + 1, t : t + 1],
            self.corrs[:, n : n + 1, t : t + 1],
        )

    def mode(self):
        """
        Calculates the mode of the GMM by calculating probabilities of a 2D mesh grid

        :param required_accuracy: Accuracy of the meshgrid
        :return: Mode of the GMM
        """
        if self.mus.shape[-2] > 1:
            samp, bs, time, comp, _ = self.mus.shape
            assert samp == 1, "For taking the mode only one sample makes sense."
            mode_node_list = []
            for n in range(bs):
                mode_t_list = []
                for t in range(time):
                    nt_gmm = self.get_for_node_at_time(n, t)
                    x_min = self.mus[:, n, t, :, 0].min()
                    x_max = self.mus[:, n, t, :, 0].max()
                    y_min = self.mus[:, n, t, :, 1].min()
                    y_max = self.mus[:, n, t, :, 1].max()
                    search_grid = (
                        torch.stack(
                            torch.meshgrid(
                                [
                                    torch.arange(x_min, x_max, 0.01),
                                    torch.arange(y_min, y_max, 0.01),
                                ]
                            ),
                            dim=2,
                        )
                        .view(-1, 2)
                        .float()
                        .to(self.device)
                    )

                    ll_score = nt_gmm.log_prob(search_grid)
                    argmax = torch.argmax(ll_score.squeeze(), dim=0)
                    mode_t_list.append(search_grid[argmax])
                mode_node_list.append(torch.stack(mode_t_list, dim=0))
            return torch.stack(mode_node_list, dim=0).unsqueeze(dim=0)
        return torch.squeeze(self.mus, dim=-2)

    def reshape_to_components(self, tensor):
        if len(tensor.shape) == 5:
            return tensor
        return torch.reshape(
            tensor, list(tensor.shape[:-1]) + [self.components, self.dimensions]
        )

    def get_covariance_matrix(self):
        cov = self.corrs * torch.prod(self.sigmas, dim=-1)
        E = torch.stack(
            [
                torch.stack([self.sigmas[..., 0] ** 2, cov], dim=-1),
                torch.stack([cov, self.sigmas[..., 1] ** 2], dim=-1),
            ],
            dim=-2,
        )
        return E


class Dynamic(object):
    """
    Base class for dynamics models that integrate samples or distributions.

    The constructor stores its arguments and then calls the
    :meth:`init_constants` and :meth:`create_graph` hooks, which do nothing
    here and are meant to be overridden. :meth:`integrate_samples` and
    :meth:`integrate_distribution` must be provided by subclasses.

    :param dt: Integration time step.
    :param dyn_limits: Stored as ``self.dyn_limits``; not used by this class.
    :param device: Device on which subclasses create their tensors.
    :param model_registrar: Stored as ``self.model_registrar``; not used by this class.
    :param xz_size: Passed on to :meth:`create_graph`.
    :param node_type: Stored as ``self.node_type``; not used by this class.
    """

    def __init__(self, dt, dyn_limits, device, model_registrar, xz_size, node_type):
        self.dt = dt
        self.device = device
        self.dyn_limits = dyn_limits
        self.initial_conditions = None
        self.model_registrar = model_registrar
        self.node_type = node_type
        self.init_constants()
        self.create_graph(xz_size)

    def set_initial_condition(self, init_con):
        """
        Store the initial conditions used by the integration methods.

        :param init_con: Initial conditions; kept as-is in ``self.initial_conditions``.
        """
        self.initial_conditions = init_con

    def init_constants(self):
        """Hook called from the constructor; does nothing by default."""
        pass

    def create_graph(self, xz_size):
        """
        Hook called from the constructor after :meth:`init_constants`;
        does nothing by default.

        :param xz_size: Value handed through from the constructor.
        """
        pass

    def integrate_samples(self, s, x):
        """
        Integrate samples; must be implemented by subclasses.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def integrate_distribution(self, dist, x):
        """
        Integrate a distribution; must be implemented by subclasses.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class SingleIntegrator(Dynamic):
    """
    Single integrator dynamics: positions are obtained by integrating
    velocities over time steps of length ``self.dt``, starting from
    ``self.initial_conditions["pos"]``.
    """

    def init_constants(self):
        """
        Build the state transition matrix ``F`` for the state
        (x, y, v_x, v_y) and its transpose ``F_t``.
        """
        self.F = torch.eye(4, device=self.device, dtype=torch.float32)
        self.F[0:2, 2:] = (
            torch.eye(2, device=self.device, dtype=torch.float32) * self.dt
        )
        self.F_t = self.F.transpose(-2, -1)

    def integrate_samples(self, v, x=None):
        """
        Integrates deterministic samples of velocity.

        The velocities are cumulatively summed along dim 2, scaled by
        ``dt`` and offset by the initial position.

        :param v: Velocity samples
        :param x: Not used for SI.
        :return: Position samples
        """
        p_0 = self.initial_conditions["pos"].unsqueeze(1)
        return torch.cumsum(v, dim=2) * self.dt + p_0

    def integrate_distribution(self, v_dist, x=None):
        r"""
        Integrates the GMM velocity distribution to a distribution over position.
        The Kalman Equations are used.

        .. math:: \mu_{t+1} =\textbf{F} \mu_{t}

        .. math:: \mathbf{\Sigma}_{t+1}={\textbf {F}} \mathbf{\Sigma}_{t} {\textbf {F}}^{T}

        with the block diagonal joint covariance of position and velocity

        .. math::
            \mathbf{\Sigma}_{t} = \left[
                            \begin{array}{cccc}
                                \sigma_x^2 & \rho_p \sigma_x \sigma_y & 0 & 0 \\
                                \rho_p \sigma_x \sigma_y & \sigma_y^2 & 0 & 0 \\
                                0 & 0 & \sigma_{v_x}^2 & \rho_v \sigma_{v_x} \sigma_{v_y} \\
                                0 & 0 & \rho_v \sigma_{v_x} \sigma_{v_y} & \sigma_{v_y}^2 \\
                            \end{array}
                        \right]_{t}

        and the transition matrix

        .. math::
            \textbf{F} = \left[
                            \begin{array}{cccc}
                                1 & 0 & \Delta t & 0 \\
                                0 & 1 & 0 & \Delta t \\
                                0 & 0 & 1 & 0 \\
                                0 & 0 & 0 & 1 \\
                            \end{array}
                        \right]

        Because the position rows of :math:`\textbf{F}` are
        :math:`[I, \Delta t I]` and :math:`\mathbf{\Sigma}_{t}` is block
        diagonal, the position block of the update reduces to
        :math:`\mathbf{\Sigma}^{p}_{t+1} = \mathbf{\Sigma}^{p}_{t} + \Delta t^2 \mathbf{\Sigma}^{v}_{t}`.
        Starting from a zero position covariance this is a cumulative sum of
        the velocity covariances over time, which is what is computed here.

        :param v_dist: Joint GMM Distribution over velocity in x and y direction.
        :param x: Not used for SI.
        :return: Joint GMM Distribution over position in x and y direction.
        """
        p_0 = self.initial_conditions["pos"].unsqueeze(1)

        # Means: initial position plus the time-integrated mean velocity
        # (time is dim 2 of the component means).
        pos_mus = p_0[:, None] + torch.cumsum(v_dist.mus, dim=2) * self.dt

        # Covariances: closed form of the recursion described in the docstring;
        # time is dim 2 of the covariance tensor as well.
        vel_dist_sigma_matrix = v_dist.get_covariance_matrix()
        pos_dist_sigma_matrix = (
            torch.cumsum(vel_dist_sigma_matrix, dim=2) * self.dt**2
        )

        return GMM2D.from_log_pis_mus_cov_mats(
            v_dist.log_pis, pos_mus, pos_dist_sigma_matrix
        )
