import torch
import torch.nn as nn


def _compute_batch_gradient(input, wrt='T', order=1):
    input = input.clone()
    assert len(input.shape) == 3  # B x T x d
    bsz, t_len, dim = input.shape
    if wrt == 'T':
        ans = input.permute(0, 2, 1).reshape(bsz * dim, t_len)
        grad = torch.gradient(ans, dim=1)[0]
        if order > 1:
            grad = torch.gradient(grad, dim=1)[0]
        grad = grad.reshape(bsz, dim, t_len).permute(0, 2, 1)
    elif wrt == 'd':
        ans = input.reshape(bsz * t_len, dim)
        grad = torch.gradient(ans, dim=1)[0]
        if order > 1:
            grad = torch.gradient(grad, dim=1)[0]
        grad = grad.reshape(bsz, t_len, dim)
    return grad


class SummaryNet(nn.Module):
    """Map a trajectory of shape (B, T, d) to per-point summary statistics.

    Every mode discards the first and last ``_TRIM`` entries along both the
    time and the spatial axis and returns a tensor of shape
    (B, T' * d', summary_dim), where T' = T - 2 * _TRIM and d' = d - 2 * _TRIM.
    """

    # Entries dropped at each end of the time and spatial axes.
    _TRIM = 2

    def __init__(self, summary_dim=3, hidden_dim=128, mode='physics', state_dim=60):
        """
        Args:
            summary_dim: dimension of summary statistics
            hidden_dim: hidden dimension for MLP
            mode: 'physics' (physics-informed), 'pointwise' (MLP per point), 'statewise' (MLP per state)
            state_dim: spatial dimension of L-96 state (60 for default L-96, only used for 'statewise' mode)

        Raises:
            ValueError: if ``mode`` is not one of the three supported values.
        """
        super().__init__()
        self.summary_dim = summary_dim
        self.hidden_dim = hidden_dim
        self.mode = mode
        self.state_dim = state_dim

        if mode == 'physics':
            # Three hand-built statistics per point, linearly projected.
            self.proj = nn.Linear(3, summary_dim)
        elif mode == 'pointwise':
            # MLP applied independently to every scalar value.
            self.mlp = nn.Sequential(
                nn.Linear(1, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, summary_dim),
            )
        elif mode == 'statewise':
            # MLP applied to the whole (trimmed) spatial state of a timestep.
            trimmed_dim = state_dim - 2 * self._TRIM
            self.mlp = nn.Sequential(
                nn.Linear(trimmed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, trimmed_dim * summary_dim),
            )
        else:
            raise ValueError(f"Invalid mode: {mode}. Choose 'physics', 'pointwise', or 'statewise'")

    def forward(self, traj):
        """Compute summaries for ``traj``.

        Args:
            traj: tensor of shape (B, T, d).

        Returns:
            Tensor of shape (B, (T - 4) * (d - 4), summary_dim).

        Raises:
            ValueError: if ``traj`` is not 3-D, if trimming would leave no
                points, if the spatial width does not match ``state_dim`` in
                'statewise' mode, or if ``self.mode`` is not a supported mode.
        """
        if traj.dim() != 3:
            raise ValueError(
                f"Expected traj of shape (B, T, d), got shape {tuple(traj.shape)}"
            )
        border = 2 * self._TRIM
        _, t_len, dim = traj.shape
        if t_len <= border or dim <= border:
            # Trimming would otherwise yield an empty summary silently.
            raise ValueError(
                f"traj needs more than {border} entries along time and space, "
                f"got T={t_len}, d={dim}"
            )

        if self.mode == 'physics':
            return self._forward_physics(traj)
        if self.mode == 'pointwise':
            return self._forward_pointwise(traj)
        if self.mode == 'statewise':
            return self._forward_statewise(traj)
        raise ValueError(f"Invalid mode: {self.mode}. Choose 'physics', 'pointwise', or 'statewise'")

    def _interior(self, tensor):
        """Drop ``_TRIM`` entries from both ends of the time and space axes."""
        trim = self._TRIM
        return tensor[:, trim:-trim, trim:-trim]

    def _forward_physics(self, traj):
        """Project [advection term, time derivative, value] for each point."""
        # torch.roll wraps around, so spatial neighbours are taken cyclically.
        prev_1 = torch.roll(traj, 1, dims=2)
        prev_2 = torch.roll(traj, 2, dims=2)
        next_1 = torch.roll(traj, -1, dims=2)
        advection = prev_1 * (prev_2 - next_1)
        grad_t = _compute_batch_gradient(traj, wrt='T', order=1)

        stats = torch.stack(
            [self._interior(advection), self._interior(grad_t), self._interior(traj)],
            dim=-1,
        )  # (B, T', d', 3)
        bsz = stats.shape[0]
        projected = self.proj(stats.reshape(-1, 3))
        return projected.reshape(bsz, -1, self.summary_dim)

    def _forward_pointwise(self, traj):
        """Run the MLP on every trimmed scalar independently."""
        interior = self._interior(traj)  # (B, T', d')
        bsz = interior.shape[0]
        summary = self.mlp(interior.reshape(-1, 1))  # (B * T' * d', summary_dim)
        return summary.reshape(bsz, -1, self.summary_dim)

    def _forward_statewise(self, traj):
        """Run the MLP on the full trimmed state of each timestep."""
        interior = self._interior(traj)  # (B, T', d')
        bsz, t_trim, d_trim = interior.shape
        expected = self.state_dim - 2 * self._TRIM
        if d_trim != expected:
            raise ValueError(
                f"statewise mode was built for state_dim={self.state_dim}, "
                f"but traj has spatial dimension {traj.shape[2]}"
            )
        summary = self.mlp(interior.reshape(bsz * t_trim, d_trim))  # (B * T', d' * summary_dim)
        return summary.reshape(bsz, -1, self.summary_dim)


class Critic(nn.Module):
    """Score a set of summary statistics with a small two-layer MLP."""

    def __init__(self, summary_dim=3, hidden_dim=64):
        """
        Args:
            summary_dim: size of the last axis of the incoming statistics
            hidden_dim: width of the single hidden layer
        """
        super().__init__()
        layers = [
            nn.Linear(summary_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, stats):
        """Return one score per batch element.

        ``stats`` is expected as (B, N, D); each of the N entries is scored
        separately and the scores are averaged over axis 1.
        """
        per_point = self.net(stats)  # (B, N, 1)
        averaged = torch.mean(per_point, dim=1)  # (B, 1)
        return torch.squeeze(averaged, -1)
