import torch
    
class ImprovedParetoOptimizer:
    """Adaptive per-loss weighting for multi-loss training.

    Each call to :meth:`get_weights` turns a vector of current loss values into
    weights that sum to 1:

    * During warm-up, weights are proportional to the current loss magnitudes
      (clamped below by ``min_weight``).
    * After warm-up, those magnitude weights are blended with "change" weights
      derived from the ratio of each loss to an exponential moving average of
      its past values. The change term is phased in linearly over
      ``ramp_steps`` steps, up to a ``change_mix`` share.

    Returned weights are always detached from the autograd graph, so they act
    as constants when multiplied into the objective.
    """

    def __init__(self, num_losses=3, alpha=0.7, max_norm=1.0, min_weight=0.05,
                 warmup_steps=5, warmup_alpha=0.7, ramp_steps=50, change_mix=0.3):
        """
        Args:
            num_losses: Expected number of losses. Stored only; not enforced.
            alpha: EMA decay for the running loss average after warm-up
                (share kept from the previous average).
            max_norm: Upper clamp applied to each blended weight before the
                final renormalization.
            min_weight: Lower clamp applied to each weight before
                renormalization.
            warmup_steps: Calls with ``step_count <= warmup_steps`` use
                magnitude weights only. The very first call always does.
            warmup_alpha: EMA decay for the running loss average during warm-up.
            ramp_steps: Number of post-warm-up steps over which the change term
                is phased in. A value <= 0 applies the full mix immediately.
            change_mix: Final share of the weight taken from the change term.
        """
        self.num_losses = num_losses
        self.alpha = alpha
        self.max_norm = max_norm
        self.min_weight = min_weight
        self.warmup_steps = warmup_steps
        self.warmup_alpha = warmup_alpha
        self.ramp_steps = ramp_steps
        self.change_mix = change_mix

        self.prev_losses = None  # detached EMA of past losses
        self.step_count = 0      # number of get_weights() calls
        self.epoch_count = 0     # incremented by new_epoch()

        self.weight_history = []  # detached copy of every returned weight vector
        self.loss_history = []    # detached copy of every input loss vector

    def new_epoch(self):
        """Increment the epoch counter reported by :meth:`get_statistics`."""
        self.epoch_count += 1

    def get_weights(self, losses):
        """Return normalized, detached weights for the current losses.

        Args:
            losses: Tensor of current loss values, one entry per loss
                (presumably 1-D; may require grad).

        Returns:
            Detached tensor shaped like ``losses`` whose entries sum to 1.
        """
        self.step_count += 1

        # The weights are a control signal, not part of the objective.
        # Computing them from detached values keeps gradients from flowing
        # through the weighting. It also keeps the history lists from
        # retaining autograd graphs step after step.
        current = losses.detach()
        self.loss_history.append(current.clone())

        if self.prev_losses is None:
            self.prev_losses = current.clone()
            return self._record(self._compute_magnitude_weights(current))

        if self.step_count <= self.warmup_steps:
            weights = self._compute_magnitude_weights(current)
            self.prev_losses = (
                (1 - self.warmup_alpha) * current + self.warmup_alpha * self.prev_losses
            )
            return self._record(weights)

        # Ratio against the average *before* it absorbs the current losses.
        loss_ratios = current / (self.prev_losses + 1e-8)
        self.prev_losses = self.alpha * self.prev_losses + (1 - self.alpha) * current

        magnitude_weights = self._compute_magnitude_weights(current)

        # Bound the ratios so one sudden jump or drop cannot dominate.
        change_factor = torch.clamp(loss_ratios, 0.5, 2.0)
        change_weights = change_factor / change_factor.sum()

        if self.ramp_steps > 0:
            progress = min(1.0, (self.step_count - self.warmup_steps) / self.ramp_steps)
        else:
            progress = 1.0
        mix = self.change_mix * progress
        weights = (1 - mix) * magnitude_weights + mix * change_weights

        weights = torch.clamp(weights, self.min_weight, self.max_norm)
        weights = weights / weights.sum()
        return self._record(weights)

    def _record(self, weights):
        """Append a copy of *weights* to ``weight_history`` and return *weights*."""
        self.weight_history.append(weights.clone())
        return weights

    def _compute_magnitude_weights(self, losses):
        """Return weights proportional to *losses*.

        Each weight is clamped to ``[min_weight, 1.0]``, then the vector is
        renormalized to sum to 1.
        """
        normalized_losses = losses / (losses.sum() + 1e-8)
        weights = torch.clamp(normalized_losses, self.min_weight, 1.0)
        return weights / weights.sum()

    def get_statistics(self, window=10):
        """Summarize the most recent ``window`` weight vectors.

        Args:
            window: Number of trailing history entries to use (at least 1).

        Returns:
            An empty dict if no weights have been produced yet. Otherwise a dict
            with the step and epoch counters, the per-loss std and mean of the
            recent weights, and ``weight_stability`` (the mean of that std).
            The std is zero when only one entry is available, rather than NaN.
        """
        if not self.weight_history:
            return {}

        recent_weights = torch.stack(self.weight_history[-max(1, int(window)):])
        if recent_weights.shape[0] > 1:
            weight_std = recent_weights.std(dim=0)
        else:
            # The unbiased std of a single sample is NaN; report "no spread".
            weight_std = torch.zeros_like(recent_weights[0])

        return {
            'total_steps': self.step_count,
            'epochs': self.epoch_count,
            'recent_weight_std': weight_std,
            'recent_weight_mean': recent_weights.mean(dim=0),
            'weight_stability': weight_std.mean().item()
        }
    
