# src/flow_matching_angle_to_angle.py
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torchdiffeq import odeint

from group_discovery.distributions import DiscreteDeltaMixture
from group_discovery.geometry_2d import wrap_angle
from group_discovery.utils import (
    batched_div,
    icdf_power,
    expand_like,
    sample_power_distribution,
)


# Model
class Flow(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim,
        prior_dist,
        device,
        time_sampling="uniform",
        time_sampling_kwargs=None,
    ):
        """Build the velocity-field MLP and store the sampling configuration.

        Args:
            in_dim: Number of state features. The first linear layer takes
                ``in_dim + 1`` inputs because ``forward`` appends the time.
            out_dim: Number of features produced by the last linear layer.
            hidden_dim: Width of the three hidden layers.
            prior_dist: Source distribution. ``to(device)`` is called on it
                here; other methods call its ``sample`` and ``log_prob``.
            device: Device used for tensors created by this module. Note that
                this constructor does not move ``self.net`` to ``device``.
            time_sampling: Name of the time schedule. The other methods of
                this class accept ``"uniform"`` and ``"power"``.
            time_sampling_kwargs: Optional mapping of schedule options. When
                ``time_sampling == "power"`` it must contain ``"skewness"``.

        Raises:
            KeyError: If ``time_sampling == "power"`` and ``"skewness"`` is
                missing from ``time_sampling_kwargs``.
        """
        super().__init__()

        # A fresh dict per call: a literal ``{}`` default would be one object
        # shared by every instance constructed without the argument.
        if time_sampling_kwargs is None:
            time_sampling_kwargs = {}

        self.net = nn.Sequential(
            nn.Linear(in_dim + 1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.device = device
        self.prior_dist = prior_dist
        self.prior_dist.to(device)

        self.time_sampling = time_sampling
        if time_sampling == "power":
            # Only the power schedule has a tunable parameter.
            self.time_skewness = time_sampling_kwargs["skewness"]

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear layer of ``self.net``.

        Hidden linear layers get Kaiming-uniform weights (computed with the
        ReLU gain); the final entry of ``self.net`` gets small uniform weights
        in ``[-0.01, 0.01]``. All biases are set to zero.
        """
        final_index = len(self.net) - 1
        for position, module in enumerate(self.net):
            # Activation modules carry no parameters to initialise.
            if not isinstance(module, nn.Linear):
                continue

            if position == final_index:
                nn.init.uniform_(module.weight, -0.01, 0.01)
            else:
                nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
            nn.init.zeros_(module.bias)

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Evaluate the network on the state ``x_t`` at time ``t``.

        Args:
            x_t: State tensor whose first axis is the batch. It is
                concatenated with the time along ``dim=1``, so it has to be
                two-dimensional for the concatenation to succeed.
            t: Time. A 0-d tensor is broadcast to every batch row, a 1-d
                tensor is viewed as ``[B, 1]``, anything else is used as-is.

        Returns:
            The network output viewed with the same shape as ``x_t``.
        """
        batch = x_t.shape[0]
        trailing_shape = x_t.shape[1:]

        # wrap_angle is a project helper; presumably it maps angles into one
        # principal interval -- confirm against group_discovery.geometry_2d.
        angles = wrap_angle(x_t)

        n_time_dims = t.dim()
        if n_time_dims == 0:
            t = t.view(1, 1).expand(batch, 1)
        elif n_time_dims == 1:
            t = t.view(batch, 1)

        net_input = torch.cat((angles, t), dim=1)
        return self.net(net_input).view(batch, *trailing_shape)

    @torch.no_grad()
    def sample_all(self, batch_size, n_steps, x_0=None):
        """Integrate the learned field from t=0 to t=1 and return every state.

        Args:
            batch_size: Number of prior samples to draw; only used when
                ``x_0`` is None.
            n_steps: Number of intervals; states are reported at
                ``n_steps + 1`` times.
            x_0: Optional start states. Drawn from ``self.prior_dist`` when
                omitted.

        Returns:
            The ``odeint`` solution, indexed by time along its first axis.
            Only the first and last entries are passed through
            ``wrap_angle``; intermediate states are returned as integrated.

        Raises:
            ValueError: If ``self.time_sampling`` is not ``"uniform"`` or
                ``"power"``.
        """
        start = x_0
        if start is None:
            start = self.prior_dist.sample((batch_size,)).to(self.device)

        n_points = n_steps + 1
        if self.time_sampling == "uniform":
            # Evenly spaced report times on [0, 1].
            time_grid = torch.linspace(0.0, 1.0, n_points, device=self.device)
        elif self.time_sampling == "power":
            time_grid = icdf_power(
                0,
                1,
                n_points,
                skewness=self.time_skewness,
                device=self.device,
            )
        else:
            raise ValueError(f"Unknown time sampling: {self.time_sampling}")

        def velocity(t, x):
            # odeint calls f(t, x); forward takes (x, t).
            return self.forward(x, t)

        trajectory = odeint(
            velocity,
            start,
            time_grid,
            atol=1e-5,
            rtol=1e-5,
        )

        for endpoint in (0, -1):
            trajectory[endpoint] = wrap_angle(trajectory[endpoint])

        return trajectory

    @torch.no_grad()
    def sample(self, batch_size, n_steps, x_0=None):
        """Return only the final state produced by ``sample_all``.

        Arguments are forwarded unchanged to ``sample_all``.
        """
        trajectory = self.sample_all(batch_size, n_steps, x_0)
        final_state = trajectory[-1]
        return final_state

    def interpolate(self, x_0: Tensor, x_1: Tensor, t) -> Tuple[Tensor, Tensor]:
        """Move a fraction ``t`` of the way from ``x_0`` towards ``x_1``.

        Args:
            x_0: Start angles.
            x_1: End angles, same shape as ``x_0``.
            t: Interpolation times; passed through ``expand_like`` so it can
                be multiplied with the angle difference (presumably one value
                per batch row -- confirm against group_discovery.utils).

        Returns:
            A pair ``(x_t, angle_diff)``: the interpolated angles and the
            wrapped difference ``wrap_angle(x_1 - x_0)``. ``train_net`` and
            ``eval_net`` use the second element as the regression target.
        """
        # Wrapping the difference first means the displacement is taken
        # after wrap_angle rather than as the raw subtraction.
        angle_diff = wrap_angle(x_1 - x_0)
        x_t = wrap_angle(x_0 + expand_like(t, angle_diff) * angle_diff)

        return x_t, angle_diff

    def train_net(self, train_loader, optimizer):
        """Run one optimisation pass over ``train_loader``.

        For every batch ``x_1`` a time ``t`` and a prior sample ``x_0`` are
        drawn, the pair is interpolated, and the network is regressed (MSE)
        onto the target returned by ``interpolate``.

        Args:
            train_loader: Iterable yielding batches of target samples as
                tensors. It does not need to support ``len()``.
            optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.

        Returns:
            The mean of the per-batch losses as a Python float.

        Raises:
            ValueError: If ``self.time_sampling`` is not ``"uniform"`` or
                ``"power"``.
            ZeroDivisionError: If ``train_loader`` yields no batches.
        """
        self.train()
        total_loss = 0.0
        # Count the batches actually seen instead of calling len(): loaders
        # backed by iterable-style datasets may not define a length.
        n_batches = 0
        for x_1 in train_loader:
            B = x_1.shape[0]

            x_1 = x_1.to(self.device)
            if self.time_sampling == "uniform":
                # Uniformly sample t in [0, 1]
                t = torch.rand((B,), device=self.device)
            elif self.time_sampling == "power":
                t = sample_power_distribution(
                    0, 1, skewness=self.time_skewness, size=(B,), device=self.device
                )
            else:
                raise ValueError(f"Unknown time sampling: {self.time_sampling}")

            x_0 = self.prior_dist.sample((B,)).to(self.device)

            x_t, target = self.interpolate(x_0, x_1, t)

            optimizer.zero_grad()
            pred = self(x_t, t)

            loss = F.mse_loss(pred, target)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            n_batches += 1

        return total_loss / n_batches

    @torch.no_grad()
    def eval_net(self, test_loader):
        """Compute the mean flow-matching loss over ``test_loader``.

        Mirrors ``train_net`` without any parameter update: each batch gets a
        fresh random time and prior sample, so the returned value is a
        stochastic estimate.

        Args:
            test_loader: Iterable yielding batches of target samples as
                tensors. It does not need to support ``len()``.

        Returns:
            The mean of the per-batch MSE losses as a Python float.

        Raises:
            ValueError: If ``self.time_sampling`` is not ``"uniform"`` or
                ``"power"``.
            ZeroDivisionError: If ``test_loader`` yields no batches.
        """
        self.eval()
        total_loss = 0.0
        # Count the batches actually seen instead of calling len(): loaders
        # backed by iterable-style datasets may not define a length.
        n_batches = 0
        for x_1 in test_loader:
            B = x_1.shape[0]

            x_1 = x_1.to(self.device)
            if self.time_sampling == "uniform":
                # Uniformly sample t in [0, 1]
                t = torch.rand((B,), device=self.device)
            elif self.time_sampling == "power":
                t = sample_power_distribution(
                    0, 1, skewness=self.time_skewness, size=(B,), device=self.device
                )
            else:
                raise ValueError(f"Unknown time sampling: {self.time_sampling}")
            x_0 = self.prior_dist.sample((B,)).to(self.device)

            x_t, target = self.interpolate(x_0, x_1, t)

            pred = self(x_t, t)
            loss = F.mse_loss(pred, target)
            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    @torch.no_grad()
    def approx_p_x1_given_xt(
        self,
        x_0: torch.Tensor,
        p1_dist,  # a torch.distributions instance
        n_steps: int = 50,
        n_x1_grid: int = 100,
    ) -> Dict[str, torch.Tensor]:
        """
        Approximate p(x1 | x_t) along the trajectories started at ``x_0``.

        The start points are integrated once through the learned field. For
        every reported time a likelihood is evaluated for each candidate x1,
        combined with the prior over x1, and normalised with a softmax.

        Args:
            x_0: Tensor [batch, dim], here dim=1.
            p1_dist: Target distribution. A ``DiscreteDeltaMixture`` is
                treated as discrete (its ``locs`` / ``weights`` are the
                candidates and prior). Anything else is treated as a
                continuous mixture evaluated on a uniform grid over
                [-pi, pi]; it must expose ``log_prob``,
                ``mixture_distribution.probs`` and
                ``component_distribution.base_dist.loc``.
            n_steps: number of time intervals between 0 and 1.
            n_x1_grid: grid size used for a continuous ``p1_dist``.

        Returns a dict with (M = number of candidates: atoms in the discrete
        case, grid points otherwise; K = number of mixture components, equal
        to M in the discrete case):
            "x_t_all": [n_steps+1, batch, dim],
            "log_p_t_all": [n_steps+1, batch],
            "p_x1_given_xt_all": [n_steps+1, batch, M],
            "weights_all": [n_steps+1, batch, K],
            "x1_grid": [M, 1] candidate x1 values,
            "t_values": [n_steps+1],
            "component_means": [M, 1] (discrete) or [K] (continuous),
            "use_discrete": bool (not a tensor).

        Raises:
            ValueError: If ``self.time_sampling`` is not ``"uniform"`` or
                ``"power"``.
        """
        self.eval()
        device = x_0.device
        batch_size = x_0.shape[0]
        dim = x_0.shape[1]

        # Time grid -- same construction as sample_all.
        if self.time_sampling == "uniform":
            # Evenly spaced timesteps
            t_values = torch.linspace(0.0, 1.0, n_steps + 1, device=self.device)
        elif self.time_sampling == "power":
            t_values = icdf_power(
                0,
                1,
                n_steps + 1,
                skewness=self.time_skewness,
                device=self.device,
            )
        else:
            raise ValueError(f"Unknown time sampling: {self.time_sampling}")

        # Detect discrete prior: Mixture of Delta distributions
        use_discrete = isinstance(p1_dist, DiscreteDeltaMixture)

        if use_discrete:
            # Discrete atoms: extract means and weights
            means = p1_dist.locs
            x1_values = means.view(-1, 1).to(device)  # [K, 1]
            prior_weights = p1_dist.weights.to(device)  # [K], sums to 1
            log_prior_weights = torch.log(prior_weights)  # [K]
            M = x1_values.shape[0]  # number of candidates (# of atoms)
            n_weights = M  # posterior over atoms doubles as the weights
        else:
            # Continuous prior: build uniform grid
            x1_values = torch.linspace(
                -torch.pi, torch.pi - 1e-12, n_x1_grid, device=device
            ).view(
                -1, 1
            )  # [M, 1]
            M = x1_values.shape[0]
            # Precompute log p1(x1) on the grid
            log_p_x1_grid = p1_dist.log_prob(x1_values).view(M)  # [M]
            comp_means = p1_dist.component_distribution.base_dist.loc.squeeze(-1).to(
                device
            )  # [n_components]
            # Component weights have one entry per mixture component, which
            # is generally different from the grid size M.
            n_weights = p1_dist.mixture_distribution.probs.shape[0]

        # Storage tensors
        x_t_all = torch.zeros(n_steps + 1, batch_size, dim, device=device)
        log_p_t_all = torch.zeros(n_steps + 1, batch_size, device=device)
        p_x1_given_xt_all = torch.zeros(n_steps + 1, batch_size, M, device=device)
        weights_all = torch.zeros(n_steps + 1, batch_size, n_weights, device=device)

        # Forward simulate once for trajectories and marginal log p_t. This
        # fills every time index, including t = 0.
        x_sim, logp_sim = self._forward_simulate(x_0, t_values)
        x_t_all[:] = x_sim
        log_p_t_all[:] = logp_sim

        # Main loop over times
        for i, t in enumerate(t_values):
            x_t = x_t_all[i]  # [batch, dim]
            t_val = t.item()

            if t_val == 0.0:
                # Posterior p(x_1 | x_0) equals prior p(x_1)
                if use_discrete:
                    post = prior_weights.unsqueeze(0).expand(
                        batch_size, -1
                    )  # [batch, K]
                    p_x1_given_xt_all[i] = post
                    weights_all[i] = post  # posterior over atoms = weights
                else:
                    p_x1 = torch.exp(log_p_x1_grid)  # [M]
                    p_x1 = p_x1 / p_x1.sum()
                    p_x1_given_xt_all[i] = p_x1.unsqueeze(0).expand(batch_size, -1)
                    # Compute mixture-component weights from this posterior:
                    weights_all[i] = self._compute_weights_from_distribution(
                        x1_values.view(-1), p_x1.unsqueeze(0), p1_dist
                    ).expand(batch_size, -1)
                continue

            # For t in (0,1]
            if t_val < 0.99:
                # Inversion: x0 = (x_t - t * x1) / (1 - t)
                x_t_exp = x_t.unsqueeze(1)  # [batch, 1, dim]
                x1_exp = x1_values.unsqueeze(0)  # [1, M, 1]
                x0_all = wrap_angle(
                    (x_t_exp - t_val * x1_exp) / (1.0 - t_val)
                )  # [batch, M, 1]

                log_p = self._integrate_divergence_to_time(
                    x0_all.view(-1, 1), t_val, n_substeps=i
                )
                log_lik = log_p.view(batch_size, M)  # [batch, M]
            else:
                # Near t=1 the inversion divides by (1 - t) -> 0, so fall
                # back to a narrow Gaussian: x_t ~ N(x_1, sigma^2).
                diff = wrap_angle(
                    x_t.unsqueeze(1) - x1_values.unsqueeze(0)
                )  # [batch, M, 1]
                # Log-likelihood up to an additive constant; sigma = 0.1,
                # i.e. variance 0.01.
                sigma = 0.1
                log_lik = -0.5 * (diff.squeeze(-1) / sigma) ** 2  # [batch, M]

            # Combine with prior and normalize
            if use_discrete:
                # log_lik: [batch, K], log_prior_weights: [K]
                log_post_unnorm = log_lik + log_prior_weights.unsqueeze(0)  # [batch, K]
            else:
                # Continuous prior on grid
                log_post_unnorm = log_lik + log_p_x1_grid.unsqueeze(0)  # [batch, M]

            # Robust normalization, identify rows without any finite entry
            all_inf_mask = ~torch.isfinite(log_post_unnorm).any(dim=1)  # [batch]

            # Replace those rows with a uniform distribution (equal logits)
            log_post_unnorm[all_inf_mask] = 0.0
            post = F.softmax(log_post_unnorm, dim=1)  # [batch, M]
            p_x1_given_xt_all[i] = post  # [batch, M]

            if use_discrete:
                weights_all[i] = post
            else:
                weights_all[i] = self._compute_weights_from_distribution(
                    x1_values.view(-1), post, p1_dist
                )

        # component_means: for discrete, equal to x1_values; for continuous,
        # the component locations read from p1_dist
        if use_discrete:
            component_means = x1_values.clone()  # [K, 1]
        else:
            component_means = comp_means  # [n_components]

        return {
            "x_t_all": x_t_all,  # [n_steps+1, batch, dim]
            "log_p_t_all": log_p_t_all,  # [n_steps+1, batch]
            "p_x1_given_xt_all": p_x1_given_xt_all,  # [n_steps+1, batch, M]
            "weights_all": weights_all,  # [n_steps+1, batch, K]
            "x1_grid": x1_values,  # [M, 1]
            "t_values": t_values,  # [n_steps+1]
            "component_means": component_means,  # [K, 1] or [K]
            "use_discrete": use_discrete,  # bool
        }

    def _forward_simulate(self, x_0: Tensor, t_values: Tensor) -> Tuple[Tensor, Tensor]:
        """Integrate states and their log-density along ``t_values``.

        The ODE state is the pair (positions, log-density). Its derivative is
        (network output, minus the value returned by ``batched_div``), and it
        is solved with ``odeint`` using ``dopri5``.

        Args:
            x_0: Start states; the first axis is the batch.
            t_values: Times at which the solution is reported.

        Returns:
            ``(x_t_all, log_p_all)`` as returned by ``odeint``, with the
            first and last position entries passed through ``wrap_angle``.
        """
        n_samples = x_0.shape[0]

        def augmented_dynamics(t, state):
            # The derivative does not depend on the running log-density.
            positions, _ = state
            t_column = torch.full((n_samples, 1), t.item(), device=self.device)

            # Gradients are switched on explicitly because callers may run
            # under no_grad while batched_div presumably differentiates the
            # field with respect to the positions.
            with torch.enable_grad():
                positions.requires_grad_(True)
                velocity = self.forward(positions, t_column)
                divergence = batched_div(self.forward, positions, t_column)

            return velocity, -divergence

        initial_state = (x_0, self.prior_dist.log_prob(x_0))

        trajectory, log_density = odeint(
            augmented_dynamics,
            initial_state,
            t_values,
            method="dopri5",
            atol=1e-5,
            rtol=1e-5,
        )

        for endpoint in (0, -1):
            trajectory[endpoint] = wrap_angle(trajectory[endpoint])

        return trajectory, log_density

    def _integrate_divergence_to_time(
        self, x_0: Tensor, t: float, n_substeps: int
    ) -> Tensor:
        """Return the log-density reached at time ``t`` when starting at ``x_0``.

        Args:
            x_0: Start states; the first axis is the batch.
            t: End time of the integration (the start time is 0).
            n_substeps: Number of intervals of the evenly spaced report grid
                on ``[0, t]``.

        Returns:
            The last log-density entry of the simulated solution, i.e. the
            value at time ``t`` for every row of ``x_0``.
        """
        t_span = torch.linspace(0, t, n_substeps + 1, device=self.device)

        # Same augmented ODE (state, log-density) as the full trajectory
        # simulation, so reuse it instead of keeping a second copy of the
        # dynamics; only the final log-density is needed here.
        _, log_p_all = self._forward_simulate(x_0, t_span)

        return log_p_all[-1]

    def _compute_weights_from_distribution(
        self, x1_grid: Tensor, p_x1: Tensor, p1_dist
    ) -> Tensor:
        """
        Compute posterior mixture weights P(comp = k | x_t[b]) = w_k[b] for a
        posterior approximation p(x1 | x_t) given over a grid x1_grid.

        For each batch b and component k,
        w_k[b] ∝ sum_j p_x1[b,j] * (π_k * q_k(x1_grid[j])) / p1_dist(x1_grid[j])
        Then normalized over k.

        Args:
            x1_grid: [M] grid of candidate x1 values.
            p_x1: [batch, M]  (posterior p(x1|x_t) on the grid)
            p1_dist: MixtureSameFamily-like object exposing
                ``mixture_distribution.probs``, ``component_distribution``,
                ``event_shape`` and ``log_prob``.
        Returns:
            weights: [batch, K] normalized posterior weights.
        """
        # Mixture weights π_k: shape [K]
        mix_probs = p1_dist.mixture_distribution.probs  # [K]
        n_components = mix_probs.shape[0]
        n_grid = x1_grid.shape[0]

        # log π_k
        log_mix = torch.log(mix_probs)  # [K]

        # Grid as a batch of samples, the same layout handed to log_prob.
        points = x1_grid.unsqueeze(-1)  # [M, 1]

        # Evaluate log q_k(x1_j) for every grid point and component. A free
        # axis is inserted just left of the event dims so the grid broadcasts
        # against the component batch axis (this is how MixtureSameFamily
        # pads its input internally); expanding the grid to [K, M, 1] instead
        # does not broadcast against per-component parameters.
        component_axis = -1 - len(p1_dist.event_shape)
        log_q = p1_dist.component_distribution.log_prob(
            points.unsqueeze(component_axis)
        ).reshape(n_grid, n_components)  # [M, K]

        # Evaluate log p1(x1_j) via p1_dist.log_prob:
        log_p1 = p1_dist.log_prob(points).reshape(n_grid)  # [M]

        # ratio_{j,k} = π_k * q_k(x1_j) / p1(x1_j), computed in log-space:
        log_ratio = log_mix.unsqueeze(0) + log_q - log_p1.unsqueeze(1)  # [M, K]
        ratio = torch.exp(log_ratio)  # [M, K]

        # weights_unnorm[b, k] = sum_j p_x1[b, j] * ratio[j, k]
        weights_unnorm = p_x1 @ ratio  # [batch, K]

        # Normalize; the epsilon guards against an all-zero row.
        weights = weights_unnorm / (weights_unnorm.sum(dim=1, keepdim=True) + 1e-12)
        return weights  # [batch, K]
