import numpy as np

class AdaptiveLearningRateScheduler:
    """Adjust the learning rates of several optimizers from per-epoch metrics.

    The scheduler expects ``optimizers_dict`` and ``initial_lrs`` to contain
    the keys ``'main'``, ``'concept'``, ``'text_weight'``, ``'audio_weight'``
    and ``'video_weight'``.  Every optimizer must expose ``param_groups``, a
    list of dicts holding an ``'lr'`` entry (the layout used by
    ``torch.optim`` optimizers).

    Policies applied on each :meth:`step`:

    * ``main`` -- reduce on plateau of accuracy.
    * ``concept`` -- cosine annealing over ``config.num_epochs`` with extra
      damping while the concept loss keeps rising.
    * ``*_weight`` -- reduce on plateau of the per-modality diff/cosine
      metrics, with a small boost after a large diff improvement.
    * global stagnation -- raise the ``video_weight`` rate when both accuracy
      and the average weight diff stop moving.
    """

    def __init__(self, optimizers_dict, initial_lrs, config):
        """Store the optimizers and initialise all bookkeeping state.

        Args:
            optimizers_dict: Mapping of optimizer name to optimizer object.
            initial_lrs: Mapping of optimizer name to its starting learning
                rate; used to derive lower and upper bounds.
            config: Object exposing ``num_epochs`` (length of the cosine
                schedule used for the concept optimizer).
        """
        self.optimizers = optimizers_dict
        self.initial_lrs = initial_lrs
        self.config = config

        # Per-epoch record of every metric passed to step().
        self.history = {
            'accuracies': [],
            'weight_diffs': {'text': [], 'audio': [], 'video': [], 'avg': []},
            'weight_cosines': {'text': [], 'audio': [], 'video': [], 'avg': []},
            'concept_losses': []
        }

        # Number of consecutive non-improving epochs tolerated before a
        # reduction.  NOTE: the 'concept' entry is not read anywhere in this
        # class; the concept optimizer follows the cosine schedule instead.
        self.patience = {
            'main': 4,
            'concept': 3,
            'text_weight': 2,
            'audio_weight': 4,
            'video_weight': 5
        }

        # Multiplicative decay applied on a reduction.  NOTE: the 'concept'
        # entry is likewise unused by the current policies.
        self.factor = {
            'main': 0.8,
            'concept': 0.7,
            'text_weight': 0.6,
            'audio_weight': 0.8,
            'video_weight': 0.9
        }

        # Lower bound on any rate, as a fraction of its initial value.
        self.min_lr_factor = 1e-3

        # Consecutive non-improving epochs per optimizer (reset on reduction).
        self.wait_counts = {key: 0 for key in optimizers_dict}
        # Number of times each optimizer's rate was actually lowered by
        # _reduce_lr; this is what get_statistics() reports.
        self.reduction_counts = {key: 0 for key in optimizers_dict}

        self.best_metrics = {
            'accuracy': 0.0,
            'text_diff': float('inf'),
            'audio_diff': float('inf'),
            'video_diff': float('inf'),
            'avg_diff': float('inf'),
            # Cosine similarity may be negative, so the baseline must be -inf
            # (not 0.0) for a rise from e.g. -0.5 to -0.3 to count as progress.
            'text_cosine': float('-inf'),
            'audio_cosine': float('-inf'),
            'video_cosine': float('-inf')
        }

        # Window length (epochs) and tolerance used by the stagnation check.
        self.stagnation_threshold = 3
        self.improvement_threshold = 0.001

    def step(self, accuracy, weight_analysis, concept_loss):
        """Record one epoch of metrics and update every learning rate.

        Args:
            accuracy: Accuracy for the epoch (higher is better).
            weight_analysis: Dict with ``{text,audio,video,avg}_diff`` and
                ``{text,audio,video,avg}_cosine`` entries.
            concept_loss: Concept loss for the epoch (lower is better).

        Prints the resulting learning rates to stdout.
        """
        self._update_history(accuracy, weight_analysis, concept_loss)
        # Epochs are counted from 1: the number of step() calls so far.
        epoch = len(self.history['accuracies'])

        print(f"  Learning Rates (Epoch {epoch}):")

        self._adjust_main_lr(accuracy)
        self._adjust_concept_lr(concept_loss, epoch)
        self._adjust_weight_predictors_lr(weight_analysis)
        self._handle_global_stagnation()

        for name, lr in self.get_current_lrs().items():
            print(f"    {name}: {lr:.2e}")

    def _update_history(self, accuracy, weight_analysis, concept_loss):
        """Append the epoch's metrics to ``self.history``."""
        self.history['accuracies'].append(accuracy)
        for modal in ('text', 'audio', 'video', 'avg'):
            self.history['weight_diffs'][modal].append(
                weight_analysis[f'{modal}_diff'])
        for modal in ('text', 'audio', 'video', 'avg'):
            self.history['weight_cosines'][modal].append(
                weight_analysis[f'{modal}_cosine'])
        self.history['concept_losses'].append(concept_loss)

    def _adjust_main_lr(self, accuracy):
        """Reduce the main rate after ``patience['main']`` epochs without a new best accuracy."""
        if accuracy > self.best_metrics['accuracy']:
            self.best_metrics['accuracy'] = accuracy
            self.wait_counts['main'] = 0
            return

        self.wait_counts['main'] += 1
        if self.wait_counts['main'] >= self.patience['main']:
            self._reduce_lr('main')
            self.wait_counts['main'] = 0

    def _adjust_concept_lr(self, concept_loss, epoch):
        """Set the concept rate from a cosine schedule with loss-based damping.

        The rate follows ``0.5 * (1 + cos(pi * epoch / num_epochs))`` times the
        initial rate.  The epoch is clamped to ``[0, num_epochs]`` so the rate
        stays at its floor once the schedule is finished instead of rising
        again, and a non-positive ``num_epochs`` is treated as an already
        finished schedule rather than dividing by zero.

        Args:
            concept_loss: Current concept loss (already stored in history by
                ``step``; the history is what is inspected here).
            epoch: 1-based epoch index.
        """
        total_epochs = self.config.num_epochs
        if total_epochs > 0:
            clamped_epoch = min(max(epoch, 0), total_epochs)
            # float() keeps the stored rate a plain Python float rather than
            # the NumPy scalar returned by np.cos.
            cosine_factor = float(
                0.5 * (1 + np.cos(np.pi * clamped_epoch / total_epochs)))
        else:
            cosine_factor = 0.0
        base_lr = self.initial_lrs['concept'] * cosine_factor

        # Two consecutive strict increases of the loss -> damp this epoch's rate.
        losses = self.history['concept_losses']
        if len(losses) >= 3:
            recent_losses = losses[-3:]
            if recent_losses[-1] > recent_losses[-2] > recent_losses[-3]:
                base_lr *= 0.8

        floor = self.initial_lrs['concept'] * self.min_lr_factor
        self._set_lr('concept', float(max(base_lr, floor)))

    def _adjust_weight_predictors_lr(self, weight_analysis):
        """Apply plateau-based reduction (and occasional boost) per modality.

        A modality counts as improved when its diff reaches a new minimum or
        its cosine reaches a new maximum.  After a new best diff that also
        dropped by more than 0.01 relative to the previous epoch, the rate is
        boosted by 10% (capped inside ``_boost_lr``).
        """
        for modal in ('text', 'audio', 'video'):
            weight_key = f'{modal}_weight'
            diff_key = f'{modal}_diff'
            cosine_key = f'{modal}_cosine'
            current_diff = weight_analysis[diff_key]
            current_cosine = weight_analysis[cosine_key]

            diff_improved = current_diff < self.best_metrics[diff_key]
            cosine_improved = current_cosine > self.best_metrics[cosine_key]

            if not (diff_improved or cosine_improved):
                self.wait_counts[weight_key] += 1
                if self.wait_counts[weight_key] >= self.patience[weight_key]:
                    self._reduce_lr(weight_key)
                    self.wait_counts[weight_key] = 0
                continue

            if diff_improved:
                self.best_metrics[diff_key] = current_diff
            if cosine_improved:
                self.best_metrics[cosine_key] = current_cosine
            self.wait_counts[weight_key] = 0

            diffs = self.history['weight_diffs'][modal]
            if diff_improved and len(diffs) >= 2:
                improvement = diffs[-2] - diffs[-1]
                if improvement > 0.01:
                    self._boost_lr(weight_key, factor=1.1)

    def _handle_global_stagnation(self):
        """Raise the ``video_weight`` rate when accuracy and avg diff both stall.

        Accuracy is stagnant when its range over the last
        ``stagnation_threshold`` epochs is below ``improvement_threshold``; the
        average weight diff is stagnant when the first and last values of the
        same window differ by less than that threshold.  The boosted rate is
        capped at twice the initial ``video_weight`` rate.
        """
        window = self.stagnation_threshold
        if len(self.history['accuracies']) < window:
            return

        recent_accs = self.history['accuracies'][-window:]
        acc_stagnant = (
            max(recent_accs) - min(recent_accs) < self.improvement_threshold)

        recent_diffs = self.history['weight_diffs']['avg'][-window:]
        weight_stagnant = (
            abs(recent_diffs[0] - recent_diffs[-1]) < self.improvement_threshold)

        if acc_stagnant and weight_stagnant:
            current_lr = self.optimizers['video_weight'].param_groups[0]['lr']
            new_lr = min(current_lr * 1.5, self.initial_lrs['video_weight'] * 2)
            self._set_lr('video_weight', new_lr)

    def _reduce_lr(self, optimizer_name):
        """Decay an optimizer's rate by its factor, bounded below by the floor.

        The floor is ``initial_lr * min_lr_factor``.  ``reduction_counts`` is
        incremented only when the rate actually decreases.
        """
        current_lr = self.optimizers[optimizer_name].param_groups[0]['lr']
        factor = self.factor.get(optimizer_name, 0.7)
        floor = self.initial_lrs[optimizer_name] * self.min_lr_factor
        new_lr = max(current_lr * factor, floor)
        self._set_lr(optimizer_name, new_lr)
        if new_lr < current_lr:
            self.reduction_counts[optimizer_name] = (
                self.reduction_counts.get(optimizer_name, 0) + 1)

    def _boost_lr(self, optimizer_name, factor=1.1):
        """Multiply an optimizer's rate by ``factor``, capped at 1.5x its initial rate.

        Nothing is written when the cap would leave the rate unchanged or lower.
        """
        current_lr = self.optimizers[optimizer_name].param_groups[0]['lr']
        max_lr = self.initial_lrs[optimizer_name] * 1.5
        new_lr = min(current_lr * factor, max_lr)
        if new_lr > current_lr:
            self._set_lr(optimizer_name, new_lr)

    def _set_lr(self, optimizer_name, new_lr):
        """Write ``new_lr`` into every param group of the named optimizer."""
        for param_group in self.optimizers[optimizer_name].param_groups:
            param_group['lr'] = new_lr

    def get_current_lrs(self):
        """Return ``{name: lr}`` using the first param group of each optimizer."""
        return {name: opt.param_groups[0]['lr']
                for name, opt in self.optimizers.items()}

    def get_statistics(self):
        """Summarise the run so far.

        Returns:
            An empty dict before the first ``step``; otherwise a dict with
            ``best_accuracy``, ``best_avg_weight_diff``, ``current_lrs``,
            ``lr_reductions`` (number of actual reductions per optimizer) and
            ``epochs_processed``.
        """
        if not self.history['accuracies']:
            return {}

        return {
            'best_accuracy': max(self.history['accuracies']),
            'best_avg_weight_diff': min(self.history['weight_diffs']['avg']),
            'current_lrs': self.get_current_lrs(),
            'lr_reductions': dict(self.reduction_counts),
            'epochs_processed': len(self.history['accuracies'])
        }
