from dataclasses import dataclass
import numpy as np
import jax.numpy as jnp

def numpy_seed_from_key(key):
    """Derive a NumPy-compatible integer seed from the first two entries of ``key``.

    Each of ``key[0]`` and ``key[1]`` is cast to ``np.uint32`` and then to a
    Python ``int``. The two values are mixed as ``(key[0] << 16) ^ key[1]``,
    and the result is reduced modulo ``2**32 - 1`` so it lies in
    ``[0, 2**32 - 2]``.

    NOTE(review): presumably ``key`` is a legacy (raw uint32 pair) JAX PRNG
    key. Typed keys from ``jax.random.key`` would need ``jax.random.key_data``
    first. Confirm against callers.
    """
    high, low = (int(np.uint32(np.array(key[idx]))) for idx in (0, 1))
    modulus = 2**32 - 1
    mixed = (high << 16) ^ low
    return int(mixed % modulus)

# =============================================================================
# TFMPE CONFIG
# =============================================================================

@dataclass
class TFMPEConfig:
    """Hyperparameters and run settings for a TFMPE experiment.

    All fields have defaults, so ``TFMPEConfig()`` is a complete config. The
    per-field comments below are inferred from names only. This class does
    not itself establish their meaning; NOTE(review): confirm each against
    the code that consumes this config.
    """

    # --- Problem discretization (presumably; TODO confirm) ---
    N_t: int = 50  # presumably the number of time steps
    nx: int = 81  # presumably the number of spatial grid points
    dt_init: float = 2e-4  # presumably an initial time-step size
    eta: float = 0.05  # NOTE(review): meaning not visible here — confirm
    # --- Model architecture (names suggest an encoder/decoder transformer; TODO confirm) ---
    latent_dim: int = 64  # presumably the latent/embedding width
    n_encoder: int = 3  # presumably the number of encoder layers
    n_decoder: int = 3  # presumably the number of decoder layers
    n_heads: int = 4  # presumably attention heads per layer
    n_ff: int = 2  # NOTE(review): feed-forward depth or width multiplier? confirm
    # --- Training schedule (presumably; TODO confirm) ---
    n_rounds: int = 1  # presumably the number of simulate-then-train rounds
    n_samples_per_round: int = 1000  # presumably simulations drawn per round
    n_val_samples: int = 100  # presumably the held-out validation set size
    n_iter_per_round: int = 2000  # presumably optimizer steps per round
    batch_size: int = 64
    learning_rate: float = 3e-4
    # --- Inference and output ---
    n_posterior_samples: int = 500  # presumably samples drawn from the fitted posterior
    output_dir: str = "tfmpe_hemo_results"  # default results directory name
    use_normalization: bool = True  # presumably toggles use of a data Normalizer; confirm

# =============================================================================
# NORMALIZATION UTILITIES
# =============================================================================

@dataclass
class Normalizer:
    """Handles per-site normalization of flow data.

    ``means`` and ``stds`` hold one entry per site (as produced by
    :meth:`fit`). They are broadcast against the data according to its rank:

    * 2-D data: sites are on axis 0 (the same layout ``fit`` uses for 2-D
      input), so the statistics are applied as ``(n_sites, 1)`` columns.
    * Any other rank (the intended case is 4-D data with sites on axis 1):
      the statistics are applied with shape ``(1, n_sites, 1, 1)``.

    NOTE(review): for 4-D data, ``fit`` only looks at index 0 of the last
    axis, but ``normalize``/``denormalize`` apply the same statistics to every
    entry of that axis. Confirm the last axis is always size 1, or that this
    is intended.
    """

    means: np.ndarray  # per-site means; shape (n_sites,) when built by fit()
    stds: np.ndarray  # per-site stds, floored at 1e-10 by fit()

    @classmethod
    def fit(cls, y: np.ndarray) -> 'Normalizer':
        """Estimate per-site mean and standard deviation from ``y``.

        Args:
            y: Either a 4-D array or a 2-D array.

                * 4-D: only index 0 of the last axis is used. Statistics pool
                  axes 0 and 2, leaving one value per entry of axis 1 (the
                  site axis).
                * 2-D: statistics pool axis 1, leaving one value per row
                  (site).

        Returns:
            A ``Normalizer`` whose ``means``/``stds`` have one entry per site.

        Raises:
            ValueError: if ``y`` is neither 2-D nor 4-D.
        """
        if y.ndim == 4:
            data, pooled_axes = y[:, :, :, 0], (0, 2)
        elif y.ndim == 2:
            data, pooled_axes = y, 1
        else:
            raise ValueError(f"Unexpected y shape: {y.shape}")
        means = np.mean(data, axis=pooled_axes)
        stds = np.std(data, axis=pooled_axes)
        # Floor the std so constant sites do not cause division by zero.
        stds = np.maximum(stds, 1e-10)
        return cls(means=means, stds=stds)

    def _stats_for(self, ndim: int) -> tuple:
        """Return ``(means, stds)`` as JAX arrays shaped to broadcast against data of rank ``ndim``."""
        means = jnp.array(self.means)
        stds = jnp.array(self.stds)
        if ndim == 2:
            # 2-D data keeps sites on axis 0 (see ``fit``), so stats must be
            # columns. The 4-D reshape below would instead broadcast to a
            # wrong (1, S, S, T) result.
            return means[:, None], stds[:, None]
        # Sites on axis 1, broadcast over every other axis.
        return means[None, :, None, None], stds[None, :, None, None]

    def normalize(self, y: jnp.ndarray) -> jnp.ndarray:
        """Subtract the per-site mean from ``y`` and divide by the per-site std."""
        means, stds = self._stats_for(jnp.ndim(y))
        return (y - means) / stds

    def denormalize(self, y_norm: jnp.ndarray) -> jnp.ndarray:
        """Invert :meth:`normalize`: multiply by the per-site std and add the per-site mean."""
        means, stds = self._stats_for(jnp.ndim(y_norm))
        return y_norm * stds + means
