import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
import collections
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM
from tqdm import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import os
import gc

# Small positive constant for keeping values away from numerically unsafe points
# (e.g. log(0) or division by zero).
EPS = 1.0e-8

def get_gpu_memory_usage():
    """Return the CUDA memory currently allocated by tensors, in MiB.

    Returns the integer 0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    bytes_per_mib = 1024 ** 2
    return torch.cuda.memory_allocated() / bytes_per_mib

def print_memory_usage(prefix=""):
    """Print allocated and reserved ("cached") CUDA memory in MiB.

    Args:
        prefix: Text printed directly before the memory report.

    Prints nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    mib = 1024 ** 2
    allocated_mb = torch.cuda.memory_allocated() / mib
    reserved_mb = torch.cuda.memory_reserved() / mib
    print(f"{prefix}Memory - Allocated: {allocated_mb:.2f}MB, Cached: {reserved_mb:.2f}MB")

def clear_memory():
    """Release cached CUDA allocator blocks (when CUDA exists), then run Python GC."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

class ProposedModelTrainer:
    """
    Trainer for fine-tuning language models using proposed model.

    The trainable model is a LoRA-adapted copy of ``model``; a separately loaded
    FP16 reference model is kept frozen. ``train`` runs a "mu" phase
    (``compute_mu_loss``) followed by a "pi" phase (``compute_pi_loss``), each with
    its own AdamW optimizer over the same parameters.
    """

    # Follow-up user turn appended after a previous answer to build "mu" contexts.
    _REVISION_REQUEST = "Could you give me a more preferred response than this?"

    def __init__(self, model, config, train_dataset, tokenizer, local_rank=0, world_size=1, inference_only=False):
        """
        Initialize the trainer.

        Args:
            model: Base language model to be fine-tuned
            config: Configuration dictionary. Keys read by this class:
                - reference_model_id, trust_remote_code: how to load the reference model
                - beta: parameter of the pi-loss denominators (default 0.0)
                - beta_pi: KL regularization weight for both losses (default 0.1)
                - batch_size: global training batch size (split across ranks)
                - mu_epochs / pi_epochs: epochs per phase (defaults 1 / 2)
                - mu_learning_rate: Learning rate for mu phase
                - pi_learning_rate: Learning rate for pi phase
                - output_dir: Directory for saving checkpoints
                - save_steps: checkpoint interval in optimizer steps (optional)
                - wandb_project / wandb_run_name: wandb logging (optional)
                - lora_r, lora_alpha, lora_dropout, target_modules, bias: LoRA settings
                - sub_batch_size: sequences per forward pass (default 4)
                - gradient_accumulation_steps: stored but not used by ``train``
            train_dataset: Dataset containing training samples with keys:
                'x': context string
                'y_w': winning response
                'y_l': losing response
            tokenizer: Tokenizer for text processing
            local_rank: Local rank for distributed training
            world_size: Total number of distributed processes
            inference_only: If True, skip training setup (no LoRA, optimizers, wandb)
        """
        self.tokenizer = tokenizer
        self.device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        self.local_rank = local_rank
        self.world_size = world_size
        self.local_step = 0
        self.config = config

        # State used by compute_seq_log_probs_batch / get_batches in both modes, so
        # inference-only instances do not hit AttributeError.
        self.tokenization_cache = collections.OrderedDict()
        self.max_cache_size = 10000  # Limit cache size (LRU eviction)
        self.sub_batch_size = config.get('sub_batch_size', 4)
        self.train_dataset = None
        self.sampler = None

        # Load reference model using FP16 instead of quantization. Under DDP each rank
        # keeps a full copy on its own GPU; "auto" would let every rank shard the
        # model across all visible GPUs.
        if world_size > 1 and torch.cuda.is_available():
            reference_device_map = {"": local_rank}
        else:
            reference_device_map = "auto"
        self.reference_model = AutoModelForCausalLM.from_pretrained(
            config['reference_model_id'],
            trust_remote_code=config['trust_remote_code'],
            torch_dtype=torch.float16,
            device_map=reference_device_map,
        )
        # The reference model is never optimized.
        self.reference_model.requires_grad_(False)

        if not inference_only:
            # Configure LoRA
            lora_config = LoraConfig(
                r=config['lora_r'],
                lora_alpha=config['lora_alpha'],
                target_modules=config['target_modules'],
                lora_dropout=config['lora_dropout'],
                bias=config['bias'],
                task_type="CAUSAL_LM",  # Fixed task type for causal language modeling
            )

            # Prepare model for (k-bit) LoRA training and attach the adapters.
            self.model = prepare_model_for_kbit_training(model)
            self.model = get_peft_model(self.model, lora_config)

            self.train_dataset = train_dataset
            # Create distributed sampler if using multiple GPUs
            if world_size > 1:
                self.sampler = DistributedSampler(train_dataset)

            # Separate optimizers (and learning rates) for the two phases.
            self.mu_optimizer = optim.AdamW(self.model.parameters(), lr=config['mu_learning_rate'])
            self.pi_optimizer = optim.AdamW(self.model.parameters(), lr=config['pi_learning_rate'])

            # Only rank 0 starts a wandb run; other ranks would create duplicate runs.
            if 'wandb_project' in config and local_rank == 0:
                wandb.init(
                    project=config['wandb_project'],
                    config=config,
                    name=config.get('wandb_run_name', None),  # Use run name if provided
                )

            # Enable gradient checkpointing for memory efficiency (the model is not
            # DDP-wrapped yet at this point).
            self.model.gradient_checkpointing_enable()

            # NOTE(review): stored for configuration compatibility; train() does not
            # currently accumulate gradients.
            self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)

            # Pre-tokenize every (context, candidate) pair the two losses request.
            if local_rank == 0:
                print("Pre-tokenizing dataset...")
            for sample in self.train_dataset:
                x, y_w, y_l = sample['x'], sample['y_w'], sample['y_l']
                prompt = self._format_prompt(x)
                self._cache_tokenization(prompt, y_w)
                self._cache_tokenization(prompt, y_l)
                self._cache_tokenization(self._format_revision_prompt(x, y_w), y_l)
                self._cache_tokenization(self._format_revision_prompt(x, y_l), y_w)
        else:
            # Inference mode setup - minimal configuration
            self.model = model

        # Move model to appropriate device.
        # NOTE(review): transformers may refuse .to() on bitsandbytes-quantized models;
        # confirm the model passed in is not loaded in 4/8-bit.
        self.model = self.model.to(self.device)
        self.model.eval()  # train() switches back to training mode
        self.reference_model.eval()

        # Wrap model with DDP if using multiple GPUs
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[local_rank])

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _unwrap_model(model):
        """Return the module inside a DDP wrapper, or ``model`` itself."""
        return model.module if isinstance(model, DDP) else model

    @staticmethod
    def _format_prompt(x):
        """Chat-format ``x`` as one user turn followed by an open assistant turn."""
        return f"<|im_start|>user\n{x}<|im_end|>\n<|im_start|>assistant\n"

    @classmethod
    def _format_revision_prompt(cls, x, previous_response):
        """Chat-format ``x``, a previous assistant answer, and a request to improve on it."""
        return (
            f"{cls._format_prompt(x)}{previous_response}<|im_end|>\n"
            f"<|im_start|>user\n{cls._REVISION_REQUEST}<|im_end|>\n<|im_start|>assistant\n"
        )

    def _per_rank_batch_size(self):
        """Samples per batch on this rank (global batch split across ranks, at least 1)."""
        return max(1, self.config['batch_size'] // self.world_size)

    def _num_batches(self):
        """Number of batches ``get_batches`` yields on this rank per epoch."""
        num_samples = len(self.sampler) if self.sampler is not None else len(self.train_dataset)
        batch_size = self._per_rank_batch_size()
        return (num_samples + batch_size - 1) // batch_size

    def _log_metrics(self, metrics):
        """Log ``metrics`` to wandb from rank 0 only, and only when a run is active."""
        if self.local_rank == 0 and wandb.run is not None:
            wandb.log(metrics)

    # ------------------------------------------------------------------ data

    def get_batches(self, epoch=0):
        """
        Yield lists of training samples for this rank.

        Args:
            epoch: Passed to ``DistributedSampler.set_epoch`` so each epoch gets a
                different shuffle (ignored without a sampler, where order is sequential).
        """
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)
            indices = list(self.sampler)
        else:
            indices = list(range(len(self.train_dataset)))

        batch_size = self._per_rank_batch_size()
        for start in range(0, len(indices), batch_size):
            yield [self.train_dataset[idx] for idx in indices[start:start + batch_size]]

    def _cache_tokenization(self, context, candidate):
        """
        Tokenize a (context, candidate) pair through the LRU cache.

        Returns:
            Tuple ``(context_tokens, candidate_tokens)`` of 1-D CPU ``long`` tensors.
        """
        key = (context, candidate)
        cached = self.tokenization_cache.get(key)
        if cached is not None:
            # Mark as most recently used.
            self.tokenization_cache.move_to_end(key)
            return cached

        try:
            context_tokens = self.tokenizer(context, return_tensors="pt")["input_ids"][0].long()
            # The candidate continues the context, so it must not receive its own
            # special tokens (e.g. a BOS token in the middle of the sequence).
            candidate_tokens = self.tokenizer(
                candidate, return_tensors="pt", add_special_tokens=False
            )["input_ids"][0].long()
        except Exception as e:
            print(f"Error in tokenization cache: {str(e)}")
            print(f"Context: {context[:100]}...")
            print(f"Candidate: {candidate[:100]}...")
            raise

        # Check for potential issues
        if len(context_tokens) == 0 or len(candidate_tokens) == 0:
            print(f"Warning: Empty tokenization for context or candidate")
            print(f"Context: {context[:100]}...")
            print(f"Candidate: {candidate[:100]}...")
            print(f"Context tokens shape: {context_tokens.shape}, type: {context_tokens.dtype}")
            print(f"Candidate tokens shape: {candidate_tokens.shape}, type: {candidate_tokens.dtype}")

        # Evict least-recently-used entries to stay within max_cache_size.
        while self.tokenization_cache and len(self.tokenization_cache) >= self.max_cache_size:
            self.tokenization_cache.popitem(last=False)
        self.tokenization_cache[key] = (context_tokens, candidate_tokens)  # Stored on CPU
        return context_tokens, candidate_tokens

    # ------------------------------------------------------------------ scoring

    @staticmethod
    def _mean_token_log_probs(logits, labels):
        """
        Per-row mean log-probability of the tokens whose label is not -100.

        Args:
            logits: ``(batch, seq_len, vocab)`` next-token logits.
            labels: ``(batch, seq_len)`` token ids, -100 at positions to ignore.

        Returns:
            ``(batch,)`` float32 tensor; rows with no scored tokens give 0.
        """
        labels = labels.to(logits.device)
        # Position t predicts token t + 1 (same shift as the HF causal-LM loss).
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        token_nll = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)).float(),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction='none',
        ).view(shift_labels.shape)
        mask = shift_labels.ne(-100)
        n_tokens = mask.sum(dim=1).clamp(min=1)
        return -(token_nll * mask).sum(dim=1) / n_tokens

    def compute_seq_log_probs_batch(self, model, contexts, candidates):
        """
        Compute log probabilities for multiple context-candidate pairs in sub-batches.

        Args:
            model: Model to compute probabilities from (DDP wrapper is unwrapped)
            contexts: List of context strings
            candidates: List of candidate responses

        Returns:
            1-D tensor on ``self.device``, one entry per pair: the average
            log-probability per candidate token given its context.
        """
        # The DDP wrapper is bypassed (several forward passes precede one backward);
        # cross-rank gradient averaging is done explicitly in train().
        model = self._unwrap_model(model)
        # Only the frozen reference model runs without autograd.
        is_reference = model is self.reference_model

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            eos_id = self.tokenizer.eos_token_id
            pad_token_id = eos_id if eos_id is not None else 0

        sub_batch_size = max(1, getattr(self, 'sub_batch_size', 4))
        total_log_probs = []
        for start in range(0, len(contexts), sub_batch_size):
            sub_contexts = contexts[start:start + sub_batch_size]
            sub_candidates = candidates[start:start + sub_batch_size]

            input_ids_list = []
            labels_list = []
            for context, candidate in zip(sub_contexts, sub_candidates):
                # Goes through the cache, re-tokenizing if the entry was evicted.
                context_tokens, candidate_tokens = self._cache_tokenization(context, candidate)
                input_ids_list.append(torch.cat([context_tokens, candidate_tokens]))
                # Context positions are labelled -100 so only the candidate is scored.
                labels_list.append(torch.cat([torch.full_like(context_tokens, -100), candidate_tokens]))

            max_length = max(len(ids) for ids in input_ids_list)
            padded_input_ids = torch.stack([
                F.pad(ids, (0, max_length - len(ids)), value=pad_token_id)
                for ids in input_ids_list
            ]).to(self.device).long()
            padded_labels = torch.stack([
                F.pad(labs, (0, max_length - len(labs)), value=-100)
                for labs in labels_list
            ]).to(self.device).long()
            attention_mask = torch.stack([
                F.pad(torch.ones_like(ids), (0, max_length - len(ids)), value=0)
                for ids in input_ids_list
            ]).to(self.device)

            # Single forward pass for the sub-batch with mixed precision
            with torch.cuda.amp.autocast():
                if is_reference:
                    with torch.no_grad():
                        outputs = model(input_ids=padded_input_ids, attention_mask=attention_mask)
                else:
                    outputs = model(input_ids=padded_input_ids, attention_mask=attention_mask)

            # Per-sequence scores (the model's own .loss is averaged over the whole
            # sub-batch and cannot distinguish sequences).
            sub_log_probs = self._mean_token_log_probs(outputs.logits, padded_labels)
            total_log_probs.append(sub_log_probs.to(self.device))

            # Release sub-batch tensors before the next forward pass
            del padded_input_ids, padded_labels, attention_mask, outputs
            torch.cuda.empty_cache()

        if not total_log_probs:
            return torch.zeros(0, device=self.device)
        return torch.cat(total_log_probs)

    @staticmethod
    def _kl_estimates(model_log_probs_w, ref_log_probs_w, model_log_probs_l, ref_log_probs_l):
        """
        Per-sample importance-weighted KL estimates, averaged over y_w and y_l.

        With r = log(pi(y|x) / pi_ref(y|x)) for y drawn from the data distribution,
        ``exp(r) * r`` is the importance-sampling estimate used for KL(pi || pi_ref).
        Reference log-probs are detached.
        """
        log_ratio_w = model_log_probs_w - ref_log_probs_w.detach()
        log_ratio_l = model_log_probs_l - ref_log_probs_l.detach()
        return 0.5 * (torch.exp(log_ratio_w) * log_ratio_w + torch.exp(log_ratio_l) * log_ratio_l)

    def compute_kl_divergence(self, model, reference_model, contexts, batch):
        """
        Compute sequence-level KL divergence using importance sampling.

        Args:
            model: Current model to evaluate
            reference_model: Reference model to compare against
            contexts: List of context strings (aligned with ``batch``)
            batch: Current batch containing y_w and y_l

        Returns:
            Scalar tensor: the KL estimate averaged over y_w and y_l and over all
            contexts (0 for an empty input).
        """
        if len(contexts) == 0:
            return torch.zeros((), device=self.device)

        y_w_list = [sample['y_w'] for sample in batch]
        y_l_list = [sample['y_l'] for sample in batch]

        kl_sum = torch.zeros((), device=self.device)
        chunk_size = 32  # Process contexts in chunks to reduce memory usage
        for start in range(0, len(contexts), chunk_size):
            chunk_contexts = contexts[start:start + chunk_size]
            chunk_y_w = y_w_list[start:start + chunk_size]
            chunk_y_l = y_l_list[start:start + chunk_size]

            with torch.cuda.amp.autocast():
                model_log_probs_w = self.compute_seq_log_probs_batch(model, chunk_contexts, chunk_y_w)
                model_log_probs_l = self.compute_seq_log_probs_batch(model, chunk_contexts, chunk_y_l)
                with torch.no_grad():
                    ref_log_probs_w = self.compute_seq_log_probs_batch(reference_model, chunk_contexts, chunk_y_w)
                    ref_log_probs_l = self.compute_seq_log_probs_batch(reference_model, chunk_contexts, chunk_y_l)

            # Sum (not mean) per chunk so the final division averages each context once.
            kl_sum = kl_sum + self._kl_estimates(
                model_log_probs_w, ref_log_probs_w, model_log_probs_l, ref_log_probs_l
            ).sum()

        return kl_sum / len(contexts)

    # ------------------------------------------------------------------ losses

    def compute_mu_loss(self, batch):
        """
        Compute the mu loss for the first phase of training.

        The model is scored on y_l given (x, y_w, revision request) and on y_w given
        (x, y_l, revision request); each log-prob is compared against the reference
        model's log-prob of the same response given x alone.

        Args:
            batch: Current batch of training samples

        Returns:
            Total loss including ``beta_pi``-weighted KL regularization
        """
        beta_pi = self.config.get('beta_pi', 0.1)

        x_list = [sample['x'] for sample in batch]
        y_w_list = [sample['y_w'] for sample in batch]
        y_l_list = [sample['y_l'] for sample in batch]

        prompts = [self._format_prompt(x) for x in x_list]
        context_mu_y_l = [self._format_revision_prompt(x, y_w) for x, y_w in zip(x_list, y_w_list)]
        context_mu_y_w = [self._format_revision_prompt(x, y_l) for x, y_l in zip(x_list, y_l_list)]

        with torch.cuda.amp.autocast():
            log_prob_mu_y_l = self.compute_seq_log_probs_batch(self.model, context_mu_y_l, y_l_list)
            log_prob_pi_d_y_l = self.compute_seq_log_probs_batch(self.reference_model, prompts, y_l_list)
            log_prob_mu_y_w = self.compute_seq_log_probs_batch(self.model, context_mu_y_w, y_w_list)
            log_prob_pi_d_y_w = self.compute_seq_log_probs_batch(self.reference_model, prompts, y_w_list)

            loss_l = log_prob_mu_y_l - log_prob_pi_d_y_l.detach()
            loss_w = log_prob_mu_y_w - log_prob_pi_d_y_w.detach()
            clamped_l = torch.clamp(loss_l, -10.0, 10.0)
            clamped_w = torch.clamp(loss_w, -10.0, 10.0)
            # exp(l) - exp(w), computed relative to the max for numerical safety.
            m = torch.max(clamped_l, clamped_w)
            loss = (torch.exp(clamped_l - m) - torch.exp(clamped_w - m)) * torch.exp(m)

            # KL on the plain prompts; the reference log-probs above are reused.
            log_prob_pi_y_w = self.compute_seq_log_probs_batch(self.model, prompts, y_w_list)
            log_prob_pi_y_l = self.compute_seq_log_probs_batch(self.model, prompts, y_l_list)
            kl_div = self._kl_estimates(
                log_prob_pi_y_w, log_prob_pi_d_y_w, log_prob_pi_y_l, log_prob_pi_d_y_l
            ).mean()

            total_loss = loss.mean() + beta_pi * kl_div

        self._log_metrics({
            'mu_loss': total_loss.item(),
            'mu_loss_kl': beta_pi * kl_div.item(),
            'mu_loss_loss': loss.mean().item(),
            'step': self.local_step,
        })
        return total_loss

    def compute_pi_loss(self, batch):
        """
        Compute the pi loss for the second phase of training.

        The model's log-probs on the plain prompt are weighted by their ratio to the
        reference model and offset by denominators built from the (detached) mu
        probabilities; a ``beta_pi``-weighted KL term is added.

        Args:
            batch: Current batch of training samples

        Returns:
            Total loss including KL regularization
        """
        x_list = [sample['x'] for sample in batch]
        y_w_list = [sample['y_w'] for sample in batch]
        y_l_list = [sample['y_l'] for sample in batch]

        prompts = [self._format_prompt(x) for x in x_list]
        mu_contexts_w = [self._format_revision_prompt(x, y_w) for x, y_w in zip(x_list, y_w_list)]
        mu_contexts_l = [self._format_revision_prompt(x, y_l) for x, y_l in zip(x_list, y_l_list)]

        beta = self.config.get('beta', 0.0)
        beta_pi = self.config.get('beta_pi', 0.1)

        with torch.cuda.amp.autocast():
            log_prob_pi_y_w = self.compute_seq_log_probs_batch(self.model, prompts, y_w_list)
            log_prob_pi_y_l = self.compute_seq_log_probs_batch(self.model, prompts, y_l_list)

            # The mu terms only enter the loss detached, so skip building their graph.
            with torch.no_grad():
                log_prob_mu_y_w = self.compute_seq_log_probs_batch(self.model, mu_contexts_l, y_w_list)
                log_prob_mu_y_l = self.compute_seq_log_probs_batch(self.model, mu_contexts_w, y_l_list)
                log_prob_ref_y_w = self.compute_seq_log_probs_batch(self.reference_model, prompts, y_w_list)
                log_prob_ref_y_l = self.compute_seq_log_probs_batch(self.reference_model, prompts, y_l_list)

            mu_w = torch.exp(log_prob_mu_y_w)
            mu_l = torch.exp(log_prob_mu_y_l)

            log_denom_w = torch.log1p(mu_w) + beta + beta * mu_w
            # log(1 - mu_l) as log(-expm1(log mu_l)); clamping the log-prob to <= -EPS
            # keeps mu_l strictly below 1 so the value stays finite.
            log_one_minus_mu_l = torch.log(-torch.expm1(torch.clamp(log_prob_mu_y_l, max=-EPS)))
            log_denom_l = log_one_minus_mu_l + beta - beta * mu_l

            log_term_w = log_prob_pi_y_w - log_denom_w.detach()
            log_term_l = log_prob_pi_y_l - log_denom_l.detach()

            ratio_pi_w = torch.exp(log_prob_pi_y_w - log_prob_ref_y_w.detach())
            ratio_pi_l = torch.exp(log_prob_pi_y_l - log_prob_ref_y_l.detach())

            loss = (ratio_pi_w * log_term_w + ratio_pi_l * log_term_l).mean()

            # Sequence-level KL, reusing the log-probs computed above.
            kl_div = self._kl_estimates(
                log_prob_pi_y_w, log_prob_ref_y_w, log_prob_pi_y_l, log_prob_ref_y_l
            ).mean()

            # Total loss: weighted cross-entropy + β_π * KL
            total_loss = loss + beta_pi * kl_div

        self._log_metrics({
            'pi_loss': total_loss.item(),
            'pi_loss_kl': beta_pi * kl_div.item(),
            'pi_loss_main': loss.item(),
            'step': self.local_step,
        })
        return total_loss

    # ------------------------------------------------------------------ training

    def _sync_gradients(self):
        """Average trainable-parameter gradients across ranks (no-op for one process)."""
        if self.world_size <= 1:
            return
        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            return
        world = dist.get_world_size()
        for param in self.model.parameters():
            if not param.requires_grad:
                continue
            if param.grad is None:
                # Every rank must join every all_reduce; contribute zeros if unused here.
                param.grad = torch.zeros_like(param)
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world)

    def _run_phase(self, *, loss_fn, optimizer, scaler, num_epochs, symbol, metric_name,
                   checkpoint_prefix, total_batches, epoch_offset=0):
        """Run ``num_epochs`` epochs of loss -> backward -> (synced) optimizer step."""
        epoch_bar = tqdm(range(num_epochs), desc=f"{symbol} Training Epochs", disable=self.local_rank != 0)
        for epoch in epoch_bar:
            batch_bar = tqdm(self.get_batches(epoch=epoch_offset + epoch), total=total_batches,
                             desc=f"{symbol} Epoch {epoch+1}", disable=self.local_rank != 0, leave=False)

            for batch in batch_bar:
                with torch.cuda.amp.autocast():
                    loss = loss_fn(batch)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                # compute_seq_log_probs_batch bypasses DDP, so average gradients here
                # (before unscale_, so an inf on any rank is seen by all ranks).
                self._sync_gradients()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

                # Clear CUDA cache after each update
                torch.cuda.empty_cache()

                if self.local_rank == 0:
                    batch_bar.set_postfix({metric_name: f"{loss.item():.4f}"})

                self.local_step += 1
                save_steps = self.config.get('save_steps')
                if self.local_rank == 0 and save_steps and self.local_step % save_steps == 0:
                    self.save_model(save_path=os.path.join(
                        self.config['output_dir'], f"{checkpoint_prefix}{self.local_step}"))

    def train(self):
        """
        Run ``mu_epochs`` (default 1) epochs of the mu loss, then ``pi_epochs``
        (default 2) epochs of the pi loss.
        """
        total_batches = self._num_batches()
        mu_epochs = self.config.get('mu_epochs', 1)
        pi_epochs = self.config.get('pi_epochs', 2)

        # Initialize gradient scalers at the beginning of training
        self.mu_scaler = torch.cuda.amp.GradScaler()
        self.pi_scaler = torch.cuda.amp.GradScaler()

        if self.local_rank == 0:
            print("\nPhase 1: Training with mu_loss...")
        self.model.train()
        self._run_phase(
            loss_fn=self.compute_mu_loss,
            optimizer=self.mu_optimizer,
            scaler=self.mu_scaler,
            num_epochs=mu_epochs,
            symbol="μ",
            metric_name='mu_loss',
            checkpoint_prefix="mu_step_",
            total_batches=total_batches,
        )

        if self.local_rank == 0:
            print("\nPhase 2: Training with pi_loss...")
        self.model.train()
        self._run_phase(
            loss_fn=self.compute_pi_loss,
            optimizer=self.pi_optimizer,
            scaler=self.pi_scaler,
            num_epochs=pi_epochs,
            symbol="π",
            metric_name='pi_loss',
            checkpoint_prefix="step_",
            total_batches=total_batches,
            # Continue the sampler epoch count so phase 2 does not replay phase-1 shuffles.
            epoch_offset=mu_epochs,
        )

        if self.local_rank == 0:
            print("\nTraining completed.")

    # ------------------------------------------------------------------ utilities

    def save_model(self, save_path=None):
        """
        Save model checkpoint and tokenizer (rank 0 only).

        Args:
            save_path: Optional base directory; defaults to
                ``<output_dir>/step_<local_step>``. A ``beta_<beta>`` subdirectory is
                always appended.
        """
        if self.local_rank != 0:
            return
        if save_path is None:
            save_path = os.path.join(self.config['output_dir'], f"step_{self.local_step}")

        model_to_save = self._unwrap_model(self.model)

        beta = self.config.get('beta', 0.0)
        save_path = os.path.join(save_path, f"beta_{beta}")
        os.makedirs(save_path, exist_ok=True)

        model_to_save.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        print(f"Saved model at step {self.local_step} to {save_path}")

    def eval(self):
        """Set model to evaluation mode"""
        self._unwrap_model(self.model).eval()

    def train_mode(self, mode=True):
        """
        Set model to training or evaluation mode.

        Args:
            mode: True for training mode, False for evaluation mode
        """
        model = self._unwrap_model(self.model)
        if mode:
            model.train()
        else:
            model.eval()

    def generate(self, prompt, num_return_sequences=1, max_length=500, **kwargs):
        """
        Generate responses using the trained model.

        Args:
            prompt: Input prompt string (wrapped in the chat template)
            num_return_sequences: Number of sequences to generate (one sampling call each)
            max_length: ``max_length`` passed to ``generate`` (includes prompt tokens)
            **kwargs: ``temperature`` (0.7), ``top_p`` (0.9) and ``top_k`` (50) are
                read; other keys are ignored.

        Returns:
            List of generated response strings (prompt stripped, special tokens skipped)
        """
        model_for_generation = self._unwrap_model(self.model)
        # Restore the caller's mode afterwards so generate() is safe mid-training.
        was_training = getattr(model_for_generation, 'training', False)
        model_for_generation.eval()

        encoded = self.tokenizer(self._format_prompt(prompt), return_tensors="pt").to(self.device)
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")

        generated_sequences = []
        try:
            for _ in range(num_return_sequences):
                with torch.no_grad():
                    outputs = model_for_generation.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=max_length,
                        do_sample=True,
                        temperature=kwargs.get('temperature', 0.7),
                        top_p=kwargs.get('top_p', 0.9),
                        top_k=kwargs.get('top_k', 50),
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                    )
                # Decode only the newly generated tokens.
                generated_sequences.append(
                    self.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
                )
        finally:
            if was_training:
                model_for_generation.train()

        return generated_sequences

    def reset_model_memory(self):
        """Drop gradients, set ``config.use_cache = False``, and free cached memory."""
        model = self._unwrap_model(self.model)
        if hasattr(model, 'zero_grad'):
            model.zero_grad(set_to_none=True)  # Frees gradient tensors entirely

        if hasattr(model, 'config'):
            model.config.use_cache = False

        gc.collect()
        torch.cuda.empty_cache()