# src/flow_matching_matrix_to_matrix.py
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from group_discovery.distributions import DiscreteDeltaMixture
from group_discovery.utils import batched_div, blogm, expand_like

# Helper functions
# Short alias for the matrix exponential used by the multiplicative update
# steps (x @ expm(dt * A)) throughout this module.
expm = torch.linalg.matrix_exp


# Model
class FlowMatrixToMatrix(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim,
        prior_dist,
        device,
        max_grad_norm=None,
        spectral_norm_weights=False,
        max_output_eigenvalue=None,
    ):
        super().__init__()
        layers = [
            nn.Linear(in_dim + 1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        ]

        if spectral_norm_weights:
            for i in range(len(layers)):
                if isinstance(layers[i], nn.Linear):
                    # Apply spectral normalization to all linear layers
                    layers[i] = nn.utils.parametrizations.spectral_norm(layers[i])

        self.net = nn.Sequential(*layers)
        self.device = device
        self.prior_dist = prior_dist

        self.max_grad_norm = max_grad_norm
        self.spectral_norm_weights = spectral_norm_weights
        self.max_output_eigenvalue = max_output_eigenvalue

        self.reset_parameters()

    def reset_parameters(self):
        for i, layer in enumerate(self.net):
            if isinstance(layer, nn.Linear):
                if i == len(self.net) - 1:  # Last layer
                    nn.init.uniform_(layer.weight, -0.01, 0.01)
                else:
                    nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
                nn.init.zeros_(layer.bias)

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        B = x_t.shape[0]
        orig_dims = x_t.shape[1:]

        x_t = x_t.flatten(1)
        if t.dim() == 0:
            t = t.view(1, 1).expand(B, 1)
        elif t.dim() == 1:
            t = t.view(B, 1)

        xin = torch.cat((x_t, t), dim=1)

        out = self.net(xin)

        out = out.view(B, *orig_dims)

        if self.max_output_eigenvalue is not None:
            # Frobenius norm scaling, approximation of spectral normalization
            frob_norm = torch.norm(out.view(B, -1), dim=1, keepdim=True)
            max_safe_norm = self.max_output_eigenvalue * 2.0  # Heuristic
            scale = torch.minimum(
                max_safe_norm / (frob_norm + 1e-8), torch.ones_like(frob_norm)
            )
            out = out * scale.unsqueeze(-1)

        return out

    @torch.no_grad()
    def sample_all(self, batch_size, n_steps, x_0=None):
        if x_0 is None:
            x_0 = self.prior_dist.sample((batch_size,)).to(self.device)

        # Ensure x_0 is in GL+(2)
        # x_0 = self.prior_dist.project_to_manifold(x_0)

        x_t = x_0
        x_t_all = [x_t]

        t_values = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device)

        for i in range(n_steps):
            t = t_values[i]
            delta_t = t_values[i + 1] - t_values[i]

            A = self.forward(x_t, t)
            tf = expm(delta_t * A)
            x_t_new = x_t @ tf

            # Safety check: detect numerical issues early
            if torch.isnan(x_t_new).any() or torch.isinf(x_t_new).any():
                print(f"Warning: NaN/Inf detected at step {i}, t={t.item():.4f}")
                print(
                    "Matrix norm before:"
                    f" {torch.norm(x_t.view(batch_size, -1), dim=1).max():.2e}"
                )
                print(
                    "Matrix norm after:"
                    f" {torch.norm(x_t_new.view(batch_size, -1), dim=1).max():.2e}"
                )
                print(
                    "Max eigenvalue of A:"
                    f" {torch.linalg.eigvals(A).real.abs().max():.2e}"
                )
                break

            x_t = x_t_new
            x_t_all.append(x_t)

        x_t_all = torch.stack(x_t_all, dim=0)  # [n_steps + 1, batch, 2, 2]

        return x_t_all

    @torch.no_grad()
    def sample(self, batch_size, n_steps, x_0=None):
        return self.sample_all(batch_size, n_steps, x_0)[-1]

    def interpolate(self, x_0: Tensor, x_1: Tensor, t) -> Tuple[Tensor, Tensor]:
        """Interpolate between ``x_0`` and ``x_1`` with a matrix exponential.

        The relative transform ``R`` solving ``x_0 @ R = x_1`` is obtained
        with a least-squares solve, its logarithm ``tf = blogm(R)`` is taken,
        and the interpolant is ``x_t = x_0 @ expm(t * tf)``.

        Args:
            x_0: Batch of start matrices.
            x_1: Batch of end matrices.
            t: Interpolation times; passed through ``expand_like(t, x_1)``
                before multiplying ``tf`` (presumably to broadcast one time
                per sample over the matrix dims; confirm against
                ``group_discovery.utils``).

        Returns:
            ``(x_t, tf)``: the interpolated matrices and the logarithm ``tf``
            that ``train_net`` and ``eval_net`` use as the regression target.
        """
        # x_0 = self.prior_dist.project_to_manifold(x_0)
        # x_1 = self.prior_dist.project_to_manifold(x_1)

        relative = torch.linalg.lstsq(x_0, x_1).solution
        tf = blogm(relative)
        x_t = x_0 @ expm(expand_like(t, x_1) * tf)

        return x_t, tf

    def train_net(self, train_loader, optimizer):
        """Run one training epoch and return the mean batch loss.

        For every batch ``x_1`` a start point ``x_0`` is drawn from
        ``self.prior_dist`` and a time ``t`` is drawn uniformly from [0, 1)
        per sample, the same scheme ``eval_net`` uses. The network is then
        regressed (MSE) onto the target returned by ``interpolate``. Sampling
        ``t`` matters because the samplers query the network at times spread
        over [0, 1]; training at a single fixed time would leave all other
        times unfitted.

        Args:
            train_loader: Sized iterable of batches of target matrices.
            optimizer: Optimizer over this module's parameters.

        Returns:
            Sum of the per-batch losses divided by ``len(train_loader)``.
        """
        self.train()
        total_loss = 0.0
        for x_1 in train_loader:
            x_1 = x_1.to(self.device)

            B = x_1.shape[0]
            # Random time per sample (was a constant 1.0).
            t = torch.rand((B,), device=self.device)

            x_0 = self.prior_dist.sample((B,)).to(self.device)

            x_t, target = self.interpolate(x_0, x_1, t)

            optimizer.zero_grad()
            pred = self(x_t, t)

            loss = F.mse_loss(pred, target)
            total_loss += loss.item()

            loss.backward()

            if self.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.parameters(), max_norm=self.max_grad_norm
                )

            optimizer.step()

        return total_loss / len(train_loader)

    @torch.no_grad()
    def eval_net(self, test_loader):
        """Compute the mean batch loss over ``test_loader`` without gradients.

        The module is switched to eval mode. For each batch a random time in
        [0, 1) and a prior sample are drawn per element, and the MSE between
        the network output and the ``interpolate`` target is accumulated.

        Returns:
            Sum of the per-batch losses divided by ``len(test_loader)``.
        """
        self.eval()
        running = 0.0
        for batch in test_loader:
            x_1 = batch.to(self.device)
            n = x_1.shape[0]

            t = torch.rand((n,), device=self.device)
            x_0 = self.prior_dist.sample((n,)).to(self.device)

            x_t, target = self.interpolate(x_0, x_1, t)
            running += F.mse_loss(self(x_t, t), target).item()

        return running / len(test_loader)

    def _matrix_geodesic_inverse(self, x_t: Tensor, x_1: Tensor, t: float) -> Tensor:
        """
        Given x_t and x_1, find x_0 such that x_t = x_0 @ expm(t * logm(x_0^{-1} @ x_1)).

        Let A = logm(x_0^{-1} @ x_1).

        x_t^{-1} @ x_1 = (x_0 expm(t A))^{-1} @ (x_0 @ expm(A)) = exp((1-t)A)
        A = \frac{1}{1-t} logm(x_t^{-1} @ x_1)
        x_0 = x_t @ expm(t/(t-1) * logm(x_t^{-1} @ x_1))
        """
        if abs(t - 1.0) < 1e-6:
            # At t=1, x_t = x_1, so any x_0 on the geodesic works
            # Return x_t as a reasonable choice
            return x_t

        # Ensure inputs are on the manifold
        # x_t = self.prior_dist.project_to_manifold(x_t)
        # x_1 = self.prior_dist.project_to_manifold(x_1)

        # x_0 = x_t @ expm((t/(t-1)) * logm(x_t^{-1} @ x_1))
        x_t_inv_x_1 = torch.linalg.lstsq(x_t, x_1).solution
        log_term = blogm(x_t_inv_x_1)
        x_0 = x_t @ expm((t / (t - 1.0)) * log_term)

        # Ensure x_0 is on manifold
        # x_0 = self.prior_dist.project_to_manifold(x_0)

        return x_0

    def _forward_simulate(self, x_0: Tensor, t_values: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward simulate from x_0 to get trajectories and log probabilities."""
        batch_size = x_0.shape[0]

        # Ensure x_0 is on manifold
        # x_0 = self.prior_dist.project_to_manifold(x_0)

        x_t = x_0
        x_t_all = [x_t]

        log_p_t = self.prior_dist.log_prob(x_0)
        log_p_all = [log_p_t]

        for i in range(len(t_values) - 1):
            t = t_values[i]
            delta_t = t_values[i + 1] - t_values[i]
            t_tensor = torch.full((batch_size, 1), t.item(), device=self.device)

            # Compute divergence
            with torch.set_grad_enabled(True):
                x_t_grad = x_t.detach().requires_grad_(True)
                div = batched_div(
                    lambda x, t: x @ self.forward(x, t), x_t_grad, t_tensor
                )

            # Exponential Euler step
            A = self.forward(x_t, t_tensor)
            x_t = x_t @ expm(delta_t * A)

            # Project back to manifold
            # x_t = self.prior_dist.project_to_manifold(x_t)

            log_p_t = log_p_t - delta_t * div

            x_t_all.append(x_t)
            log_p_all.append(log_p_t)

        x_t_all = torch.stack(x_t_all, dim=0)  # [n_steps + 1, batch, 2, 2]
        log_p_all = torch.stack(log_p_all, dim=0)

        return x_t_all, log_p_all

    def _integrate_divergence_to_time(
        self, x_0: Tensor, t: float, n_substeps: int
    ) -> Tensor:
        """Integrate divergence from x_0 to time t."""
        batch_size = x_0.shape[0]
        delta_t = t / n_substeps

        x_t = x_0
        log_p_t = self.prior_dist.log_prob(x_0)  # [batch, M]

        for i in range(n_substeps):
            s = i * delta_t
            s_tensor = torch.full((batch_size, 1), s, device=self.device)

            # Compute divergence
            with torch.set_grad_enabled(True):
                x_t_grad = x_t.detach().requires_grad_(True)
                div = batched_div(
                    lambda x, t: x @ self.forward(x, t), x_t_grad, s_tensor
                )

            # Exponential Euler step
            A = self.forward(x_t, s_tensor)
            x_t = x_t @ expm(delta_t * A)
            log_p_t = log_p_t - delta_t * div

        # Return log probability at time t
        return log_p_t

    @torch.no_grad()
    def approx_p_x1_given_xt(
        self,
        x_0: torch.Tensor,
        p1_dist,  # a torch.distributions instance
        n_steps: int = 50,
        n_x1_grid: int = 100,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute p(x1 | x_t) for matrix flow matching on prior group.

        Works entirely in matrix space without angle conversions.

        The flow is simulated once from ``x_0`` on a uniform grid of
        ``n_steps + 1`` times. At every grid time a posterior over a finite
        set of candidate end points is formed as
        ``softmax(log_lik + log_prior)``:

        * candidates are the atoms of ``p1_dist`` when it is a
          ``DiscreteDeltaMixture``, otherwise ``n_x1_grid`` samples from it;
        * at ``t == 0`` the posterior is the (normalised) prior;
        * for ``0 < t < 0.99`` the likelihood of each candidate comes from
          inverting the path to its ``x_0`` and integrating the divergence
          back up to ``t``;
        * for ``t >= 0.99`` a Gaussian kernel (``sigma = 0.1``) on the
          Frobenius norm of ``blogm(x_t^{-1} @ x_1)`` is used instead.

        NOTE(review): matrices are assumed to be 2x2; that shape is hard-coded
        in the storage tensors and reshapes below.

        Args:
            x_0: Batch of starting matrices.
            p1_dist: Target distribution over end points.
            n_steps: Number of time steps of the simulation grid.
            n_x1_grid: Number of candidate samples for non-discrete targets.

        Returns:
            Dict with keys ``x_t_all``, ``log_p_t_all``, ``p_x1_given_xt_all``,
            ``weights_all``, ``x1_grid``, ``t_values``, ``component_means``
            and ``use_discrete``. ``weights_all`` stays zero for targets that
            are neither discrete nor expose ``mixture_distribution``.
        """
        self.eval()
        device = x_0.device
        batch_size = x_0.shape[0]

        # Time grid
        t_values = torch.linspace(0.0, 1.0, n_steps + 1, device=device)

        # Detect the kind of target distribution
        use_discrete = isinstance(p1_dist, DiscreteDeltaMixture)
        is_gmm = (not use_discrete) and hasattr(p1_dist, "mixture_distribution")

        if use_discrete:
            # Discrete atoms: use matrix means directly
            x1_values = p1_dist.locs.to(device)  # [K, 2, 2]
            prior_weights = p1_dist.weights.to(device)
            log_prior = torch.log(prior_weights)
            M = x1_values.shape[0]
        else:
            # For continuous prior, sample from this distribution
            x1_values = p1_dist.sample((n_x1_grid,)).to(device)  # [M, 2, 2]
            M = x1_values.shape[0]
            log_prior = p1_dist.log_prob(x1_values).view(M)

        # Storage tensors
        x_t_all = torch.zeros(n_steps + 1, batch_size, 2, 2, device=device)
        log_p_t_all = torch.zeros(n_steps + 1, batch_size, device=device)
        p_x1_given_xt_all = torch.zeros(n_steps + 1, batch_size, M, device=device)
        # For GMM, we also track component weights
        if is_gmm:
            K = p1_dist.mixture_distribution.probs.shape[0]
            weights_all = torch.zeros(n_steps + 1, batch_size, K, device=device)
        else:
            weights_all = torch.zeros(n_steps + 1, batch_size, M, device=device)

        # Forward simulate once for trajectories and marginal log p_t
        x_sim, logp_sim = self._forward_simulate(x_0, t_values)
        x_t_all[:] = x_sim
        log_p_t_all[:] = logp_sim

        # Number of (sample, candidate) pairs. Defined once up front so both
        # likelihood branches can use it; previously it only existed after the
        # first t < 0.99 iteration, which broke grids such as n_steps=1.
        batch_M = batch_size * M

        # Main loop over times
        for i, t in enumerate(t_values):
            x_t = x_t_all[i]  # [batch, 2, 2]
            t_val = t.item()

            if t_val == 0.0:
                # At t=0, posterior equals prior
                if use_discrete:
                    post = prior_weights.unsqueeze(0).expand(
                        batch_size, -1
                    )  # [batch, K]
                    p_x1_given_xt_all[i] = post
                    weights_all[i] = post
                else:
                    p_x1 = torch.exp(log_prior)
                    p_x1 = p_x1 / p_x1.sum()
                    p_x1_given_xt_all[i] = p_x1.unsqueeze(0).expand(batch_size, -1)

                    # Compute GMM component weights if applicable
                    if is_gmm:
                        weights_all[i] = self._compute_gmm_weights_matrix(
                            x1_values, p_x1.unsqueeze(0), p1_dist
                        ).expand(batch_size, -1)
                continue

            # Pair every x_t with every candidate x_1: [batch * M, 2, 2]
            x_t_flat = (
                x_t.unsqueeze(1).expand(-1, M, -1, -1).reshape(batch_M, 2, 2)
            )
            x1_flat = (
                x1_values.unsqueeze(0)
                .expand(batch_size, -1, -1, -1)
                .reshape(batch_M, 2, 2)
            )

            if t_val < 0.99:
                # Matrix geodesic inversion: x_0 for all pairs
                x0_all = self._matrix_geodesic_inverse(x_t_flat, x1_flat, t_val)

                # Compute log probability by integrating divergence
                log_p = self._integrate_divergence_to_time(x0_all, t_val, n_substeps=i)
                log_lik = log_p.view(batch_size, M)
            else:
                # Near t=1: use geodesic distance ||log(x_t^{-1} @ x_1)||_F
                x_t_inv_x1 = torch.linalg.lstsq(x_t_flat, x1_flat).solution

                log_matrices = blogm(x_t_inv_x1).view(batch_size, M, 2, 2)
                distances = torch.norm(
                    log_matrices.view(batch_size, M, -1), dim=-1
                )  # [batch, M]

                sigma = 0.1
                log_lik = -0.5 * (distances / sigma) ** 2

            # Combine with prior
            log_post_unnorm = log_lik + log_prior.unsqueeze(0)

            # Normalize; rows with no finite entry fall back to uniform
            all_inf_mask = ~torch.isfinite(log_post_unnorm).any(dim=1)
            log_post_unnorm[all_inf_mask] = 0.0
            post = F.softmax(log_post_unnorm, dim=1)
            p_x1_given_xt_all[i] = post

            # Update weights
            if use_discrete:
                weights_all[i] = post
            elif is_gmm:
                weights_all[i] = self._compute_gmm_weights_matrix(
                    x1_values, post, p1_dist
                )

        # Extract component means as matrices
        if use_discrete:
            component_means = x1_values.clone()
        elif hasattr(p1_dist, "component_distribution"):
            # Convert GMM component means to matrices
            component_means = p1_dist.component_distribution.base_dist.loc
        else:
            component_means = None

        return {
            "x_t_all": x_t_all,
            "log_p_t_all": log_p_t_all,
            "p_x1_given_xt_all": p_x1_given_xt_all,
            "weights_all": weights_all,
            "x1_grid": x1_values,
            "t_values": t_values,
            "component_means": component_means,  # [K, 2, 2]
            "use_discrete": use_discrete,
        }

    def _compute_gmm_weights_matrix(
        self, x1_matrices: Tensor, p_x1: Tensor, p1_dist
    ) -> Tensor:
        """
        Compute posterior mixture weights for GMM components in GL+(2) space.

        Args:
            x1_matrices: [M, 2, 2] sampled matrices
            p_x1: [batch, M] posterior probabilities
            p1_dist: GMM distribution in GL+(2) space

        Returns:
            weights: [batch, K] component weights
        """
        if not hasattr(p1_dist, "mixture_distribution"):
            # Not a GMM, return p_x1 as weights
            return p_x1

        mix_probs = p1_dist.mixture_distribution.probs  # [K]
        K = mix_probs.shape[0]
        M = x1_matrices.shape[0]

        # Evaluate each component at each grid point
        log_mix = torch.log(mix_probs)  # [K]

        # For each component k, evaluate log p_k(x1_matrices)
        log_q = torch.zeros(K, M, device=x1_matrices.device)  # [K, M]

        for k in range(K):
            # Get k-th component distribution
            comp_k = p1_dist.component_distribution.base_dist
            # Evaluate at all grid points
            log_q[k] = comp_k.log_prob(x1_matrices)  # [M]

        # Evaluate total log p1(x1_matrices)
        log_p1 = p1_dist.log_prob(x1_matrices)  # [M]

        # Compute ratio: p(k) * p(x|k) / p(x)
        log_ratio = log_mix.unsqueeze(1) + log_q - log_p1.unsqueeze(0)  # [K, M]
        ratio = torch.exp(log_ratio)

        # Compute weights: sum over grid points weighted by posterior
        weights_unnorm = p_x1 @ ratio.t()  # [batch, K]
        weights = weights_unnorm / (weights_unnorm.sum(dim=1, keepdim=True) + 1e-12)

        return weights


class OracleFlow(nn.Module):
    """Reference flow that moves ``x_0`` to a known target ``x_1``.

    Instead of a learned network it uses the constant generator
    ``A = blogm(x_0^{-1} @ x_1)`` and repeatedly applies
    ``x <- x @ expm(dt * A)`` on a uniform time grid over [0, 1].
    """

    def __init__(
        self,
        prior_dist,
        device,
    ):
        """Store the prior used to draw ``x_0`` and the working device."""
        super().__init__()
        self.device = device
        self.prior_dist = prior_dist

    @torch.no_grad()
    def sample_all(self, x_1, n_steps, x_0=None):
        """Return the whole trajectory from ``x_0`` towards ``x_1``.

        Args:
            x_1: Batch of target matrices.
            n_steps: Number of uniform steps on [0, 1].
            x_0: Optional batch of start matrices; when omitted, one start per
                target is drawn from ``self.prior_dist``.

        Returns:
            Tensor of ``n_steps + 1`` stacked states (dim 0 is time), starting
            with a copy of ``x_0``.
        """
        # Keep both end points on the working device so the solve and the
        # updates below never mix devices.
        x_1 = x_1.to(self.device)
        if x_0 is None:
            x_0 = self.prior_dist.sample((x_1.shape[0],))
        x_0 = x_0.to(self.device)

        x_t = x_0.clone()
        x_t_all = [x_t]

        t_values = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device)

        if n_steps > 0:
            # The generator depends only on the end points, so it is computed
            # once rather than inside the loop.
            A = blogm(torch.linalg.lstsq(x_0, x_1, driver="gels").solution)

        for i in range(n_steps):
            delta_t = t_values[i + 1] - t_values[i]
            x_t = x_t @ expm(delta_t * A)
            x_t_all.append(x_t)

        return torch.stack(x_t_all, dim=0)  # [n_steps + 1, batch, 2, 2]

    @torch.no_grad()
    def sample(self, batch_size, n_steps, x_0=None):
        """Return only the final state of ``sample_all``.

        NOTE: despite its name, ``batch_size`` is forwarded as the ``x_1``
        argument of ``sample_all`` and therefore has to be the batch of target
        matrices, not an integer. The parameter name is kept so existing
        keyword callers keep working.
        """
        return self.sample_all(batch_size, n_steps, x_0)[-1]
