"""
Spatio-Temporal Neural Process (STNP) model for GLEAM-AI.

This module contains the main STNP model that combines spatial and temporal modeling
for epidemiological forecasting using neural processes.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from torch.distributions import constraints
import pytorch_lightning as pl
from typing import Dict, Any, Optional, Tuple, List
import numpy as np
import logging
import pandas as pd
from torchmetrics import WeightedMeanAbsolutePercentageError
from .components import EmbedModel, EncoderRNN, DecoderRNN_5, LatentEncoder, MeanAggregator
from ..config.settings import ModelConfig, TrainingConfig
from ..data.utils import load_graph_data
from ..data.constructors import apply_seasonality

logger = logging.getLogger(__name__)

class NegativeBinomial2(dist.Distribution):
    """Negative binomial distribution parameterised by mean ``mu`` and dispersion ``phi``.

    Built as a Gamma-Poisson mixture: ``lam ~ Gamma(phi, rate=phi/mu)`` and
    ``y ~ Poisson(lam)``. Hence ``E[y] = mu`` and ``Var[y] = mu + mu**2 / phi``.

    ``mean`` and ``variance`` are exposed as properties, as
    ``torch.distributions.Distribution`` expects (this also makes the inherited
    ``stddev`` work).

    Args:
        mu: Positive mean tensor.
        phi: Positive dispersion (Gamma concentration) tensor.
    """

    arg_constraints = {'mu': constraints.positive, 'phi': constraints.positive}

    def __init__(self, mu, phi):
        self.mu = mu
        self.phi = phi
        # Gamma rate chosen so the Gamma mixing distribution has mean mu.
        self.rate = phi / mu
        # batch_shape is left at the default (empty). sample()/rsample() rely on
        # this: _extended_shape() then returns just sample_shape, and Gamma
        # appends the parameter shape itself.
        super().__init__()

    def _gamma_mixing(self):
        """Return ``(gamma, target_device)`` for drawing the Poisson rates.

        Gamma and Poisson sampling are not supported on MPS, so MPS parameters
        are copied to CPU and ``target_device`` tells the caller where to move
        the result back. For every other device ``target_device`` is None.
        """
        device = self.mu.device
        if device.type == 'mps':
            return dist.Gamma(concentration=self.phi.cpu(), rate=self.rate.cpu()), device
        return dist.Gamma(concentration=self.phi, rate=self.rate), None

    def sample(self, sample_shape=torch.Size()):
        """Draw non-negative integer-valued samples without tracking gradients."""
        with torch.no_grad():
            gamma, target_device = self._gamma_mixing()
            rates = gamma.sample(self._extended_shape(sample_shape))
            counts = dist.Poisson(rates).sample()
            return counts if target_device is None else counts.to(target_device)

    def rsample(self, sample_shape=torch.Size()):
        """Draw samples with a straight-through gradient estimator.

        The forward value is the Poisson draw. The backward pass uses the
        gradient of the reparameterised Gamma draw.
        """
        gamma, target_device = self._gamma_mixing()
        rates = gamma.rsample(self._extended_shape(sample_shape))
        counts = dist.Poisson(rates).sample().detach()
        straight_through = counts + (rates - rates.detach())
        return straight_through if target_device is None else straight_through.to(target_device)

    def log_prob(self, y):
        """Log-pmf of ``y`` under NB(mu, phi).

        ``lgamma(phi + y + 1) - log(phi + y)`` equals ``lgamma(phi + y)``. The
        original two-term form and its evaluation order are kept unchanged.
        """
        return (
            -torch.lgamma(self.phi)
            - torch.lgamma(y + 1)
            + torch.lgamma(self.phi + y + 1)
            - torch.log(self.phi + y)
            + self.phi * torch.log(self.rate)
            - (self.phi + y) * torch.log1p(self.rate)
        )

    @property
    def mean(self):
        """Distribution mean, i.e. ``mu``."""
        return self.mu

    @property
    def variance(self):
        """Gamma-Poisson variance ``mu + mu**2 / phi`` (== ``phi * (1 + rate) / rate**2``)."""
        return self.mu + self.mu ** 2 / self.phi


class STNP(pl.LightningModule):
    """
    Spatio-Temporal Neural Process (STNP) model.
    
    This model combines spatial and temporal modeling for epidemiological forecasting
    using neural processes. It processes spatial relationships through graph neural
    networks and temporal patterns through recurrent neural networks.
    
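    Args:
        config: Model configuration containing all model parameters. If omitted,
            the legacy keyword arguments (x_dim, xt_dim, ..., NUM_COMP) are used instead.
        edge_index: Graph edge indices [2, num_edges] as a numpy array
        edge_weight: Graph edge weights [num_edges] as a numpy array (required whenever edge_index is given)
        training_config: Optional training configuration; only kept when config is given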
    """
    
    def __init__(
        self, 
        config: Optional[ModelConfig] = None,
        edge_index: Optional[np.ndarray] = None,
        edge_weight: Optional[np.ndarray] = None,
        training_config: Optional[TrainingConfig] = None,
        # Old parameter names for backward compatibility
        x_dim: Optional[int] = None,
        xt_dim: Optional[int] = None,
        y_dim: Optional[int] = None,
        z_dim: Optional[int] = None,
        r_dim: Optional[int] = None,
        seq_len: Optional[int] = None,
        num_nodes: Optional[int] = None,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        embed_out_dim: Optional[int] = None,
        max_diffusion_step: Optional[int] = None,
        encoder_num_rnn: Optional[int] = None,
        decoder_num_rnn: Optional[int] = None,
        decoder_hidden_dims: Optional[List[int]] = None,
        context_percentage: Optional[float] = None,
        NUM_COMP: Optional[int] = None,
        # Training parameters for backward compatibility
        lr: Optional[float] = None,
        lr_encoder: Optional[float] = None,
        lr_decoder: Optional[float] = None,
        lr_milestones: Optional[List[int]] = None,
        lr_gamma: Optional[float] = None
    ):
        """Create the model from a ``ModelConfig`` or from legacy keyword arguments.

        When ``config`` is given, every architecture attribute is copied from it
        and the legacy architecture keywords are ignored. When ``config`` is
        None, those keywords are used directly and ``training_config`` is
        discarded (``self.training_config`` is set to None).

        Args:
            config: Architecture configuration (preferred style).
            edge_index: Graph edge indices as a numpy array; converted to a long tensor.
            edge_weight: Graph edge weights as a numpy array; converted to a float
                tensor. Must be provided whenever ``edge_index`` is, because
                ``torch.from_numpy(None)`` would fail.
            training_config: Training configuration; only stored on the ``config`` path.
            x_dim ... NUM_COMP: Legacy architecture parameters, used when ``config`` is None.
            lr ... lr_gamma: Legacy training parameters. This method does not assign
                them to attributes; they are only named in ``save_hyperparameters``.
        """
        super().__init__()
        
        # Handle both new and old parameter styles
        if config is not None:
            # New style: using ModelConfig
            self.config = config
            self.training_config = training_config
            self.x_dim = config.x_dim
            self.xt_dim = config.xt_dim
            self.y_dim = config.y_dim
            self.z_dim = config.z_dim
            self.r_dim = config.r_dim
            self.seq_len = config.seq_len
            self.num_nodes = config.num_nodes
            self.in_channels = config.in_channels
            self.out_channels = config.out_channels
            self.embed_out_dim = config.embed_out_dim
            self.max_diffusion_step = config.max_diffusion_step
            self.encoder_num_rnn = config.encoder_num_rnn
            self.decoder_num_rnn = config.decoder_num_rnn
            self.decoder_hidden_dims = config.decoder_hidden_dims
            self.context_percentage = config.context_percentage
            self.NUM_COMP = config.NUM_COMP
        else:
            # Old style: using individual parameters
            # (run()/run_batch() then take their "dict config" branch and will
            # fail on self.config[...] unless self.config is replaced later.)
            self.config = None
            self.training_config = None
            self.x_dim = x_dim
            self.xt_dim = xt_dim
            self.y_dim = y_dim
            self.z_dim = z_dim
            self.r_dim = r_dim
            self.seq_len = seq_len
            self.num_nodes = num_nodes
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.embed_out_dim = embed_out_dim
            self.max_diffusion_step = max_diffusion_step
            self.encoder_num_rnn = encoder_num_rnn
            self.decoder_num_rnn = decoder_num_rnn
            self.decoder_hidden_dims = decoder_hidden_dims
            self.context_percentage = context_percentage
            self.NUM_COMP = NUM_COMP
        
        # Store graph structure
        if edge_index is not None:
            self.edge_index = torch.from_numpy(edge_index).long()
            self.edge_weight = torch.from_numpy(edge_weight).float()
        else:
            # For backward compatibility, these might be passed as tensors
            self.edge_index = None
            self.edge_weight = None
        
        # Save hyperparameters for PyTorch Lightning - use the old parameter names for compatibility
        # NOTE(review): these names refer to this __init__'s own arguments, so
        # presumably Lightning records the values passed in this call. On the
        # `config` path they are the None defaults, not the values read from
        # `config`. Verify what checkpoints' `hyper_parameters` contain.
        self.save_hyperparameters(
            "x_dim", "xt_dim", "y_dim", "z_dim", "r_dim", "seq_len", "num_nodes", 
            "in_channels", "out_channels", "embed_out_dim", "max_diffusion_step",
            "encoder_num_rnn", "decoder_num_rnn", "decoder_hidden_dims", 
            "context_percentage", "NUM_COMP", "lr", "lr_encoder", "lr_decoder", 
            "lr_milestones", "lr_gamma"
        )
        
        # Initialize model components
        self._build_model()
        
        # Initialize loss functions
        self.mae = nn.L1Loss(reduction="mean")
        self.mse = nn.MSELoss(reduction="mean")
        
        # Global latent variables (for amortized inference)
        # Registered as buffers: they follow .to(device) and are part of state_dict.
        self.register_buffer(
            "mu_z_global", 
            torch.zeros((self.seq_len, self.z_dim), requires_grad=False)
        )
        self.register_buffer(
            "var_z_global", 
            torch.ones((self.seq_len, self.z_dim), requires_grad=False)
        )
        
        # Statistics for normalization - use regular tensors like the old model
        # Plain attributes (not buffers): they are not saved in state_dict and are
        # not moved by .to(); get_samples() moves the slices it needs to the model device.
        # Placeholder values until update_y_stats() is called.
        self.y_mean = torch.zeros(self.seq_len, self.y_dim * self.NUM_COMP)
        self.y_std = torch.ones(self.seq_len, self.y_dim * self.NUM_COMP)
        
        # Validation outputs for PyTorch Lightning
        self.validation_step_outputs = []
    
    def run(self, starting_date: str, R0: float, seasonality_min: float, frac_susceptible: float, frac_latent: float, 
            frac_recovered: float, num_runs: int):
        """Sample ``num_runs`` trajectories for a single scenario.

        Every run receives identical inputs (same date, parameters and initial
        state). Any difference between runs comes from the sampling done in
        ``get_samples``.

        Args:
            starting_date: First simulated date (anything ``pd.to_datetime`` accepts).
            R0: Reproduction number.
            seasonality_min: Seasonality minimum passed to ``apply_seasonality``.
            frac_susceptible: Initial susceptible fraction.
            frac_latent: Initial latent fraction; also scales the populations into y0.
            frac_recovered: Initial recovered fraction.
            num_runs: Number of trajectories to draw.

        Returns:
            Long-format DataFrame with columns ``run_id``, ``date``, ``state_id``,
            ``compartment`` and ``value``.
        """
        cfg = self.config
        if hasattr(cfg, 'seq_len'):
            # ModelConfig object; meta_data paths live in the raw config kept by the loader.
            seq_len, num_nodes, population_scaler = cfg.seq_len, cfg.num_nodes, cfg.POPULATION_SCALER
            population_csv_path = (
                self._original_config['meta_data']['population_csv_path']
                if hasattr(self, '_original_config')
                else "meta_data/populations.csv"  # Default path
            )
        else:
            # Legacy nested-dictionary config.
            seq_len = cfg['model']['seq_len']
            num_nodes = cfg['model']['num_nodes']
            population_scaler = cfg['model']['POPULATION_SCALER']
            population_csv_path = cfg['meta_data']['population_csv_path']

        start = pd.to_datetime(starting_date)
        populations = pd.read_csv(population_csv_path)["population"].values
        populations_rescaled = populations / population_scaler

        # 0-based day of year scaled by 366.
        time_index = (start - pd.to_datetime(f"{start.year}-01-01")).days / 366

        # Temporal covariate: (seq_len,) -> (num_runs, seq_len, num_nodes, 1).
        dates = pd.date_range(start=start, periods=seq_len)
        seasonal = np.array([apply_seasonality(day, seasonality_min) for day in dates], dtype=np.float32)
        xt = np.tile(seasonal[:, np.newaxis, np.newaxis], (num_runs, 1, num_nodes, 1))

        # Static covariates, identical for every node, plus the rescaled population column.
        scenario = [R0, time_index, frac_susceptible, frac_latent, frac_recovered]
        per_node = np.column_stack([np.tile(scenario, (num_nodes, 1)), populations_rescaled])
        x = np.tile(per_node, (num_runs, 1, 1))  # num_runs x num_nodes x x_dim

        # Initial latent counts per node, repeated for each run.
        y0 = np.tile(np.round(frac_latent * populations), (num_runs, 1))  # num_runs x num_nodes

        y_pred = self.get_samples(x, xt, y0)

        # Last axis of y_pred is compartment-major:
        # [hosp_inc nodes | hosp_prev nodes | latent_inc nodes | latent_prev nodes].
        compartment_names = ['hosp_inc', 'hosp_prev', 'latent_inc', 'latent_prev']
        by_compartment = y_pred.reshape((y_pred.shape[0], y_pred.shape[1], len(compartment_names), num_nodes))
        n_runs, n_dates, n_comps, n_nodes = by_compartment.shape
        values = by_compartment.transpose(0, 1, 3, 2)  # -> (runs, dates, nodes, compartments)

        per_row = n_nodes * n_comps
        return pd.DataFrame({
            'run_id': np.repeat(np.arange(n_runs), n_dates * per_row),
            'date': np.tile(np.repeat(dates, per_row), n_runs),
            'state_id': np.tile(np.repeat(np.arange(n_nodes), n_comps), n_runs * n_dates),
            'compartment': np.tile(compartment_names, n_runs * n_dates * n_nodes),
            'value': values.ravel(),
        })


    def run_batch(self, starting_dates: np.ndarray, R0s: np.ndarray, seasonality_mins: np.ndarray, fracs_susceptible: np.ndarray, fracs_latent: np.ndarray, fracs_recovered: np.ndarray):
        """Sample one trajectory per scenario for a batch of scenarios.

        Element ``i`` of every argument describes run ``i``, so all arguments
        must have the same length. Lists and pandas Series are accepted as well
        as numpy arrays. Lengths are checked before the population CSV is read.

        Args:
            starting_dates: First simulated date of each run (anything ``pd.to_datetime`` accepts).
            R0s: Reproduction number per run.
            seasonality_mins: Seasonality minimum per run, passed to ``apply_seasonality``.
            fracs_susceptible: Initial susceptible fraction per run.
            fracs_latent: Initial latent fraction per run; also scales the populations into y0.
            fracs_recovered: Initial recovered fraction per run.

        Returns:
            Long-format DataFrame with columns ``run_id``, ``date``, ``state_id``,
            ``compartment`` and ``value``. Each run uses its own ``seq_len`` dates,
            beginning at its starting date.

        Raises:
            AssertionError: If an argument's length differs from ``starting_dates``.
        """
        # Resolve settings from either a ModelConfig or the legacy nested-dict config.
        if hasattr(self.config, 'seq_len'):
            seq_len = self.config.seq_len
            num_nodes = self.config.num_nodes
            population_scaler = self.config.POPULATION_SCALER
            # meta_data is only available from the raw config kept by the loaders.
            if hasattr(self, '_original_config'):
                population_csv_path = self._original_config['meta_data']['population_csv_path']
            else:
                population_csv_path = "meta_data/populations.csv"  # Default path
        else:
            seq_len = self.config['model']['seq_len']
            num_nodes = self.config['model']['num_nodes']
            population_scaler = self.config['model']['POPULATION_SCALER']
            population_csv_path = self.config['meta_data']['population_csv_path']

        starting_dates = pd.to_datetime(starting_dates)
        num_runs = len(starting_dates)
        # Validate before touching the filesystem; every check carries a message.
        per_run_args = (
            ("R0s", R0s),
            ("seasonality_mins", seasonality_mins),
            ("fracs_susceptible", fracs_susceptible),
            ("fracs_latent", fracs_latent),
            ("fracs_recovered", fracs_recovered),
        )
        for arg_name, values in per_run_args:
            assert len(values) == num_runs, f"{arg_name} must have the same length as starting_dates"
        # Accept list-likes: the outer product below needs a real 1-D array.
        fracs_latent = np.asarray(fracs_latent)

        populations = pd.read_csv(population_csv_path)["population"].values
        populations_rescaled = populations / population_scaler

        # 0-based day of year scaled by 366, computed for all runs at once.
        year_starts = pd.to_datetime(starting_dates.year.astype(str) + '-01-01')
        time_indices = (starting_dates - year_starts).days / 366

        # Temporal covariates: (num_runs, seq_len) -> (num_runs, seq_len, num_nodes, 1).
        dates = [pd.date_range(start=start, periods=seq_len) for start in starting_dates]
        xt = np.array(
            [[apply_seasonality(day, s_min) for day in run_dates]
             for run_dates, s_min in zip(dates, seasonality_mins)],
            dtype=np.float32,
        )
        xt = np.repeat(xt[:, :, np.newaxis, np.newaxis], num_nodes, axis=2)

        # Static covariates per run, broadcast to every node, plus the rescaled population.
        x = np.array([R0s, time_indices, fracs_susceptible, fracs_latent, fracs_recovered]).T
        x = np.concatenate(
            [
                np.repeat(x[:, np.newaxis], num_nodes, axis=1),
                np.repeat(populations_rescaled[np.newaxis, :], num_runs, axis=0)[:, :, np.newaxis],
            ],
            axis=2,
        )  # num_runs x num_nodes x x_dim

        # Initial latent counts per run and node.
        y0 = np.outer(fracs_latent, populations).round()  # num_runs x num_nodes

        y_pred = self.get_samples(x, xt, y0)

        # Last axis of y_pred is compartment-major:
        # [hosp_inc nodes | hosp_prev nodes | latent_inc nodes | latent_prev nodes].
        compartment_names = ['hosp_inc', 'hosp_prev', 'latent_inc', 'latent_prev']
        y_pred_reshaped = y_pred.reshape((y_pred.shape[0], y_pred.shape[1], len(compartment_names), num_nodes))
        n_runs, n_dates, n_comps, n_nodes = y_pred_reshaped.shape
        # -> (runs, dates, nodes, compartments) so ravel() matches the row order below.
        y_pred_reshaped = y_pred_reshaped.transpose(0, 1, 3, 2)

        rows_per_date = n_nodes * n_comps
        df = pd.DataFrame({
            'run_id': np.repeat(np.arange(n_runs), n_dates * rows_per_date),
            # Each run has its own calendar, so the per-run ranges are concatenated, not tiled.
            'date': np.repeat(np.concatenate(dates), rows_per_date),
            'state_id': np.tile(np.repeat(np.arange(n_nodes), n_comps), n_runs * n_dates),
            'compartment': np.tile(compartment_names, n_runs * n_dates * n_nodes),
            'value': y_pred_reshaped.ravel(),
        })

        return df
    
    def get_samples_df(self, x: np.ndarray, xt: np.ndarray, y0: np.ndarray):
        y_pred = self.get_samples(x, xt, y0)


        return pd.DataFrame(y_pred)

    @classmethod
    def load_model_from_checkpoint(cls, checkpoint_filename: str, config: Dict[str, Any], device: str = "cpu"):
        """Build an STNP from a raw config dict and load checkpoint weights into it.

        The dict's ``model`` and ``train`` sections are translated into
        ``ModelConfig`` and ``TrainingConfig``. Graph data comes from
        ``config["meta_data"]["metaPath"]``. The raw dict is kept on the model as
        ``_original_config`` so ``run``/``run_batch`` can find meta-data paths.

        Args:
            checkpoint_filename: Path to a Lightning checkpoint (with a
                ``state_dict`` key) or to a bare state dict.
            config: Configuration dictionary with ``meta_data``, ``model`` and
                ``train`` sections.
            device: "cpu", "cuda" or "mps". A "cuda" request falls back to MPS,
                then CPU, when CUDA is unavailable.

        Returns:
            STNP: Model in eval mode on ``device``.
        """
        cuda_missing = device == "cuda" and not torch.cuda.is_available()
        if cuda_missing and torch.backends.mps.is_available():
            device = "mps"
            logger.info("CUDA not available, using MPS (Metal Performance Shaders) for Apple Silicon GPU acceleration")
        elif cuda_missing:
            device = "cpu"
            logger.info("CUDA not available, falling back to CPU")

        edge_index, edge_weight = load_graph_data(config["meta_data"]["metaPath"])

        model_section = config["model"]
        # (ModelConfig keyword, key in config["model"]). num_rnn and hidden_dims
        # are backward-compatibility aliases of the encoder/decoder entries.
        model_fields = (
            ("x_dim", "x_dim"),
            ("xt_dim", "xt_dim"),
            ("y_dim", "y_dim"),
            ("num_nodes", "num_nodes"),
            ("z_dim", "z_dim"),
            ("r_dim", "r_dim"),
            ("seq_len", "seq_len"),
            ("in_channels", "in_channels"),
            ("embed_out_dim", "embed_out_dim"),
            ("out_channels", "out_channels"),
            ("max_diffusion_step", "max_diffusion_step"),
            ("encoder_num_rnn", "encoder_num_rnn"),
            ("decoder_num_rnn", "decoder_num_rnn"),
            ("num_rnn", "encoder_num_rnn"),
            ("decoder_hidden_dims", "decoder_hidden_dims"),
            ("hidden_dims", "decoder_hidden_dims"),
            ("context_percentage", "context_percentage"),
            ("NUM_COMP", "NUM_COMP"),
        )
        model_config = ModelConfig(
            **{field: model_section[key] for field, key in model_fields},
            POPULATION_SCALER=model_section.get("POPULATION_SCALER", 1_000_000.0),
        )

        train_section = config["train"]
        required_train_keys = (
            "max_epochs", "lr", "lr_encoder", "lr_decoder", "lr_milestones", "lr_gamma",
            "train_batch_size", "val_batch_size", "patience", "gradient_clip_val",
        )
        training_config = TrainingConfig(
            **{key: train_section[key] for key in required_train_keys},
            min_delta=train_section.get("min_delta", 0.001),
            weight_decay=train_section.get("weight_decay", 0.0),
            device=device,
            num_workers=train_section.get("num_workers", 4),
        )

        model = cls(
            config=model_config,
            edge_index=edge_index,
            edge_weight=edge_weight,
            training_config=training_config,
        )
        # run()/run_batch() read meta_data paths from the raw dict.
        model._original_config = config

        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_filename, map_location=device)
        # Lightning checkpoints wrap the weights; bare state dicts are used as-is.
        weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        model.load_state_dict(weights)

        model.eval()
        return model.to(device)
    
    @classmethod
    def from_pretrained(cls, model_name: Optional[str] = None, config_path: Optional[str] = None, device: str = "cpu"):
        """Load a registry-managed STNP model ready for inference.

        Steps:
        - resolve the device;
        - look up the checkpoint and y-statistics paths in ``ModelRegistry``;
        - load the checkpoint through ``load_model_from_checkpoint``;
        - install the y statistics;
        - record the registry and model name on the instance, which
          ``get_samples`` checks for.

        Args:
            model_name: Registry entry to load; None selects the registry default.
            config_path: Configuration file for the registry. None is passed
                through to ``ModelRegistry`` (documented default: "config.yaml").
            device: "cpu", "cuda" or "mps". A "cuda" request falls back to MPS,
                then CPU, when CUDA is unavailable.

        Returns:
            STNP: Loaded and configured model instance.

        Raises:
            RuntimeError: Wrapping any exception raised while loading.
        """
        cuda_missing = device == "cuda" and not torch.cuda.is_available()
        if cuda_missing and torch.backends.mps.is_available():
            device = "mps"
            logger.info("CUDA not available, using MPS (Metal Performance Shaders) for Apple Silicon GPU acceleration")
        elif cuda_missing:
            device = "cpu"
            logger.info("CUDA not available, falling back to CPU")

        from .registry import ModelRegistry

        try:
            registry = ModelRegistry(config_path)
            checkpoint_path, y_stats_path = registry.get_model_paths(model_name)
            model = cls.load_model_from_checkpoint(checkpoint_path, registry.config, device)

            y_mean, y_std = registry.load_y_stats(y_stats_path)
            model.update_y_stats(y_mean, y_std)

            # get_samples() treats the presence of these attributes as "statistics loaded".
            model._registry = registry
            model._model_name = model_name or registry.get_default_model_name()

            model = model.to(device)

            for message in (
                f"Successfully loaded pretrained model '{model._model_name}' from {checkpoint_path}",
                f"Y statistics loaded from {y_stats_path}",
                f"Model moved to device: {device}",
            ):
                logger.info(message)

            return model
        except Exception as e:
            raise RuntimeError(
                f"Failed to load pretrained model '{model_name or 'default'}': {e}\n"
                f"Please ensure:\n"
                f"1. Configuration file exists and contains model_registry section\n"
                f"2. Model checkpoint and statistics files exist\n"
                f"3. All paths in config are correct"
            ) from e
    
    def _build_model(self) -> None:
        """Instantiate the sub-modules using the legacy layout so old checkpoints load.

        The attribute names (``embed``, ``rnn_encoder``, ``rnn_decoder_*``,
        ``z_encoder``, ``aggregator``) become state-dict key prefixes. The
        construction order is unchanged so parameter initialisation consumes
        the random number generator in the same order.
        """
        # Spatial embedding; it receives out_channels in addition to embed_out_dim.
        self.embed = EmbedModel(
            in_channels=self.in_channels,
            embed_out_dim=self.embed_out_dim,
            out_channels=self.out_channels,
            max_diffusion_step=self.max_diffusion_step,
            num_nodes=self.num_nodes,
        )

        # Temporal encoder input: the embedding plus current and previous targets.
        self.rnn_encoder = EncoderRNN(
            enc_in_dim=self.embed_out_dim + 2 * self.y_dim,
            r_dim=self.r_dim,
            num_rnn=self.encoder_num_rnn,
        )

        # One decoder per compartment, all built from the same hyper-parameters.
        shared_decoder_kwargs = dict(
            embed_out_dim=self.embed_out_dim,
            z_dim=self.z_dim,
            hidden_dims=self.decoder_hidden_dims,
            y_dim=self.y_dim,
            num_rnn=self.decoder_num_rnn,
        )
        for compartment in ("hosp_inc", "hosp_prev", "latent_inc", "latent_prev"):
            setattr(self, f"rnn_decoder_{compartment}", DecoderRNN_5(**shared_decoder_kwargs))

        self.z_encoder = LatentEncoder(r_dim=self.r_dim, z_dim=self.z_dim)

        # Combines per-context representations.
        self.aggregator = MeanAggregator()
    
    def update_y_stats(self, y_mean, y_std):
        """Update y statistics for normalization - compatibility with old model."""
        if isinstance(y_mean, np.ndarray):
            self.y_mean = torch.from_numpy(y_mean).float()
        elif isinstance(y_mean, torch.Tensor):
            self.y_mean = y_mean.float()
        else:
            raise TypeError("y_mean must be numpy array or torch tensor")
            
        if isinstance(y_std, np.ndarray):
            self.y_std = torch.from_numpy(y_std).float()
        elif isinstance(y_std, torch.Tensor):
            self.y_std = y_std.float()
        else:
            raise TypeError("y_std must be numpy array or torch tensor")
            
        self.y_mean.requires_grad = False
        self.y_std.requires_grad = False
    
    def get_latent_tensors(self):
        """Get latent tensors for compatibility with old model."""
        return self.mu_z_global, self.var_z_global
    
    @torch.no_grad()
    def get_samples(self, x: np.ndarray, xt: np.ndarray, y0_latent_prev: np.ndarray, x_mean: Optional[np.ndarray] = None, x_std: Optional[np.ndarray] = None, y_mean: Optional[np.ndarray] = None, y_std: Optional[np.ndarray] = None):
        """Get samples from the model with automatic statistics handling.

        The y statistics already on the model are used, unless ``y_mean`` and
        ``y_std`` are both given. In that case they are installed with
        ``update_y_stats`` and persist on the model for later calls. Unbatched
        inputs get a leading batch axis, including ``xt``.

        Args:
            x: Static features [batch_size, num_nodes, x_dim] or [num_nodes, x_dim].
            xt: Temporal features [batch_size, seq_len, num_nodes, xt_dim] or
                [seq_len, num_nodes, xt_dim].
            y0_latent_prev: Initial latent prevalence [batch_size, y_dim] or [y_dim];
                normalised with the last (latent_prev) slice of the y statistics.
            x_mean: Accepted for backward compatibility; unused.
            x_std: Accepted for backward compatibility; unused.
            y_mean: Mean for y normalisation [seq_len, y_dim * NUM_COMP]; must be given with ``y_std``.
            y_std: Std for y normalisation [seq_len, y_dim * NUM_COMP]; must be given with ``y_mean``.

        Returns:
            numpy array of samples [batch_size, seq_len, y_dim * 4]

        Raises:
            TypeError: If ``x``, ``xt`` or ``y0_latent_prev`` is not a numpy array.
            ValueError: If only one of ``y_mean``/``y_std`` is given, or no
                statistics are available.
        """
        for arg_name, value in (("x", x), ("xt", xt), ("y0_latent_prev", y0_latent_prev)):
            if not isinstance(value, np.ndarray):
                raise TypeError(f"{arg_name} must be a numpy array")

        # A half-specified override used to be silently ignored.
        if (y_mean is None) != (y_std is None):
            raise ValueError("y_mean and y_std must be provided together")
        if y_mean is not None:
            self.update_y_stats(y_mean, y_std)
        elif hasattr(self, '_registry') and hasattr(self, '_model_name'):
            # Use loaded statistics from registry
            logger.info("Using cached y statistics from model registry")
        else:
            raise ValueError(
                "No y statistics available. Please either:\n"
                "1. Load the model using STNP.from_pretrained() to automatically load statistics, or\n"
                "2. Provide y_mean and y_std parameters to this method"
            )

        # Get the device currently used by the model
        device = next(self.parameters()).device

        # The last NUM_COMP chunk of the statistics belongs to latent_prev.
        if isinstance(self.y_mean, torch.Tensor):
            y_mean_split = torch.chunk(self.y_mean, self.NUM_COMP, dim=-1)
            y_std_split = torch.chunk(self.y_std, self.NUM_COMP, dim=-1)
            y_latent_prev_mean_tensor = y_mean_split[-1].to(device)
            y_latent_prev_std_tensor = y_std_split[-1].to(device)
        else:
            # Fallback to numpy splitting for backward compatibility
            *_, y_latent_prev_mean = np.split(self.y_mean, self.NUM_COMP, axis=-1)
            *_, y_latent_prev_std = np.split(self.y_std, self.NUM_COMP, axis=-1)
            y_latent_prev_mean_tensor = torch.from_numpy(y_latent_prev_mean).float().to(device)
            y_latent_prev_std_tensor = torch.from_numpy(y_latent_prev_std).float().to(device)

        # Add a leading batch axis to unbatched inputs (xt was previously left 3-D).
        if x.ndim < 3:
            x = x[np.newaxis, ...]
        if xt.ndim < 4:
            xt = xt[np.newaxis, ...]
        if y0_latent_prev.ndim < 2:
            y0_latent_prev = y0_latent_prev[np.newaxis, ...]

        x_tensor = torch.from_numpy(x).float().to(device)
        xt_tensor = torch.from_numpy(xt).float().to(device)
        y0_raw = torch.from_numpy(y0_latent_prev).float().to(device)
        y0_latent_prev_tensor = self.normalize(y0_raw, y_latent_prev_mean_tensor, y_latent_prev_std_tensor)

        mu_pred, phi_pred = self(x_tensor, xt_tensor, y0_latent_prev_tensor)
        samples = self.sample_post(mu_pred, phi_pred)
        return samples.detach().cpu().numpy()
    
    @staticmethod
    def normalize(data, mean, std):
        """Normalize data using mean and std - compatibility with old model."""
        eps = 1e-8
        if isinstance(data, torch.Tensor):
            return (data - mean.to(data.device)) / (std.to(data.device) + eps)
        else:
            return (data - mean) / (std + eps)
    
    @staticmethod
    def unnormalize(data, mean, std):
        """Unnormalize data using mean and std - compatibility with old model."""
        eps = 1e-8
        if isinstance(data, torch.Tensor):
            return mean.to(data.device) + (std.to(data.device) + eps) * data
        else:
            return mean + (std + eps) * data
    

    def sample_post(self, mu_y, phi_y):
        """
        Draw one sample from the predictive distribution (legacy-compatible API).

        Uses the module's ``NegativeBinomial2`` with mean ``mu_y`` and
        dispersion ``phi_y``; its ``sample`` runs under ``torch.no_grad``.
        """
        predictive = NegativeBinomial2(mu_y, phi_y)
        return predictive.sample()
    
    def get_post(self, y0, embed_out, zs):
        """
        Run the four compartment decoders and join their outputs (legacy-compatible API).

        Decoders are called in the order hosp_inc, hosp_prev, latent_inc,
        latent_prev with identical inputs. Each returns a ``(mu, phi)`` pair,
        and the means and dispersions are each concatenated along the last dim.

        Returns:
            ``(mu_post, phi_post)``.
        """
        decoders = (
            self.rnn_decoder_hosp_inc,
            self.rnn_decoder_hosp_prev,
            self.rnn_decoder_latent_inc,
            self.rnn_decoder_latent_prev,
        )
        means, dispersions = [], []
        for decoder in decoders:
            mu_part, phi_part = decoder(y0, embed_out, zs)
            means.append(mu_part)
            dispersions.append(phi_part)
        return torch.cat(means, dim=-1), torch.cat(dispersions, dim=-1)
    
    def forward(
        self, 
        x: torch.Tensor, 
        xt: torch.Tensor, 
        y0_latent_prev: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict all compartments from the global latent distribution.

        Embeds the inputs, draws one latent sample per batch element from
        ``(mu_z_global, var_z_global)``, and decodes. If either global is
        not a tensor yet, a standard-normal default of shape
        ``(seq_len, z_dim)`` (zeros / ones) is used instead.

        Args:
            x: Static features [batch_size, num_nodes, x_dim]
            xt: Temporal features [batch_size, seq_len, num_nodes, xt_dim]
            y0_latent_prev: Initial latent prevalence [batch_size, y_dim]

        Returns:
            Tuple of:
            - mu_post: Predicted means [batch_size, seq_len, y_dim * 4]
            - phi_post: Predicted dispersion parameters [batch_size, seq_len, y_dim * 4]
        """
        device = y0_latent_prev.device
        embed_out = self.get_input_embedding(x, xt)

        prior_mean = self.mu_z_global
        if not isinstance(prior_mean, torch.Tensor):
            prior_mean = torch.zeros((self.seq_len, self.z_dim), device=device)
        prior_var = self.var_z_global
        if not isinstance(prior_var, torch.Tensor):
            prior_var = torch.ones((self.seq_len, self.z_dim), device=device)

        batch_size = y0_latent_prev.shape[0]
        zs = self.sample_z(prior_mean.to(device), prior_var.to(device), batch_size)

        mu_post, phi_post = self.get_post(y0_latent_prev, embed_out, zs)
        return mu_post, phi_post
    
    def get_input_embedding(
        self, 
        x: torch.Tensor, 
        xt: torch.Tensor
    ) -> torch.Tensor:
        """
        Get embeddings for input features, one graph-embedding step per timestep.

        At each step ``t`` the temporal slice ``xt[:, t]`` is concatenated with
        the static features ``x`` along the last dim and fed to ``self.embed``
        together with the graph edges and the hidden state from the previous
        step (``None`` at ``t == 0``).

        Args:
            x: Static features [batch_size, num_nodes, x_dim]
            xt: Temporal features [batch_size, seq_len, num_nodes, xt_dim]

        Returns:
            The per-step ``embed`` outputs stacked along dim 1
            (documented as [batch_size, seq_len, embed_out_dim]).
        """
        # The graph is the same for every timestep, so convert (numpy or other
        # array-likes -> tensor) and move it to the input device once, rather
        # than repeating the conversion and host->device copy every iteration.
        edge_index = self.edge_index
        if not isinstance(edge_index, torch.Tensor):
            edge_index = torch.as_tensor(edge_index).long()
        edge_weight = self.edge_weight
        if not isinstance(edge_weight, torch.Tensor):
            edge_weight = torch.as_tensor(edge_weight).float()
        edge_index = edge_index.to(x.device)
        edge_weight = edge_weight.to(x.device)

        embed_out = []
        h0 = None
        for t in range(self.seq_len):
            # Concatenate temporal and static features for this timestep.
            inputs = torch.cat([xt[:, t, ...], x], dim=-1)
            output, h0 = self.embed(inputs, edge_index, edge_weight, h0)
            embed_out.append(output)

        return torch.stack(embed_out, dim=1)
    
    def sample_z(
        self, 
        mu_z: torch.Tensor, 
        var_z: torch.Tensor, 
        num_samples: int = 1
    ) -> torch.Tensor:
        """
        Draw reparameterized latent samples ``mu + sqrt(var) * eps``.

        ``eps`` is standard-normal noise of shape
        ``(num_samples, mu_z.shape[0], mu_z.shape[-1])``. It is drawn via
        ``torch.distributions.Normal(0, 1).rsample`` and then moved to
        ``mu_z``'s device. ``mu_z`` and ``var_z`` broadcast over the new
        leading sample dimension.

        Args:
            mu_z: Mean parameters [seq_len, z_dim]
            var_z: Variance parameters [seq_len, z_dim]
            num_samples: Number of samples to generate

        Returns:
            Sampled latent variables [num_samples, seq_len, z_dim]
        """
        noise_shape = (num_samples, mu_z.shape[0], mu_z.shape[-1])
        eps = dist.Normal(0, 1).rsample(noise_shape).to(mu_z.device)
        std = var_z.sqrt()
        return mu_z[None, ...] + std[None, ...] * eps
    
    def get_latent_representation(
        self, 
        embed_out: torch.Tensor, 
        y: torch.Tensor, 
        y0: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode observed trajectories into latent distribution parameters.

        ``y0`` is prepended to ``y`` along the time axis. Each step's encoder
        input is ``[embedding, y_t, y_{t-1}]`` (concatenated on the last dim).
        The RNN encoder output goes through ``self.aggregator`` (documented as
        aggregating over the batch) and then ``self.z_encoder``.

        Args:
            embed_out: Embeddings [batch_size, seq_len, embed_out_dim]
            y: Observed targets [batch_size, seq_len, y_dim]
            y0: Initial conditions [batch_size, y_dim]

        Returns:
            Tuple of:
            - mu_z: Latent means [seq_len, z_dim]
            - var_z: Latent variances [seq_len, z_dim]
        """
        trajectory = torch.cat([y0.unsqueeze(dim=1), y], dim=1)
        current_step = trajectory[:, 1:, :]
        previous_step = trajectory[:, :-1, :]

        encoder_input = torch.cat([embed_out, current_step, previous_step], dim=-1)
        summary = self.aggregator(self.rnn_encoder(encoder_input))

        latent_mean, latent_var = self.z_encoder(summary)
        return latent_mean, latent_var
    
    def get_posterior(
        self, 
        y0: torch.Tensor, 
        embed_out: torch.Tensor, 
        zs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get posterior predictions for all compartments.

        This used to be a line-for-line copy of :meth:`get_post`. It now
        delegates so the two public entry points cannot drift apart. The
        decoders run in the order hosp_inc, hosp_prev, latent_inc, latent_prev,
        and their means/dispersions are concatenated along the last dim.

        Args:
            y0: Initial conditions [batch_size, y_dim]
            embed_out: Embeddings [batch_size, seq_len, embed_out_dim]
            zs: Latent variables [batch_size, seq_len, z_dim]

        Returns:
            Tuple of:
            - mu_post: Predicted means [batch_size, seq_len, y_dim * 4]
            - phi_post: Predicted dispersion parameters [batch_size, seq_len, y_dim * 4]
        """
        return self.get_post(y0, embed_out, zs)
    
    def update_y_stats(self, y_mean: torch.Tensor, y_std: torch.Tensor) -> None:
        """
        Update normalization statistics.

        Accepts tensors, numpy arrays (memory is shared, as with
        ``torch.from_numpy``) or plain sequences. Both are stored as float32
        tensors that do not require grad.

        Args:
            y_mean: Mean statistics [seq_len, y_dim * 4]
            y_std: Standard deviation statistics [seq_len, y_dim * 4]
        """
        # Use detach() rather than assigning `requires_grad = False`. That
        # assignment raises a RuntimeError for non-leaf tensors (e.g. stats
        # computed from tensors that require grad). For leaf tensors it also
        # silently flipped the flag on the caller's own tensor.
        self.y_mean = torch.as_tensor(y_mean).detach().float()
        self.y_std = torch.as_tensor(y_std).detach().float()
    
    def get_latent_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the global latent parameters as a ``(mu_z_global, var_z_global)`` pair."""
        latent_mean = self.mu_z_global
        latent_var = self.var_z_global
        return latent_mean, latent_var
    
    @staticmethod
    def context_target_split(
        x: torch.Tensor, 
        xt: torch.Tensor, 
        y: torch.Tensor, 
        y0: torch.Tensor,
        context_perc: float = 0.2
    ) -> Tuple[torch.Tensor, ...]:
        """
        Split data into context and target sets for neural process training.
        
        Args:
            x: Static features [batch_size, num_nodes, x_dim]
            xt: Temporal features [batch_size, seq_len, num_nodes, xt_dim]
            y: Target values [batch_size, seq_len, y_dim]
            y0: Initial conditions [batch_size, y_dim]
            context_perc: Percentage of data to use as context (0 < context_perc < 1)
            
        Returns:
            Tuple of:
            - x_context: Context static features [n_context, num_nodes, x_dim]
            - xt_context: Context temporal features [n_context, seq_len, num_nodes, xt_dim]
            - y_context: Context targets [n_context, seq_len, y_dim]
            - y0_context: Context initial conditions [n_context, y_dim]
            - x_target: Target static features [n_target, num_nodes, x_dim]
            - xt_target: Target temporal features [n_target, seq_len, num_nodes, xt_dim]
            - y_target: Target values [n_target, seq_len, y_dim]
            - y0_target: Target initial conditions [n_target, y_dim]
            - idc: Context indices
            - idt: Target indices
        """
        assert context_perc < 1.0 and context_perc > 0.0, "context_perc must be between 0 and 1"
        
        B = y.size(0)
        assert x.size(0) == B, "Batch size mismatch between x and y"
        
        n_context = int(context_perc * B)
        
        # Randomly shuffle indices
        idx = np.arange(B)
        np.random.shuffle(idx)
        
        # Split into context and target
        idc = idx[:n_context]
        idt = idx[n_context:]
        
        return (
            x[idc, ...], xt[idc, ...], y[idc, ...], y0[idc],
            x[idt, ...], xt[idt, ...], y[idt, ...], y0[idt],
            idc, idt
        )
    
    def configure_optimizers(self):
        """
        Configure optimizers (and optionally an LR scheduler) for PyTorch Lightning.

        The first matching configuration is used:

        1. ``hparams`` defines ``lr`` (legacy Lightning checkpoint): Adam with
           ``hparams.lr`` for the encoder, z-encoder and embedding, and
           ``hparams.lr_decoder`` for the four decoders. A ``MultiStepLR`` is
           built from ``hparams.lr_milestones`` / ``hparams.lr_gamma``.
        2. A non-``None`` ``training_config``: the same parameter grouping
           using ``training_config.lr`` / ``lr_decoder``. A ``MultiStepLR`` is
           added only when ``lr_milestones`` is set (``lr_gamma`` defaults to
           0.8).
        3. Otherwise: Adam over all parameters with lr=0.001.

        Returns:
            ``([optimizer], [scheduler])`` when a scheduler is configured,
            otherwise the bare optimizer.
        """
        def build_param_groups(lr_decoder):
            # The four compartment decoders get `lr_decoder`. The encoder,
            # z-encoder and embedding fall back to the optimizer-level `lr`.
            decoders = (
                self.rnn_decoder_hosp_inc,
                self.rnn_decoder_hosp_prev,
                self.rnn_decoder_latent_inc,
                self.rnn_decoder_latent_prev,
            )
            return (
                [{"params": self.rnn_encoder.parameters()}]
                + [{"params": decoder.parameters(), "lr": lr_decoder} for decoder in decoders]
                + [{"params": self.z_encoder.parameters()}, {"params": self.embed.parameters()}]
            )

        # Old style configuration from Lightning checkpoint
        if hasattr(self, 'hparams') and hasattr(self.hparams, 'lr'):
            hparams = self.hparams
            optimizer = torch.optim.Adam(build_param_groups(hparams.lr_decoder), lr=hparams.lr)
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=hparams.lr_milestones,
                gamma=hparams.lr_gamma
            )
            return [optimizer], [lr_scheduler]

        # Read with getattr so that a model built without a `training_config`
        # attribute falls through to the default optimizer instead of raising
        # AttributeError.
        training_config = getattr(self, 'training_config', None)
        if training_config is not None:
            optimizer = torch.optim.Adam(
                build_param_groups(training_config.lr_decoder), lr=training_config.lr
            )
            milestones = getattr(training_config, 'lr_milestones', None)
            if milestones:
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=milestones,
                    gamma=getattr(training_config, 'lr_gamma', 0.8)
                )
                return [optimizer], [scheduler]
            return optimizer

        # Default optimizer configuration
        return torch.optim.Adam(self.parameters(), lr=0.001)
    
    def loss_fn(
        self,
        mu_post: torch.Tensor,
        phi_post: torch.Tensor,
        y_hosp_inc_true: torch.Tensor,
        y_hosp_prev_true: torch.Tensor,
        y_latent_inc_true: torch.Tensor,
        y_latent_prev_true: torch.Tensor,
        mu_z_post: torch.Tensor,
        var_z_post: torch.Tensor,
        mu_z_prior: torch.Tensor,
        var_z_prior: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the negative ELBO for the STNP model.

        The KL term is ``KL(N(mu_z_post, sqrt(var_z_post)) || N(mu_z_prior,
        sqrt(var_z_prior)))``, summed over all elements. The reconstruction
        term uses ``NegativeBinomial2``, one per compartment (predictions split
        into ``NUM_COMP`` chunks on the last dim): the log-prob is summed over
        dim 2, averaged over dims 0 and 1, and negated. The KL and the four
        NLLs are logged per epoch.

        Args:
            mu_post: Posterior mean predictions [batch_size, seq_len, y_dim * 4]
            phi_post: Posterior dispersion parameters [batch_size, seq_len, y_dim * 4]
            y_hosp_inc_true: True hospital incidence [batch_size, seq_len, y_dim]
            y_hosp_prev_true: True hospital prevalence [batch_size, seq_len, y_dim]
            y_latent_inc_true: True latent incidence [batch_size, seq_len, y_dim]
            y_latent_prev_true: True latent prevalence [batch_size, seq_len, y_dim]
            mu_z_post: Posterior latent mean [seq_len, z_dim]
            var_z_post: Posterior latent variance [seq_len, z_dim]
            mu_z_prior: Prior latent mean [seq_len, z_dim]
            var_z_prior: Prior latent variance [seq_len, z_dim]

        Returns:
            Negative ELBO loss (NLL + KL).
        """
        posterior = dist.Normal(mu_z_post, var_z_post.sqrt())
        prior = dist.Normal(mu_z_prior, var_z_prior.sqrt())
        kl = torch.distributions.kl_divergence(posterior, prior).sum()

        def compartment_nll(mu, phi, target):
            # Sum over dim 2, then average over dims 0 and 1.
            return -NegativeBinomial2(mu, phi).log_prob(target).sum(dim=2).mean(dim=(0, 1))

        mu_hi, mu_hp, mu_li, mu_lp = torch.chunk(mu_post, self.NUM_COMP, dim=-1)
        phi_hi, phi_hp, phi_li, phi_lp = torch.chunk(phi_post, self.NUM_COMP, dim=-1)

        nll_terms = {
            "nll_hosp_inc": compartment_nll(mu_hi, phi_hi, y_hosp_inc_true),
            "nll_hosp_prev": compartment_nll(mu_hp, phi_hp, y_hosp_prev_true),
            "nll_latent_inc": compartment_nll(mu_li, phi_li, y_latent_inc_true),
            "nll_latent_prev": compartment_nll(mu_lp, phi_lp, y_latent_prev_true),
        }
        nll = (
            nll_terms["nll_hosp_inc"]
            + nll_terms["nll_hosp_prev"]
            + nll_terms["nll_latent_inc"]
            + nll_terms["nll_latent_prev"]
        )

        scalars = {"kl": kl.item()}
        scalars.update({name: term.item() for name, term in nll_terms.items()})
        self.log_dict(scalars, on_epoch=True, on_step=False, prog_bar=False)

        return nll + kl
    
    @torch.no_grad()
    def compute_wmape(
        self,
        x: torch.Tensor,
        xt: torch.Tensor,
        y_hosp_inc: torch.Tensor,
        y_hosp_prev: torch.Tensor,
        y_latent_inc: torch.Tensor,
        y_latent_prev: torch.Tensor,
        y0_latent_prev: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the Weighted Mean Absolute Percentage Error of a posterior sample.

        The latent posterior is inferred from the given (normalized)
        ``y_latent_prev`` / ``y0_latent_prev``. One sample is decoded and
        compared with the raw targets, both overall and per compartment.

        Returns:
            Tuple of (total_wmape, hosp_inc_wmape, hosp_prev_wmape, latent_inc_wmape, latent_prev_wmape)
        """
        # Only the last (latent_prev) chunk of the stored stats is used.
        lp_mean = torch.chunk(self.y_mean, self.NUM_COMP, dim=-1)[-1]
        lp_std = torch.chunk(self.y_std, self.NUM_COMP, dim=-1)[-1]

        y0_scaled = self.normalize(y0_latent_prev, lp_mean, lp_std)
        y_lp_scaled = self.normalize(y_latent_prev, lp_mean, lp_std)

        metric = WeightedMeanAbsolutePercentageError().to(y_hosp_inc.device)

        embedding = self.get_input_embedding(x, xt)
        mu_z, var_z = self.get_latent_representation(embedding, y_lp_scaled, y0_scaled)
        zs = self.sample_z(mu_z, var_z, y0_latent_prev.shape[0])
        mu_post, phi_post = self.get_post(y0_scaled, embedding, zs)
        y_post = self.sample_post(mu_post, phi_post)

        hi_pred, hp_pred, li_pred, lp_pred = torch.chunk(y_post, chunks=self.NUM_COMP, dim=-1)
        truths = (y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev)
        preds = (hi_pred, hp_pred, li_pred, lp_pred)

        total = metric(y_post, torch.concat(truths, dim=-1))
        per_compartment = [metric(pred, true) for pred, true in zip(preds, truths)]
        return (total, *per_compartment)
    
    def on_train_epoch_start(self) -> None:
        """Reset the per-batch latent buffers that ``on_train_epoch_end`` averages."""
        self.mu_z_list, self.var_z_list = [], []
    
    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """
        Training step for PyTorch Lightning.

        The latent posterior is built from the full batch and the latent prior
        from a random context subset. The target subset is decoded with a
        posterior sample, and the negative ELBO is returned. A sampled
        prediction is also scored with MSE/MAE/WMAPE for logging.

        Args:
            batch: Batch of data (x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0_latent_prev)
            batch_idx: Batch index (unused)

        Returns:
            Loss tensor (negative ELBO, with autograd history).
        """
        x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0_latent_prev = batch

        # Only latent_prev is normalized; its stats are the last chunk.
        *_, y_latent_prev_mean = torch.chunk(self.y_mean, self.NUM_COMP, dim=-1)
        *_, y_latent_prev_std = torch.chunk(self.y_std, self.NUM_COMP, dim=-1)

        # Context-target split
        context_percentage = getattr(self, 'context_percentage', 0.2)
        if hasattr(self, 'hparams') and hasattr(self.hparams, 'context_percentage'):
            context_percentage = self.hparams.context_percentage

        x_context, xt_context, y_context, y0_context, \
        x_target, xt_target, y_latent_prev_target, y0_target, \
        _, idt = self.context_target_split(
            x, xt, y_latent_prev, y0_latent_prev, context_percentage
        )

        # The split only carried latent_prev; select the other compartments
        # with the same target indices.
        y_hosp_inc_target = y_hosp_inc[idt, ...]
        y_hosp_prev_target = y_hosp_prev[idt, ...]
        y_latent_inc_target = y_latent_inc[idt, ...]

        # Normalize latent_prev data
        y_latent_prev_standard = self.normalize(y_latent_prev, y_latent_prev_mean, y_latent_prev_std)
        y0_latent_prev_standard = self.normalize(y0_latent_prev, y_latent_prev_mean, y_latent_prev_std)
        y_context_standard = self.normalize(y_context, y_latent_prev_mean, y_latent_prev_std)
        y0_context_standard = self.normalize(y0_context, y_latent_prev_mean, y_latent_prev_std)
        y0_target_standard = self.normalize(y0_target, y_latent_prev_mean, y_latent_prev_std)

        # Posterior latent distribution from the full batch
        embed_out = self.get_input_embedding(x, xt)
        mu_z_post, var_z_post = self.get_latent_representation(
            embed_out, y_latent_prev_standard, y0_latent_prev_standard
        )

        zs_post = self.sample_z(mu_z_post, var_z_post, y0_target.shape[0])

        # Get target predictions
        embed_out = self.get_input_embedding(x_target, xt_target)
        mu_post, phi_post = self.get_post(y0_target_standard, embed_out, zs_post)

        # Prior latent distribution from the context subset
        embed_out = self.get_input_embedding(x_context, xt_context)
        mu_z_prior, var_z_prior = self.get_latent_representation(
            embed_out, y_context_standard, y0_context_standard
        )

        # Store the latent stats for the epoch-end global update. Detach them
        # first: the buffers live for the whole epoch, and keeping each batch's
        # autograd history alive would grow memory with every step.
        self.mu_z_list.append(mu_z_post.detach())
        self.var_z_list.append(var_z_post.detach())

        # Compute loss (gradients flow through the non-detached tensors)
        loss = self.loss_fn(
            mu_post, phi_post,
            y_hosp_inc_target, y_hosp_prev_target,
            y_latent_inc_target, y_latent_prev_target,
            mu_z_post, var_z_post,
            mu_z_prior, var_z_prior
        )

        # Compute metrics on a sampled prediction
        y_ground_truth = torch.cat([
            y_hosp_inc_target, y_hosp_prev_target,
            y_latent_inc_target, y_latent_prev_target
        ], dim=-1)
        y_pred = self.sample_post(mu_post, phi_post)

        y_hosp_inc_pred, y_hosp_prev_pred, y_latent_inc_pred, y_latent_prev_pred = torch.chunk(
            y_pred, chunks=self.NUM_COMP, dim=-1
        )

        train_mse = self.mse(y_pred, y_ground_truth)
        train_mae = self.mae(y_pred, y_ground_truth)

        train_mse_hosp_inc = self.mse(y_hosp_inc_pred, y_hosp_inc_target)
        train_mse_hosp_prev = self.mse(y_hosp_prev_pred, y_hosp_prev_target)
        train_mse_latent_inc = self.mse(y_latent_inc_pred, y_latent_inc_target)
        train_mse_latent_prev = self.mse(y_latent_prev_pred, y_latent_prev_target)

        train_mae_hosp_inc = self.mae(y_hosp_inc_pred, y_hosp_inc_target)
        train_mae_hosp_prev = self.mae(y_hosp_prev_pred, y_hosp_prev_target)
        train_mae_latent_inc = self.mae(y_latent_inc_pred, y_latent_inc_target)
        train_mae_latent_prev = self.mae(y_latent_prev_pred, y_latent_prev_target)

        # WMAPE runs its own posterior pass over the full batch.
        train_wmape, train_wmape_hosp_inc, train_wmape_hosp_prev, \
        train_wmape_latent_inc, train_wmape_latent_prev = self.compute_wmape(
            x, xt, y_hosp_inc, y_hosp_prev,
            y_latent_inc, y_latent_prev, y0_latent_prev
        )

        self.log_dict({
            "train_mse_hosp_inc": train_mse_hosp_inc.item(),
            "train_mse_hosp_prev": train_mse_hosp_prev.item(),
            "train_mse_latent_inc": train_mse_latent_inc.item(),
            "train_mse_latent_prev": train_mse_latent_prev.item(),
            "train_mae": train_mae.item(),
            "train_mae_hosp_inc": train_mae_hosp_inc.item(),
            "train_mae_hosp_prev": train_mae_hosp_prev.item(),
            "train_mae_latent_inc": train_mae_latent_inc.item(),
            "train_mae_latent_prev": train_mae_latent_prev.item(),
            "train_wmape": train_wmape.item(),
            "train_wmape_hosp_inc": train_wmape_hosp_inc.item(),
            "train_wmape_hosp_prev": train_wmape_hosp_prev.item(),
            "train_wmape_latent_inc": train_wmape_latent_inc.item(),
            "train_wmape_latent_prev": train_wmape_latent_prev.item()
        }, on_epoch=True, on_step=False, prog_bar=False)

        self.log_dict({
            "train_loss": loss.item(),
            "train_mse": train_mse.item()
        }, on_epoch=True, on_step=False, prog_bar=True)

        return loss
    
    def on_train_epoch_end(self) -> None:
        """
        Called at the end of the training epoch.

        If any per-batch latent stats were collected, their element-wise mean
        becomes the new ``mu_z_global`` / ``var_z_global``. Those are later used
        by ``forward``. The validation output buffer is reset in every case.
        """
        mu_batches = getattr(self, 'mu_z_list', None)
        if mu_batches:
            # Detach so the stored globals never carry an autograd graph. The
            # per-batch stats may still have history if they were appended
            # without detaching.
            self.mu_z_global = torch.stack(mu_batches, dim=0).mean(0).detach()
            self.var_z_global = torch.stack(self.var_z_list, dim=0).mean(0).detach()

        # Reset validation outputs
        self.validation_step_outputs = []
    
    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """
        Validation step for PyTorch Lightning.

        Predicts from the global latent (``self(...)``). The validation loss is
        the NegativeBinomial2 NLL of the raw targets (log-prob summed over dim 2,
        averaged over dims 0 and 1). MSE/MAE/WMAPE of one sampled prediction are
        computed overall and per compartment. The per-batch results are appended
        to ``self.validation_step_outputs`` for epoch-level aggregation.

        Args:
            batch: Batch of validation data
            batch_idx: Batch index
        """
        x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0_latent_prev = batch

        # Only the last (latent_prev) chunk of the stats is needed for y0.
        lp_mean = torch.chunk(self.y_mean, self.NUM_COMP, dim=-1)[-1]
        lp_std = torch.chunk(self.y_std, self.NUM_COMP, dim=-1)[-1]
        y0_scaled = self.normalize(y0_latent_prev, lp_mean, lp_std)

        y_true = torch.concat([y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev], dim=-1)

        mu_post, phi_post = self(x, xt, y0_scaled)
        y_pred = self.sample_post(mu_post, phi_post)

        val_loss = -NegativeBinomial2(mu_post, phi_post).log_prob(y_true).sum(dim=2).mean(dim=(0, 1))

        hi_pred, hp_pred, li_pred, lp_pred = torch.chunk(y_pred, chunks=self.NUM_COMP, dim=-1)
        components = (
            ("hosp_inc", hi_pred, y_hosp_inc),
            ("hosp_prev", hp_pred, y_hosp_prev),
            ("latent_inc", li_pred, y_latent_inc),
            ("latent_prev", lp_pred, y_latent_prev),
        )

        val_mse = self.mse(y_pred, y_true)
        val_mae = self.mae(y_pred, y_true)
        mse_parts = {f"val_mse_{name}": self.mse(pred, true) for name, pred, true in components}
        mae_parts = {f"val_mae_{name}": self.mae(pred, true) for name, pred, true in components}

        wmape_fn = WeightedMeanAbsolutePercentageError().to(y_true.device)
        val_wmape = wmape_fn(y_pred, y_true)
        wmape_parts = {f"val_wmape_{name}": wmape_fn(pred, true) for name, pred, true in components}

        self.validation_step_outputs.append({
            "val_loss": val_loss,
            "val_mse": val_mse,
            **mse_parts,
            "val_mae": val_mae,
            **mae_parts,
            "val_wmape": val_wmape,
            **wmape_parts,
        })
    
    def on_validation_epoch_end(self) -> None:
        """
        Average the per-batch validation metrics and log them.

        Each metric in ``self.validation_step_outputs`` is stacked across
        batches and mean-reduced. ``val_mse`` is logged on its own with
        ``prog_bar=True``; all other metrics are logged together without it.
        The buffer is then cleared. Does nothing if no batches were recorded.
        """
        step_outputs = self.validation_step_outputs
        if not step_outputs:
            return

        metric_keys = (
            "val_loss",
            "val_mse",
            "val_mse_hosp_inc", "val_mse_hosp_prev", "val_mse_latent_inc", "val_mse_latent_prev",
            "val_mae",
            "val_mae_hosp_inc", "val_mae_hosp_prev", "val_mae_latent_inc", "val_mae_latent_prev",
            "val_wmape",
            "val_wmape_hosp_inc", "val_wmape_hosp_prev", "val_wmape_latent_inc", "val_wmape_latent_prev",
        )
        epoch_means = {
            key: torch.stack([output[key] for output in step_outputs], dim=0).mean().item()
            for key in metric_keys
        }

        # val_mse goes to the progress bar separately; log the rest in one call.
        progress_mse = epoch_means.pop("val_mse")
        self.log_dict(epoch_means, on_epoch=True, on_step=False, prog_bar=False)
        self.log("val_mse", progress_mse, on_epoch=True, on_step=False, prog_bar=True)

        self.validation_step_outputs = []
    
    def get_model_info(self) -> Dict[str, Any]:
        """
        Summarize the model configuration and parameter counts.

        Returns:
            Dict with ``model_name``, the dimension fields read from
            ``self.config``, the total element count of all parameters, and the
            count of those with ``requires_grad``.
        """
        cfg = self.config
        info: Dict[str, Any] = {"model_name": "STNP"}
        for field in ("x_dim", "y_dim", "seq_len", "embed_out_dim", "z_dim", "r_dim", "num_nodes"):
            info[field] = getattr(cfg, field)

        params = list(self.parameters())
        info["num_parameters"] = sum(p.numel() for p in params)
        info["trainable_parameters"] = sum(p.numel() for p in params if p.requires_grad)
        return info
