import os
import sys
import torch
import torch.nn as nn
from torch.linalg import matrix_exp
import math
import numpy as np
import warnings
from scipy.linalg import logm
from torchdiffeq import odeint
import contextlib

@contextlib.contextmanager
def suppress_stdout():
    """Discard everything printed to ``sys.stdout`` inside the ``with`` body.

    Output goes to ``os.devnull``.  The previous ``sys.stdout`` is restored
    when the body exits, including when it raises.
    """
    with open(os.devnull, "w") as sink, contextlib.redirect_stdout(sink):
        yield


class Latent_MoS(nn.Module):
    """Latent mixture-of-symmetries sequence model.

    Pipeline (see ``forward``):

    1. A ``DeterministicODERNNEncoder`` reads the input sequence backwards in
       time.  Its final state, norm-clamped to ``[0.5, 1.5]``, is ``z``.
    2. ``z`` is advanced one step per remaining time point by learned affine
       matrices acting on the homogeneous vector ``[z, 1]``:
       - a planar rotation (``Pi_rr``),
       - a diagonal scaling (``Pi_sca``),
       - a translation (``Pi_tra``),
       - when ``gate_dim == 9``, their six ordered pairwise products.
       The matrices and gates are recomputed only at the start of each of
       ``num_subintervals`` equal subintervals of ``[t[0], t[-1]]``.
    3. ``self.fc`` decodes every latent state into an output.
    """

    # Expert names in gate order; a composite "Pi_a_b" is the product
    # Pi_a @ Pi_b.
    _BASE_EXPERTS = ("Pi_rr", "Pi_sca", "Pi_tra")
    _ALL_EXPERTS = _BASE_EXPERTS + (
        "Pi_rr_sca", "Pi_rr_tra",
        "Pi_sca_rr", "Pi_sca_tra",
        "Pi_tra_rr", "Pi_tra_sca",
    )

    def __init__(self, input_dim, latent_dim, output_dim, output_length,
                 time_feat_dim=0, gate_dim=3, top_k_gates=3, ode_func=None, num_subintervals=5):
        super().__init__()

        self.latent_dim = latent_dim
        self.output_length = output_length
        # NOTE(review): the ``time_feat_dim`` argument is ignored and the
        # attribute is hard-wired to 0 — confirm this is intended.
        self.time_feat_dim = 0
        self.gate_dim = gate_dim
        self.top_k_gates = top_k_gates
        self.current_epoch = 0
        self.num_subintervals = num_subintervals

        # Default encoder dynamics when no ODE function is supplied.
        if ode_func is None:
            ode_func = LatentODEFunc(latent_dim)

        # ODE-RNN encoder: ODE evolution between observations, GRU-style
        # update at each observation.
        gru_update = GRUUpdate(input_dim, latent_dim)
        self.encoder = DeterministicODERNNEncoder(
            input_dim=input_dim,
            latent_dim=latent_dim,
            ode_func=ode_func,
            gru_update=gru_update,
        )

        # Generator for a general rotation exp(skew(A)).
        # ``compute_rr_matrix`` takes such a net as an argument; no method
        # of this class currently calls it.
        self.omega_net = nn.Sequential(
            nn.Linear(latent_dim + 1, latent_dim * latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim * latent_dim, latent_dim * latent_dim),
        )

        # Per-dimension scaling factors in [-1, 1] (final Tanh).
        self.alpha_mlp = nn.Sequential(
            nn.Linear(latent_dim + 1, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
        )

        # NOTE: not referenced by any method of this class (the learned
        # ``self.velocity`` vector is used instead).  It is kept so the
        # parameter layout / state_dict is unchanged.
        self.velocity_net = nn.Sequential(
            nn.Linear(latent_dim + 1, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
        )

        # Gate logits, one per expert.
        self.gating_net = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, gate_dim)
        )

        # Decoder from latent state to output.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, output_dim),
        )

        # Translation vector shared by all samples (initialized to zero).
        self.velocity = nn.Parameter(torch.zeros(latent_dim))

        # Projection P in R^{m x 2}.  The final layer emits 2*m values, which
        # ``compute_pi_rr`` reshapes to [m, 2] and orthonormalizes.
        self.P_net = nn.Sequential(
            nn.Linear(latent_dim + 1, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, 2 * latent_dim),
        )

        # Rotation angle theta in [-1, 1] (Tanh); ``compute_pi_rr`` scales it.
        self.theta_net = nn.Sequential(
            nn.Linear(latent_dim + 1, 1),
            nn.Tanh(),
        )

    def set_epoch(self, epoch):
        """Record the current training epoch.

        ``forward`` passes it to ``compute_gates`` to select the warm-up phase.
        """
        self.current_epoch = epoch

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _expert_names(self):
        """Return the expert names in gate order for ``self.gate_dim``.

        Raises:
            ValueError: if ``gate_dim`` is neither 3 nor 9.
        """
        if self.gate_dim == 3:
            return list(self._BASE_EXPERTS)
        if self.gate_dim == 9:
            return list(self._ALL_EXPERTS)
        raise ValueError(f"gate_dim must be 3 or 9, got {self.gate_dim}")

    def _batched_identity(self, batch_size, like):
        """Return ``[batch_size, m+1, m+1]`` identities on ``like``'s device and dtype."""
        eye = torch.eye(self.latent_dim + 1, device=like.device, dtype=like.dtype)
        return eye.unsqueeze(0).repeat(batch_size, 1, 1)

    # ------------------------------------------------------------------ #
    # Symmetry matrices
    # ------------------------------------------------------------------ #
    def compute_rr_matrix(self, zt, omega_net):
        """Return rotations ``R = exp((A - A^T) / 2)`` with ``A = omega_net(zt)``.

        Args:
            zt: ``[B, D + 1]`` latent state concatenated with time.
            omega_net: module mapping ``[B, D + 1]`` to ``[B, D * D]``.

        Returns:
            ``[B, D, D]`` matrix exponentials of skew-symmetric matrices,
            which are orthogonal.
        """
        B, D_full = zt.size()
        D = D_full - 1

        A = omega_net(zt).view(B, D, D)
        Omega = (A - A.transpose(-1, -2)) / 2  # skew-symmetric part of A
        return torch.matrix_exp(Omega)

    def compute_pi_rr(self, z_t):
        """Build a batch of planar rotations ``pi_rr = I + P (R - I) P^T``.

        ``R`` is the 2x2 rotation by ``theta = 0.6 * theta_net(z_t)``.  ``P``
        holds the orthonormal columns from the reduced QR factorization of
        ``P_net(z_t)`` reshaped to ``[m, 2]``.  The result rotates within the
        plane spanned by ``P`` and leaves its orthogonal complement fixed.

        Args:
            z_t: ``[B, m+1]`` concatenation of latent state and time.

        Returns:
            ``[B, m, m]`` rotation matrices (``m = self.latent_dim``).
        """
        B, _ = z_t.shape
        m = self.latent_dim
        opts = dict(device=z_t.device, dtype=z_t.dtype)

        # 1. Rotation angle -> SO(2).
        max_angle = 0.6
        theta = max_angle * self.theta_net(z_t).squeeze(-1)  # [B]
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        R = torch.stack([
            torch.stack([cos_theta, -sin_theta], dim=-1),
            torch.stack([sin_theta, cos_theta], dim=-1),
        ], dim=-2)  # [B, 2, 2]

        # 2. Orthonormal 2-frame P in R^{m x 2}.
        P, _ = torch.linalg.qr(self.P_net(z_t).view(B, m, 2))  # [B, m, 2]

        # 3. Embed the 2-D rotation into m dimensions.
        I_2 = torch.eye(2, **opts).unsqueeze(0).expand(B, -1, -1)
        I_m = torch.eye(m, **opts).unsqueeze(0).expand(B, -1, -1)
        delta = torch.matmul(P, torch.matmul(R - I_2, P.transpose(1, 2)))  # [B, m, m]
        return I_m + delta

    def compute_group_matrices_all(self, z, t_scalar):
        """Build all expert matrices in homogeneous coordinates.

        Args:
            z: ``[B, m]`` latent states.
            t_scalar: ``[B, 1]`` time value appended to ``z`` as network input.

        Returns:
            ``(Pi, alpha, velocity)``:
            - ``Pi`` maps expert name to ``[B, m+1, m+1]``.  It always holds
              ``Pi_rr``, ``Pi_sca`` and ``Pi_tra``, plus the six composites
              ``Pi_a_b = Pi_a @ Pi_b`` when ``gate_dim == 9``.
            - ``alpha`` is the ``[B, m]`` scaling, multiplied by 1.5.
            - ``velocity`` is the ``[B, m]`` translation; every row is the
              same vector, norm-clipped to at most 1e-3.
        """
        batch_size = z.size(0)
        m = self.latent_dim
        zt = torch.cat([z, t_scalar], dim=-1)  # [B, m+1]

        # Rotation.
        Pi_rr = self._batched_identity(batch_size, z)
        Pi_rr[:, :m, :m] = self.compute_pi_rr(zt)

        # Scaling, alpha in [-1.5, 1.5].
        alpha_scale = 1.5
        alpha = alpha_scale * self.alpha_mlp(zt)
        Pi_sca = self._batched_identity(batch_size, z)
        Pi_sca[:, :m, :m] = torch.diag_embed(alpha)

        # Translation.  Clip the norm of the shared [m] vector *before*
        # expanding.  Measuring the norm of the expanded [B, m] tensor would
        # inflate it by sqrt(B) and make the clip batch-size dependent.
        max_v = 0.001
        velocity = self.velocity
        v_norm = velocity.norm()
        if v_norm > max_v:
            velocity = velocity * (max_v / (v_norm + 1e-8))
        velocity = velocity.unsqueeze(0).expand(batch_size, -1)  # [B, m]
        Pi_tra = self._batched_identity(batch_size, z)
        Pi_tra[:, :m, -1] = velocity

        Pi = {"Pi_rr": Pi_rr, "Pi_sca": Pi_sca, "Pi_tra": Pi_tra}

        # Ordered pairwise products of distinct base matrices.
        if self.gate_dim == 9:
            for name in self._ALL_EXPERTS[len(self._BASE_EXPERTS):]:
                _, first, second = name.split("_")
                Pi[name] = torch.bmm(Pi[f"Pi_{first}"], Pi[f"Pi_{second}"])

        return Pi, alpha, velocity

    def compute_group_matrices(self, z, t_scalar):
        """Build the three base matrices (unscaled, unclipped variant).

        Unlike ``compute_group_matrices_all``, ``alpha`` is not multiplied by
        1.5 and the velocity is neither norm-clipped nor expanded per sample.

        Args:
            z: ``[B, m]`` latent states.
            t_scalar: ``[B, 1]`` time value appended to ``z``.

        Returns:
            ``(Pi_rr, Pi_sca, Pi_tra, alpha, velocity)``:
            - the three matrices are ``[B, m+1, m+1]``;
            - ``alpha`` is ``[B, m]``;
            - ``velocity`` is the ``[m]`` parameter itself.
        """
        batch_size = z.size(0)
        m = self.latent_dim
        zt = torch.cat([z, t_scalar], dim=-1)

        Pi_rr = self._batched_identity(batch_size, z)
        Pi_rr[:, :m, :m] = self.compute_pi_rr(zt)

        alpha = self.alpha_mlp(zt)
        Pi_sca = self._batched_identity(batch_size, z)
        Pi_sca[:, :m, :m] = torch.diag_embed(alpha)

        velocity = self.velocity
        Pi_tra = self._batched_identity(batch_size, z)
        Pi_tra[:, :m, -1] = velocity  # broadcast over the batch

        return Pi_rr, Pi_sca, Pi_tra, alpha, velocity

    def rescale_norm(self, z, min_val=0.95, max_val=1.05):
        """Rescale each row of ``z`` so its L2 norm lies in ``[min_val, max_val]``.

        Direction is preserved.  Rows already inside the range are unchanged,
        apart from the ``1e-8`` stabilizer added to the norm.
        """
        norm = z.norm(dim=-1, keepdim=True) + 1e-8
        clipped_norm = norm.clamp(min=min_val, max=max_val)
        return z / norm * clipped_norm

    def compute_gates(self, gate_logits, current_epoch, warmup_epochs=10, temperature=1, noise_std=0.1):
        """Turn gate logits into mixture weights.

        In training mode with ``current_epoch < warmup_epochs``, Gaussian
        noise with std ``noise_std`` is added and a dense softmax is taken.
        Otherwise only the top ``min(top_k_gates, gate_dim)`` logits are kept
        (the rest set to ``-inf``) before the softmax.

        Returns:
            ``[B, gate_dim]`` weights; each row sums to 1.
        """
        if self.training and current_epoch < warmup_epochs:
            noisy_logits = gate_logits + torch.randn_like(gate_logits) * noise_std
            return torch.softmax(noisy_logits / temperature, dim=-1)

        top_k = min(self.top_k_gates, self.gate_dim)
        topk_values, topk_indices = torch.topk(gate_logits, k=top_k, dim=-1)
        sparse_logits = torch.full_like(gate_logits, float('-inf'))
        sparse_logits.scatter_(dim=-1, index=topk_indices, src=topk_values)
        return torch.softmax(sparse_logits / temperature, dim=-1)

    # ------------------------------------------------------------------ #
    # Applying the experts
    # ------------------------------------------------------------------ #
    def compute_symmetry_flows(self, Pi_rr, Pi_sca, Pi_tra, z_aug):
        """Apply every expert matrix to ``z_aug`` separately.

        Composite experts ``Pi_a_b`` are formed here as ``Pi_a @ Pi_b``.

        Args:
            Pi_rr, Pi_sca, Pi_tra: ``[B, m+1, m+1]`` base matrices.
            z_aug: ``[B, m+1, 1]`` homogeneous latent vector.

        Returns:
            List, in gate order, of ``[B, m]`` transformed latents with the
            homogeneous coordinate dropped.

        Raises:
            ValueError: if ``gate_dim`` is not 3 or 9.
        """
        Pi = {"Pi_rr": Pi_rr, "Pi_sca": Pi_sca, "Pi_tra": Pi_tra}

        z_flows = []
        for name in self._expert_names():
            parts = name.split("_")
            if len(parts) == 2:
                Pi_comp = Pi[name]
            else:  # composite "Pi_a_b"
                Pi_comp = torch.bmm(Pi[f"Pi_{parts[1]}"], Pi[f"Pi_{parts[2]}"])
            z_flows.append(torch.bmm(Pi_comp, z_aug).squeeze(-1)[:, :-1])

        return z_flows

    def compute_symmetry_flow_weighted_sum(self, Pi_dict, z_aug, gates):
        """Gate-weighted sum of the expert flows, one product per active expert.

        Args:
            Pi_dict: expert name -> ``[B, m+1, m+1]`` matrix.
            z_aug: ``[B, m+1, 1]`` homogeneous latent vector.
            gates: ``[B, K]`` weights in ``_expert_names()`` order.  Experts
                whose weight is zero for the whole batch are skipped.

        Returns:
            ``[B, m]`` next latent (homogeneous coordinate dropped), in
            ``z_aug``'s dtype.

        Raises:
            ValueError: if ``gate_dim`` is not 3 or 9.
        """
        expert_names = self._expert_names()
        B, _, _ = z_aug.shape
        z_next = torch.zeros(B, self.latent_dim, device=z_aug.device, dtype=z_aug.dtype)

        for k, name in enumerate(expert_names):
            h_k = gates[:, k]  # [B]
            if torch.all(h_k == 0):
                continue  # expert inactive for the whole batch
            z_k = torch.bmm(Pi_dict[name], z_aug).squeeze(-1)[:, :-1]  # [B, m]
            z_next += z_k * h_k.unsqueeze(-1)

        return z_next

    def compute_symmetry_flow_weighted_sum_v2(self, Pi_dict, z_aug, gates):
        """Gate-weighted sum of expert flows via one effective matrix.

        Matrix-vector products are linear, so this accumulates
        ``Pi_eff = sum_k gates[:, k] * Pi_k`` and applies it once.  The
        result equals ``compute_symmetry_flow_weighted_sum`` up to rounding.

        Args:
            Pi_dict: expert name -> ``[B, m+1, m+1]`` matrix.
            z_aug: ``[B, m+1, 1]`` homogeneous latent vector.
            gates: ``[B, K]`` weights in ``_expert_names()`` order.

        Returns:
            ``[B, m]`` next latent (homogeneous coordinate dropped).

        Raises:
            ValueError: if ``gate_dim`` is not 3 or 9.
        """
        expert_names = self._expert_names()
        B, _, _ = z_aug.shape
        D = self.latent_dim
        # Accumulate in z_aug's dtype so the final bmm has matching dtypes.
        Pi_eff = torch.zeros(B, D + 1, D + 1, device=z_aug.device, dtype=z_aug.dtype)

        for k, name in enumerate(expert_names):
            h_k = gates[:, k]  # [B]
            if torch.all(h_k == 0):
                continue  # expert inactive for the whole batch
            Pi_eff += Pi_dict[name] * h_k.view(B, 1, 1)

        z_next_aug = torch.bmm(Pi_eff, z_aug).squeeze(-1)  # [B, m+1]
        return z_next_aug[:, :-1]

    def _refresh_anchor(self, z_anchor, t_scalar):
        """Recompute the symmetry matrices and gates at a subinterval anchor.

        Returns:
            ``(Pi_dict, alpha, velocity, gates)``.
        """
        Pi_dict, alpha, vt = self.compute_group_matrices_all(z_anchor, t_scalar)
        gate_logits = self.gating_net(z_anchor)
        gates = self.compute_gates(gate_logits, current_epoch=self.current_epoch)
        return Pi_dict, alpha, vt, gates

    # ------------------------------------------------------------------ #
    # Forward pass
    # ------------------------------------------------------------------ #
    def forward(self, x, t, mask):
        """Encode the sequence, then roll the latent state forward over ``t``.

        Args:
            x: ``[B, T, D]`` input sequence.
            t: ``[T]`` time points.  ``t[i].item()`` is used, so ``t`` must
               be 1-D.
            mask: ``[B, T]`` observation mask, passed to the encoder.

        Returns:
            Training mode: ``[B, T, output_dim]`` predictions — the decoded
            encoder state followed by one decoding per rollout step.

            Eval mode: ``(outputs, z_seq, gates_seq, vt_seq, alpha_seq)``.
            The last four have one entry per rollout step (``T - 1``).
            ``vt_seq`` and ``alpha_seq`` are detached CPU tensors.
        """
        # Encode backwards in time.  The encoder's final state is the one
        # reached after consuming the observation at t[0].
        x_reversed = torch.flip(x, dims=[1])
        mask_reversed = torch.flip(mask, dims=[1])
        t_reversed = torch.flip(t, dims=[0])

        z = self.encoder(x_reversed, t_reversed, mask_reversed)  # [B, m]
        z = self.rescale_norm(z, min_val=0.5, max_val=1.5)

        batch_size = z.size(0)
        ones_col = torch.ones(batch_size, 1, device=z.device, dtype=z.dtype)
        z_aug = torch.cat([z, ones_col], dim=-1).unsqueeze(-1)  # [B, m+1, 1]

        collect = not self.training
        z_seq, gates_seq, alpha_seq, vt_seq = [], [], [], []
        outputs = [self.fc(z).unsqueeze(1)]

        # Split [t_0, t_N] into L equal subintervals with left endpoints
        # ``t_anchors``.  Symmetry matrices and gates are refreshed only when
        # the rollout enters a new subinterval.
        t_0, t_N = t[0].item(), t[-1].item()
        L = self.num_subintervals
        delta_T = (t_N - t_0) / L
        t_anchors = torch.tensor([t_0 + l * delta_T for l in range(L)], device=t.device)  # [L]
        current_anchor_idx = 0

        Pi_dict, alpha, vt, gates = self._refresh_anchor(
            z_aug[:, :-1, 0], t_anchors[0].expand(batch_size, 1)
        )

        for step_idx in range(x.shape[1] - 1):
            t_i = t[step_idx].item()

            # Move to the subinterval containing t_i.  This is a ``while``
            # (not an ``if``) so that a coarse time grid crossing several
            # boundaries in one step does not leave the anchor lagging.
            new_anchor_idx = current_anchor_idx
            while new_anchor_idx < L - 1 and t_i >= t_anchors[new_anchor_idx + 1]:
                new_anchor_idx += 1
            if new_anchor_idx != current_anchor_idx:
                current_anchor_idx = new_anchor_idx
                Pi_dict, alpha, vt, gates = self._refresh_anchor(
                    z_aug[:, :-1, 0],
                    t_anchors[current_anchor_idx].expand(batch_size, 1),
                )

            # NOTE(review): gating is currently disabled.  Every expert gets
            # weight 1, overriding the gates computed above.  Those gates are
            # still evaluated, so warm-up noise keeps drawing from the RNG.
            gates = torch.ones(batch_size, self.gate_dim, device=z_aug.device, dtype=z_aug.dtype)
            z_next = self.compute_symmetry_flow_weighted_sum_v2(Pi_dict, z_aug, gates)
            z_next = self.rescale_norm(z_next, min_val=0.5, max_val=1.5)

            z_aug = torch.cat([z_next, ones_col], dim=-1).unsqueeze(-1)
            z = z_next

            if collect:
                gates_seq.append(gates.unsqueeze(1))
                alpha_seq.append(alpha.unsqueeze(1))
                vt_seq.append(vt.unsqueeze(1))
                z_seq.append(z.unsqueeze(1))

            outputs.append(self.fc(z).unsqueeze(1))

        predictions = torch.cat(outputs, dim=1)
        if not collect:
            return predictions
        return (
            predictions,
            torch.cat(z_seq, dim=1),
            torch.cat(gates_seq, dim=1),
            torch.cat(vt_seq, dim=1).detach().cpu(),
            torch.cat(alpha_seq, dim=1).detach().cpu(),
        )


class GRUUpdate(nn.Module):
    """GRU-style update of a latent state ``h`` from an observation ``x``.

    Each gate is a two-layer MLP on ``[h, x]`` rather than the single affine
    map of a standard GRU cell.  Rows whose mask entry is 0 keep their
    previous state.

    Args:
        input_dim: feature size of ``x``.
        latent_dim: size of the latent state ``h``.
        hidden_units: accepted for API compatibility but unused; every hidden
            layer is ``latent_dim`` wide.
    """

    def __init__(self, input_dim, latent_dim, hidden_units=64):
        super().__init__()
        concat_dim = input_dim + latent_dim
        self.update_gate = nn.Sequential(
            nn.Linear(concat_dim,  latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
            nn.Sigmoid()
        )
        self.reset_gate = nn.Sequential(
            nn.Linear(concat_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
            nn.Sigmoid()
        )
        self.new_state = nn.Sequential(
            nn.Linear(concat_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim)
        )

    def forward(self, h, x, mask_1d):
        """Return the masked, updated state.

        Args:
            h: ``[B, H]`` current latent state.
            x: ``[B, D]`` observation at this step.
            mask_1d: ``[B]`` per-row mask, bool or numeric.  It is cast to
                ``h``'s dtype, so fractional values blend ``h_new`` and ``h``.

        Returns:
            ``[B, H]`` tensor ``mask * h_new + (1 - mask) * h``.
        """
        concat = torch.cat([h, x], dim=-1)  # [B, H + D]

        z = self.update_gate(concat)
        r = self.reset_gate(concat)

        # Candidate state from the reset-gated previous state.
        h_tilde = self.new_state(torch.cat([r * h, x], dim=-1))
        h_new = (1 - z) * h_tilde + z * h

        # Cast so bool masks work (``1 - bool_tensor`` raises in PyTorch), and
        # reshape [B] -> [B, 1] to broadcast over the feature axis.
        mask = mask_1d.to(h.dtype).unsqueeze(-1)
        return mask * h_new + (1 - mask) * h



class LatentODEFunc(nn.Module):
    """Autonomous latent vector field ``dh/dt = f(h)`` for the encoder ODE.

    ``f`` is a two-layer Tanh MLP of width ``latent_dim``.  ``hidden_units``
    is accepted for API compatibility but not used.
    """

    def __init__(self, latent_dim, hidden_units=64):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, h):
        # ``t`` is required by the ODE-solver callback signature, but the
        # field does not depend on time.
        derivative = self.net(h)
        return derivative


class DeterministicODERNNEncoder(nn.Module):
    """ODE-RNN encoder.

    Between consecutive time points the latent state is evolved with
    ``ode_func``; at every time point the observation is folded in with
    ``gru_update``.

    Args:
        input_dim: feature size of the inputs (not used inside this class).
        latent_dim: size of the latent state.
        ode_func: callable ``f(t, h)`` returning ``dh/dt``.
        gru_update: callable ``g(h, x_i, mask_i)`` returning the new state.
        min_step: gaps with ``|dt| < min_step`` take one explicit Euler step.
            Longer gaps are integrated with ``odeint(..., method='rk4')`` on
            a grid of ``max(2, int(|dt| / min_step))`` points.
    """

    def __init__(self, input_dim, latent_dim, ode_func, gru_update, min_step=2e-3):
        super().__init__()
        self.ode_func = ode_func
        self.gru_update = gru_update
        self.latent_dim = latent_dim
        self.min_step = min_step

    def forward(self, x, t, mask):
        """Encode a sequence into its final latent state.

        Args:
            x:    ``[B, T, D]`` input data.
            t:    ``[T]`` time points shared by the whole batch.  Each
                  ``t[i]`` must be a scalar because ``.item()`` is called on
                  it.  Integration runs from ``t[i-1]`` to ``t[i]`` using the
                  signed gap.
            mask: ``[B, T]`` observation mask; ``mask[:, i]`` is forwarded to
                  ``gru_update``.

        Returns:
            ``[B, latent_dim]`` latent state after the last time point.
        """
        B, T, _ = x.size()
        h = torch.zeros(B, self.latent_dim, device=x.device)

        for i in range(T):
            if i > 0:
                t_prev = t[i - 1].item()
                t_now = t[i].item()
                dt = t_now - t_prev

                if abs(dt) < self.min_step:
                    # Short gap: a single explicit Euler step.
                    t_tensor = torch.as_tensor(t_prev, dtype=h.dtype, device=h.device)
                    h = h + self.ode_func(t_tensor, h) * dt
                else:
                    # Long gap: fixed-grid RK4 over intermediate points; keep
                    # only the state at t_now.
                    num_steps = max(2, int(abs(dt) / self.min_step))
                    time_points = torch.linspace(
                        t_prev, t_now, num_steps, dtype=h.dtype, device=h.device
                    )
                    h_traj = odeint(self.ode_func, h, time_points, method='rk4')  # [num_steps, B, H]
                    h = h_traj[-1]

            # Fold in the observation at t[i].
            h = self.gru_update(h, x[:, i, :], mask[:, i])

        return h