from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.model.torch_model.torch_dlhp import ComplexEmbedding

EPS = 1e-8


def _perm(x: torch.Tensor) -> torch.Tensor:
    """Swap every adjacent pair along the last axis, e.g. [0, 1, 2, 3] -> [1, 0, 3, 2].

    The last axis must have even length; otherwise the odd- and even-indexed
    slices differ in size and `torch.stack` raises.
    """
    odd_entries = x[..., 1::2]
    even_entries = x[..., ::2]
    # Stacking on a new trailing axis and flattening it back interleaves the
    # two slices as (odd, even, odd, even, ...).
    swapped = torch.stack([odd_entries, even_entries], dim=-1)
    return swapped.view(x.shape)


def _orth_mult(x: torch.Tensor, angles: torch.Tensor, is_odd: bool) -> torch.Tensor:
    """Apply one layer of independent 2x2 rotations to adjacent pairs of `x`.

    For a pair `(x0, x1)` with angle `theta` the result is
    `(a * (cos * x0 - sin * x1), sin * x0 + cos * x1)`, where `a` is 1 for
    real angles and `exp(1j * angles.imag)` when `angles` is complex64
    (`theta` is then `angles.real`).

    Args:
        x: values to transform; pairs are taken along the last axis.
        angles: one angle per pair along the last axis; must have the same
            number of dimensions as `x`.
        is_odd: if True, `x` is cyclically shifted by one position before the
            rotation and shifted back afterwards, so the pairing is offset by
            one element (wrapping around the ends).

    Returns:
        The transformed tensor.
    """
    assert x.ndim == angles.ndim

    if is_odd:
        x = x.roll(1, -1)

    theta = angles.real
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)

    if angles.dtype == torch.complex64:
        # Unit-modulus phase applied to the first output of every pair.
        phase, unit = torch.exp(angles.imag * 1j), 1 + 0j
    else:
        phase, unit = 1, 1

    # Interleave per-pair coefficients so they line up elementwise with x.
    coeff_shape = (*angles.shape[:-1], x.shape[-1])
    same_coeff = torch.stack([phase * cos_t, unit * cos_t], -1).view(*coeff_shape)
    swap_coeff = torch.stack([-phase * sin_t, unit * sin_t], -1).view(*coeff_shape)

    mixed = same_coeff * x + swap_coeff * _perm(x)

    return mixed.roll(-1, -1) if is_odd else mixed


def _rotate(x: torch.Tensor, angles: torch.Tensor, inv: bool = False) -> torch.Tensor:
    """Apply a stack of pairwise-rotation layers to `x`.

    `angles` is cut along its last axis into chunks of `x.shape[-1] // 2`
    angles; each chunk drives one `_orth_mult` layer. Layers alternate between
    the unshifted and the shifted pairing, beginning with the unshifted one.

    Args:
        x: values to transform along the last axis.
        angles: concatenated per-layer angles along the last axis.
        inv: if True, the angles are negated, the layers are visited in
            reverse order and the first layer visited uses the shifted
            pairing. This mirrors the forward pass only when the number of
            layers is even.

    Returns:
        The transformed tensor.
    """
    pair_count = x.shape[-1] // 2

    if inv:
        layers = torch.split(-angles, pair_count, -1)[::-1]
        shifted = True
    else:
        layers = torch.split(angles, pair_count, -1)
        shifted = False

    # NOTE(review): for complex angles the negation also flips the phase term
    # inside `_orth_mult`; whether that composes to an exact inverse of the
    # forward pass should be verified.
    for layer_angles in layers:
        x = _orth_mult(x, layer_angles, shifted)
        shifted = not shifted

    return x


@dataclass
class LatentParameters:
    """Bundle of per-event quantities returned by `HyperNetwork.forward`."""

    # Mixing term. Depending on the configuration this is a full square matrix,
    # a vector of rotation angles (orthogonal case), or identity plus a
    # rank-one outer product.
    A: torch.Tensor
    # Explicit inverse used for the rank-one form of A; None when A is a full
    # matrix or a set of rotation angles.
    A_inv: Optional[torch.Tensor]
    # Per-event impulse vectors added to the latent state at each event.
    B: torch.Tensor
    # Read-out reshaped to (..., num_marks, latent_dim) by the hypernetwork.
    C: torch.Tensor
    # Multiplier applied to the diagonal decay rates; the constant 1 when it is
    # not predicted.
    S: torch.FloatTensor | float
    # Final recurrent state of the hypernetwork (None if none was produced).
    last_hidden: torch.Tensor


class SoftNorm(nn.Module):
    """Divide a tensor by `(vector norm + eps)` along one axis.

    With a non-zero `eps` the result has norm strictly below one and an
    all-zero input maps to zeros instead of NaN.
    """

    def __init__(self, dim=-1, eps=1, ord=2):
        """Store the axis (`dim`), the additive offset (`eps`) and the norm order (`ord`)."""
        super().__init__()
        self.dim, self.eps, self.ord = dim, eps, ord

    def forward(self, x):
        """Return `x / (||x|| + eps)`, with the norm taken over `self.dim`."""
        magnitude = torch.linalg.vector_norm(x, ord=self.ord, dim=self.dim, keepdim=True)
        denominator = magnitude + self.eps
        return x / denominator


class HyperNetwork(nn.Module):
    """Produce the per-event latent parameters (A, B, C, S) of the process.

    Each component is either predicted from the event history by a sequence
    model (GRU, LSTM or causal Transformer encoder followed by a linear head)
    or held as a static learned parameter / buffer, as selected by the
    `predict_A`, `predict_B`, `predict_C` and `predict_S` config flags.
    """

    def __init__(self, model_config):
        """Build the optional sequence model and the static parameters.

        Args:
            model_config: configuration object. Attributes read here:
                `non_latent`, `num_event_types_pad`, `pad_token_id`,
                `hyper_size`, `hidden_size`, `complex_latent`, `orthogonal_A`,
                `num_rotations`, `full_A`, `full_C`, `predict_A`, `predict_B`,
                `predict_C`, `predict_S`, `normalize_A`, `normalize_B` and,
                when at least one component is predicted, `hyper_type` and
                `hyper_layers` (plus `hyper_nhead` and `hyper_dropout` for the
                transformer). When `non_latent` is set, the four `predict_*`
                attributes of `model_config` are overwritten with False.

        Raises:
            ValueError: if a sequence model is needed and `hyper_type` is not
                one of GRU, LSTM or TRANSFORMER (case-insensitive).
        """
        super().__init__()

        self.non_latent = model_config.non_latent
        self.num_marks = model_config.num_event_types_pad

        if self.non_latent:
            # Nothing is predicted without a latent space. This mutates the caller's config object.
            (
                model_config.predict_A,
                model_config.predict_B,
                model_config.predict_C,
                model_config.predict_S,
            ) = (False, False, False, False)

        self.rnn_dim, self.latent_dim = (
            model_config.hyper_size,
            model_config.hidden_size,
        )
        self.latent_dtype = torch.complex64 if model_config.complex_latent else torch.float32

        # Widths of the A, B, C and S slices of the linear head output, in that order.
        # A width of 0 means the component is not predicted.
        self.output_groups = []

        self.orthogonal_A = model_config.orthogonal_A
        if self.orthogonal_A:
            # The inverse rotation starts with the shifted pairing, which mirrors
            # the forward pass only for an even number of layers.
            assert model_config.num_rotations % 2 == 0
        print("IS ORTHOGONAL A:", self.orthogonal_A)
        self.full_A = model_config.full_A
        assert not (self.orthogonal_A and self.full_A)  # Cannot be orthogonal and a full matrix
        self.full_C = model_config.full_C
        self.A, self.B, self.C, self.S = None, None, None, 1
        self.pred_A, self.pred_B, self.pred_C, self.pred_S = (
            model_config.predict_A,
            model_config.predict_B,
            model_config.predict_C,
            model_config.predict_S,
        )

        if self.pred_A:
            if self.full_A:
                self.output_groups.append(self.latent_dim**2)
            elif self.orthogonal_A:
                self.output_groups.append((self.latent_dim // 2) * model_config.num_rotations)
            else:
                self.output_groups.append(self.latent_dim)
        else:
            self.output_groups.append(0)
            if self.full_A:
                self.A = nn.Parameter(torch.randn(self.latent_dim * self.latent_dim, dtype=self.latent_dtype) * 1e-3)
            else:
                # NOTE(review): a static orthogonal A also gets `latent_dim` entries here,
                # i.e. it does not depend on `num_rotations` - confirm this is intended.
                self.A = nn.Parameter(torch.randn(self.latent_dim, dtype=self.latent_dtype) * 1e-3)
        if not self.full_A:
            self.register_buffer("I", torch.eye(self.latent_dim, dtype=self.latent_dtype))

        if self.pred_B:
            self.output_groups.append(
                self.latent_dim
            )  # Predict mark-specific B, which is a vector and not a full matrix
        else:
            self.output_groups.append(0)
            emb_class = ComplexEmbedding if model_config.complex_latent else nn.Embedding
            self.B = emb_class(
                self.num_marks,
                self.latent_dim,
                padding_idx=model_config.pad_token_id,
            )

        if self.pred_C:  # estimate low-rank components
            if self.full_C:
                self.output_groups.append(self.latent_dim * self.num_marks)
            else:
                self.output_groups.append(self.latent_dim + self.num_marks)
        elif self.non_latent:
            self.output_groups.append(0)
            self.full_C = True
            del self.C  # To not collide with register buffer
            self.register_buffer("C", torch.eye(self.num_marks, dtype=self.latent_dtype).reshape(-1))
        else:
            self.output_groups.append(0)
            if self.full_C:
                self.C = nn.Parameter(torch.randn(self.latent_dim * self.num_marks, dtype=self.latent_dtype) * 1e-3)
            else:
                self.C = nn.Parameter(torch.randn(self.latent_dim + self.num_marks, dtype=self.latent_dtype) * 1e-3)

        if self.pred_S:  # scales the diagonal components for A
            self.output_groups.append(self.latent_dim)
            self.softplus = nn.Softplus()
        else:
            self.output_groups.append(0)

        assert not (self.full_A and model_config.normalize_A)
        self.A_norm = SoftNorm() if model_config.normalize_A else nn.Identity()
        self.B_norm = SoftNorm() if model_config.normalize_B else nn.Identity()

        # A sequence model is only built when at least one component is predicted.
        self.uses_model = sum(self.output_groups) > 0
        if self.uses_model:
            self.embeddings = nn.Embedding(
                self.num_marks,  # have padding
                self.rnn_dim - 1,  # will include time as the last component
                padding_idx=model_config.pad_token_id,
            )

            self.uses_transformer = False
            hyper_type = model_config.hyper_type.upper()
            if hyper_type == "GRU":
                self.network = nn.GRU(
                    input_size=self.rnn_dim,
                    hidden_size=self.rnn_dim,
                    batch_first=True,
                    num_layers=model_config.hyper_layers,
                )
            elif hyper_type == "LSTM":
                # This branch previously built an nn.GRU by mistake. Checkpoints saved
                # with hyper_type="LSTM" before the fix hold GRU-shaped weights.
                self.network = nn.LSTM(
                    input_size=self.rnn_dim,
                    hidden_size=self.rnn_dim,
                    batch_first=True,
                    num_layers=model_config.hyper_layers,
                )
            elif hyper_type == "TRANSFORMER":
                self.uses_transformer = True
                layer = nn.TransformerEncoderLayer(
                    d_model=self.rnn_dim,
                    nhead=model_config.hyper_nhead,
                    batch_first=True,
                    dropout=model_config.hyper_dropout,
                )
                self.network = nn.TransformerEncoder(encoder_layer=layer, num_layers=model_config.hyper_layers)
            else:
                # Fail here rather than with an AttributeError on `self.network` in forward().
                raise ValueError(
                    f"Unsupported hyper_type {model_config.hyper_type!r}; expected 'GRU', 'LSTM' or 'TRANSFORMER'"
                )

            self.to_comps = nn.Linear(self.rnn_dim, sum(self.output_groups), dtype=self.latent_dtype)
            # Zero weights and a shrunken bias: the head initially outputs the same
            # small values regardless of the history.
            self.to_comps.weight.data = self.to_comps.weight.data * 0.0
            self.to_comps.bias.data = self.to_comps.bias.data * 1e-3

    def forward(self, dts, marks, last_hidden=None) -> LatentParameters:
        """Compute the latent parameters for every event of the input.

        Args:
            dts: inter-event times; the last dimension is the sequence dimension.
            marks: event types, same shape as `dts`.
            last_hidden: optional initial state handed to the recurrent
                network; ignored by the transformer variant.

        Returns:
            LatentParameters holding A, A_inv, B, C, S and the final recurrent
            state (`last_hidden` is returned unchanged when no recurrent
            network runs).
        """
        A, B, C, S = None, None, None, None
        if self.uses_model:
            # Mark embedding concatenated with the log inter-event time as the last feature.
            model_input = torch.cat([self.embeddings(marks), torch.log(dts + EPS)[..., None]], dim=-1)
            if self.uses_transformer:
                output = self.network(
                    model_input,
                    mask=nn.Transformer.generate_square_subsequent_mask(dts.shape[-1], device=dts.device),
                    is_causal=True,
                )
            else:
                output, last_hidden = self.network(model_input, last_hidden)

            A, B, C, S = torch.split(
                self.to_comps(output.type(self.latent_dtype)),
                self.output_groups,
                dim=-1,
            )

        A = A if self.pred_A else self.A
        if not self.pred_A:
            # Broadcast the static parameter over all leading (batch / sequence) dimensions.
            A = A.view(*(1 for _ in dts.shape), *A.shape).expand(*dts.shape, -1)
        if self.full_A:
            A = A.view(*A.shape[:-1], self.latent_dim, self.latent_dim)
            A_inv = None  # To be computed implicitly with `torch.linalg.solve` later
        elif self.orthogonal_A:
            # Will use `_rotate` command later with A being the angles argument
            A_inv = None  # Not needed
        else:
            # Rank-one case of the Woodbury identity (Sherman-Morrison):
            # (I + u u^T)^{-1} = I - u u^T / (1 + u^T u)
            # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
            A = self.A_norm(A)
            UV = torch.einsum("...i, ...j -> ...ij", A, A)  # outer product u u^T
            # Inner product u^T u: one scalar per position, shaped to broadcast over the
            # matrix dimensions. The index must be summed out, not kept.
            VU = torch.einsum("...i, ...i -> ...", A, A)[..., None, None]
            eye = self.I.view(*(1 for _ in A.shape[:-1]), self.latent_dim, self.latent_dim)
            A = eye + UV
            A_inv = eye - UV / (1 + VU)

        B = B if self.pred_B else self.B(marks)
        B = self.B_norm(B)

        C = C if self.pred_C else self.C
        if not self.pred_C:
            C = C.view(*(1 for _ in dts.shape), *C.shape).expand(*dts.shape, -1)
        if self.full_C:
            C = C.reshape(*C.shape[:-1], self.num_marks, self.latent_dim)
        else:
            # Rank-one C: outer product of a mark vector and a latent vector.
            C = torch.einsum(
                "...l, ...m -> ...ml",
                *torch.split(C, [self.latent_dim, self.num_marks], dim=-1),
            )

        S = self.softplus(S.real) if self.pred_S else self.S

        return LatentParameters(
            A=A,
            A_inv=A_inv,
            B=B,
            C=C,
            S=S,
            last_hidden=last_hidden,
        )


class HHP(TorchBaseModel):
    """Hyper Hawkes Process"""

    def __init__(self, model_config):
        """Initialize the model

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs. Note
                that `full_A` is forced to False when `orthogonal_A` is set and
                `hidden_size` is overwritten with the number of marks when
                `non_latent` is set, so the passed object is mutated.

        """
        model_config.full_A = model_config.full_A and not model_config.orthogonal_A
        super().__init__(model_config)

        del self.layer_type_emb

        self.num_marks = model_config.num_event_types_pad
        if model_config.non_latent:
            print("DISABLED LATENT SPACE")

            model_config.hidden_size = self.num_marks

        self.is_complex = model_config.complex_latent
        self.latent_dtype = torch.complex64 if self.is_complex else torch.float32
        self.latent_dim = model_config.hidden_size
        self.orthogonal_A = model_config.orthogonal_A

        self.hypernet = HyperNetwork(model_config)
        self.mu = nn.Parameter(torch.randn(self.num_marks) * 1e-3)
        self.softplus = nn.Softplus()
        # Real part of lambda is parameterized as -exp(.) so it is always negative.
        self.lambda_log_neg_real = nn.Parameter(torch.randn(self.latent_dim) * 1e-3)
        if self.is_complex:
            self.lambda_imag = nn.Parameter(torch.randn(self.latent_dim) * 1e-3)

    def _scaled_lambda(self, p):
        """Return the diagonal rates `S * lambda`, broadcast to the shape of `p.B`."""
        lambda_diag = -self.lambda_log_neg_real.exp()
        if self.is_complex:
            lambda_diag = torch.complex(real=lambda_diag, imag=self.lambda_imag)

        # Cast into a sequence of lambda values and scale
        return p.S * lambda_diag.view(*(1 for _ in p.B.shape[:-1]), -1).expand(*p.B.shape)

    def forward(self, dts, marks, num_samples=None):
        """Call the model.

        Args:
            dts (tensor): [batch_size, seq_len], inter-event time seqs.
            marks (tensor): [batch_size, seq_len], event type seqs.
            num_samples (int): number of uniformly sampled time points per
                inter-event interval; if falsy, no sampled intensities are
                computed.

        Returns:
            tuple: `(event_intensities, sampled_intensities)`.
                `event_intensities` has shape [..., seq_len - 1, num_marks] and
                holds the intensities at the left limit of events 1..n.
                `sampled_intensities` has shape
                [..., seq_len - 1, num_samples, num_marks], or is None when
                `num_samples` is falsy.
        """

        # dts: dt_0:=0, dt_1:=t_1, dt_2:=t_2-t_1, ..., dt_n:=t_n-t_{n-1}
        # marks: k_0, k_1, k_2, ..., k_n
        # Condition on (dt_0, k_0), only predict (dt_1, k_1) to (dt_n, k_n)

        n = dts.shape[-1]
        p: LatentParameters = self.hypernet.forward(dts=dts, marks=marks)

        mu = self.mu
        for _ in range(dts.ndim):
            mu = mu.unsqueeze(0)

        lambda_diag = self._scaled_lambda(p)
        L = (
            dts[..., 1:, None] * lambda_diag[..., :-1, :]
        ).exp()  # evolve from t_i to t_{i+1} using lambda_i and dt_{i+1} := t_{i+1} - t_i

        # Recurrence
        left_x = 0
        left_xs = []  # left limits, just before event occurrences - used for intensities at event times
        right_xs = []  # right limits, just after event occurrences - used to evolve towards next events
        for i in range(n - 1):
            # Add impulse to left limit to get right limit
            right_x = left_x + p.B[..., i, :]

            # Evolve right limit to get left limit: A @ exp(dt * scale * diag(lambda)) @ A^{-1} @ x
            # Multiply right to left to perform matrix-vector products instead matrix-matrix products
            if self.orthogonal_A:
                right_x = _rotate(right_x, p.A[..., i, :], inv=True)
            elif p.A_inv is None:
                right_x = torch.linalg.solve(p.A[..., i, :, :], right_x)
            else:
                right_x = torch.einsum("...ij, ...j -> ...i", p.A_inv[..., i, :, :], right_x)
            left_x = L[..., i, :] * right_x
            if self.orthogonal_A:
                left_x = _rotate(left_x, p.A[..., i, :])
            else:
                left_x = torch.einsum("...ij, ...j -> ...i", p.A[..., i, :, :], left_x)

            right_xs.append(right_x)  # Store precomputed A^{-1} @ x_{t_i+} instead of just x_{t_i+}
            left_xs.append(left_x)

        right_xs, left_xs = (
            torch.stack(right_xs, dim=-2),
            torch.stack(left_xs, dim=-2),
        )  # Shapes: (..., seq_dim, latent_dim)

        # Times for resulting x values:
        # left_xs: t_1-, t_2-, ..., t_n-
        # right_xs: t_0+, t_1+, ..., t_{n-1}+

        event_intensities = self.softplus(
            2
            * torch.einsum(
                "...ml, ...l -> ...m",
                p.C[..., :-1, :, :],
                left_xs,
            ).real  # Apply conjugate symmetry trick
            + mu
        )  # Shape: (..., seq_dim-1, marks)
        sampled_intensities = None
        if num_samples:
            # Uniform random time offsets inside every inter-event interval.
            sampled_dts = dts[..., 1:, None] * torch.rand(dts.shape[:-1] + (n - 1, num_samples), device=dts.device)
            L_samples = (
                sampled_dts[..., None] * lambda_diag[..., :-1, None, :]
            ).exp()  # Shape: (..., seq_dim, sample_dim, latent_dim)
            if self.orthogonal_A:
                sampled_intensities = (
                    2
                    * torch.einsum(
                        "...mi, ...ni -> ...nm",
                        p.C[..., :-1, :, :],
                        _rotate(
                            L_samples * right_xs[..., :, None, :],
                            p.A[..., :-1, None, :],
                        ),
                    ).real
                )
            else:
                sampled_intensities = (
                    2
                    * torch.einsum(
                        "...mi, ...ij, ...nj -> ...nm",
                        p.C[..., :-1, :, :],
                        p.A[..., :-1, :, :],
                        L_samples * right_xs[..., :, None, :],
                    ).real
                )
            sampled_intensities = self.softplus(sampled_intensities + mu[..., None, :])
            # Shape: (..., seq_dim-1, sample_dim, marks)

        return event_intensities, sampled_intensities

    def loglike_loss(self, batch, **kwargs):
        """Compute the loglike loss.

        Args:
            batch (list): batch input, unpacked as
                `(times, dts, marks, batch_non_pad_mask, _)`.
            **kwargs: `return_raw_ll` (bool, default False) additionally
                returns a dict with the raw non-event log-likelihood and the
                event intensities.

        Returns:
            tuple: `(loss, num_events, summed mark log-likelihood,
                summed time log-likelihood, res_dict)` where `res_dict` is
                None unless `return_raw_ll` is set.
        """

        _, dts, marks, batch_non_pad_mask, _ = batch

        event_intensities, sampled_intensities = self.forward(
            dts=dts, marks=marks, num_samples=self.loss_integral_num_sample_per_step
        )

        # The first event is only conditioned on, hence the [:, 1:] slices.
        event_ll, non_event_ll, num_events, mark_ll, time_ll_pos = self.compute_loglikelihood(
            lambda_at_event=event_intensities,
            lambdas_loss_samples=sampled_intensities,
            time_delta_seq=dts[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks[:, 1:],
        )

        # compute extra statistics
        time_ll = time_ll_pos - non_event_ll

        # compute loss to optimize
        loss = -(event_ll - non_event_ll).sum()

        return_raw_ll = kwargs.get("return_raw_ll", False)
        res_dict = {"non_event_ll": non_event_ll, "mark_intensity": event_intensities} if return_raw_ll else None

        return loss, num_events, mark_ll.sum(), time_ll.sum(), res_dict

    def predict_one_step_at_every_event(self, batch, **kwargs):
        """One-step prediction for every event in the sequence.

        Not implemented for this model.

        Args:
            batch (list): batch input.

        Raises:
            NotImplementedError: always.
        """

        raise NotImplementedError()

    @torch.no_grad()
    def attributions(self, dts: torch.Tensor, marks: torch.Tensor, num_samples=100):
        """Decompose the pre-activation intensity into per-event contributions.

        Every event gets its own copy of the latent state that only ever
        receives that event's impulse; all copies are evolved with the same
        dynamics as in `forward`.

        Args:
            dts (tensor): [seq_len], inter-event time seqs (extra singleton
                dimensions are squeezed away).
            marks (tensor): [seq_len], event type seqs.
            num_samples (int): number of evenly spaced time points per
                inter-event interval (both interval ends included).

        Returns:
            tuple: `(mu, sampled_intensities, sampled_dts, sampled_xs)` where
                `mu` is the base-rate parameter without its last (padding)
                entry, `sampled_intensities` are the per-source-event
                pre-softplus contributions of shape
                [seq_len, seq_len - 1, num_samples, num_marks - 1] (base rate
                not added), `sampled_dts` are the time offsets of shape
                [seq_len - 1, num_samples], and `sampled_xs` are the
                per-source-event latent states of shape
                [seq_len, seq_len - 1, num_samples, latent_dim].
        """

        # dts: dt_0:=0, dt_1:=t_1, dt_2:=t_2-t_1, ..., dt_n:=t_n-t_{n-1}
        # marks: k_0, k_1, k_2, ..., k_n
        # Condition on (dt_0, k_0), only predict (dt_1, k_1) to (dt_n, k_n)
        if dts.ndim > 1:
            dts = dts.squeeze()

        if marks.ndim > 1:
            marks = marks.squeeze()

        assert dts.ndim == 1 and marks.ndim == 1
        n = dts.shape[-1]
        dts, marks = dts[None, :], marks[None, :]
        p: LatentParameters = self.hypernet.forward(dts=dts, marks=marks)

        lambda_diag = self._scaled_lambda(p)
        L = (
            dts[..., 1:, None] * lambda_diag[..., :-1, :]
        ).exp()  # evolve from t_i to t_{i+1} using lambda_i and dt_{i+1} := t_{i+1} - t_i

        impulses = p.B.squeeze(0)  # [n, latent dim]

        # Recurrence
        left_x = torch.zeros_like(impulses)  # Every event gets its own hidden state
        left_xs = []  # left limits, just before event occurrences - used for intensities at event times
        right_xs = []  # right limits, just after event occurrences - used to evolve towards next events
        for i in range(n - 1):
            # Add impulse to left limit to get right limit (only row i receives event i's impulse)
            right_x = left_x.clone()
            right_x[i, :] += impulses[i, :]

            # Evolve right limit to get left limit: A @ exp(dt * scale * diag(lambda)) @ A^{-1} @ x
            # Multiply right to left to perform matrix-vector products instead matrix-matrix products
            if self.orthogonal_A:
                right_x = _rotate(right_x, p.A[..., i, :], inv=True)
            elif p.A_inv is None:
                # Repeat the single A matrix once per row of right_x for the batched solve.
                right_x = torch.linalg.solve(p.A[[0] * right_x.shape[0], i, :, :], right_x)
            else:
                right_x = torch.einsum("...ij, ...j -> ...i", p.A_inv[..., i, :, :], right_x)
            left_x = L[..., i, :] * right_x
            if self.orthogonal_A:
                left_x = _rotate(left_x, p.A[..., i, :])
            else:
                left_x = torch.einsum("...ij, ...j -> ...i", p.A[..., i, :, :], left_x)

            right_xs.append(right_x)  # Store precomputed A^{-1} @ x_{t_i+} instead of just x_{t_i+}
            left_xs.append(left_x)

        right_xs, left_xs = (
            torch.stack(right_xs, dim=-2),
            torch.stack(left_xs, dim=-2),
        )  # Shapes: (..., seq_dim, latent_dim)

        # Times for resulting x values:
        # left_xs: t_1-, t_2-, ..., t_n-
        # right_xs: t_0+, t_1+, ..., t_{n-1}+

        sampled_dts = dts[..., 1:, None] * torch.linspace(0.0, 1.0, num_samples, device=dts.device)[None, None, :]
        L_samples = (
            sampled_dts[..., None] * lambda_diag[..., :-1, None, :]
        ).exp()  # Shape: (..., seq_dim, sample_dim, latent_dim)

        # Latent states at the sampled times, mapped back through A. They are computed in
        # both branches because they are part of the return value; previously only the
        # orthogonal branch defined `sampled_xs`, so the other path raised UnboundLocalError.
        decayed = L_samples * right_xs[..., :, None, :]
        if self.orthogonal_A:
            sampled_xs = _rotate(decayed, p.A[..., :-1, None, :])
        else:
            sampled_xs = torch.einsum("...ij, ...nj -> ...ni", p.A[..., :-1, :, :], decayed)

        sampled_intensities = (
            2
            * torch.einsum(
                "...mi, ...ni -> ...nm",
                p.C[..., :-1, :, :],
                sampled_xs,
            ).real
        )
        # Shape: (..., seq_dim-1, sample_dim, marks)
        # Do not actually compute intensity, as we want pre-activation values
        # (no base rate added and no softplus applied).

        # Last dimension is padding
        return self.mu[:-1], sampled_intensities[..., :-1], sampled_dts.squeeze(0), sampled_xs
