"""
SONATA Model

Integrates:
- Martingale representation theorem for coreset selection
- Optimal stopping problem for data point selection
- Enhanced multi-scale time weighting with Itô formula
- Improved state estimation and prediction methods
"""

import numpy as np
import torch
import logging
import bisect
import tensorly as tl
from collections import deque
import math
import sys

# Logging: basicConfig configures the root logger as a side effect of
# importing this module (INFO level, timestamped format); `logger` is the
# module-level logger used by the classes below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Make tensorly operate on PyTorch tensors.
tl.set_backend("pytorch")

# Constants
# Small positive constant (1e-4). NOTE(review): not referenced in the part of
# the file reviewed here; presumably a numerical-stability jitter — confirm.
JITTER = 1e-4


class DynamicCoreSetTensorFactorization:
    """
    Base class for Dynamic CoreSet Tensor Factorization
    Implements core methods for data point coreset selection and streaming tensor factorization
    """
    def __init__(self, hyper_dict, data_dict):
        """Build the model state from hyper-parameters and the dataset.

        Args:
            hyper_dict: Configuration mapping. Keys read directly here:
                ``device``, ``R_U``, ``v``, ``a0``, ``b0``, ``DAMPING`` and
                ``DAMPING_tau``. Optional keys: ``coreset_manager`` and
                ``multi_scale_weighting`` (ready-made collaborators), or the
                tuning keys used below to build the default ones. The mapping
                is also forwarded to ``_create_lds_params``.
            data_dict: Dataset mapping with keys ``ndims``, ``tr_ind``,
                ``tr_y``, ``te_ind``, ``te_y``, ``tr_T_disct``,
                ``te_T_disct`` and ``time_uni``.

        Raises:
            KeyError: If a required key is missing from either mapping.
        """
        # Basic parameters
        self.device = hyper_dict["device"]
        self.R_U = hyper_dict["R_U"]

        # Prior parameters
        self.v = hyper_dict["v"]  # Prior variance
        self.a0 = hyper_dict["a0"]
        self.b0 = hyper_dict["b0"]
        self.DAMPING = hyper_dict["DAMPING"]
        self.DAMPING_tau = hyper_dict["DAMPING_tau"]

        # Data related parameters
        self.ndims = data_dict["ndims"]
        self.nmods = len(self.ndims)

        self.tr_ind = data_dict["tr_ind"]
        self.tr_y = torch.tensor(data_dict["tr_y"]).to(self.device)  # N*1

        self.te_ind = data_dict["te_ind"]
        self.te_y = torch.tensor(data_dict["te_y"]).to(self.device)  # N*1

        self.train_time_ind = data_dict["tr_T_disct"]  # N_train*1
        self.test_time_ind = data_dict["te_T_disct"]  # N_test*1

        self.unique_train_time = list(np.unique(self.train_time_ind))

        self.time_uni = data_dict["time_uni"]  # N_time*1
        self.N_time = len(self.time_uni)

        # Create LDS parameters
        self.lds_params = self._create_lds_params(hyper_dict, data_dict)

        # One dynamics object per factor, keyed by (mode, uid); filled lazily
        # by ensure_factor_dynamics.
        self.factor_dynamics = {}

        # Initialize coreset manager - will be replaced in subclasses
        if "coreset_manager" in hyper_dict:
            self.coreset_manager = hyper_dict["coreset_manager"]
        else:
            # Imported here to avoid circular imports
            from utils_martingale_coreset import DataPointCoreSetManager
            self.coreset_manager = DataPointCoreSetManager(
                max_size=hyper_dict.get("coreset_max_size", 100),
                initial_threshold=hyper_dict.get("coreset_threshold", 0.5),
                adaptive_threshold=hyper_dict.get("adaptive_threshold", True),
                importance_weights=hyper_dict.get("importance_weights", (0.4, 0.3, 0.3)),
                device=self.device,
                exploration_rate=hyper_dict.get("initial_exploration_rate", 0.9),
                decay_rate=hyper_dict.get("exploration_decay_rate", 0.1)
            )

        # Initialize multi-scale weighting mechanism
        if "multi_scale_weighting" in hyper_dict:
            self.multi_scale_weighting = hyper_dict["multi_scale_weighting"]
        else:
            # Imported here to avoid circular imports
            from utils_martingale_coreset import MultiScaleWeighting
            self.multi_scale_weighting = MultiScaleWeighting(
                num_scales=hyper_dict.get("num_time_scales", 3),
                hidden_dim=hyper_dict.get("scale_hidden_dim", 32),
                device=self.device,
                temperature=hyper_dict.get("attention_temperature", 1.0)
            )

        # Factor posterior distributions
        self.post_U_m = [
            torch.rand(dim, self.R_U, 1, self.N_time).double().to(self.device)
            for dim in self.ndims
        ]  #  (dim, R_U, 1, T) * nmod
        self.post_U_v = [
            torch.eye(self.R_U).reshape(
                (1, self.R_U, self.R_U,
                 1)).repeat(dim, 1, 1, self.N_time).double().to(self.device)
            for dim in self.ndims
        ]  # (dim, R_U, R_U, T) * nmod

        # Noise posterior distribution
        self.post_a = self.a0
        self.post_b = self.b0
        self.E_tau = 1

        # Build time-data tables: given timestamp ID, returns entry indices.
        # Use the utils_streaming helper whenever it can be imported (the same
        # try/except pattern as track_envloved_objects). A sys.modules
        # membership test would only notice a module that some other code had
        # already imported, and would otherwise pick the fallback silently.
        try:
            import utils_streaming
        except ImportError:
            # Basic implementation if utils_streaming not available
            build_table = self._build_time_data_table
        else:
            build_table = utils_streaming.build_time_data_table
        self.time_data_table_tr = build_table(self.train_time_ind)
        self.time_data_table_te = build_table(self.test_time_ind)

        # Placeholders, filled per time step by track_envloved_objects
        self.ind_T = None
        self.y_T = None
        self.uid_table = None
        self.data_table = None

        # Store messages (by uid order)
        self.msg_U_m = None
        self.msg_U_V = None

        # Store messages (by data-llk order)
        self.msg_U_lam_llk = None
        self.msg_U_eta_llk = None

        self.msg_a_llk = None
        self.msg_b_llk = None

        # Set product method
        self.product_method = None  # Set in subclasses

        logger.info(f"Initialized DynamicCoreSetTensorFactorization: R_U={self.R_U}, nmods={self.nmods}")
    
    def _build_time_data_table(self, time_indices):
        """Group entry positions by timestamp id (local fallback helper).

        Args:
            time_indices: Iterable holding one hashable timestamp id per entry.

        Returns:
            dict mapping each timestamp id to the list of entry positions that
            carry it, positions in order of appearance.
        """
        entries_by_time = {}
        for position, stamp in enumerate(time_indices):
            # setdefault creates the bucket on first sight of a timestamp
            entries_by_time.setdefault(stamp, []).append(position)
        return entries_by_time
    
    def _create_lds_params(self, hyper_dict, data_dict):
        """Assemble the parameter dict handed to every factor dynamics object.

        Args:
            hyper_dict: Configuration mapping. Reads ``device`` and ``R_U``
                (required) and ``noise`` (default 1.0), ``kernel`` (default
                ``"Matern_23"``), ``lengthscale`` (default 0.3) and
                ``variance`` (default 1.0).
            data_dict: Unused; kept for signature compatibility.

        Returns:
            dict with keys ``device``, ``R``, ``F``, ``H``, ``P_inf``, ``P_0``
            and ``m_0``. ``P_0`` is the same tensor object as ``P_inf``.
            For ``"Matern_21"`` the state has size ``R_U``; for
            ``"Matern_23"`` it has size ``2 * R_U`` and ``H`` selects the
            first ``R_U`` components.

        Raises:
            ValueError: If ``kernel`` is not ``"Matern_21"`` or
                ``"Matern_23"``. Without this check the returned dict would
                silently lack ``F``/``H``/``P_inf``/``P_0``/``m_0``.
        """
        LDS_init = {}
        LDS_init["device"] = hyper_dict["device"]

        # Build F,H,R matrices
        D = hyper_dict["R_U"]

        kernel = hyper_dict.get("kernel", "Matern_23")
        if kernel not in ("Matern_21", "Matern_23"):
            raise ValueError(
                f"Unsupported kernel {kernel!r}; expected 'Matern_21' or 'Matern_23'"
            )

        # Get parameters from config
        LDS_init["R"] = torch.tensor(hyper_dict.get("noise", 1.0))
        lengthscale = hyper_dict.get("lengthscale", 0.3)
        variance = hyper_dict.get("variance", 1.0)

        if kernel == "Matern_21":
            # First-order dynamics: each factor component decays independently
            LDS_init["F"] = -1 / lengthscale * torch.eye(D)
            LDS_init["H"] = torch.eye(D)
            LDS_init["P_inf"] = torch.eye(D) * variance
            LDS_init["P_0"] = LDS_init["P_inf"]
            LDS_init["m_0"] = torch.randn(D, 1) * 0.3
        else:  # "Matern_23"
            lamb = np.sqrt(3) / lengthscale

            # Block layout [[0, I], [-lamb^2 I, -2 lamb I]]; the top-left
            # block stays zero from torch.zeros.
            F = torch.zeros((2 * D, 2 * D))
            F[:D, D:] = torch.eye(D)
            F[D:, :D] = -lamb * lamb * torch.eye(D)
            F[D:, D:] = -2 * lamb * torch.eye(D)

            P_inf = torch.diag(
                torch.cat((
                    variance * torch.ones(D),
                    lamb * lamb * variance * torch.ones(D),
                ))
            )

            LDS_init["F"] = F
            LDS_init["P_inf"] = P_inf
            # Observation picks the first D state components
            LDS_init["H"] = torch.cat((torch.eye(D), torch.zeros(D, D)), dim=1)
            LDS_init["P_0"] = P_inf
            LDS_init["m_0"] = 0.1 * torch.ones(2 * D, 1)

        logger.info(f"LDS parameters created: kernel={kernel}, lengthscale={lengthscale}, variance={variance}")
        return LDS_init

    def ensure_factor_dynamics(self, mode, uid):
        """Return the dynamics object for factor ``(mode, uid)``, creating it on first use.

        A new object is built from ``self.lds_params`` with
        ``model_LDS.LDS_GP_streaming``; if that raises ``ImportError`` the
        ``SimplifiedLDS`` class is used instead. The object is cached in
        ``self.factor_dynamics`` under the key ``(mode, uid)``.
        """
        key = (mode, uid)
        if key in self.factor_dynamics:
            return self.factor_dynamics[key]

        try:
            from model_LDS import LDS_GP_streaming
            process = LDS_GP_streaming(self.lds_params)
        except ImportError:
            # model_LDS not available: fall back to the simplified version
            process = SimplifiedLDS(self.lds_params)

        self.factor_dynamics[key] = process
        return self.factor_dynamics[key]
    
    def track_envloved_objects(self, T):
        """Load the training entries observed at timestamp id ``T``.

        Sets ``self.ind_T`` (their index rows), ``self.y_T`` (their values,
        reshaped to ``(-1, 1, 1)``), ``self.N_T`` (their count) and the
        ``self.uid_table`` / ``self.data_table`` pair describing which objects
        of each mode are involved.
        """
        entry_ids = self.time_data_table_tr[T]  # entry IDs seen at this timestamp

        self.ind_T = self.tr_ind[entry_ids]
        self.y_T = self.tr_y[entry_ids].reshape(-1, 1, 1)
        self.N_T = len(self.y_T)

        # Object-ID and data tables: shared helper when importable, local
        # implementation otherwise.
        try:
            import utils_streaming
            tables = utils_streaming.build_id_key_table(
                nmod=self.nmods, ind=self.ind_T
            )
        except ImportError:
            tables = self._build_id_key_table(self.ind_T)

        self.uid_table, self.data_table = tables
    
    def _build_id_key_table(self, ind):
        """Build the per-mode object-id and entry tables (local fallback helper).

        Args:
            ind: Iterable of index rows, one per entry; each row holds one
                object id per mode. Ids must be hashable (they are also used
                as dictionary keys for the factor dynamics objects).

        Returns:
            ``(uid_table, data_table)``: for every mode, ``uid_table[mode]``
            lists the distinct object ids in order of first appearance and
            ``data_table[mode][k]`` lists the positions of the entries that
            involve ``uid_table[mode][k]``.
        """
        uid_table = [[] for _ in range(self.nmods)]
        data_table = [[] for _ in range(self.nmods)]
        # id -> slot in uid_table[mode]; avoids a linear list scan per entry
        slot_of = [{} for _ in range(self.nmods)]

        for entry, indices in enumerate(ind):
            for mode, idx in enumerate(indices):
                slot = slot_of[mode].get(idx)
                if slot is None:
                    slot_of[mode][idx] = len(uid_table[mode])
                    uid_table[mode].append(idx)
                    data_table[mode].append([entry])
                else:
                    data_table[mode][slot].append(entry)

        return uid_table, data_table
    
    def filter_predict(self, T):
        """Run the KF prediction step for every object involved at timestamp id ``T``.

        For each uid listed in ``self.uid_table`` the factor's dynamics object
        is advanced to ``self.time_uni[T]``, and its latest predicted state is
        projected through ``H`` into column ``T`` of ``post_U_m`` / ``post_U_v``.
        """
        stamp = self.time_uni[T]

        for mode in range(self.nmods):
            for uid in self.uid_table[mode]:
                # Created lazily the first time an object is seen
                process = self.ensure_factor_dynamics(mode, uid)
                process.filter_predict(stamp)

                # Latest predicted state, mapped to factor space
                obs_matrix = process.H
                pred_mean = process.m_pred_list[-1]
                pred_cov = process.P_pred_list[-1]

                projected_cov = torch.mm(torch.mm(obs_matrix, pred_cov), obs_matrix.T)
                self.post_U_m[mode][uid, :, :, T] = torch.mm(obs_matrix, pred_mean)
                self.post_U_v[mode][uid, :, :, T] = projected_cov
    
    def update_coreset(self, T):
        """Offer the entries observed at timestamp id ``T`` to the coreset manager.

        Each entry is submitted as an ``(indices, y, T)`` tuple together with
        the model itself. Every tenth step the coreset size and the number of
        added/removed points are logged. For ``T > 0`` a positive prediction
        error is turned into a confidence ``1 / (1 + error)`` and passed to
        the manager.

        Returns:
            The ``(added, removed)`` pair returned by the coreset manager.
        """
        data_batch = [
            (self.ind_T[row], self.y_T[row].item(), T)
            for row in range(self.N_T)
        ]

        added, removed = self.coreset_manager.update_coreset(data_batch, self)

        # Log every 10 time steps
        if T % 10 == 0:
            logger.info(f"T={T}, coreset size: {self.coreset_manager.get_coreset_size()}, "
                        f"added: {len(added)}, removed: {len(removed)}")

        # Confidence update is skipped on the first step and when no error
        # estimate is available (error of 0).
        if T > 0:
            pred_error = self.compute_prediction_error(T)
            if pred_error > 0:
                self.coreset_manager.update_confidence(1.0 / (1.0 + pred_error))

        return added, removed

    def compute_prediction_error(self, T):
        """Estimate the current prediction MSE on a random sample of the test split.

        Up to 100 test entries are drawn without replacement, all evaluated at
        timestamp id ``T`` through ``model_test``.

        Returns:
            The mean squared error as a float, or ``0.0`` when the test split
            attributes are missing or empty.
        """
        has_test_split = hasattr(self, 'te_ind') and hasattr(self, 'te_y')
        if not has_test_split:
            return 0.0

        n_available = len(self.te_ind)
        sample_size = min(100, n_available)
        if sample_size <= 0:
            return 0.0

        # Small random subset of the test data, all placed at time step T
        indices = np.random.choice(n_available, sample_size, replace=False)
        sample_ind = self.te_ind[indices]
        sample_y = self.te_y[indices]
        sample_time = np.ones_like(indices) * T

        pred, _ = self.model_test(sample_ind, sample_y, sample_time)

        residual = pred.squeeze() - sample_y.squeeze()
        return torch.mean(residual ** 2).item()
    
    def get_scale_weights(self, T, mode=None):
        """Compute multi-scale weights from the states of currently involved factors.

        Collects up to three hidden states (``factor_process.m`` flattened to
        1-D float32) from the first objects of each considered mode and passes
        them to ``self.multi_scale_weighting.compute_weights``.

        Args:
            T: Timestamp id. Not used by the computation itself.
            mode: If given, only objects of this mode are considered.

        Returns:
            Whatever ``compute_weights`` returns, or a ``(1, 1)`` tensor of
            ones on ``self.device`` when no state is available or an error
            occurs (the error is logged, not raised).
        """
        max_states = 3
        try:
            h_scales = []

            for m in range(self.nmods):
                # Only consider current mode (if specified)
                if mode is not None and m != mode:
                    continue

                # uid_table is indexed by mode. A membership test such as
                # ``m in self.uid_table`` would compare the integer against
                # the per-mode id lists and never select anything, so index
                # directly and treat a missing table/mode as "no objects".
                try:
                    sample_uids = self.uid_table[m][:max_states]
                except (IndexError, KeyError, TypeError):
                    sample_uids = []

                for uid in sample_uids:
                    factor_process = self.factor_dynamics.get((m, uid))
                    state = getattr(factor_process, 'm', None)
                    if state is None:
                        continue

                    # Float copy, flattened so all states share one layout
                    h_scales.append(
                        state.clone().detach().to(torch.float32).reshape(-1))

                    if len(h_scales) >= max_states:
                        break

                # The cap applies to the total across modes
                if len(h_scales) >= max_states:
                    break

            # Nothing to weight: return uniform weights
            if not h_scales:
                return torch.ones(1, 1, device=self.device)

            return self.multi_scale_weighting.compute_weights(h_scales)

        except Exception as e:
            # Deliberate best-effort: weighting must not abort training
            logger.error(f"Error calculating multi-scale weights: {e}")
            return torch.ones(1, 1, device=self.device)
    
    def msg_llk_init(self):
        """Initialise the per-entry likelihood messages for the CEP inner loop.

        Messages are kept in natural parameters (lam = S_inv, eta = S_inv x m),
        one tensor per mode sized by the number of entries in ``self.y_T``:
        ``msg_U_lam_llk`` starts at ``1e-4 * I`` and ``msg_U_eta_llk`` at
        small random values. The tau messages ``msg_a`` / ``msg_b`` start at one.
        """
        n_entries = len(self.y_T)  # entries at the current time step
        rank = self.R_U

        # (N, R_U, R_U) stack of identities in double precision; scaling it
        # below yields a fresh tensor for every mode.
        identity_stack = torch.eye(rank).reshape((1, rank, rank)).repeat(
            n_entries, 1, 1).double().to(self.device)
        self.msg_U_lam_llk = [
            1e-4 * identity_stack for _ in range(self.nmods)
        ]  # (N*R_U*R_U)*nmod

        eta_messages = []
        for _ in range(self.nmods):
            noise = torch.rand(n_entries, rank, 1).double().to(self.device)
            eta_messages.append(1e-3 * noise)
        self.msg_U_eta_llk = eta_messages  # (N*R_U*1)*nmod

        # tau messages
        self.msg_a = torch.ones(n_entries, 1, dtype=torch.float64, device=self.device)  # N*1
        self.msg_b = torch.ones(n_entries, 1, dtype=torch.float64, device=self.device)  # N*1
    
    def filter_update(self, T, mode, add_to_list=True):
        """Run the KF update step for the objects of ``mode`` involved at timestamp id ``T``.

        The approximate message of each object (``msg_U_m`` as observation,
        ``msg_U_V`` as observation noise) is fed to its dynamics object, and
        the updated state is projected through ``H`` into column ``T`` of
        ``post_U_m`` / ``post_U_v``. Objects without a stored message are
        skipped.

        Args:
            T: Timestamp id whose posterior column is refreshed.
            mode: Mode whose objects are updated.
            add_to_list: Forwarded to the dynamics object's ``filter_update``.
        """
        for slot, uid in enumerate(self.uid_table[mode]):
            # Messages are stored in uid-table order; skip objects beyond them
            if slot < len(self.msg_U_m[mode]):
                observation = self.msg_U_m[mode][slot]
                obs_noise = self.msg_U_V[mode][slot]

                process = self.ensure_factor_dynamics(mode, uid)
                process.filter_update(
                    y=observation, R=obs_noise, add_to_list=add_to_list)

                # Updated state, mapped to factor space
                obs_matrix = process.H
                state_mean = process.m
                state_cov = process.P

                projected_cov = torch.mm(torch.mm(obs_matrix, state_cov), obs_matrix.T)
                self.post_U_m[mode][uid, :, :, T] = torch.mm(obs_matrix, state_mean)
                self.post_U_v[mode][uid, :, :, T] = projected_cov

    def smooth(self):
        """Run the smoother of every factor dynamics object created so far."""
        for process in self.factor_dynamics.values():
            process.smooth()
    
    def inner_smooth(self):
        """Smooth online for mid-training evaluation.

        Smooths every trajectory, refreshes ``post_U`` from the smoothed
        states, then clears each dynamics object's smoothing lists.
        """
        self.smooth()
        self.get_post_U()

        # Drop the smoothing results so filtering can continue afterwards
        for process in self.factor_dynamics.values():
            process.reset_smooth_list()
    
    def get_post_U(self):
        """Refresh ``post_U_m`` / ``post_U_v`` from the smoothed trajectories.

        For every timestamp in ``self.time_uni`` and every factor that has a
        dynamics object with at least one recorded timestamp, column ``T`` of
        the posterior is overwritten with a state projected through ``H``:

        - timestamp recorded by the factor: its smoothed state is used;
        - timestamp before the first recorded one: backward extrapolation
          from the first smoothed state;
        - timestamp between two recorded ones: merge of the forward jump from
          the previous state and the term built from the next state;
        - otherwise (after the last recorded one, or when the interpolation
          branch was not entered because ``m_smooth_list`` is too short):
          forward extrapolation from the previous smoothed state.

        Jumps over an interval ``dt`` use ``A = matrix_exp(F * dt)`` and
        ``Q = P_inf - A P_inf A^T``. Entries for which no branch applies are
        left unchanged.
        """
        for T, time_stamp in enumerate(self.time_uni):
            for mode in range(self.nmods):
                for uid in range(self.ndims[mode]):
                    # Only factors that were observed at some point have dynamics
                    if (mode, uid) in self.factor_dynamics:
                        factor_process = self.factor_dynamics[(mode, uid)]
                        
                        if len(factor_process.time_stamp_list) > 0:
                            # At least one observation
                            
                            if time_stamp in factor_process.time_stamp_list:
                                # Timestamp appeared before
                                
                                T_id = factor_process.time_2_ind_table[time_stamp]
                                # Update posterior based on smoothed state
                                
                                # Guard: the smoothed lists may be shorter than
                                # the recorded timestamps
                                if T_id < len(factor_process.m_smooth_list):
                                    H = factor_process.H
                                    m = factor_process.m_smooth_list[T_id]
                                    P = factor_process.P_smooth_list[T_id]
                                    
                                    self.post_U_m[mode][uid, :, :, T] = torch.mm(H, m)
                                    self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                        torch.mm(H, P), H.T)
                            
                            else:
                                # Timestamp never appeared before
                                # Locate position of unseen timestamp
                                # (assumes time_stamp_list is sorted ascending,
                                # as bisect requires — TODO confirm)
                                loc = bisect.bisect(factor_process.time_stamp_list,
                                                   time_stamp)
                                
                                if loc == 0 and len(factor_process.m_smooth_list) > 0:
                                    # First backwards Gaussian jump extrapolation
                                    prev_time_stamp = factor_process.time_stamp_list[loc]
                                    prev_m = factor_process.m_smooth_list[loc]
                                    prev_P = factor_process.P_smooth_list[loc]
                                    prev_time_int = prev_time_stamp - time_stamp
                                    
                                    # Inverse of the forward transition over
                                    # the gap, i.e. a step backwards in time
                                    prev_A = torch.inverse(
                                        torch.matrix_exp(factor_process.F *
                                                       prev_time_int).double())
                                    prev_Q = factor_process.P_inf - torch.mm(
                                        torch.mm(prev_A, factor_process.P_inf), prev_A.T)
                                    
                                    jump_m = torch.mm(prev_A, prev_m)
                                    jump_P = (torch.mm(torch.mm(prev_A, prev_P),
                                                     prev_A.T) + prev_Q)
                                    
                                    H = factor_process.H
                                    self.post_U_m[mode][uid, :, :, T] = torch.mm(H, jump_m)
                                    self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                        torch.mm(H, jump_P), H.T)
                                
                                elif loc < len(factor_process.time_stamp_list) and len(factor_process.m_smooth_list) > loc:
                                    # Interpolation, merge (according to time sequence interpolation formula)
                                    
                                    prev_time_stamp = factor_process.time_stamp_list[loc - 1]
                                    next_time_stamp = factor_process.time_stamp_list[loc]
                                    
                                    if loc - 1 < len(factor_process.m_smooth_list) and loc < len(factor_process.m_smooth_list):
                                        prev_m = factor_process.m_smooth_list[loc - 1]
                                        prev_P = factor_process.P_smooth_list[loc - 1]
                                        
                                        next_m = factor_process.m_smooth_list[loc]
                                        next_P = factor_process.P_smooth_list[loc]
                                        
                                        prev_time_int = time_stamp - prev_time_stamp
                                        next_time_int = next_time_stamp - time_stamp
                                        
                                        # Forward jump from the previous state
                                        prev_A = torch.matrix_exp(
                                            factor_process.F * prev_time_int).double()
                                        prev_Q = factor_process.P_inf - torch.mm(
                                            torch.mm(prev_A, factor_process.P_inf), prev_A.T)
                                        
                                        Q1_inv = torch.inverse(
                                            torch.mm(torch.mm(prev_A, prev_P),
                                                   prev_A.T) + prev_Q)
                                        
                                        # Same construction over the gap up to
                                        # the next recorded timestamp
                                        next_A = torch.matrix_exp(
                                            factor_process.F * next_time_int).double()
                                        next_Q = factor_process.P_inf - torch.mm(
                                            torch.mm(next_A, factor_process.P_inf), next_A.T)
                                        
                                        Q2_inv = torch.inverse(
                                            torch.mm(torch.mm(next_A, next_P),
                                                   next_A.T) + next_Q)
                                        
                                        merge_P = torch.inverse(Q1_inv + torch.mm(
                                            next_A.T, torch.mm(Q2_inv, next_A)))
                                        
                                        # NOTE(review): merge_P uses
                                        # next_A.T @ Q2_inv @ next_A, while the
                                        # term below uses
                                        # Q2_inv @ next_A @ next_m — confirm this
                                        # asymmetry is intended.
                                        temp_term = torch.mm(
                                            Q1_inv, torch.mm(
                                                prev_A, prev_m)) + torch.mm(
                                                    Q2_inv, torch.mm(next_A, next_m))
                                        merge_m = torch.mm(merge_P, temp_term)
                                        
                                        H = factor_process.H
                                        self.post_U_m[mode][uid, :, :, T] = torch.mm(H, merge_m)
                                        self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                            torch.mm(H, merge_P), H.T)
                                
                                elif loc > 0 and loc - 1 < len(factor_process.m_smooth_list):
                                    # Extrapolate at end, forward Gaussian jump
                                    prev_time_stamp = factor_process.time_stamp_list[loc - 1]
                                    prev_m = factor_process.m_smooth_list[loc - 1]
                                    prev_P = factor_process.P_smooth_list[loc - 1]
                                    prev_time_int = time_stamp - prev_time_stamp
                                    
                                    prev_A = torch.matrix_exp(
                                        factor_process.F * prev_time_int).double()
                                    prev_Q = factor_process.P_inf - torch.mm(
                                        torch.mm(prev_A, factor_process.P_inf), prev_A.T)
                                    
                                    jump_m = torch.mm(prev_A, prev_m)
                                    jump_P = (torch.mm(torch.mm(prev_A, prev_P),
                                                     prev_A.T) + prev_Q)
                                    
                                    H = factor_process.H
                                    self.post_U_m[mode][uid, :, :, T] = torch.mm(H, jump_m)
                                    self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                        torch.mm(H, jump_P), H.T)
    
    def model_test(self, test_ind, test_y, test_time):
        """Predict the given entries and score the predictions.

        Entries that the coreset manager reports as coreset members are
        predicted with ``model_test_coreset``, all others with
        ``model_test_noncore``. If either predictor raises, the error is
        logged and the affected predictions are set to zero.

        Args:
            test_ind: Index rows of the entries; must support list indexing.
            test_y: Ground-truth values (tensor).
            test_time: Timestamp ids of the entries; must support list indexing.

        Returns:
            ``(pred, loss_test)`` where ``pred`` is a 1-D tensor with one
            prediction per entry and ``loss_test`` holds ``"rmse"`` and
            ``"MAE"``. If scoring itself fails, the metrics are set to the
            placeholder values 9999.0 and 999.0.
        """
        mse_fn = torch.nn.MSELoss()
        mae_fn = torch.nn.L1Loss()

        loss_test = {}

        # Split entry positions by coreset membership
        core_indices = []
        noncore_indices = []
        for row, indices in enumerate(test_ind):
            if self.coreset_manager.is_in_coreset(indices):
                core_indices.append(row)
            else:
                noncore_indices.append(row)

        pred = torch.zeros(len(test_ind), device=self.device)

        # Coreset data points (full model)
        if core_indices:
            try:
                core_pred = self.model_test_coreset(
                    test_ind[core_indices], test_time[core_indices])
                pred[core_indices] = core_pred.squeeze().to(pred.dtype)
            except Exception as e:
                logger.error(f"Coreset prediction error: {e}")
                # Fill with zeros on error
                pred[core_indices] = 0.0

        # Non-coreset data points (approximate model)
        if noncore_indices:
            try:
                noncore_pred = self.model_test_noncore(
                    test_ind[noncore_indices], test_time[noncore_indices])
                pred[noncore_indices] = noncore_pred.squeeze().to(pred.dtype)
            except Exception as e:
                logger.error(f"Non-coreset prediction error: {e}")
                # Fill with zeros on error
                pred[noncore_indices] = 0.0

        try:
            flat_pred = pred.squeeze()
            target = test_y.squeeze().to(self.device)
            loss_test["rmse"] = torch.sqrt(mse_fn(flat_pred, target))
            loss_test["MAE"] = mae_fn(flat_pred, target)
        except Exception as e:
            logger.error(f"Error calculating error metrics: {e}")
            # Use large values on error
            loss_test["rmse"] = torch.tensor(9999.0, device=self.device)
            loss_test["MAE"] = torch.tensor(999.0, device=self.device)

        return pred, loss_test

    def model_test_coreset(self, test_ind, test_time):
        """Predict coreset data points; subclasses must provide this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Should be implemented in subclass")
    
    def model_test_noncore(self, test_ind, test_time):
        """Predict non-coreset data points; subclasses must provide this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Should be implemented in subclass")
    
    def reset(self):
        """Clear every dynamics object's lists, then forget all dynamics objects."""
        for process in self.factor_dynamics.values():
            process.reset_list()
        self.factor_dynamics = {}


class DCTF_CP(DynamicCoreSetTensorFactorization):
    """
    Dynamic CoreSet Tensor Factorization CP form
    CP model with data point coreset support

    The moment products are delegated to ``utils_streaming`` whenever that
    module can be imported; otherwise the simplified ``_moment_product`` /
    ``_single_prediction`` fallbacks defined on this class are used.
    """

    def __init__(self, hyper_dict, data_dict):
        """Initialize CP model"""
        super().__init__(hyper_dict, data_dict)
        self.product_method = "hadamard"  # CP

        # For CP, gamma is a constant all-one vector
        self.post_gamma_m = torch.ones(self.R_U, 1).double().to(self.device)  # (R)*1

        logger.info("Initialized DCTF_CP model")

    def product_with_gamma(self, E_z, E_z_2, mode):
        """Multiply with gamma: for CP, gamma is constant all-1 vector, so we actually do nothing here"""
        return E_z, E_z_2

    @staticmethod
    def _global_scale_weight(weights):
        """Extract the first (global) scale weight as a usable float.

        Args:
            weights: Tensor returned by ``get_scale_weights`` or ``None``.

        Returns:
            The first weight as a Python float when it is strictly positive
            and finite, otherwise ``None`` (meaning: apply no scaling).
        """
        if weights is None or weights.numel() == 0:
            return None
        try:
            if weights.dim() == 1:
                weights = weights.unsqueeze(1)  # Convert to column vector
            scale_weight = weights[0].item()
        except (IndexError, ValueError, RuntimeError, TypeError) as e:
            # Best effort: a malformed weight tensor must not abort the update
            logger.error(f"Error applying weights: {e}")
            return None
        if scale_weight > 0 and math.isfinite(scale_weight):
            return scale_weight
        return None

    def msg_approx_U(self, T, mode):
        """Approximate msg from data-llk groups at T

        Builds the natural-parameter messages for the factors of ``mode`` from
        the moments of the remaining modes, down-weights entries outside the
        coreset, applies the global multi-scale weight, damps the result into
        ``self.msg_U_lam_llk[mode]`` / ``self.msg_U_eta_llk[mode]`` and appends
        the per-embedding mean / covariance lists to ``self.msg_U_m`` /
        ``self.msg_U_V``.

        Args:
            T: Time-step index used to slice the last axis of the posteriors.
            mode: Tensor mode whose factor messages are updated.
        """
        msg_U_m_mode = []
        msg_U_V_mode = []

        condi_modes = list(range(self.nmods))
        condi_modes.remove(mode)  # [1,2], [0,2]

        U_m_T = [ele[:, :, :, T] for ele in self.post_U_m]
        U_v_T = [ele[:, :, :, T] for ele in self.post_U_v]
        try:
            import utils_streaming
            E_z, E_z_2 = utils_streaming.moment_product(
                modes=condi_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                order="second",
                sum_2_scaler=False,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation if utils_streaming not available
            E_z, E_z_2 = self._moment_product(
                modes=condi_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                T=T
            )

        E_z, E_z_2 = self.product_with_gamma(E_z, E_z_2, mode)

        # Multi-scale weights for this time step / mode
        weights = self.get_scale_weights(T, mode)

        # First use natural parameters for easier msg merging
        msg_U_lam_new = self.E_tau * E_z_2  # (N,R,R)
        # One observation per entry: reshape so it broadcasts over (N,R,1),
        # as msg_approx_gamma does in the Tucker model
        msg_U_eta_new = self.y_T.reshape(-1, 1, 1) * E_z * self.E_tau  # (N,R,1)

        # Use smaller weights for non-coreset data points, reducing influence;
        # coreset data points keep the standard weight
        noncore_weight = 0.9
        for i, indices in enumerate(self.ind_T):
            if not self.coreset_manager.is_in_coreset(indices):
                msg_U_lam_new[i] = msg_U_lam_new[i] * noncore_weight
                msg_U_eta_new[i] = msg_U_eta_new[i] * noncore_weight

        # Use only global scale weight to simplify implementation
        scale_weight = self._global_scale_weight(weights)
        if scale_weight is not None:
            msg_U_lam_new = scale_weight * msg_U_lam_new
            msg_U_eta_new = scale_weight * msg_U_eta_new

        # DAMPING step:
        self.msg_U_lam_llk[mode] = (self.DAMPING * self.msg_U_lam_llk[mode] +
                                    (1 - self.DAMPING) * msg_U_lam_new)

        self.msg_U_eta_llk[mode] = (self.DAMPING * self.msg_U_eta_llk[mode] +
                                    (1 - self.DAMPING) * msg_U_eta_new)

        # Fill msg_U_M, msg_U_V: one (mean, covariance) pair per embedding id
        for i in range(len(self.uid_table[mode])):
            eid = self.data_table[mode][i]  # Associated entry id

            S_inv_cur = self.msg_U_lam_llk[mode][eid].sum(dim=0)  # (R,R)
            S_inv_Beta_cur = self.msg_U_eta_llk[mode][eid].sum(dim=0)  # (R,1)

            try:
                U_V = torch.linalg.inv(S_inv_cur)
            except RuntimeError as e:
                # torch.linalg.LinAlgError derives from RuntimeError;
                # retry once with a small diagonal jitter
                logger.warning(f"Matrix inversion failed: {e}")
                jitter = 1e-3 * torch.eye(
                    S_inv_cur.size(0),
                    dtype=S_inv_cur.dtype,
                    device=S_inv_cur.device)
                U_V = torch.linalg.inv(S_inv_cur + jitter)
            U_M = torch.mm(U_V, S_inv_Beta_cur)  # (R,1)

            msg_U_m_mode.append(U_M)
            msg_U_V_mode.append(U_V)

        self.msg_U_m.append(msg_U_m_mode)
        self.msg_U_V.append(msg_U_V_mode)

    def _moment_product(self, modes, ind, U_m, U_v, T):
        """Simple implementation of moment product if utils_streaming not available

        For every entry the first and second moments of the element-wise
        (Hadamard) product of the selected modes' factors are accumulated.

        Args:
            modes: Modes whose factors enter the product.
            ind: Per-entry index rows; ``ind[i][mode]`` selects the factor.
            U_m: Per-mode factor means, already sliced at the time step.
            U_v: Per-mode factor covariances, already sliced at the time step.
            T: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(E_z, E_z_2)`` of shapes (N, R, 1) and (N, R, R), in the
            dtype of the supplied factor means.
        """
        N = len(ind)
        # Allocate in the factors' dtype so nothing is silently downcast
        dtype = U_m[modes[0]].dtype if len(modes) > 0 else torch.get_default_dtype()
        E_z = torch.zeros(N, self.R_U, 1, dtype=dtype, device=self.device)
        E_z_2 = torch.zeros(N, self.R_U, self.R_U, dtype=dtype, device=self.device)

        for i in range(N):
            # The neutral element of an element-wise product is all-ones.
            # (Starting the second moment from the identity matrix would zero
            # every off-diagonal term.)
            mean_prod = torch.ones(self.R_U, 1, dtype=dtype, device=self.device)
            second_prod = torch.ones(self.R_U, self.R_U, dtype=dtype, device=self.device)

            for mode in modes:
                idx = ind[i][mode]
                m = U_m[mode][idx]
                v = U_v[mode][idx]

                # Hadamard product for CP
                mean_prod = mean_prod * m
                second_prod = second_prod * (v + torch.mm(m, m.T))

            E_z[i] = mean_prod
            E_z_2[i] = second_prod

        return E_z, E_z_2

    def msg_approx_tau(self, T):
        """Approximate msg for tau

        Updates ``self.msg_a`` and damps the new rate message into
        ``self.msg_b`` using the moments over all modes at time step ``T``.
        """
        all_modes = list(range(self.nmods))

        U_m_T = [ele[:, :, :, T] for ele in self.post_U_m]
        U_v_T = [ele[:, :, :, T] for ele in self.post_U_v]
        try:
            import utils_streaming
            E_z, E_z_2 = utils_streaming.moment_product(
                modes=all_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                order="second",
                sum_2_scaler=False,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation if utils_streaming not available
            E_z, E_z_2 = self._moment_product(
                modes=all_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                T=T
            )

        self.msg_a = 1.5 * torch.ones(self.N_T, 1).to(self.device)

        term1 = 0.5 * torch.square(self.y_T)  # N_T*1

        # For CP gamma is all-ones, so the products with gamma reduce to
        # plain sums over the rank axes.
        term2 = self.y_T.reshape(-1, 1) * E_z.sum(dim=(1, 2)).reshape(-1, 1)  # N_T*1
        term3 = 0.5 * E_z_2.sum(dim=(1, 2)).reshape(-1, 1)  # N*1

        self.msg_b = self.DAMPING_tau * self.msg_b + (1 - self.DAMPING_tau) * (
            term1.reshape(-1, 1) - term2.reshape(-1, 1) + term3.reshape(-1, 1)
        )  # N*1

    def post_update_tau(self, T=None):
        """Update posterior factor tau based on current msg factors"""
        self.post_a = self.post_a + self.msg_a.sum() - self.N_T
        self.post_b = self.post_b + self.msg_b.sum()
        self.E_tau = self.post_a / self.post_b

    def model_test_coreset(self, test_ind, test_time):
        """
        Test prediction function for coreset data points
        Uses full CP decomposition model for coreset data points
        """
        all_modes = list(range(self.nmods))

        # Use CP decomposition for prediction
        try:
            import utils_streaming
            pred = utils_streaming.moment_product_T(
                modes=all_modes,
                ind=test_ind,
                ind_T=test_time,
                U_m_T=self.post_U_m,
                U_v_T=self.post_U_v,
                order="first",
                sum_2_scaler=True,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation if utils_streaming not available
            pred = self._moment_product_T(
                test_ind=test_ind,
                test_time=test_time
            )

        return pred

    def _moment_product_T(self, test_ind, test_time):
        """Simple implementation of moment product T if utils_streaming not available

        Returns a 1-D tensor with one CP prediction per test entry.
        """
        pred = torch.zeros(len(test_ind), device=self.device)
        for i, indices in enumerate(test_ind):
            pred[i] = self._single_prediction(indices, test_time[i])
        return pred

    def _entry_prediction(self, indices, time):
        """Return the CP point prediction for one entry at one time step as a float."""
        try:
            import utils_streaming
            value = utils_streaming.moment_product(
                modes=list(range(self.nmods)),
                ind=np.array([indices]),
                U_m=[ele[:, :, :, time] for ele in self.post_U_m],
                U_v=[ele[:, :, :, time] for ele in self.post_U_v],
                order="first",
                sum_2_scaler=True,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation
            value = self._single_prediction(indices, time)
        return value.item()

    @staticmethod
    def _nearest_coreset_entry(indices, time, coreset_data):
        """Find the coreset entry closest to ``(indices, time)``.

        Similarity is 0.7 * (fraction of matching indices) +
        0.3 * exp(-0.1 * |time difference|); distance is 1 - similarity.

        Returns:
            Tuple ``(position, distance)``; position is -1 when no entry
            compared as closer than infinity.
        """
        nearest_core_idx = -1
        min_distance = float('inf')
        for core_idx, (core_indices, _, core_time) in enumerate(coreset_data):
            # Index similarity: number of matching indices
            common_indices = sum(1 for a, b in zip(indices, core_indices) if a == b)
            index_sim = common_indices / len(indices)

            # Time similarity: smaller time diff means higher similarity
            time_sim = math.exp(-0.1 * abs(time - core_time))

            distance = 1.0 - (0.7 * index_sim + 0.3 * time_sim)
            if distance < min_distance:
                min_distance = distance
                nearest_core_idx = core_idx
        return nearest_core_idx, min_distance

    def model_test_noncore(self, test_ind, test_time):
        """
        Test prediction function for non-coreset data points
        Uses approximate model based on coreset data points

        Each prediction is a similarity-weighted mix of the model prediction
        for the nearest coreset entry and the model prediction for the test
        entry itself. Falls back to ``model_test_coreset`` when the coreset is
        empty.
        """
        coreset_data = self.coreset_manager.get_coreset_data()

        if not coreset_data:
            # If coreset is empty, use standard test method
            return self.model_test_coreset(test_ind, test_time)

        pred = torch.zeros(len(test_ind), device=self.device)

        # A coreset entry's prediction does not depend on the test point,
        # so compute it once per entry within this call.
        core_pred_cache = {}

        for i, (indices, time) in enumerate(zip(test_ind, test_time)):
            nearest_core_idx, min_distance = self._nearest_coreset_entry(
                indices, time, coreset_data)
            test_pred = self._entry_prediction(indices, time)

            if nearest_core_idx < 0:
                # If no similar coreset data point found, use standard model
                pred[i] = test_pred
                continue

            if nearest_core_idx not in core_pred_cache:
                core_indices, _, core_time = coreset_data[nearest_core_idx]
                core_pred_cache[nearest_core_idx] = self._entry_prediction(
                    core_indices, core_time)

            # Mix prediction results based on similarity
            similarity = 1.0 - min_distance
            pred[i] = (similarity * core_pred_cache[nearest_core_idx] +
                       (1 - similarity) * test_pred)

        return pred

    def _single_prediction(self, indices, time):
        """Make prediction for a single data point

        CP prediction: element-wise product of the per-mode factor means at
        ``time``, summed over the rank. Returns a 0-dim tensor.
        """
        prod = torch.ones(self.R_U, 1, device=self.device)
        for mode in range(self.nmods):
            prod = prod * self.post_U_m[mode][indices[mode], :, :, time]

        return torch.sum(prod)


class DCTF_Tucker(DCTF_CP):
    """
    Dynamic CoreSet Tensor Factorization Tucker form
    Tucker model with data point coreset support

    The Tucker core is kept vectorized as ``post_gamma_m`` /
    ``post_gamma_v``. Moment products are delegated to ``utils_streaming``
    when it can be imported; otherwise the simplified Kronecker fallbacks
    defined on this class are used.
    """

    def __init__(self, hyper_dict, data_dict):
        """Initialize Tucker model"""
        super().__init__(hyper_dict, data_dict)

        self.DAMPING_gamma = hyper_dict.get("DAMPING_gamma", 0.5)

        self.product_method = "kronecker"
        self.nmod_list = [self.R_U for _ in range(self.nmods)]

        # Vectorized Tucker-Core llk-msg and post
        self.gamma_size = int(np.prod(self.nmod_list))  # R_U^{K}

        # gamma posterior
        self.post_gamma_m = (torch.rand(self.gamma_size,
                                        1).double().to(self.device))  # (R^K)*1
        self.post_gamma_v = (torch.eye(self.gamma_size).double().to(
            self.device))  # (R^K)*(R^K)

        logger.info("Initialized DCTF_Tucker model")

    def product_with_gamma(self, E_z, E_z_2, mode):
        """Multiply with gamma: for tucker, gamma is folded tucker core, we do tensor-matrix product here

        Args:
            E_z: First moments of the product over the other modes.
            E_z_2: Second moments of the product over the other modes.
            mode: Mode along which the core tensor is unfolded.

        Returns:
            Tuple ``(E_a, E_a_2)``: the moments multiplied by the mode-``mode``
            unfolding of the core mean.
        """
        E_gamma_tensor = tl.tensor(self.post_gamma_m.reshape(
            self.nmod_list))  # (R^k *1)-> (R * R * R ...)
        E_gamma_mat_k = tl.unfold(E_gamma_tensor, mode).double()

        # Some intermediate terms (calculate E_a_2 = gamma_fold * z\z\.T * gamma_fold.T)
        term1 = torch.matmul(E_z_2, E_gamma_mat_k.T)  # N * R_U^{K-1} * R_U
        E_a_2 = torch.matmul(term1.transpose(dim0=1, dim1=2),
                             E_gamma_mat_k.T).transpose(
                                 dim0=1, dim1=2)  # N * R_U * R_U

        # Calculate E_a = gamma_fold * z\
        E_a = torch.matmul(E_z.transpose(dim0=1, dim1=2),
                           E_gamma_mat_k.T).transpose(
                               dim0=1, dim1=2)  # num_eid * R_U * 1

        return E_a, E_a_2

    def msg_llk_init(self):
        """Initialize llk-msg for CEP inner loop, for Tucker including U, tau, and gamma msgs"""
        # Every message buffer is sized from the same entry count so the
        # U, tau and gamma messages cannot disagree in length.
        N_T = len(self.y_T)  # Current time step entry count

        # Initialize msg_U_llk using natural parameters: lam = S_inv, eta = S_inv x m
        self.msg_U_lam_llk = [
            1e-3 * torch.eye(self.R_U).reshape((1, self.R_U, self.R_U)).repeat(
                N_T, 1, 1).double().to(self.device) for _ in range(self.nmods)
        ]  # (N*R_U*R_U)*nmod
        self.msg_U_eta_llk = [
            1e-3 * torch.rand(N_T, self.R_U, 1).double().to(self.device)
            for _ in range(self.nmods)
        ]  # (N*R_U*1)*nmod

        # tau messages
        self.msg_a = torch.ones(N_T, 1).double().to(self.device)  # N*1
        self.msg_b = torch.ones(N_T, 1).double().to(self.device)  # N*1

        # gamma message initialization
        self.msg_gamma_lam = 1e-4 * torch.eye(self.gamma_size).reshape(
            (1, self.gamma_size, self.gamma_size)).repeat(
                N_T, 1, 1).double().to(self.device)  # N*(R^K)*(R^K)
        self.msg_gamma_eta = 1e-4 * torch.rand(N_T, self.gamma_size,
                                               1).double().to(self.device)

    def _moment_product(self, modes, ind, U_m, U_v, T):
        """Route the inherited CP fallback to the Kronecker moments.

        ``msg_approx_U`` is inherited from the CP model and calls
        ``_moment_product`` when ``utils_streaming`` is unavailable. The
        Hadamard moments it would otherwise receive do not have the
        R_U^{K-1} size that ``product_with_gamma`` multiplies with the
        unfolded core.
        """
        return self._kronecker_moment_product(
            modes=modes, ind=ind, U_m=U_m, U_v=U_v, T=T)

    @staticmethod
    def _gamma_scale_weight(weights):
        """Return the first (global) scale weight as a float, or None if unusable.

        A weight is usable when it is strictly positive and finite, matching
        the guard used for the factor messages.
        """
        if weights is None or weights.numel() == 0:
            return None
        try:
            if weights.shape[0] <= 0:
                return None
            scale_weight = weights[0].item()
        except (IndexError, ValueError, RuntimeError, TypeError) as e:
            # Best effort: a malformed weight tensor must not abort the update
            logger.error(f"Error applying gamma weights: {e}")
            return None
        if scale_weight > 0 and math.isfinite(scale_weight):
            return scale_weight
        return None

    def msg_approx_gamma(self, T):
        """Approximate gamma message

        Builds the natural-parameter messages for the vectorized core from the
        Kronecker moments over all modes at time step ``T``, down-weights
        entries outside the coreset, applies the global multi-scale weight and
        damps the result into ``self.msg_gamma_lam`` / ``self.msg_gamma_eta``.
        """
        all_modes = list(range(self.nmods))

        U_m_T = [ele[:, :, :, T] for ele in self.post_U_m]
        U_v_T = [ele[:, :, :, T] for ele in self.post_U_v]
        try:
            import utils_streaming
            E_z, E_z_2 = utils_streaming.moment_product(
                modes=all_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                order="second",
                sum_2_scaler=False,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation if utils_streaming not available
            E_z, E_z_2 = self._kronecker_moment_product(
                modes=all_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                T=T
            )

        # Multi-scale weights for this time step
        weights = self.get_scale_weights(T)

        msg_gamma_lam_new = self.E_tau * E_z_2  # N*(R^K)*(R^K)

        msg_gamma_eta_new = self.E_tau * E_z * self.y_T.reshape(-1, 1, 1)  # N*(R^K)*1

        # Use smaller weights for non-coreset data points; coreset data
        # points keep the standard weight
        noncore_weight = 0.9
        for i, indices in enumerate(self.ind_T):
            if not self.coreset_manager.is_in_coreset(indices):
                msg_gamma_lam_new[i] = msg_gamma_lam_new[i] * noncore_weight
                msg_gamma_eta_new[i] = msg_gamma_eta_new[i] * noncore_weight

        # Use only global scale weight to simplify implementation
        scale_weight = self._gamma_scale_weight(weights)
        if scale_weight is not None:
            msg_gamma_lam_new = scale_weight * msg_gamma_lam_new
            msg_gamma_eta_new = scale_weight * msg_gamma_eta_new

        self.msg_gamma_lam = (self.DAMPING_gamma * self.msg_gamma_lam +
                              (1 - self.DAMPING_gamma) * msg_gamma_lam_new
                              )  # N*(R^K)*(R^K)
        self.msg_gamma_eta = (self.DAMPING_gamma * self.msg_gamma_eta +
                              (1 - self.DAMPING_gamma) * msg_gamma_eta_new
                              )  # N*(R^K)*1

    def _kronecker_moment_product(self, modes, ind, U_m, U_v, T):
        """Simple implementation of kronecker moment product if utils_streaming not available

        For every entry the first and second moments of the Kronecker product
        of the selected modes' factors are computed, treating the factors of
        different modes as independent.

        Args:
            modes: Modes whose factors enter the product (in this order).
            ind: Per-entry index rows; ``ind[i][mode]`` selects the factor.
            U_m: Per-mode factor means, already sliced at the time step.
            U_v: Per-mode factor covariances, already sliced at the time step.
            T: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(E_z, E_z_2)`` of shapes (N, R^k, 1) and (N, R^k, R^k)
            with k = len(modes), in the dtype of the supplied factor means.
        """
        N = len(ind)
        total_R = self.R_U ** len(modes)
        # Allocate in the factors' dtype so nothing is silently downcast
        dtype = U_m[modes[0]].dtype if len(modes) > 0 else torch.get_default_dtype()
        E_z = torch.zeros(N, total_R, 1, dtype=dtype, device=self.device)
        E_z_2 = torch.zeros(N, total_R, total_R, dtype=dtype, device=self.device)

        for i in range(N):
            # A 1x1 one is the neutral element of the Kronecker product
            m_kron = torch.ones(1, 1, dtype=dtype, device=self.device)
            second_kron = torch.ones(1, 1, dtype=dtype, device=self.device)

            for mode in modes:
                idx = ind[i][mode]
                m = U_m[mode][idx]
                v = U_v[mode][idx]

                m_kron = torch.kron(m_kron, m)
                # Mixed-product property: (a⊗b)(a⊗b)^T = (a a^T)⊗(b b^T), so
                # for independent factors E[z z^T] is the Kronecker product of
                # the per-mode second moments (v + m m^T). Using
                # kron(v) + m_kron m_kron^T would drop the cross terms.
                second_kron = torch.kron(second_kron, v + torch.mm(m, m.T))

            E_z[i] = m_kron.reshape(-1, 1)
            E_z_2[i] = second_kron

        return E_z, E_z_2

    def msg_approx_tau(self, T):
        """Approximate msg for tau

        Updates ``self.msg_a`` and damps the new rate message into
        ``self.msg_b`` using the Kronecker moments over all modes at time step
        ``T`` and the current core mean.
        """
        all_modes = list(range(self.nmods))

        U_m_T = [ele[:, :, :, T] for ele in self.post_U_m]
        U_v_T = [ele[:, :, :, T] for ele in self.post_U_v]
        try:
            import utils_streaming
            E_z, E_z_2 = utils_streaming.moment_product(
                modes=all_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                order="second",
                sum_2_scaler=False,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation if utils_streaming not available
            E_z, E_z_2 = self._kronecker_moment_product(
                modes=all_modes,
                ind=self.ind_T,
                U_m=U_m_T,
                U_v=U_v_T,
                T=T
            )

        self.msg_a = 1.5 * torch.ones(self.N_T, 1).to(self.device)

        term1 = 0.5 * torch.square(self.y_T)  # N_T*1

        term2 = self.y_T.reshape(-1, 1) * torch.matmul(
            E_z.transpose(dim0=1, dim1=2),
            self.post_gamma_m).reshape(-1, 1)  # N_T*1

        temp = torch.matmul(E_z_2, self.post_gamma_m)  # N*R*1
        term3 = 0.5 * torch.matmul(temp.transpose(dim0=1, dim1=2),
                                   self.post_gamma_m).reshape(-1, 1)  # N*1

        self.msg_b = self.DAMPING_tau * self.msg_b + (1 - self.DAMPING_tau) * (
            term1.reshape(-1, 1) - term2.reshape(-1, 1) + term3.reshape(-1, 1)
        )  # N*1

    def post_update_gamma(self, T=None):
        """Update gamma posterior

        Converts the current posterior to natural parameters, adds the summed
        gamma messages and converts back to mean / covariance form.
        """
        prior_lam = torch.linalg.inv(self.post_gamma_v)
        prior_eta = torch.mm(prior_lam, self.post_gamma_m)

        self.post_gamma_v = torch.linalg.inv(
            self.msg_gamma_lam.sum(dim=0) + prior_lam)  # (R^K) * (R^K)

        self.post_gamma_m = torch.mm(
            self.post_gamma_v, prior_eta + self.msg_gamma_eta.sum(dim=0))

    def model_test_coreset(self, test_ind, test_time):
        """
        Test prediction function for coreset data points
        Uses full Tucker decomposition model for coreset data points
        """
        all_modes = list(range(self.nmods))

        try:
            import utils_streaming
            E_z = utils_streaming.moment_product_T(
                modes=all_modes,
                ind=test_ind,
                ind_T=test_time,
                U_m_T=self.post_U_m,
                U_v_T=self.post_U_v,
                order="first",
                sum_2_scaler=False,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation if utils_streaming not available
            E_z = self._kronecker_moment_product_T(
                test_ind=test_ind,
                test_time=test_time
            )

        pred = torch.matmul(E_z.transpose(dim0=1, dim1=2),
                            self.post_gamma_m).squeeze()

        return pred

    def _kronecker_moment_product_T(self, test_ind, test_time):
        """Simple implementation of kronecker moment product T if utils_streaming not available

        Returns a tensor of shape (N, R^K, 1) holding, per test entry, the
        Kronecker product of the factor means at that entry's time step.
        """
        N = len(test_ind)
        total_R = self.R_U ** self.nmods
        # The caller multiplies the result with post_gamma_m, so allocate in
        # that dtype to keep the matmul operands consistent.
        E_z = torch.zeros(N, total_R, 1,
                          dtype=self.post_gamma_m.dtype, device=self.device)

        for i, indices in enumerate(test_ind):
            E_z[i] = self._kronecker_single_prediction(indices, test_time[i])[0]

        return E_z

    def _entry_tucker_prediction(self, indices, time):
        """Return the Tucker point prediction for one entry at one time step as a float."""
        try:
            import utils_streaming
            E_z = utils_streaming.moment_product(
                modes=list(range(self.nmods)),
                ind=np.array([indices]),
                U_m=[ele[:, :, :, time] for ele in self.post_U_m],
                U_v=[ele[:, :, :, time] for ele in self.post_U_v],
                order="first",
                sum_2_scaler=False,
                device=self.device,
                product_method=self.product_method,
            )
        except ImportError:
            # Simple implementation; align dtype with the core for the matmul
            E_z = self._kronecker_single_prediction(indices, time).to(
                self.post_gamma_m.dtype)

        # Contract the entry's representation with the tucker core
        return torch.matmul(E_z.transpose(dim0=1, dim1=2),
                            self.post_gamma_m).squeeze().item()

    @staticmethod
    def _closest_coreset_entry(indices, time, coreset_data):
        """Find the coreset entry closest to ``(indices, time)``.

        Similarity is 0.7 * (fraction of matching indices) +
        0.3 * exp(-0.1 * |time difference|); distance is 1 - similarity.

        Returns:
            Tuple ``(position, distance)``; position is -1 when no entry
            compared as closer than infinity.
        """
        nearest_core_idx = -1
        min_distance = float('inf')
        for core_idx, (core_indices, _, core_time) in enumerate(coreset_data):
            # Index similarity: number of matching indices
            common_indices = sum(1 for a, b in zip(indices, core_indices) if a == b)
            index_sim = common_indices / len(indices)

            # Time similarity: smaller time diff means higher similarity
            time_sim = math.exp(-0.1 * abs(time - core_time))

            distance = 1.0 - (0.7 * index_sim + 0.3 * time_sim)
            if distance < min_distance:
                min_distance = distance
                nearest_core_idx = core_idx
        return nearest_core_idx, min_distance

    def model_test_noncore(self, test_ind, test_time):
        """
        Test prediction function for non-coreset data points
        Uses approximate model based on coreset data points

        Each prediction is a similarity-weighted mix of the model prediction
        for the nearest coreset entry and the model prediction for the test
        entry itself. Falls back to ``model_test_coreset`` when the coreset is
        empty.
        """
        coreset_data = self.coreset_manager.get_coreset_data()

        if not coreset_data:
            # If coreset is empty, use standard test method
            return self.model_test_coreset(test_ind, test_time)

        pred = torch.zeros(len(test_ind), device=self.device)

        # A coreset entry's prediction does not depend on the test point,
        # so compute it once per entry within this call.
        core_pred_cache = {}

        for i, (indices, time) in enumerate(zip(test_ind, test_time)):
            nearest_core_idx, min_distance = self._closest_coreset_entry(
                indices, time, coreset_data)
            pred_test = self._entry_tucker_prediction(indices, time)

            if nearest_core_idx < 0:
                # If no similar coreset data point found, use standard model
                pred[i] = pred_test
                continue

            if nearest_core_idx not in core_pred_cache:
                core_indices, _, core_time = coreset_data[nearest_core_idx]
                core_pred_cache[nearest_core_idx] = self._entry_tucker_prediction(
                    core_indices, core_time)

            # Mix prediction results based on similarity
            similarity = 1.0 - min_distance
            pred[i] = (similarity * core_pred_cache[nearest_core_idx] +
                       (1 - similarity) * pred_test)

        return pred

    def _kronecker_single_prediction(self, indices, time):
        """Make prediction tensor for a single data point using Kronecker product

        Returns the Kronecker product of the per-mode factor means at
        ``time``, reshaped to (1, R^K, 1).
        """
        # Get factor means for each mode
        means = [
            self.post_U_m[mode][indices[mode], :, :, time]
            for mode in range(self.nmods)
        ]

        # Calculate Kronecker product (simplified)
        m_kron = means[0]
        for m in means[1:]:
            m_kron = torch.kron(m_kron, m)

        return m_kron.reshape(1, -1, 1)


class MartingaleDCTF:
    """
    Martingale Dynamic CoreSet Tensor Factorization base class
    Extends original DCTF with martingale representation theorem and optimal stopping

    The class declares no base of its own: every ``super()`` call below is
    resolved through the MRO of the concrete subclass, so it is meant to be
    combined with a DCTF model class that provides ``filter_predict``,
    ``msg_llk_init``, ``inner_smooth`` and ``model_test``.
    """
    def __init__(self, hyper_dict, data_dict):
        """Initialize model

        Reads ``self.device``, so the DCTF parent initializer must already
        have run on this instance (it is deliberately not called from here).

        Args:
            hyper_dict: Hyperparameter dictionary; optional keys fall back to
                the defaults given in the ``.get`` calls below.
            data_dict: Data dictionary (not used by this initializer).
        """
        # Call parent initialization in subclasses
        # super().__init__(hyper_dict, data_dict)

        # Replace coreset manager with martingale theory based manager
        try:
            from utils_martingale_coreset import MartingaleDataPointCoreSetManager
            self.coreset_manager = MartingaleDataPointCoreSetManager(
                max_size=hyper_dict.get("coreset_max_size", 100),
                initial_threshold=hyper_dict.get("coreset_threshold", 0.5),
                adaptive_threshold=hyper_dict.get("adaptive_threshold", True),
                importance_weights=hyper_dict.get("importance_weights", (0.3, 0.2, 0.2, 0.3)),
                device=self.device,
                exploration_rate=hyper_dict.get("initial_exploration_rate", 0.9),
                decay_rate=hyper_dict.get("exploration_decay_rate", 0.1),
                prediction_history_size=hyper_dict.get("prediction_history_size", 50),
                discount_factor=hyper_dict.get("discount_factor", 0.9),
                simulation_samples=hyper_dict.get("simulation_samples", 5)
            )
        except ImportError:
            # Best effort: whatever manager is already on the instance stays in place
            logger.error("Could not import MartingaleDataPointCoreSetManager, using base manager")

        # Use improved multi-scale weighting mechanism
        try:
            from utils_martingale_coreset import EnhancedMultiScaleWeighting
            self.multi_scale_weighting = EnhancedMultiScaleWeighting(
                num_scales=hyper_dict.get("num_time_scales", 3),
                hidden_dim=hyper_dict.get("scale_hidden_dim", 32),
                device=self.device,
                temperature=hyper_dict.get("attention_temperature", 1.0),
                time_scale_factor=hyper_dict.get("time_scale_factor", 0.1)
            )
        except ImportError:
            logger.error("Could not import EnhancedMultiScaleWeighting, using base weighting")

        # Save recent state estimates for martingale increment calculation
        self.recent_state_estimates = []
        self.max_state_history = 20

        # Recent prediction errors for adaptive adjustment
        self.recent_prediction_errors = []

        # Model confidence estimate for exploration-exploitation balance
        self.model_confidence = 0.5

        # Initialize time step related attributes
        self.current_msg_init_time = 0

        logger.info("Initialized martingale theory based DCTF model")

    def filter_predict(self, T):
        """Enhanced filtering prediction step that tracks state estimates

        Runs the parent prediction step and, on every fifth time index, stores
        a detached copy of the posterior factor means/variances at ``T``.

        Args:
            T: Time index; also used to slice the last axis of
                ``post_U_m`` / ``post_U_v``.
        """
        # Call parent method for basic prediction
        super().filter_predict(T)

        # Save every 5 steps to reduce storage
        if T % 5 != 0:
            return

        # Entries are None for tensors whose time axis does not reach T
        state_snapshot = {
            'time': T,
            'post_means': [m[:, :, :, T].clone().detach() if m.size(3) > T else None for m in self.post_U_m],
            'post_vars': [v[:, :, :, T].clone().detach() if v.size(3) > T else None for v in self.post_U_v]
        }
        self.recent_state_estimates.append(state_snapshot)

        # Limit history size
        if len(self.recent_state_estimates) > self.max_state_history:
            self.recent_state_estimates.pop(0)

    def update_coreset(self, T):
        """Use martingale theory based coreset update mechanism

        Args:
            T: Current time index.

        Returns:
            The ``(added, removed)`` pair returned by the coreset manager.
        """
        # Current time step's data batch as (indices, y, time) triples
        data_batch = [
            (self.ind_T[i], self.y_T[i].item(), T)
            for i in range(self.N_T)
        ]

        # Use coreset manager to update coreset
        added, removed = self.coreset_manager.update_coreset(data_batch, self)

        # Log coreset size and changes
        if T % 10 == 0:  # Log every 10 time steps
            logger.info(f"T={T}, coreset size: {self.coreset_manager.get_coreset_size()}, "
                       f"added: {len(added)}, removed: {len(removed)}, "
                       f"model confidence: {self.model_confidence:.4f}")

        # Update model confidence
        if T > 0:
            pred_error = self.compute_prediction_error(T)
            self.recent_prediction_errors.append(pred_error)
            if len(self.recent_prediction_errors) > 10:
                self.recent_prediction_errors.pop(0)

            # The window is never empty here: an error was appended just above
            avg_error = np.mean(self.recent_prediction_errors)
            if avg_error > 0:
                new_confidence = 1.0 / (1.0 + avg_error)
                # Smooth update: 80% previous confidence, 20% new estimate
                self.model_confidence = 0.8 * self.model_confidence + 0.2 * new_confidence

            # Update coreset manager's confidence
            self.coreset_manager.update_confidence(self.model_confidence)

        return added, removed

    def get_scale_weights(self, T, mode=None):
        """Enhanced multi-scale weight calculation using Ito formula time scale transform

        Collects up to three flattened float32 hidden states (``.m`` of the
        factor processes of the first three uids per mode) and hands them to
        ``self.multi_scale_weighting.compute_weights``.

        Args:
            T: Time index (not used by this implementation).
            mode: If given, only this mode's factor processes are considered.

        Returns:
            The weights from ``compute_weights``, or ``torch.ones(1, 1)`` on
            ``self.device`` when no hidden state is available or any error
            occurs.
        """
        try:
            # Collect hidden states
            h_scales = []

            # Collect states from different trajectories
            for m in range(self.nmods):
                # Only consider current mode (if specified)
                if mode is not None and m != mode:
                    continue

                # Extract first few objects from this mode
                sample_uids = self.uid_table[m][:3] if m in self.uid_table else []

                for uid in sample_uids:
                    if (m, uid) not in self.factor_dynamics:
                        continue
                    factor_process = self.factor_dynamics[(m, uid)]
                    if not hasattr(factor_process, 'm') or factor_process.m is None:
                        continue

                    # Convert to float type and flatten to a 1D vector
                    state = factor_process.m.clone().detach().to(torch.float32)
                    h_scales.append(state.reshape(-1))

                    # Maximum 3 hidden states
                    if len(h_scales) >= 3:
                        break

                # Stop scanning further modes once 3 states are collected
                if len(h_scales) >= 3:
                    break

            # If not enough scales, return uniform weights
            if len(h_scales) < 1:
                return torch.ones(1, 1, device=self.device)

            # Use enhanced multi-scale weighting to calculate weights
            return self.multi_scale_weighting.compute_weights(h_scales)

        except Exception as e:
            logger.error(f"Error calculating multi-scale weights: {e}")
            # Return uniform weights on error
            return torch.ones(1, 1, device=self.device)

    def compute_prediction_error(self, T):
        """Calculate current prediction error, used to update model confidence

        Evaluates ``self.model_test`` on a random subset (at most 50 points,
        drawn without replacement) of the test data, with every sampled point
        assigned time ``T``.

        Args:
            T: Time index assigned to all sampled test points.

        Returns:
            The MSE on the sample as a float, or ``1.0`` when no test data is
            available or the evaluation raises.
        """
        if not (hasattr(self, 'te_ind') and hasattr(self, 'te_y')):
            return 1.0

        # Use a small subset of test data
        sample_size = min(50, len(self.te_ind))
        if sample_size <= 0:
            return 1.0

        try:
            indices = np.random.choice(len(self.te_ind), sample_size, replace=False)
            sample_ind = self.te_ind[indices]
            sample_y = self.te_y[indices]
            sample_time = np.ones_like(indices) * T

            # Only the prediction is needed; the reported loss is discarded
            pred, _ = self.model_test(sample_ind, sample_y, sample_time)

            # Calculate MSE
            mse = torch.mean((pred.squeeze() - sample_y.squeeze()) ** 2)
            return mse.item()
        except Exception as e:
            logger.error(f"Error calculating prediction error: {e}")
            return 1.0

    def msg_llk_init(self):
        """Initialize llk-msg for CEP inner loop, adding martingale information"""
        # Call parent method for basic initialization
        super().msg_llk_init()

        # Record current message initialization time point
        try:
            # Use first training time point when one exists; otherwise keep the current value
            if hasattr(self, 'unique_train_time') and len(self.unique_train_time) > 0:
                self.current_msg_init_time = self.unique_train_time[0]
        except Exception as e:
            logger.error(f"Error setting current_msg_init_time: {e}")
            self.current_msg_init_time = 0

    def inner_smooth(self):
        """Enhanced online smoothing, adding martingale theory dynamic weighting

        After the parent smoothing step, estimates the relative change between
        consecutive stored state snapshots on a random ~10% row sample. At
        present a large mean change (> 0.1) is only logged; no smoothing
        parameter is modified here.
        """
        # Call parent method for basic smoothing
        super().inner_smooth()

        # Need at least two snapshots to measure a change
        if len(self.recent_state_estimates) <= 1:
            return

        try:
            changes = []
            history = self.recent_state_estimates
            for prev, curr in zip(history, history[1:]):
                total_change = 0
                count = 0

                for prev_mean, curr_mean in zip(prev['post_means'], curr['post_means']):
                    if prev_mean is None or curr_mean is None:
                        continue

                    # Randomly sample some rows, avoid calculating all points
                    mask = torch.rand_like(prev_mean[:, 0, 0]) < 0.1
                    if not mask.any():
                        continue

                    prev_sample = prev_mean[mask]
                    curr_sample = curr_mean[mask]

                    # Relative change; skipped when the reference magnitude is ~0
                    prev_scale = prev_sample.abs().mean()
                    if prev_scale > 1e-6:
                        change = (curr_sample - prev_sample).abs().mean() / prev_scale
                        total_change += change.item()
                        count += 1

                if count > 0:
                    changes.append(total_change / count)

            # If state change rate large, system in rapid change phase
            if changes and np.mean(changes) > 0.1:
                # Placeholder: special smoothing handling can be implemented here
                logger.info(f"Detected large state change rate: {np.mean(changes):.4f}, adjusting smoothing strategy")
        except Exception as e:
            logger.error(f"Error in dynamic smoothing adjustment: {e}")

    def model_test(self, test_ind, test_y, test_time):
        """Enhanced model test method, considering martingale increments

        Delegates to the parent ``model_test`` and records a random sample of
        10 (index, truth, prediction) triples in ``self.recent_predictions``
        whenever more than 10 test points are supplied.

        Args:
            test_ind: Test indices (indexable by an integer array).
            test_y: Test targets.
            test_time: Test time indices; must expose ``.shape``.

        Returns:
            The ``(pred, loss_test)`` pair returned by the parent method.
        """
        # Call parent method for testing
        pred, loss_test = super().model_test(test_ind, test_y, test_time)

        # Create the history up front so the very first call is recorded too
        if not hasattr(self, 'recent_predictions'):
            self.recent_predictions = []

        # Limit size
        if len(self.recent_predictions) > 100:
            self.recent_predictions.pop(0)

        # Save a sample
        if len(test_ind) > 10:
            indices = np.random.choice(len(test_ind), 10, replace=False)
            self.recent_predictions.append({
                'time': test_time[0] if test_time.shape else test_time,
                'indices': test_ind[indices],
                'true': test_y[indices].clone().detach() if isinstance(test_y, torch.Tensor) else test_y[indices],
                'pred': pred[indices].clone().detach() if isinstance(pred, torch.Tensor) else pred[indices]
            })

        return pred, loss_test


class MartingaleDCTF_CP(MartingaleDCTF, DCTF_CP):
    """
    Martingale Dynamic CoreSet Tensor Factorization CP form

    Combines the MartingaleDCTF mixin with DCTF_CP and rescales the parent's
    coreset / non-coreset predictions according to the recent prediction
    error history.
    """
    def __init__(self, hyper_dict, data_dict):
        """Initialize CP model

        Args:
            hyper_dict: Hyperparameter dictionary, passed to both initializers.
            data_dict: Data dictionary, passed to both initializers.
        """
        # DCTF_CP goes first: MartingaleDCTF.__init__ reads self.device,
        # which is expected to be set by the DCTF initializer
        DCTF_CP.__init__(self, hyper_dict, data_dict)

        # Replaces the coreset manager and multi-scale weighting (when they
        # can be imported) and resets recent_state_estimates,
        # max_state_history, recent_prediction_errors, model_confidence and
        # current_msg_init_time
        MartingaleDCTF.__init__(self, hyper_dict, data_dict)

        logger.info(f"Initialized martingale theory based DCTF_CP: R_U={self.R_U}, nmods={self.nmods}")

    def _shrink_by_error_trend(self, pred, threshold, max_shrink, context):
        """Scale ``pred`` down when the mean recent prediction error exceeds ``threshold``.

        The factor is ``1 - max_shrink * min(1, error_trend)``. Any exception
        is logged (tagged with ``context``) and ``pred`` is returned as it was
        before the failing step.
        """
        try:
            error_trend = np.mean(self.recent_prediction_errors)
            if error_trend > threshold:
                adjust_factor = 1.0 - max_shrink * min(1.0, error_trend)
                pred = pred * adjust_factor
        except Exception as e:
            logger.error(f"Error in {context}: {e}")
        return pred

    def model_test_coreset(self, test_ind, test_time):
        """Enhanced coreset data points test method

        Returns the parent prediction, scaled by up to 10% when at least two
        state snapshots exist and the mean recent prediction error is above 0.5.
        """
        # Basic test logic same as parent class
        pred = super().model_test_coreset(test_ind, test_time)

        # Adjustment needs both state history and an error history
        if len(self.recent_state_estimates) > 1 and len(self.recent_prediction_errors) > 0:
            pred = self._shrink_by_error_trend(
                pred, threshold=0.5, max_shrink=0.1, context="prediction adjustment"
            )

        return pred

    def model_test_noncore(self, test_ind, test_time):
        """Enhanced non-coreset data points test method

        Returns the parent prediction, scaled by up to 15% when the mean
        recent prediction error is above 0.3.
        """
        # Call parent method
        pred = super().model_test_noncore(test_ind, test_time)

        # Lower threshold and stronger scaling than for coreset points
        if len(self.recent_prediction_errors) > 0:
            pred = self._shrink_by_error_trend(
                pred, threshold=0.3, max_shrink=0.15, context="non-coreset prediction adjustment"
            )

        return pred


class MartingaleDCTF_Tucker(MartingaleDCTF, DCTF_Tucker):
    """
    Martingale Dynamic CoreSet Tensor Factorization Tucker form

    Combines the MartingaleDCTF mixin with DCTF_Tucker, inflates the gamma
    posterior variance when the stored state snapshots change quickly, and
    rescales the parent's predictions according to the recent prediction
    error history.
    """
    def __init__(self, hyper_dict, data_dict):
        """Initialize Tucker model

        Args:
            hyper_dict: Hyperparameter dictionary, passed to both initializers.
            data_dict: Data dictionary, passed to both initializers.
        """
        # DCTF_Tucker goes first: MartingaleDCTF.__init__ reads self.device,
        # which is expected to be set by the DCTF initializer
        DCTF_Tucker.__init__(self, hyper_dict, data_dict)

        # Replaces the coreset manager and multi-scale weighting (when they
        # can be imported) and resets recent_state_estimates,
        # max_state_history, recent_prediction_errors, model_confidence and
        # current_msg_init_time
        MartingaleDCTF.__init__(self, hyper_dict, data_dict)

        logger.info(f"Initialized martingale theory based DCTF_Tucker: R_U={self.R_U}, nmods={self.nmods}")

    def post_update_gamma(self, T=None):
        """Enhanced gamma posterior update, considering state change rate

        After the parent update, compares the two most recent state snapshots
        on a random ~10% row sample. If the mean relative change exceeds 0.2,
        ``self.post_gamma_v`` is multiplied by
        ``1 + 0.5 * min(1, change_rate)``.

        Args:
            T: Optional time index; forwarded to the parent and used only to
                throttle logging to every 20th step.
        """
        # Call parent method for basic update
        super().post_update_gamma(T)

        # Need at least two snapshots to measure a change
        if len(self.recent_state_estimates) <= 1:
            return

        try:
            # Most recent two states
            prev = self.recent_state_estimates[-2]
            curr = self.recent_state_estimates[-1]

            total_change = 0
            count = 0

            for prev_mean, curr_mean in zip(prev['post_means'], curr['post_means']):
                if prev_mean is None or curr_mean is None:
                    continue

                # Random sampling for calculation
                mask = torch.rand_like(prev_mean[:, 0, 0]) < 0.1
                if not mask.any():
                    continue

                prev_sample = prev_mean[mask]
                curr_sample = curr_mean[mask]

                # Relative change; skipped when the reference magnitude is ~0
                prev_scale = prev_sample.abs().mean()
                if prev_scale > 1e-6:
                    change = (curr_sample - prev_sample).abs().mean() / prev_scale
                    total_change += change.item()
                    count += 1

            if count == 0:
                return

            change_rate = total_change / count

            # If change rate greater than threshold, system in rapid change phase
            if change_rate > 0.2:  # Threshold adjustable
                # Increase gamma variance to make model more sensitive to changes
                variance_factor = 1.0 + 0.5 * min(1.0, change_rate)  # Max 50% increase

                # Apply to gamma posterior variance
                self.post_gamma_v = self.post_gamma_v * variance_factor

                if T is not None and T % 20 == 0:  # Avoid excessive logging
                    logger.info(f"Detected large state change: {change_rate:.4f}, adjusting gamma variance: {variance_factor:.4f}")
        except Exception as e:
            logger.error(f"Error in gamma adaptive adjustment: {e}")

    def _shrink_by_error_trend(self, pred, threshold, max_shrink, context):
        """Scale ``pred`` down when the mean recent prediction error exceeds ``threshold``.

        The factor is ``1 - max_shrink * min(1, error_trend)``. Any exception
        is logged (tagged with ``context``) and ``pred`` is returned as it was
        before the failing step.
        """
        try:
            error_trend = np.mean(self.recent_prediction_errors)
            if error_trend > threshold:
                adjust_factor = 1.0 - max_shrink * min(1.0, error_trend)
                pred = pred * adjust_factor
        except Exception as e:
            logger.error(f"Error in {context}: {e}")
        return pred

    def model_test_coreset(self, test_ind, test_time):
        """Enhanced coreset data points test method

        Returns the parent prediction, scaled by up to 10% when the mean
        recent prediction error is above 0.5.
        """
        # Basic test logic same as parent class
        pred = super().model_test_coreset(test_ind, test_time)

        if len(self.recent_prediction_errors) > 0:
            pred = self._shrink_by_error_trend(
                pred, threshold=0.5, max_shrink=0.1, context="Tucker coreset prediction adjustment"
            )

        return pred

    def model_test_noncore(self, test_ind, test_time):
        """Enhanced non-coreset data points test method

        Returns the parent prediction, scaled by up to 15% when the mean
        recent prediction error is above 0.3.
        """
        pred = super().model_test_noncore(test_ind, test_time)

        if len(self.recent_prediction_errors) > 0:
            pred = self._shrink_by_error_trend(
                pred, threshold=0.3, max_shrink=0.15, context="Tucker non-coreset prediction adjustment"
            )

        return pred

# Factory function for easier model creation
def create_martingale_dctf(hyper_dict, data_dict):
    """Build the martingale DCTF model selected by ``hyper_dict["method"]``.

    The method name is matched case-insensitively and defaults to ``"CP"``
    when the key is absent.

    Args:
        hyper_dict: Hyperparameter dictionary
        data_dict: Data dictionary

    Returns:
        A MartingaleDCTF_CP or MartingaleDCTF_Tucker instance

    Raises:
        ValueError: If the method is neither "CP" nor "TUCKER".
    """
    requested = hyper_dict.get("method", "CP").upper()

    if requested == "TUCKER":
        return MartingaleDCTF_Tucker(hyper_dict, data_dict)

    # Anything left that is not CP is unknown
    if requested != "CP":
        raise ValueError(f"Unsupported method: {requested}")

    return MartingaleDCTF_CP(hyper_dict, data_dict)


# Simplified LDS implementation for when model_LDS is not available
class SimplifiedLDS:
    """
    Simplified Linear Dynamical System for when model_LDS is not available
    Provides basic functionality needed by DCTF models

    Implements a Kalman-style predict / update / RTS-smooth cycle with the
    trajectories kept in plain Python lists. The predict step propagates the
    state with ``F`` only; no process-noise term is added.
    """

    # Diagonal jitter added when a matrix inversion fails
    _INV_JITTER = 1e-5

    def __init__(self, params):
        """Set up system matrices and empty trajectory storage.

        Args:
            params: Dict-like with a required ``"m_0"`` (initial state tensor)
                and optional ``"F"``, ``"H"``, ``"P_inf"``, ``"P_0"``, ``"R"``
                and ``"device"``. Missing ``F`` / ``H`` / ``P_inf`` default to
                identity matrices of size ``m_0.shape[0]`` created with
                ``m_0``'s dtype and device, so they can be multiplied with
                the state without a dtype/device mismatch. Missing ``P_0``
                defaults to ``P_inf``.

        Raises:
            KeyError: If ``"m_0"`` is missing.
        """
        m_0 = params["m_0"]
        dim = m_0.shape[0]

        def default_eye():
            # Match the state so torch.mm(F, m) works for non-float32 / non-CPU states
            return torch.eye(dim, dtype=m_0.dtype, device=m_0.device)

        self.F = params.get("F", default_eye())
        self.H = params.get("H", default_eye())
        self.P_inf = params.get("P_inf", default_eye())
        self.P_0 = params.get("P_0", self.P_inf)
        self.m_0 = m_0
        self.R = params.get("R", torch.tensor(1.0))
        self.device = params.get("device", torch.device("cpu"))

        # State variables
        self.m = self.m_0.clone()
        self.P = self.P_0.clone()

        # Storage for trajectories
        self.time_stamp_list = []
        self.time_2_ind_table = {}
        self.m_pred_list = []
        self.P_pred_list = []
        self.m_filt_list = []
        self.P_filt_list = []
        self.m_smooth_list = []
        self.P_smooth_list = []

    @classmethod
    def _safe_inverse(cls, mat):
        """Invert ``mat``; on failure retry once with jitter on the diagonal."""
        try:
            return torch.inverse(mat)
        except RuntimeError:
            # torch reports singular inputs as RuntimeError (or a subclass)
            jitter = cls._INV_JITTER * torch.eye(mat.shape[0], device=mat.device)
            return torch.inverse(mat + jitter)

    def filter_predict(self, time_stamp):
        """KF prediction step

        Registers ``time_stamp`` (keeping ``time_stamp_list`` sorted and the
        timestamp-to-index table in sync) and appends the predicted mean and
        covariance to ``m_pred_list`` / ``P_pred_list``.

        Args:
            time_stamp: Time stamp of the step being predicted.
        """
        # Add timestamp if new
        if time_stamp not in self.time_stamp_list:
            self.time_stamp_list.append(time_stamp)
            self.time_stamp_list.sort()
            # Update mapping from timestamp to index
            self.time_2_ind_table = {t: i for i, t in enumerate(self.time_stamp_list)}

        # Start from the initial state until a filtered state exists,
        # afterwards from the latest filtered state
        if len(self.m_filt_list) == 0:
            m_prev = self.m_0
            P_prev = self.P_0
        else:
            m_prev = self.m_filt_list[-1]
            P_prev = self.P_filt_list[-1]

        # Predict: m_pred = F m, P_pred = F P F^T (no process noise is added)
        m_pred = torch.mm(self.F, m_prev)
        P_pred = torch.mm(torch.mm(self.F, P_prev), self.F.t())

        # Store results
        self.m_pred_list.append(m_pred.clone())
        self.P_pred_list.append(P_pred.clone())

    def filter_update(self, y, R, add_to_list=True):
        """KF update step

        Args:
            y: Observation; combined with ``H @ m_pred`` to form the innovation.
            R: Observation noise added to the innovation covariance.
            add_to_list: When True, the filtered mean/covariance are also
                appended to ``m_filt_list`` / ``P_filt_list``.
        """
        # Latest predicted state, or the initial state if nothing was predicted yet
        if len(self.m_pred_list) == 0:
            m_pred = self.m_0
            P_pred = self.P_0
        else:
            m_pred = self.m_pred_list[-1]
            P_pred = self.P_pred_list[-1]

        # Calculate Kalman gain
        PHt = torch.mm(P_pred, self.H.t())
        S = torch.mm(torch.mm(self.H, P_pred), self.H.t()) + R
        K = torch.mm(PHt, self._safe_inverse(S))

        # Update
        innovation = y - torch.mm(self.H, m_pred)
        m_filt = m_pred + torch.mm(K, innovation)
        P_filt = P_pred - torch.mm(K, torch.mm(self.H, P_pred))

        # Ensure P is symmetric
        P_filt = 0.5 * (P_filt + P_filt.t())

        # Store current state
        self.m = m_filt.clone()
        self.P = P_filt.clone()

        # Add to lists if requested
        if add_to_list:
            self.m_filt_list.append(m_filt.clone())
            self.P_filt_list.append(P_filt.clone())

    def smooth(self):
        """RTS smoother

        Runs a backward pass over the filtered trajectory and stores the
        result, in forward time order, in ``m_smooth_list`` / ``P_smooth_list``.
        Does nothing when no filtered state exists. Entry ``t + 1`` of the
        prediction lists is paired with entry ``t`` of the filtered lists.
        """
        if len(self.m_filt_list) == 0:
            return  # Nothing to smooth

        # Start from the last filtered state; results are collected newest
        # first and reversed once at the end (avoids repeated insert(0, ...))
        m_next = self.m_filt_list[-1].clone()
        P_next = self.P_filt_list[-1].clone()
        m_reversed = [m_next]
        P_reversed = [P_next]

        # Backward pass
        for t in range(len(self.m_filt_list) - 2, -1, -1):
            m_filt_t = self.m_filt_list[t]
            P_filt_t = self.P_filt_list[t]
            m_pred_t1 = self.m_pred_list[t + 1]
            P_pred_t1 = self.P_pred_list[t + 1]

            # Smoother gain J_t = P_filt_t F^T P_pred_{t+1}^{-1}
            J_t = torch.mm(torch.mm(P_filt_t, self.F.t()), self._safe_inverse(P_pred_t1))

            # Smooth
            m_smooth_t = m_filt_t + torch.mm(J_t, m_next - m_pred_t1)
            P_smooth_t = P_filt_t + torch.mm(J_t, torch.mm(P_next - P_pred_t1, J_t.t()))

            # Ensure P is symmetric
            P_smooth_t = 0.5 * (P_smooth_t + P_smooth_t.t())

            m_reversed.append(m_smooth_t)
            P_reversed.append(P_smooth_t)
            m_next, P_next = m_smooth_t, P_smooth_t

        self.m_smooth_list = m_reversed[::-1]
        self.P_smooth_list = P_reversed[::-1]

    def reset_list(self):
        """Reset all trajectory lists"""
        self.time_stamp_list = []
        self.time_2_ind_table = {}
        self.m_pred_list = []
        self.P_pred_list = []
        self.m_filt_list = []
        self.P_filt_list = []
        self.reset_smooth_list()

    def reset_smooth_list(self):
        """Reset only smooth lists"""
        self.m_smooth_list = []
        self.P_smooth_list = []


# Import sys for module checking
import sys