"""
Martingale Coreset Selection Utilities

Enhanced coreset selection mechanism with:
- Martingale increment scoring
- Optimal stopping strategies
- Bellman equation optimization
- Multi-scale time weighting with Ito formula
- Improved replacement strategy with fill ratio threshold
"""

import numpy as np
import torch
import math
import logging
from collections import deque
import heapq

# Set up logging
# NOTE: basicConfig() runs at import time and configures the *root* logger, so
# importing this module affects logging for the whole process. It does nothing
# if the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by every class in this file
logger = logging.getLogger(__name__)


class DataPointCoreSetManager:
    """
    Core data point selection manager for tensor factorization
    Selects important data points based on uncertainty, influence, and novelty
    """
    
    def __init__(self, max_size=100, initial_threshold=0.5, adaptive_threshold=True,
                 importance_weights=(0.4, 0.3, 0.3), device=torch.device("cpu"),
                 exploration_rate=0.9, decay_rate=0.1, batch_replace_size=5,
                 fill_ratio_threshold=0.8):
        """
        Initialize the data point coreset manager

        Args:
            max_size: Maximum coreset size
            initial_threshold: Initial importance threshold
            adaptive_threshold: Whether to use adaptive thresholding
            importance_weights: Weights for (uncertainty, influence, novelty)
            device: Computation device
            exploration_rate: Initial exploration rate
            decay_rate: Exploration rate decay coefficient
            batch_replace_size: Number of data points to replace in a batch
            fill_ratio_threshold: Ratio threshold to start replacement
                (e.g., 0.8 means start at 80% capacity)
        """
        self.coreset = []  # Stores coreset data points [(indices, y, time_ind, score), ...]
        self.coreset_indices = set()  # Hash set of coreset data point index tuples
        self.max_size = max_size
        self.threshold = initial_threshold
        self.adaptive = adaptive_threshold
        self.weights = importance_weights  # α, β, γ weights
        self.device = device
        self.batch_replace_size = batch_replace_size
        self.fill_ratio_threshold = fill_ratio_threshold

        # Size at which the manager switches from filling to replacing.
        # Clamped to [1, max_size]: a threshold of 0 (e.g. max_size=1 with
        # ratio 0.8) would enter replacement mode on an empty coreset, so no
        # point could ever be added; a ratio above 1 would let the coreset
        # grow past max_size.
        raw_threshold = int(self.max_size * self.fill_ratio_threshold)
        self.size_threshold = min(self.max_size, max(1, raw_threshold))

        # Exploration-exploitation balance parameters
        self.epsilon_0 = exploration_rate  # Initial exploration rate ϵ0
        self.lambda_decay = decay_rate  # Exploration rate decay coefficient λ
        self.confidence = 0.0  # Current model confidence

        # History of scored points (bounded by max_history_size)
        self.historical_points = []
        self.max_history_size = 1000  # Limit history size

        logger.info(f"Data point coreset manager initialized: max_size={max_size}, "
                   f"threshold={initial_threshold}, fill_ratio_threshold={fill_ratio_threshold}")
    
    def compute_importance_score(self, indices, y, time_ind, model):
        """
        Combine uncertainty, influence and novelty into one importance score

        Args:
            indices: Data point indices (ℓn)
            y: Observed value (yn)
            time_ind: Timestamp index (tn)
            model: Model object passed through to the uncertainty and
                influence helpers

        Returns:
            Weighted sum of the three components as a plain number
            (a tensor result is unwrapped with ``.item()``)
        """
        # α, β, γ — unpacked first so a malformed weight tuple fails early
        alpha, beta, gamma = self.weights

        unc = self._compute_uncertainty(indices, time_ind, model)
        inf = self._compute_influence(indices, time_ind, model)
        nov = self._compute_novelty(indices, y, time_ind)

        # Accumulate left to right
        total = alpha * unc
        total = total + beta * inf
        total = total + gamma * nov

        if isinstance(total, torch.Tensor):
            return total.item()
        return total
    
    def _compute_uncertainty(self, indices, time_ind, model):
        """
        Calculate data point uncertainty
        Uses prediction variance as uncertainty measure
        
        Args:
            indices: Data point indices
            time_ind: Timestamp index
            model: DCTF model
            
        Returns:
            Uncertainty score
        """
        try:
            # Get factor variance for each mode
            variances = []
            for mode, idx in enumerate(indices):
                if hasattr(model, 'post_U_v') and model.post_U_v:
                    # Get factor variance at this time point
                    var = torch.diagonal(model.post_U_v[mode][idx, :, :, time_ind]).mean()
                    variances.append(var.item())
            
            # If variances can't be obtained, use estimated prediction error
            if not variances:
                return 1.0
                
            # Use average variance across all modes as data point uncertainty
            return np.mean(variances)
            
        except Exception as e:
            logger.error(f"Error computing uncertainty: {e}")
            return 1.0  # Return default value on error
    
    def _compute_influence(self, indices, time_ind, model):
        """
        Calculate data point influence
        Measures how this point affects other data points
        
        Args:
            indices: Data point indices
            time_ind: Timestamp index
            model: DCTF model
            
        Returns:
            Influence score
        """
        try:
            # Calculate relevance of this point to others in coreset
            influence_score = 0.0
            count = 0
            
            # Get embedding for this data point
            point_embedding = []
            for mode, idx in enumerate(indices):
                if hasattr(model, 'post_U_m') and model.post_U_m:
                    # Get factor mean at this time point
                    emb = model.post_U_m[mode][idx, :, :, time_ind]
                    point_embedding.append(emb.flatten().detach().cpu().numpy())
            
            if not point_embedding:
                return 0.5  # Return medium influence if embeddings unavailable
                
            # Calculate similarity with other points in coreset
            for core_indices, _, core_time, _ in self.coreset:
                if tuple(indices) != tuple(core_indices):  # Don't compare with self
                    # Get embedding for coreset point
                    core_embedding = []
                    for mode, idx in enumerate(core_indices):
                        if hasattr(model, 'post_U_m') and model.post_U_m:
                            # Get factor mean at that time point
                            emb = model.post_U_m[mode][idx, :, :, core_time]
                            core_embedding.append(emb.flatten().detach().cpu().numpy())
                    
                    if core_embedding:
                        # Calculate similarity between points (cosine similarity)
                        similarity = self._compute_point_similarity(point_embedding, core_embedding)
                        influence_score += similarity
                        count += 1
            
            return influence_score / max(1, count)
            
        except Exception as e:
            logger.error(f"Error computing influence: {e}")
            return 0.5  # Return medium influence on error
    
    def _compute_point_similarity(self, emb1, emb2):
        """Calculate similarity between two data point embeddings"""
        try:
            # Concatenate embeddings across modes
            vec1 = np.concatenate([e.flatten() for e in emb1])
            vec2 = np.concatenate([e.flatten() for e in emb2])
            
            # Calculate cosine similarity
            similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)
            return max(0, similarity)  # Ensure non-negative
            
        except Exception as e:
            logger.error(f"Error computing point similarity: {e}")
            return 0.0
    
    def _compute_novelty(self, indices, y, time_ind):
        """
        Calculate data point novelty
        Measures difference from existing coreset
        
        Args:
            indices: Data point indices
            y: Observed value
            time_ind: Timestamp index
            
        Returns:
            Novelty score
        """
        try:
            # Initially assume maximum novelty
            if not self.coreset:
                return 1.0
                
            # Convert indices to list for safe comparison
            indices_list = [int(i) for i in indices]
            
            # Check if data point with same indices already exists
            already_exists = False
            for core_indices, _, _, _ in self.coreset:
                # Safely compare two lists
                if len(indices_list) == len(core_indices):
                    if all(int(a) == int(b) for a, b in zip(indices_list, core_indices)):
                        already_exists = True
                        break
            
            if already_exists:
                return 0.0  # Point with identical indices exists, zero novelty
            
            # Calculate distance to nearest time point
            time_diffs = []
            for _, _, core_time, _ in self.coreset:
                time_diff = abs(int(time_ind) - int(core_time))
                time_diffs.append(time_diff)
            
            # Nearest time point distance
            nearest_time_diff = min(time_diffs) if time_diffs else 1.0
            
            # Time novelty: farther from nearest time point means more novel
            time_novelty = 1.0 - math.exp(-0.1 * nearest_time_diff)
            
            # Index novelty: check how many mode dimensions are new
            index_novelty = 0.0
            unique_indices_per_mode = [set() for _ in range(len(indices_list))]
            
            # Collect existing indices in coreset for each mode
            for core_indices, _, _, _ in self.coreset:
                for mode, idx in enumerate(core_indices):
                    if mode < len(unique_indices_per_mode):
                        unique_indices_per_mode[mode].add(int(idx))
            
            # Calculate proportion of new indices
            new_mode_count = 0
            for mode, idx in enumerate(indices_list):
                if mode < len(unique_indices_per_mode):
                    # Safe membership check
                    if int(idx) not in unique_indices_per_mode[mode]:
                        new_mode_count += 1
            
            index_novelty = new_mode_count / len(indices_list) if indices_list else 0.0
            
            # Combined novelty (time and index)
            novelty = 0.6 * time_novelty + 0.4 * index_novelty
            
            return max(0.0, min(1.0, novelty))  # Ensure in [0,1] range
            
        except Exception as e:
            logger.error(f"Error computing novelty: {e}")
            return 0.5  # Return medium novelty on error
    
    def compute_exploration_rate(self):
        """
        Calculate current exploration rate
        Rate decreases as model confidence increases
        """
        epsilon = self.epsilon_0 * math.exp(-self.lambda_decay * self.confidence)
        return min(1.0, max(0.01, epsilon))  # Ensure within reasonable range
    
    def update_confidence(self, new_confidence):
        """Store the model confidence, clamped to the [0, 1] range"""
        # The confidence feeds compute_exploration_rate()
        capped = min(1.0, new_confidence)
        self.confidence = max(0.0, capped)
    
    def update_threshold(self, scores):
        """
        Adapt the selection threshold θt to the given score distribution

        Args:
            scores: Importance scores of the current candidates

        Returns:
            The threshold now in effect. It is left unchanged when adaptive
            thresholding is disabled or ``scores`` is empty; otherwise it is
            mean - 0.5 * std, limited to [0.1, 0.9].
        """
        if not self.adaptive or not scores:
            return self.threshold

        values = np.array(scores)

        # Half a standard deviation below the mean keeps enough samples selectable
        proposal = np.mean(values) - 0.5 * np.std(values)

        # Bound the threshold to avoid excessive fluctuation
        upper_bounded = min(0.9, proposal)
        self.threshold = max(0.1, upper_bounded)
        return self.threshold
    
    def select_points_to_replace(self, candidates, batch_size=None):
        """
        Decide which candidates enter the coreset and which entries they evict

        Below ``size_threshold`` the best candidates simply fill the free
        slots. Otherwise ``self.coreset`` is sorted in place by ascending
        score and the i-th best candidate is paired with the i-th worst entry;
        pairing stops at the first candidate that does not beat its entry.

        Args:
            candidates: Candidate data points [(indices, y, time_ind, score), ...]
            batch_size: Maximum number of replacements
                (default: self.batch_replace_size)

        Returns:
            (points to add, positions in the sorted ``self.coreset`` to remove)
        """
        limit = self.batch_replace_size if batch_size is None else batch_size

        # Best candidates first
        ranked = sorted(candidates, key=lambda entry: entry[3], reverse=True)

        # Free slots left before replacement starts
        vacancies = self.size_threshold - len(self.coreset)
        if vacancies > 0:
            return ranked[:vacancies], []

        # Weakest entries first; the returned positions refer to this order
        self.coreset.sort(key=lambda entry: entry[3])

        incoming = []
        evicted_positions = []
        for position, (challenger, resident) in enumerate(zip(ranked, self.coreset)):
            if position >= limit:
                break
            if not challenger[3] > resident[3]:
                # Later challengers are weaker and later residents stronger
                break
            incoming.append(challenger)
            evicted_positions.append(position)

        return incoming, evicted_positions
    
    def should_start_replacement(self):
        """
        Tell whether the coreset is full enough to switch to replacement

        Returns:
            True once the coreset holds at least ``size_threshold`` points,
            False while it is still being filled
        """
        still_filling = len(self.coreset) < self.size_threshold
        return not still_filling
    
    def update_coreset(self, data_batch, model):
        """
        Update coreset with new data points using improved replacement strategy
        
        Args:
            data_batch: Data batch, each element contains (indices, y, time_ind)
            model: DCTF model for importance metrics
            
        Returns:
            Added and removed data points
        """
        # If batch is empty, return empty results
        if not data_batch:
            return [], []
            
        candidates = []
        scores = []
        
        # Calculate importance score for each data point
        for batch_idx, (indices, y, time_ind) in enumerate(data_batch):
            try:
                # Check if already in coreset
                indices_tuple = tuple(int(idx) for idx in indices)
                if indices_tuple in self.coreset_indices:
                    continue
                
                # Calculate importance score
                score = self.compute_importance_score(indices, y, time_ind, model)
                
                # Add to candidate list
                candidates.append((indices, y, time_ind, score))
                scores.append(score)
            except Exception as e:
                logger.error(f"Error calculating data point importance score: {e}")
        
        # Update threshold
        current_threshold = self.update_threshold(scores)
        
        # Apply exploration-exploitation balance
        epsilon = self.compute_exploration_rate()
        
        # Filter candidates based on threshold and exploration rate
        filtered_candidates = []
        for indices, y, time_ind, score in candidates:
            # Exploration-exploitation balance
            if np.random.random() < epsilon:
                # Explore: select randomly
                if np.random.random() < 0.3:  # 30% chance to select
                    filtered_candidates.append((indices, y, time_ind, score))
            else:
                # Exploit: select based on score
                if score > current_threshold:
                    filtered_candidates.append((indices, y, time_ind, score))
        
        added = []
        removed = []
        
        # Check if we should start replacement
        if self.should_start_replacement():
            logger.info(f"Coreset size ({len(self.coreset)}) has reached threshold ({self.size_threshold}). "
                      f"Starting replacement strategy.")
            
            # Apply replacement strategy
            points_to_add, indices_to_remove = self.select_points_to_replace(filtered_candidates)
            
            # Remove points with lowest scores
            for idx in sorted(indices_to_remove, reverse=True):  # Remove from highest index to lowest
                if idx < len(self.coreset):
                    indices, y, time_ind, _ = self.coreset[idx]
                    indices_tuple = tuple(int(i) for i in indices)
                    
                    # Remove from coreset and hash set
                    self.coreset.pop(idx)
                    if indices_tuple in self.coreset_indices:
                        self.coreset_indices.remove(indices_tuple)
                        
                    removed.append((indices, time_ind))
            
            # Add new points
            for indices, y, time_ind, score in points_to_add:
                try:
                    indices_tuple = tuple(int(idx) for idx in indices)
                    if indices_tuple not in self.coreset_indices:
                        self.coreset.append((indices, y, time_ind, score))
                        self.coreset_indices.add(indices_tuple)
                        added.append((indices, time_ind))
                except Exception as e:
                    logger.error(f"Error adding replacement data point to coreset: {e}")
        else:
            # If we haven't reached threshold, directly add points
            # Sort by score
            filtered_candidates.sort(key=lambda x: x[3], reverse=True)
            
            # Add up to size_threshold
            for indices, y, time_ind, score in filtered_candidates[:self.size_threshold - len(self.coreset)]:
                try:
                    # Generate index hash for quick lookup
                    indices_tuple = tuple(int(idx) for idx in indices)
                    if indices_tuple not in self.coreset_indices:
                        self.coreset.append((indices, y, time_ind, score))
                        self.coreset_indices.add(indices_tuple)
                        added.append((indices, time_ind))
                except Exception as e:
                    logger.error(f"Error adding data point to coreset: {e}")
        
        # Add to history
        for indices, y, time_ind, _ in candidates:
            try:
                self.historical_points.append((indices, y, time_ind))
                # Limit history size
                if len(self.historical_points) > self.max_history_size:
                    self.historical_points.pop(0)
            except Exception as e:
                logger.error(f"Error updating history: {e}")
        
        # Log coreset statistics
        if added or removed:
            logger.info(f"Coreset update: {len(added)} points added, {len(removed)} points removed. "
                      f"Current size: {len(self.coreset)}/{self.max_size} (threshold: {self.size_threshold})")
        
        return added, removed
    
    def is_in_coreset(self, indices):
        """
        Check whether a data point with these indices is stored in the coreset

        Returns False (instead of raising) when ``indices`` cannot be turned
        into a tuple of integers.
        """
        try:
            key = tuple(map(int, indices))
        except Exception:
            return False
        return key in self.coreset_indices
    
    def get_coreset_data(self):
        """Return the coreset as a list of (indices, y, time_ind), scores dropped"""
        triples = []
        for indices, y, time_ind, _score in self.coreset:
            triples.append((indices, y, time_ind))
        return triples
    
    def get_coreset_size(self):
        """Return the number of data points currently held in the coreset"""
        current_size = len(self.coreset)
        return current_size


class MultiScaleWeighting:
    """
    Multi-scale weighting mechanism for time-dependent data
    Dynamically adjusts weights for different time scales
    """
    
    def __init__(self, num_scales=3, hidden_dim=32, device=torch.device("cpu"), temperature=1.0):
        """
        Initialize multi-scale weighting mechanism
        
        Args:
            num_scales: Number of time scales
            hidden_dim: Hidden layer dimension
            device: Computation device
            temperature: Attention softmax temperature parameter
        """
        self.num_scales = num_scales
        self.hidden_dim = hidden_dim
        self.device = device
        self.temperature = temperature
        
        # Initialize parameters
        self.W = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim, device=device))
        self.v = torch.nn.Parameter(torch.randn(hidden_dim, device=device))
        self.b = torch.nn.Parameter(torch.zeros(hidden_dim, device=device))
        self.gamma_k = torch.nn.Parameter(torch.ones(num_scales, device=device))
        
        logger.info(f"Multi-scale weighting mechanism initialized: num_scales={num_scales}, hidden_dim={hidden_dim}")
    
    def compute_weights(self, h_k):
        """
        Calculate weights for different time scales
        
        Args:
            h_k: Time scale hidden states list
            
        Returns:
            Normalized weight vector, shape [num_scales, 1]
        """
        # Prevent empty list
        if not h_k or len(h_k) == 0:
            return torch.ones(1, 1, device=self.device)
        
        try:
            # Prepare weight storage
            weights = torch.zeros(min(self.num_scales, len(h_k)), device=self.device)
            
            # Calculate attention score for each scale
            for k in range(min(self.num_scales, len(h_k))):
                # Get hidden state for current scale
                h = h_k[k]
                
                # Ensure h is a 1D tensor
                if isinstance(h, torch.Tensor) and h.dim() > 1:
                    h = h.reshape(-1)
                
                # Convert to tensor if not already
                if not isinstance(h, torch.Tensor):
                    h = torch.tensor(h, device=self.device)
                
                # Map to hidden space
                h_mapped = torch.matmul(self.W, h.float()) + self.b
                
                # Apply tanh activation
                h_tanh = torch.tanh(h_mapped)
                
                # Calculate attention score
                score = torch.matmul(self.v, h_tanh) * self.gamma_k[k]
                weights[k] = score.item() if isinstance(score, torch.Tensor) else score
            
            # Ensure all weights are finite
            weights = torch.where(torch.isfinite(weights), weights, torch.zeros_like(weights))
            
            # If all weights are zero, return uniform weights
            if torch.sum(weights) == 0:
                return torch.ones_like(weights) / weights.size(0)
            
            # Use softmax with temperature parameter
            weights = torch.nn.functional.softmax(weights / self.temperature, dim=0)
            
            return weights.unsqueeze(1)  # Return [num_scales, 1] shape
            
        except Exception as e:
            logger.error(f"Error calculating multi-scale weights: {e}")
            # Return uniform weights on error
            uniform_weights = torch.ones(min(self.num_scales, len(h_k)), device=self.device)
            return (uniform_weights / uniform_weights.sum()).unsqueeze(1)


class MartingaleDataPointCoreSetManager(DataPointCoreSetManager):
    """Enhanced coreset manager using martingale theory"""
    
    def __init__(self, max_size=100, initial_threshold=0.5, adaptive_threshold=True,
                 importance_weights=(0.3, 0.2, 0.2, 0.3), device=torch.device("cpu"),
                 exploration_rate=0.9, decay_rate=0.1, prediction_history_size=50,
                 discount_factor=0.9, simulation_samples=5, batch_replace_size=20,
                 fill_ratio_threshold=0.8):
        """
        Initialize the martingale-based coreset manager

        Args:
            importance_weights: Importance weights
                (uncertainty, influence, novelty, martingale_increment).
                If the fourth weight is missing, 0.3 is used.
            prediction_history_size: Maximum length of the prediction history
            discount_factor: Bellman equation discount factor
            simulation_samples: Number of samples for simulation evaluation
            batch_replace_size: Number of data points to replace in a batch
            fill_ratio_threshold: Ratio threshold to start replacement
                (e.g., 0.8 means start at 80% capacity)
            Other parameters same as parent class
        """
        # The parent class only understands the first three weights
        super().__init__(
            max_size=max_size,
            initial_threshold=initial_threshold,
            adaptive_threshold=adaptive_threshold,
            importance_weights=importance_weights[:3],
            device=device,
            exploration_rate=exploration_rate,
            decay_rate=decay_rate,
            batch_replace_size=batch_replace_size,
            fill_ratio_threshold=fill_ratio_threshold,
        )

        # Fourth weight: martingale increment
        if len(importance_weights) > 3:
            self.martingale_weight = importance_weights[3]
        else:
            self.martingale_weight = 0.3

        # Bounded history of past predictions and a value-function cache
        self.prediction_history = deque(maxlen=prediction_history_size)
        self.value_function_cache = {}

        # Bellman equation parameters
        self.discount_factor = discount_factor
        self.simulation_samples = simulation_samples

        # Recent importance scores
        self.score_history = []

        logger.info(f"Martingale theory coreset manager initialized: max_size={max_size}, "
                   f"importance_weights={importance_weights}, fill_ratio_threshold={fill_ratio_threshold}")
        
    def _compute_martingale_increment(self, indices, time_ind, model):
        """Estimate the information gain (martingale increment) of a data point

        The gain is the drop in estimated prediction variance when the point
        is hypothetically added to the coreset; the model is not modified.

        Returns:
            tanh(5 * gain) with negative gains treated as zero, so the value
            lies in [0, 1); 0.0 on error
        """
        try:
            # Variance estimate before adding the point
            baseline = self._compute_prediction_variance(model)

            # Variance estimate with the point added (y is not needed)
            hypothetical = [(indices, None, time_ind)]
            projected = self._evaluate_with_additional_points(model, hypothetical)

            # Information gain = variance reduction, never negative
            gain = baseline - projected
            if not gain > 0:
                gain = 0

            # Squash into [0, 1)
            return np.tanh(gain * 5)

        except Exception as e:
            logger.error(f"Error calculating martingale increment: {e}")
            return 0.0
    
    def _compute_prediction_variance(self, model):
        """Estimate the model's current prediction variance

        Higher variance indicates lower prediction ability. Three sources are
        tried in order:
            1. If the model has ``te_ind`` and ``te_y``, up to 20 randomly
               chosen test points are predicted with ``model.model_test`` at
               the latest training time, and the variance of the residuals
               is returned.
            2. Otherwise, if the model has ``post_U_v``, the mean diagonal of
               the last time slice is averaged over the non-empty modes.
            3. Otherwise 1.0.

        Returns:
            Variance estimate; 1.0 on error
        """
        try:
            # 1. Small random test sample
            if hasattr(model, 'te_ind') and hasattr(model, 'te_y'):
                n_test = len(model.te_ind)
                draw = min(20, n_test)
                if draw > 0:
                    picked = np.random.choice(n_test, draw, replace=False)
                    sample_ind = model.te_ind[picked]
                    sample_y = model.te_y[picked]

                    # Evaluate every sampled point at the latest training time
                    latest_time = max(model.unique_train_time) if model.unique_train_time else 0
                    sample_time = np.ones_like(picked) * latest_time

                    prediction, _ = model.model_test(sample_ind, sample_y, sample_time)

                    # Variance of the residuals
                    residual = prediction.squeeze() - sample_y.squeeze()
                    return torch.var(residual).item()

            # 2. Posterior variance as a proxy for prediction uncertainty
            if hasattr(model, 'post_U_v'):
                mode_means = []
                for mode in range(len(model.post_U_v)):
                    factor_var = model.post_U_v[mode]
                    if factor_var.numel() > 0:
                        last_slice = factor_var[:, :, :, -1]
                        diag = torch.diagonal(last_slice, dim1=1, dim2=2)
                        mode_means.append(torch.mean(diag).item())

                if mode_means:
                    return sum(mode_means) / len(mode_means)

            # 3. Fixed default
            return 1.0

        except Exception as e:
            logger.error(f"Error calculating prediction variance: {e}")
            return 1.0
    
    def _evaluate_with_additional_points(self, model, additional_points):
        """Estimate prediction variance after hypothetically adding points

        This is a heuristic; the model is not retrained. Starting from the
        current variance estimate:
            - every (mode, index) pair not yet covered by the coreset reduces
              the variance by 1%, every uncovered time index by 0.5%;
            - the result is scaled by sqrt(n1 / n2), where n1 is the current
              coreset size + 1 and n2 = n1 + number of added points.

        Time indices are normalized with ``int(...)`` before comparison (as
        in ``_compute_novelty``) so tensor or NumPy scalars match stored
        values, and each new pair/time is counted once even if several added
        points share it.

        Args:
            model: Current model
            additional_points: Extra points to add, each element (indices, y, time_ind)

        Returns:
            Estimated variance, never below 0.1; the current variance
            estimate on error
        """
        try:
            current_coreset = self.get_coreset_data()

            # Coverage of the existing coreset
            covered_modes = set()
            covered_times = set()
            for indices, _, time_ind in current_coreset:
                covered_modes.update((mode, int(idx)) for mode, idx in enumerate(indices))
                covered_times.add(int(time_ind))

            # Coverage contributed by the extra points
            added_modes = set()
            added_times = set()
            for indices, _, time_ind in additional_points:
                added_modes.update((mode, int(idx)) for mode, idx in enumerate(indices))
                added_times.add(int(time_ind))

            new_modes = len(added_modes - covered_modes)
            new_times = len(added_times - covered_times)

            # Variance reduction from new coverage
            current_variance = self._compute_prediction_variance(model)
            estimated_variance = current_variance * (1 - 0.01 * new_modes - 0.005 * new_times)

            # Variance reduction from the larger sample count
            n1 = len(current_coreset) + 1  # +1 avoids division by zero
            n2 = n1 + len(additional_points)
            estimated_variance = estimated_variance * math.sqrt(n1 / n2)

            return max(0.1, estimated_variance)  # Keep the variance positive

        except Exception as e:
            logger.error(f"Error evaluating extra points: {e}")
            return self._compute_prediction_variance(model)  # Return current variance on error
    
    def compute_importance_score(self, indices, y, time_ind, model):
        """Compute the combined importance score of a single data point.

        The score is the weighted sum of the uncertainty, influence and novelty
        components plus ``self.martingale_weight`` times the martingale
        increment. Each computed score is appended to ``self.score_history``,
        which is capped at the 100 most recent entries.

        Args:
            indices: Index tuple identifying the data point.
            y: Observed value of the data point.
            time_ind: Time index of the data point.
            model: Current model, forwarded to the component scorers.

        Returns:
            The scalar score (tensors are converted with ``.item()``), or 0.5
            if any step raises.
        """
        try:
            # Base components
            uncertainty = self._compute_uncertainty(indices, time_ind, model)
            influence = self._compute_influence(indices, time_ind, model)
            novelty = self._compute_novelty(indices, y, time_ind)

            # Only the first three weights belong to the base components.
            # ``self.weights`` may carry extra entries (the configuration
            # default is a 4-tuple); unpacking it directly would raise and make
            # every score collapse to the 0.5 error fallback below.
            w_u, w_i, w_n = self.weights[:3]

            # Additional martingale increment component
            martingale_increment = self._compute_martingale_increment(indices, time_ind, model)

            # Combined score
            score = (w_u * uncertainty +
                     w_i * influence +
                     w_n * novelty +
                     self.martingale_weight * martingale_increment)

            if isinstance(score, torch.Tensor):
                score = score.item()

            # Keep a rolling window of the last 100 scores for statistics
            self.score_history.append(score)
            if len(self.score_history) > 100:
                self.score_history.pop(0)

            return score

        except Exception as e:
            logger.error(f"Error calculating importance score: {e}")
            return 0.5  # Return medium importance on error
    
    def optimal_stopping_strategy(self, candidates, model, budget):
        """Select candidates with a threshold-based optimal stopping rule.

        Candidates are visited in descending score order. Points already in
        the coreset are skipped; selection stops at the first remaining point
        whose score is not above the threshold, or as soon as ``budget`` points
        have been chosen.

        The threshold is ``self.threshold`` while no score history exists.
        Otherwise it is ``mean - (1 - fill_ratio) * std`` of the score history,
        scaled by ``1 - exploration_rate``.

        Args:
            candidates: Candidate data points, each (indices, y, time_ind, score)
            model: Current model (not used by this method)
            budget: Maximum data points to add

        Returns:
            Selected data points list. If an error occurs, the first ``budget``
            candidates in ranked order (or an empty list if ranking itself
            failed).
        """
        # Bound before the try block so the error fallback never touches an
        # undefined name when the failure happens before/during sorting.
        ranked = []
        try:
            if not candidates:
                return []

            # Highest score first
            ranked = sorted(candidates, key=lambda x: x[3], reverse=True)

            # Calculate dynamic threshold
            if self.score_history:
                mean_score = np.mean(self.score_history)
                std_score = np.std(self.score_history) if len(self.score_history) > 1 else 0.1

                # Mean minus part of the standard deviation: a fuller coreset
                # subtracts less and therefore gets a higher threshold
                coreset_fill_ratio = len(self.coreset) / self.max_size
                threshold = mean_score - (1.0 - coreset_fill_ratio) * std_score

                # Adjust threshold based on exploration rate
                exploration_rate = self.compute_exploration_rate()
                threshold = threshold * (1 - exploration_rate)
            else:
                # Use initial threshold if no history data
                threshold = self.threshold

            selected = []
            for indices, y, time_ind, score in ranked:
                # Points already in the coreset do not trigger the stop
                indices_tuple = tuple(int(idx) for idx in indices)
                if indices_tuple in self.coreset_indices:
                    continue

                # Stopping rule: the first point that fails the threshold (or
                # would exceed the budget) ends the scan
                if not (score > threshold and len(selected) < budget):
                    break
                selected.append((indices, y, time_ind, score))

            return selected

        except Exception as e:
            logger.error(f"Error in optimal stopping strategy: {e}")
            # On error, simply return the first `budget` ranked points. A
            # negative budget must yield nothing rather than slicing from the
            # end of the list.
            limit = None if budget is None else max(0, budget)
            return ranked[:limit]
    
    def _apply_bellman_optimization(self, candidates, model):
        """Apply Bellman equation optimization to coreset update
        
        Consider impact of current decision on future, based on dynamic programming
        """
        try:
            # If too many candidates, first filter with heuristic
            if len(candidates) > 10:
                # Keep top 10 candidates by score
                sorted_candidates = sorted(candidates, key=lambda x: x[3], reverse=True)
                candidates = sorted_candidates[:10]
            
            # Current state value
            current_value = self._estimate_state_value(model)
            
            best_action = []
            max_future_value = current_value  # Initial value is taking no action
            
            # Consider all possible candidate point combinations (max 3 to avoid explosion)
            max_to_add = min(3, len(candidates))
            
            # Iterate all possible selection quantities
            for num_to_add in range(1, max_to_add + 1):
                # For each quantity, consider all possible combinations
                from itertools import combinations
                for combo in combinations(candidates, num_to_add):
                    # Check if combination valid (within budget)
                    if len(self.coreset) + len(combo) <= self.max_size:
                        # Calculate immediate reward for this combination
                        immediate_reward = sum(c[3] for c in combo)
                        
                        # Estimate future value after selecting this combination
                        points_to_add = [(c[0], c[1], c[2]) for c in combo]
                        future_state_value = self._estimate_future_value(model, points_to_add)
                        
                        # Total value = immediate reward + discount factor * future value
                        total_value = immediate_reward + self.discount_factor * future_state_value
                        
                        if total_value > max_future_value:
                            max_future_value = total_value
                            best_action = list(combo)
            
            return best_action
            
        except Exception as e:
            logger.error(f"Error in Bellman optimization: {e}")
            # On error, simply return highest scoring points
            return sorted(candidates, key=lambda x: x[3], reverse=True)[:3]
    
    def _estimate_state_value(self, model):
        """Estimate current state's value function
        
        Args:
            model: Current model
            
        Returns:
            State value, higher means stronger prediction ability
        """
        # Simply use inverse of prediction variance as state value
        prediction_variance = self._compute_prediction_variance(model)
        if prediction_variance > 0:
            return 1.0 / prediction_variance
        return 0.0
    
    def _estimate_future_value(self, model, additional_points):
        """Estimate future state value after adding points
        
        Args:
            model: Current model
            additional_points: Extra points to add
            
        Returns:
            Estimated future state value
        """
        # Estimate prediction variance after adding points
        future_variance = self._evaluate_with_additional_points(model, additional_points)
        if future_variance > 0:
            return 1.0 / future_variance
        return 0.0
    
    def compute_exploration_rate(self):
        """Current exploration rate, clamped to the range [0.01, 1.0].

        The rate is the product of three factors:
        - ``epsilon_0 * exp(-lambda_decay * confidence)``: confidence decay
        - ``exp(-2 * fill_ratio)``: a fuller coreset explores less
        - ``1 + mean uncertainty`` over the prediction history (0.5 is used
          while the history is empty): more uncertainty explores more

        Returns:
            The clamped exploration rate, or ``self.epsilon_0`` on error.
        """
        try:
            # Confidence-driven decay of the initial rate
            decayed_epsilon = self.epsilon_0 * math.exp(-self.lambda_decay * self.confidence)

            # Saturation damping from the coreset fill ratio
            saturation = len(self.coreset) / self.max_size
            saturation_damping = math.exp(-2.0 * saturation)

            # Mean recorded uncertainty, with a neutral default
            if self.prediction_history:
                mean_uncertainty = np.mean(
                    [record['uncertainty'] for record in self.prediction_history]
                )
            else:
                mean_uncertainty = 0.5
            uncertainty_boost = 1.0 + mean_uncertainty

            epsilon = decayed_epsilon * saturation_damping * uncertainty_boost

            # Clamp into a reasonable range
            return min(1.0, max(0.01, epsilon))

        except Exception as e:
            logger.error(f"Error calculating exploration rate: {e}")
            return self.epsilon_0  # Return initial exploration rate on error
    
    def update_coreset(self, data_batch, model):
        """Score a batch of data points and add them to, or swap them into, the coreset.

        Steps:
        1. Score every batch point that is not already in the coreset.
        2. If adaptive thresholding is on, update the threshold from those scores.
        3. Choose candidates with either the Bellman optimization or the
           optimal stopping strategy.
        4. If ``should_start_replacement()`` is true, replace existing points in
           a batch; otherwise append the chosen points up to ``size_threshold``.
        5. If anything was added, refresh the model confidence and record the
           prediction state.

        Args:
            data_batch: Iterable of (indices, y, time_ind) data points
            model: Current model

        Returns:
            Tuple ``(added, removed)``; each is a list of (indices, time_ind)
            pairs. ``([], [])`` is returned for an empty batch or on any error.
        """
        try:
            if not data_batch:
                return [], []
                
            # Candidates as (indices, y, time_ind, score), plus the bare scores
            candidates = []
            scores = []
            
            # Calculate importance score for each data point
            for batch_idx, (indices, y, time_ind) in enumerate(data_batch):
                # Skip points already in the coreset (membership is keyed on indices only)
                indices_tuple = tuple(int(idx) for idx in indices)
                if indices_tuple in self.coreset_indices:
                    continue
                    
                # Calculate importance score
                score = self.compute_importance_score(indices, y, time_ind, model)
                
                # Add to candidate list
                candidates.append((indices, y, time_ind, score))
                scores.append(score)
            
            # Update threshold
            if self.adaptive:
                self.update_threshold(scores)
            
            # Free slots left before max_size is reached
            remaining_budget = self.max_size - len(self.coreset)
            
            # Decide which strategy to use. The random draw only happens when
            # there are more than 3 candidates (short-circuit `or`).
            if len(candidates) <= 3 or np.random.random() < 0.3:
                # Use Bellman equation optimization (small scale or 30% probability)
                selected = self._apply_bellman_optimization(candidates, model)
            else:
                # Use optimal stopping strategy (large scale or 70% probability)
                selected = self.optimal_stopping_strategy(candidates, model, remaining_budget)
            
            added = []
            removed = []
            
            # Check if we should start replacement strategy
            if self.should_start_replacement():
                logger.info(f"Coreset size ({len(self.coreset)}) has reached threshold ({self.size_threshold}). "
                          f"Starting replacement strategy.")
                
                # Apply batch replacement strategy; indices_to_remove are used
                # below as positions into self.coreset
                points_to_add, indices_to_remove = self.select_points_to_replace(selected, self.batch_replace_size)
                
                # Pop from the highest position down so earlier pops do not
                # shift the positions still to be removed
                for idx in sorted(indices_to_remove, reverse=True):  # Remove from highest index to lowest
                    if idx < len(self.coreset):
                        indices, y, time_ind, _ = self.coreset[idx]
                        indices_tuple = tuple(int(i) for i in indices)
                        
                        # Remove from coreset and hash set
                        self.coreset.pop(idx)
                        if indices_tuple in self.coreset_indices:
                            self.coreset_indices.remove(indices_tuple)
                            
                        removed.append((indices, time_ind))
                
                # Add new points (the membership check also drops duplicates
                # within points_to_add)
                for indices, y, time_ind, score in points_to_add:
                    indices_tuple = tuple(int(idx) for idx in indices)
                    if indices_tuple not in self.coreset_indices:
                        self.coreset.append((indices, y, time_ind, score))
                        self.coreset_indices.add(indices_tuple)
                        added.append((indices, time_ind))
            else:
                # Replacement has not started yet: append directly, taking at
                # most (size_threshold - current size) of the selected points.
                # NOTE(review): size_threshold is presumably derived from
                # fill_ratio_threshold and max_size - confirm in __init__.
                for indices, y, time_ind, score in selected[:self.size_threshold - len(self.coreset)]:
                    indices_tuple = tuple(int(idx) for idx in indices)
                    if indices_tuple not in self.coreset_indices:
                        self.coreset.append((indices, y, time_ind, score))
                        self.coreset_indices.add(indices_tuple)
                        added.append((indices, time_ind))
            
            # Update model confidence (only when something was added)
            if added:
                pred_error = self._compute_prediction_error(model)
                # Map error to a confidence in (0, 1); 0.9 is used for a non-positive error
                confidence = 1.0 / (1.0 + pred_error) if pred_error > 0 else 0.9
                self.update_confidence(confidence)
                
                # Record current prediction state
                self.prediction_history.append({
                    'error': pred_error,
                    'uncertainty': 1.0 - confidence,
                    'coreset_size': len(self.coreset)
                })
            
            # Log coreset statistics
            if added or removed:
                logger.info(f"Coreset update: {len(added)} points added, {len(removed)} points removed. "
                          f"Current size: {len(self.coreset)}/{self.max_size} (threshold: {self.size_threshold})")
            
            return added, removed
            
        except Exception as e:
            # Best-effort: any failure is logged and reported as "no changes"
            logger.error(f"Error updating coreset: {e}")
            return [], []
    
    def _compute_prediction_error(self, model):
        """Calculate current prediction error"""
        try:
            if hasattr(model, 'te_ind') and hasattr(model, 'te_y'):
                # Use small test data sample to evaluate
                sample_size = min(50, len(model.te_ind))
                if sample_size > 0:
                    indices = np.random.choice(len(model.te_ind), sample_size, replace=False)
                    sample_ind = model.te_ind[indices]
                    sample_y = model.te_y[indices]
                    
                    # Get current time
                    current_time = max(model.unique_train_time) if model.unique_train_time else 0
                    sample_time = np.ones_like(indices) * current_time
                    
                    # Calculate prediction
                    pred, _ = model.model_test(sample_ind, sample_y, sample_time)
                    
                    # Calculate MSE
                    mse = torch.mean((pred.squeeze() - sample_y.squeeze()) ** 2)
                    return mse.item()
            
            # Default return 1.0
            return 1.0
            
        except Exception as e:
            logger.error(f"Error calculating prediction error: {e}")
            return 1.0


class EnhancedMultiScaleWeighting(MultiScaleWeighting):
    """Enhanced multi-scale weighting mechanism with Ito formula time scale transform.

    The base weights produced by ``MultiScaleWeighting.compute_weights`` are
    rescaled per time scale with a stochastic drift/volatility step derived
    from the hidden states and then re-normalized to sum to one.
    """

    def __init__(self, num_scales=3, hidden_dim=32, device=torch.device("cpu"),
                 temperature=1.0, time_scale_factor=0.1):
        """
        Initialize multi-scale weighting mechanism

        Args:
            num_scales: Number of time scales
            hidden_dim: Hidden layer dimension
            device: Computation device
            temperature: Attention softmax temperature parameter
            time_scale_factor: Time scale transform factor (used as the time
                increment ``dt`` of the Ito step)
        """
        super().__init__(num_scales, hidden_dim, device, temperature)
        self.time_scale_factor = time_scale_factor

        # Time scale transform parameters.
        # NOTE(review): neither parameter is read by the methods defined in
        # this class - confirm whether they are consumed elsewhere.
        self.scale_transform = torch.nn.Parameter(torch.randn(num_scales, device=device))
        self.quadratic_term = torch.nn.Parameter(torch.ones(num_scales, device=device) * 0.5)

    def compute_weights(self, h_k):
        """Calculate weights for different time scales with Ito formula transform

        Args:
            h_k: Time scale hidden states list

        Returns:
            Normalized weight vector [num_scales, 1]. On error, uniform weights
            with ``min(num_scales, max(1, len(h_k)))`` rows.
        """
        try:
            # First calculate base weights using parent method
            base_weights = super().compute_weights(h_k)

            # Apply time scale transform
            return self._apply_time_scale_transformation(base_weights, h_k)

        except Exception as e:
            logger.error(f"Error calculating multi-scale weights: {e}")
            # h_k itself may be the cause of the failure (e.g. None)
            try:
                n_states = len(h_k)
            except TypeError:
                n_states = 0
            # Divide by the number of rows actually returned so the uniform
            # weights sum to one even when len(h_k) exceeds num_scales
            rows = min(self.num_scales, max(1, n_states))
            return torch.ones(rows, 1, device=self.device) / rows

    def _apply_time_scale_transformation(self, weights, h_k):
        """Apply Ito formula based nonlinear time scale transform

        Simplified Ito formula transform applied multiplicatively:
        w_i' = w_i * (1 + drift_i * dt + volatility_i * dW_i + 0.5 * volatility_i^2 * dt)

        ``drift_i`` and ``volatility_i`` are the mean and standard deviation of
        the i-th hidden state (0.0 and 0.1 when the state is not a tensor or
        has fewer than two elements); ``dW_i`` is a fresh normal draw scaled by
        ``sqrt(dt)``, so the result is stochastic.

        Args:
            weights: Base weights, [num_scales, 1] (a 1-D [num_scales] tensor
                is accepted as well)
            h_k: Hidden states list

        Returns:
            Transformed weights [n, 1] with n = min(len(h_k), num_scales),
            non-negative and summing to one. ``weights`` is returned unchanged
            when there is nothing to transform or an error occurs.
        """
        try:
            if len(h_k) == 0 or weights.numel() == 0:
                return weights

            # Time increment
            dt = self.time_scale_factor

            # Hidden state statistics act as proxies for drift and volatility
            volatilities = []
            drifts = []
            for h in h_k:
                volatility = 0.1
                drift = 0.0
                if isinstance(h, torch.Tensor):
                    h_flat = h.flatten()
                    # std needs at least two elements to be meaningful
                    if h_flat.numel() > 1:
                        volatility = torch.std(h_flat).item()
                        drift = torch.mean(h_flat).item()
                volatilities.append(volatility)
                drifts.append(drift)

            # Only scales that have both a weight and a hidden state are kept
            n = min(len(volatilities), len(drifts), weights.shape[0])

            # Work on a copy so the caller's tensor is not modified
            adjusted_weights = weights.clone()[:n]

            for i in range(n):
                # Brownian increment
                dW = torch.randn(1).item() * np.sqrt(dt)

                # Ito correction term
                ito_correction = 0.5 * volatilities[i]**2 * dt

                # dX = drift*dt + volatility*dW + 0.5*volatility^2*dt
                delta = drifts[i] * dt + volatilities[i] * dW + ito_correction

                adjusted_weights[i] = adjusted_weights[i] * (1 + delta)

            # Ensure weights are positive
            adjusted_weights = torch.abs(adjusted_weights)

            # Normalize, falling back to uniform weights if everything is zero
            total = torch.sum(adjusted_weights)
            if total > 0:
                adjusted_weights = adjusted_weights / total
            else:
                adjusted_weights = torch.ones_like(adjusted_weights) / adjusted_weights.shape[0]

            # Return the documented [n, 1] layout. unsqueeze(1) would turn an
            # [n, 1] input into [n, 1, 1]; reshape gives [n, 1] for both 1-D
            # and [n, 1] inputs.
            return adjusted_weights.reshape(n, -1)

        except Exception as e:
            logger.error(f"Error applying time scale transform: {e}")
            return weights


def make_martingale_coreset_dict(config, args=None):
    """Build the hyper-parameter dictionary for the martingale coreset.

    The base dictionary comes from ``utils_streaming.make_hyper_dict`` when
    that module can be imported; otherwise a shallow copy of ``config`` is
    used. Every coreset-specific key is then set from ``config``, falling back
    to its default value.

    Args:
        config: Original configuration dictionary
        args: Command line arguments

    Returns:
        Extended configuration dictionary
    """
    try:
        # Imported lazily to avoid circular imports
        from utils_streaming import make_hyper_dict
        hyper_dict = make_hyper_dict(config, args)
    except ImportError:
        # utils_streaming is not available: start from the plain config
        hyper_dict = config.copy()

    defaults = (
        # Coreset parameters
        ("coreset_max_size", 100),
        ("coreset_threshold", 0.5),
        ("adaptive_threshold", True),
        ("importance_weights", (0.3, 0.2, 0.2, 0.3)),
        ("batch_replace_size", 5),
        ("fill_ratio_threshold", 0.8),
        # Multi-scale weighting parameters
        ("num_time_scales", 3),
        ("scale_hidden_dim", 32),
        ("attention_temperature", 1.0),
        ("time_scale_factor", 0.1),
        # Exploration-exploitation balance parameters
        ("initial_exploration_rate", 0.9),
        ("exploration_decay_rate", 0.1),
        # Martingale theory parameters
        ("prediction_history_size", 50),
        ("discount_factor", 0.9),
        ("simulation_samples", 5),
        ("bellman_optimization", True),
    )
    for key, fallback in defaults:
        hyper_dict[key] = config.get(key, fallback)

    return hyper_dict

# Alias function to resolve import issues
def create_martingale_coreset_dict(config, args=None):
    """Alias of ``make_martingale_coreset_dict``, kept for API compatibility.

    Args:
        config: Configuration dictionary
        args: Command line arguments

    Returns:
        Extended configuration dictionary, exactly as produced by
        ``make_martingale_coreset_dict(config, args)``
    """
    extended_config = make_martingale_coreset_dict(config, args)
    return extended_config