import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel
from sacrebleu.metrics import BLEU
from bert_score import BERTScorer
import numpy as np
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): this assignment runs after `import torch`; presumably the
# allocator only reads the variable when CUDA is first initialised - confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'



class TextCompressor(pl.LightningModule):
    def __init__(self,
                 vocab_size,
                 latent_dim=256,
                 hidden_dim=768, # Multiplied by 4 to size the decoder feed-forward layers
                 num_layers=6,  # Number of decoder layers
                 num_heads=8,   # Attention heads per decoder layer
                 dropout=0.1,
                 pooling_strategy="cls", # Encoder pooling: "mean" or "cls"
                 teacher_forcing_start_ratio=0.9,
                 teacher_forcing_end_ratio=0.1,
                 tokenizer=None, # If passed, will be used; otherwise loaded from modern_bert_model_name
                 lr = 3e-4,
                 new_lr=5e-5,
                 noise_sigma = None,
                 max_length= 30, # Stored as self.max_length; passed to generate() by validation/test steps
                 modern_bert_model_name="answerdotai/ModernBERT-base", # Allow specifying model
                 scheduler_type = "plateau", # "plateau" or "cosine"
                ):
        """Build the pretrained encoder, the latent projection and the decoder.

        Args:
            vocab_size: Size of the decoder embedding table and of the output
                projection.
            latent_dim: Width of the latent vector; also used as the decoder's
                ``d_model``.
            hidden_dim: The decoder feed-forward width is ``hidden_dim * 4``.
            num_layers: Number of ``nn.TransformerDecoderLayer`` layers.
            num_heads: Attention heads per decoder layer.
            dropout: Dropout passed to the decoder layers.
            pooling_strategy: ``"cls"`` or ``"mean"``; anything else fails the
                assertion below.
            teacher_forcing_start_ratio: Stored for ``get_teacher_forcing_ratio``.
            teacher_forcing_end_ratio: Stored for ``get_teacher_forcing_ratio``.
            tokenizer: Optional tokenizer; excluded from the saved
                hyperparameters.
            lr: Stored as ``self.lr``; ``configure_optimizers`` builds the
                optimizer with ``new_lr`` instead.
            new_lr: Learning rate used by ``configure_optimizers`` and
                re-applied in ``on_train_start``.
            noise_sigma: Scale of the noise added to the latent in ``encode``;
                ``None`` means no noise.
            max_length: Stored as ``self.max_length``.
            modern_bert_model_name: Name handed to ``AutoModel.from_pretrained``
                and, when no tokenizer is given, ``AutoTokenizer.from_pretrained``.
            scheduler_type: Stored as ``self.scheduler_type``.
        """
        super().__init__()
        # Ensure pooling_strategy is valid
        assert pooling_strategy in ["mean", "cls"], "pooling_strategy must be 'mean' or 'cls'"
        # Save hyperparameters, ignoring tokenizer if passed externally
        self.save_hyperparameters(ignore=['tokenizer'])

        self.lr = lr
        self.new_lr = new_lr
        self.teacher_forcing_start_ratio = teacher_forcing_start_ratio
        self.teacher_forcing_end_ratio = teacher_forcing_end_ratio
        # Starting value of the exponential moving average updated in training_step
        self.running_loss = 1.0
        self.noise_sigma = noise_sigma
        self.pooling_strategy = pooling_strategy # Store the strategy
        self.scheduler_type = scheduler_type
        # Load Model and Tokenizer
        self.modern_bert = AutoModel.from_pretrained(modern_bert_model_name)
        # Use provided tokenizer or load default
        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(modern_bert_model_name)
        self.max_length = max_length # Store max_length if needed elsewhere

        # Projection layer: encoder hidden size -> latent_dim
        self.projection = nn.Linear(self.modern_bert.config.hidden_size, latent_dim)

        # --- Rest of the decoder setup ---
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward= hidden_dim * 4, # Note: 4x hidden_dim, not 4x d_model
            dropout=dropout,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_projection = nn.Linear(latent_dim, vocab_size)
        self.decoder_embedding = nn.Embedding(vocab_size, latent_dim) # Token embedding for decoder inputs
        # Contrastive head (Consider removing if only doing reconstruction)
        self.proj = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )
        # Per-batch results collected by validation_step
        self.validation_step_outputs = []
        
    def exponential_schedule(self, current_step: int) -> float:
        progress = current_step / self.num_training_steps
        return self.end_ratio + (self.start_ratio - self.end_ratio) * math.exp(-self.decay_rate * progress)

    def get_teacher_forcing_ratio(self):
        """Calculate teacher forcing ratio based on current epoch."""
        if not self.trainer:
            return self.teacher_forcing_start_ratio
        
        current_epoch = 0#self.trainer.current_epoch
        max_epochs = self.trainer.max_epochs - self.trainer.current_epoch
        
        ratio = self.teacher_forcing_start_ratio - (
            (self.teacher_forcing_start_ratio - self.teacher_forcing_end_ratio) *
            (current_epoch / max_epochs)
        )
        return max(ratio, self.teacher_forcing_end_ratio)
        
    def get_noise_sigma(self):
        # During evaluation, use full noise unless explicitly testing without noise
        if not self.training:
            return self.noise_sigma if self.noise_sigma is not None else 0.0
        
        # During training, apply warmup if needed
        if self.noise_sigma is None:
            return 0.0
        
        current_epoch = self.trainer.current_epoch
        max_epochs = self.trainer.max_epochs
        warmup_fraction = 0.00  # Use 30% of training for warmup
        warmup_epochs = int(max_epochs * warmup_fraction)
        
        if current_epoch < warmup_epochs:
            return self.noise_sigma * (current_epoch / warmup_epochs)
        return self.noise_sigma
        
    def encode(self, x, attention_mask=None, test_mode=False):
            """
            Encodes input text tokens into a fixed-size latent vector.
    
            Args:
                x (torch.Tensor): Input token IDs (batch_size, seq_len).
                attention_mask (torch.Tensor, optional): Mask for input tokens
                                                        (batch_size, seq_len). Defaults to None.
                                                        Required for 'mean' pooling.
                test_mode (bool): Flag to potentially disable noise during inference/testing.
    
            Returns:
                torch.Tensor: The normalized latent vector (batch_size, latent_dim).
                              Noise may be added depending on mode and noise_sigma.
            """
            # --- 1. Get Encoder Hidden States ---
            outputs = self.modern_bert(
                input_ids=x,
                attention_mask=attention_mask,
                return_dict=True # Ensure outputs object has attributes like last_hidden_state
            )
            last_hidden_state = outputs.last_hidden_state # (batch_size, seq_len, hidden_size)
    
            # --- 2. Apply Pooling Strategy ---
            if self.pooling_strategy == "cls":
                pooled_repr = last_hidden_state[:, 0, :] # (batch_size, hidden_size)
            elif self.pooling_strategy == "mean":
                if attention_mask is None:
                    # A simple average if no mask provided (less robust)
                    # Warning: This includes padding if attention_mask isn't used!
                    # print("Warning: Mean pooling without attention mask!") # Optional warning
                    # pooled_repr = last_hidden_state.mean(dim=1)
                    raise ValueError("Attention mask is required for mean pooling.")
    
                # Expand attention mask for broadcasting: (batch_size, seq_len, 1)
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
                # Sum weighted by mask
                sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, dim=1)
                # Count non-padding tokens (handle potential all-padding sequences)
                sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
                # Calculate mean
                pooled_repr = sum_embeddings / sum_mask # (batch_size, hidden_size)
            else:
                # This case should be caught by __init__ assertion, but good practice
                raise ValueError(f"Invalid pooling_strategy: {self.pooling_strategy}")
    
            # --- 3. Project to Latent Dimension ---
            z = self.projection(pooled_repr) # (batch_size, latent_dim)
    
            # --- 4. Normalize (Pre-Noise) ---
            z = F.normalize(z, p=2, dim=-1)
            # Store the pre-noise embedding if needed for analysis (optional)
            # pre_noise_z = z.clone()
    
            # --- 5. Optionally Add Noise ---
            # Use get_noise_sigma() to be consistent with training_step logic if it exists
            current_sigma = self.get_noise_sigma() if not test_mode else self.noise_sigma
            if current_sigma is not None and current_sigma > 0:
                # Generate noise from N(0, sigma^2) - note torch.randn_like is N(0, 1)
                noise = torch.randn_like(z) * current_sigma
                z = z + noise
                # --- 6. Re-Normalize (Post-Noise) ---
                # Normalize again after adding noise to stay on the unit sphere
                z = F.normalize(z, p=2, dim=-1)
            # else: # If no noise added, z is already normalized from step 4
    
            return z
            
    def generate(self, z, max_length=30):
        batch_size = z.size(0)
        device = z.device
        
        # Start with start token
        curr_tokens = torch.full((batch_size, 1), 
                               fill_value=self.tokenizer.cls_token_id or 0,
                               dtype=torch.long, device=device)
        
        # Create attention mask for generated sequence
        attention_mask = torch.ones_like(curr_tokens, dtype=torch.bool)
        
        for _ in range(max_length - 1):
            logits = self.decode_step(z, curr_tokens, attention_mask)
            next_token_logits = logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            
            # Update sequences and attention mask
            curr_tokens = torch.cat([curr_tokens, next_token], dim=1)
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((batch_size, 1), dtype=torch.bool, device=device)
            ], dim=1)
        
        return curr_tokens
        
    def generate_square_subsequent_mask(self, sz):
        """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
        
    def decode_step(self, z, curr_tokens, attention_mask=None):
        batch_size, seq_len = curr_tokens.size()
        
        # Create causal mask for decoder self-attention
        tgt_mask = self.generate_square_subsequent_mask(seq_len).to(curr_tokens.device)
        
        # First embed the tokens
        tgt_emb = self.decoder_embedding(curr_tokens)  # [batch_size, seq_len, latent_dim]
        
        # Expand z to match sequence length
        z = z.unsqueeze(1).expand(-1, seq_len, -1)  # [batch_size, seq_len, latent_dim]
        
        # Handle attention mask
        if attention_mask is not None:
            # Make sure attention_mask matches the sequence length
            if attention_mask.size(1) != seq_len:
                attention_mask = attention_mask[:, :seq_len]
            # Convert to boolean mask where True means "mask this position"
            key_padding_mask = ~attention_mask.bool()  # [batch_size, seq_len]
        else:
            key_padding_mask = None
        
        output = self.decoder(
            tgt=tgt_emb,           # [batch_size, seq_len, latent_dim]
            memory=z,              # [batch_size, seq_len, latent_dim]
            tgt_mask=tgt_mask,     # [seq_len, seq_len]
            tgt_key_padding_mask=key_padding_mask,  # [batch_size, seq_len]
            memory_key_padding_mask=None
        )

        # Project to vocabulary
        logits = self.output_projection(output)  # [batch_size, seq_len, vocab_size]
        return logits

    def decode(self, z, tgt, tgt_mask=None, tgt_padding_mask=None):
        """Full decoding pass for teacher-forced or evaluation mode."""
        seq_len = tgt.shape[1]
        memory = z.unsqueeze(1).expand(-1, seq_len, -1)
        
        if not hasattr(self, 'decoder_embedding'):
            self.decoder_embedding = nn.Embedding(self.hparams.vocab_size, self.hparams.latent_dim)
        tgt_emb = self.decoder_embedding(tgt)
        
        decoded = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask
        )
        
        logits = self.output_projection(decoded)
        return logits
    import torch.profiler

    def forward(self, src, tgt, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as prof:
            z = self.encode(src, src_padding_mask)
            logits = self.decode(z, tgt, tgt_mask, tgt_padding_mask)
        
        print(prof.key_averages().table(sort_by="cuda_time_total"))
        
        return logits, z
        """
    def forward(self, src, tgt, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        z = self.encode(src, src_padding_mask)
        logits = self.decode(z, tgt, tgt_mask, tgt_padding_mask)
        return logits, z
    """
    def create_causal_mask(self, seq_len, device):
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return mask.to(device)

    def on_train_start(self):
        # Explicitly set the learning rate when training starts
        optimizer = self.trainer.optimizers[0]
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.new_lr
            
    def on_train_epoch_start(self):
        
        """Implement noise curriculum"""
        current_epoch = self.trainer.current_epoch
        max_epochs = self.trainer.max_epochs
        
        # Optional: Gradually increase noise during training
        if self.noise_sigma is not None:
            # Start with smaller noise and increase to target
            min_noise = 0.01
            target_noise = self.noise_sigma
            self.current_noise = min_noise + (target_noise - min_noise) * (current_epoch / (max_epochs * 0.5))
            self.current_noise = min(self.current_noise, target_noise)
            
    def training_step(self, batch, batch_idx):
        batch = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in batch.items()}
    
        # Unpack batch
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask'].bool()
        target_ids = batch['target_ids']
        aug_ids = batch['aug_ids']
        aug_attention_mask = batch['aug_attention_mask'].bool()
        
        # Get ModernBERT outputs and project to latent space
        outputs = self.modern_bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        """
        cls_token_repr = outputs.last_hidden_state[:, 0, :]
        z = self.projection(cls_token_repr)
        z = F.normalize(z, p=2, dim=-1)
        """
        # Apply noise using the encode method to ensure consistency
        z = self.encode(input_ids, attention_mask)
        
#        z_aug = self.encode(aug_ids, aug_attention_mask)
        
        # Get current performance metrics
        if not hasattr(self, 'running_loss'):
            self.running_loss = 1.0
        
        # Adaptive teacher forcing ratio
        base_ratio = self.get_teacher_forcing_ratio()
        adaptive_ratio = 1.0#base_ratio * min(1.0, self.running_loss)
        # Prepare target sequences
        prev_tokens = target_ids[:, :-1]
        target_tokens = target_ids[:, 1:]
        
        if torch.rand(1).item() < adaptive_ratio:
            # Teacher forcing
            logits = self.decode_step(z, prev_tokens, attention_mask[:, :-1])
        else:
            # No teacher forcing
            batch_size = z.size(0)
            seq_len = prev_tokens.size(1)
            
            curr_tokens = torch.full(
                (batch_size, 1),
                fill_value=self.tokenizer.cls_token_id,
                dtype=torch.long,
                device=self.device
            )
            
            for _ in range(seq_len):
                curr_mask = torch.ones_like(curr_tokens, dtype=torch.bool)
                step_logits = self.decode_step(z, curr_tokens, curr_mask)
                next_token = step_logits[:, -1:].argmax(dim=-1)
                curr_tokens = torch.cat([curr_tokens, next_token], dim=1)
            
            curr_attention_mask = torch.ones_like(curr_tokens[:, :-1], dtype=torch.bool, device=self.device)
            logits = self.decode_step(z, curr_tokens[:, :-1], curr_attention_mask)
        
        # Calculate losses
        rec_loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            target_tokens.reshape(-1),
            ignore_index=self.tokenizer.pad_token_id
        )
        
        # Contrastive loss
#        z_proj = F.normalize(self.proj(z), dim=-1)
 #       z_aug_proj = F.normalize(self.proj(z_aug), dim=-1)
        
        temp = 0.1
  #      sim = torch.matmul(z_proj, z_aug_proj.T) / temp
   #     labels = torch.arange(z.shape[0], device=self.device)
    #    con_loss = F.cross_entropy(sim, labels)
        
        # Combined loss
        loss = rec_loss #+ 0.0 * con_loss
        
        # Update running loss
        with torch.no_grad():
            self.running_loss = 0.95 * self.running_loss + 0.05 * loss.item()
        
        # Logging
        self.log('train_loss', loss, prog_bar=True)
        self.log('rec_loss', rec_loss, prog_bar=True)
        #self.log('con_loss', con_loss, prog_bar=True)
        self.log('teacher_forcing_ratio', adaptive_ratio, prog_bar=True)
        self.log('running_loss', self.running_loss, prog_bar=True)
        
        # Get the current learning rate
        optimizer = self.optimizers()
        if isinstance(optimizer, list):
            optimizer = optimizer[0]
        current_lr = optimizer.param_groups[0]['lr']
        self.log('learning_rate', current_lr, prog_bar=True)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        """Score one validation batch by teacher-forced loss and by BLEU.

        The loss is computed against ``batch['target_ids']`` when that key is
        present and against the input ids otherwise. BLEU always compares the
        greedy generation with the decoded input ids. The decoded texts are
        appended to ``self.validation_step_outputs`` for the epoch-end hook.

        Args:
            batch: Mapping with ``'input_ids'`` and ``'attention_mask'`` and
                optionally ``'target_ids'``.
            batch_idx: Batch index; the first batch prints up to three examples.

        Returns:
            Dict with ``'val_loss'`` (tensor) and ``'bleu'`` (float).
        """
        batch = {
            name: (value.to(self.device) if torch.is_tensor(value) else value)
            for name, value in batch.items()
        }

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask'].bool()

        with torch.no_grad():
            z = self.encode(input_ids, attention_mask)
            generated_ids = self.generate(z, self.max_length)

            # Both loss variants differ only in which ids serve as reference.
            reference_ids = batch['target_ids'] if 'target_ids' in batch else input_ids
            logits = self.decode_step(z, reference_ids[:, :-1], attention_mask[:, :-1])
            val_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                reference_ids[:, 1:].reshape(-1),
                ignore_index=self.tokenizer.pad_token_id
            )
            self.log('val_loss', val_loss, prog_bar=True, sync_dist=True)

            target_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            if batch_idx == 0:
                shown = min(3, len(target_texts))
                for i in range(shown):
                    self.print(f"\nExample {i+1}:")
                    self.print(f"Target    : {target_texts[i]}")
                    self.print(f"Generated : {generated_texts[i]}")

            bleu_score = BLEU().corpus_score(generated_texts, [target_texts]).score
            self.log('val_bleu', bleu_score, prog_bar=True, sync_dist=True)

            # on_validation_end deletes the list, so it may need re-creating.
            if not hasattr(self, 'validation_step_outputs'):
                self.validation_step_outputs = []
            record = {
                'target_texts': target_texts,
                'generated_texts': generated_texts,
                'bleu': bleu_score,
                'val_loss': val_loss.item()
            }
            self.validation_step_outputs.append(record)

            return {'val_loss': val_loss, 'bleu': bleu_score}
    
    def on_validation_epoch_start(self):
        self.validation_step_outputs = []
    
    def on_validation_epoch_end(self):
        all_targets = []
        all_generated = []
        
        for output in self.validation_step_outputs:
            all_generated.extend(output['generated_texts'])
            all_targets.extend(output['target_texts'])
        
        # BLEU
        bleu = BLEU()
        epoch_bleu = bleu.corpus_score(all_generated, [all_targets]).score
        
        # BERTScore
        if not hasattr(self, 'bert_scorer'):
            self.bert_scorer = BERTScorer(lang='en', rescale_with_baseline=True)
        
        max_samples = len(all_targets)#min(100, len(all_targets))
        P, R, F1 = self.bert_scorer.score(
            all_generated[:max_samples], 
            all_targets[:max_samples]
        )
        bert_score = F1.mean().item()
        
        # METEOR
        try:
            from nltk.translate import meteor_score
            import nltk
            nltk.download('wordnet', quiet=True)
            
            meteor_scores = []
            for ref, hyp in zip(all_targets[:max_samples], all_generated[:max_samples]):
                try:
                    score = meteor_score.meteor_score([ref], hyp)
                    meteor_scores.append(score)
                except:
                    continue
            meteor = sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0.0
        except ImportError:
            meteor = 0.0
        
        self.log('epoch_val_bleu', epoch_bleu, sync_dist=True)
        self.log('epoch_val_bert_score', bert_score, sync_dist=True)
        self.log('epoch_val_meteor', meteor, sync_dist=True)
        
        self.print("\nValidation Epoch End:")
        self.print(f"BLEU Score: {epoch_bleu:.2f}")
        self.print(f"BERTScore: {bert_score:.2f}")
        self.print(f"METEOR Score: {meteor:.2f}")
        
        import random
        self.print("\nRandom Examples:")
        indices = random.sample(range(len(all_targets)), k=min(5, len(all_targets)))
        for idx in indices:
            self.print(f"\nExample:")
            self.print(f"Target    : {all_targets[idx]}")
            self.print(f"Generated : {all_generated[idx]}")
        
        self.validation_step_outputs.clear()
        del all_targets, all_generated
        
    def on_validation_end(self):
        if hasattr(self, 'validation_step_outputs'):
            del self.validation_step_outputs
        torch.cuda.empty_cache()   
         
    def configure_optimizers(self):
        """
        Configure optimizer and learning rate scheduler for training.
        Supports both ReduceLROnPlateau and CosineAnnealingLR.
        """
        # Define optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=self.new_lr)
        
        # Choose scheduler based on configuration
        scheduler_type = getattr(self.hparams, 'scheduler_type', 'plateau')  # Default to plateau
        
        if scheduler_type.lower() == 'cosine':
            # CosineAnnealingLR scheduler
            T_max = getattr(self.hparams, 'cosine_t_max', 225)  # Default: 50 epochs cycle
            eta_min = getattr(self.hparams, 'cosine_eta_min', 1e-6)  # Minimum learning rate
            
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=T_max,
                eta_min=eta_min
            )
            
            scheduler_config = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": None,  # Not needed for CosineAnnealingLR
            }
            
            print(f"Using CosineAnnealingLR scheduler with T_max={T_max}, eta_min={eta_min}")
            
        else:  # 'plateau' or any other value defaults to ReduceLROnPlateau
            # ReduceLROnPlateau scheduler parameters
            factor = getattr(self.hparams, 'plateau_factor', 0.75)
            patience = getattr(self.hparams, 'plateau_patience', 1)
            min_lr = getattr(self.hparams, 'plateau_min_lr', 1e-6)
            
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=factor,
                patience=patience,
                min_lr=min_lr,
                verbose=True
            )
            
            scheduler_config = {
                "scheduler": scheduler,
                "monitor": "val_loss",  # The metric to monitor
                "interval": "epoch",    # Adjust LR every epoch
                "frequency": 1,         # Adjust LR after every epoch
                "strict": False,        # Don't crash if metric is not available
            }
            
            print(f"Using ReduceLROnPlateau scheduler with factor={factor}, patience={patience}, min_lr={min_lr}")
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler_config,
        }

    def test_step(self, batch, batch_idx):
        """Generate text for one test batch and score it with BLEU.

        The decoded inputs are the references. Texts and the batch BLEU are
        appended to ``self.test_step_outputs`` for the epoch-end hook.

        Args:
            batch: Mapping with ``'input_ids'`` and ``'attention_mask'``.
            batch_idx: Batch index; the first batch prints up to three examples.

        Returns:
            Dict with a zero ``'test_loss'`` tensor and the batch ``'bleu'``.
        """
        batch = {
            name: (value.to(self.device) if torch.is_tensor(value) else value)
            for name, value in batch.items()
        }

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask'].bool()

        with torch.no_grad():
            # encode() is called with test_mode=False, i.e. the noise scale
            # comes from get_noise_sigma().
            z = self.encode(input_ids, attention_mask, test_mode=False)
            generated_ids = self.generate(z, max_length=self.max_length)

            target_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            if batch_idx == 0:
                shown = min(3, len(target_texts))
                for i in range(shown):
                    self.print(f"\nTest Example {i+1}:")
                    self.print(f"Target          : {target_texts[i]}")
                    self.print(f"Generated       : {generated_texts[i]}")

            bleu_score = BLEU().corpus_score(generated_texts, [target_texts]).score
            self.log('test_bleu', bleu_score, prog_bar=True, sync_dist=True)

            # on_test_end deletes the list, so it may need re-creating.
            if not hasattr(self, 'test_step_outputs'):
                self.test_step_outputs = []
            record = {
                'target_texts': target_texts,
                'generated_texts': generated_texts,
                'bleu': bleu_score
            }
            self.test_step_outputs.append(record)

            return {
                'test_loss': torch.tensor(0.0, device=self.device),
                'bleu': bleu_score
            }
    
    def on_test_epoch_start(self):
        self.test_step_outputs = []
    
    def on_test_epoch_end(self):
        all_targets = []
        all_generated = []
        
        for output in self.test_step_outputs:
            all_generated.extend(output['generated_texts'])
            all_targets.extend(output['target_texts'])
        
        # BLEU
        bleu = BLEU()
        epoch_bleu = bleu.corpus_score(all_generated, [all_targets]).score
        
        # BERTScore
        if not hasattr(self, 'bert_scorer'):
            self.bert_scorer = BERTScorer(lang='en', rescale_with_baseline=True)
        
        max_samples = len(all_targets)  # Use all samples for test set
        P, R, F1 = self.bert_scorer.score(
            all_generated[:max_samples], 
            all_targets[:max_samples]
        )
        
        bert_score = F1.mean().item()
        
        # METEOR
        try:
            from nltk.translate import meteor_score
            import nltk
            nltk.download('wordnet', quiet=True)
            
            meteor_scores = []
            for ref, hyp in zip(all_targets[:max_samples], all_generated[:max_samples]):
                try:
                    score = meteor_score.meteor_score([ref], hyp)
                    meteor_scores.append(score)
                except:
                    continue
            meteor = sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0.0
        except ImportError:
            meteor = 0.0
        
        self.log('epoch_test_bleu', epoch_bleu, sync_dist=True)
        self.log('epoch_test_bert_score', bert_score, sync_dist=True)
        self.log('epoch_test_meteor', meteor, sync_dist=True)
        
        self.print("\nTest Epoch End:")
        self.print(f"BLEU Score: {epoch_bleu:.2f}")
        self.print(f"BERTScore: {bert_score:.2f}")
        self.print(f"METEOR Score: {meteor:.2f}")
        
        import random
        self.print("\nRandom Test Examples:")
        indices = random.sample(range(len(all_targets)), k=min(5, len(all_targets)))
        for idx in indices:
            self.print(f"\nExample:")
            self.print(f"Target    : {all_targets[idx]}")
            self.print(f"Generated : {all_generated[idx]}")
        
        # Save test results to file
        self.print("\nSaving test results to file...")
        results = []
        for i in range(len(all_targets)):
            results.append({
                "target": all_targets[i],
                "generated": all_generated[i]
            })
        
        import json
        import os
        os.makedirs("test_results", exist_ok=True)
        with open("test_results/generated_texts.json", "w") as f:
            json.dump(results, f, indent=2)
        
        self.test_step_outputs.clear()
        del all_targets, all_generated
        
    def on_test_end(self):
        if hasattr(self, 'test_step_outputs'):
            del self.test_step_outputs
    def save_pretrained(self, path):
        """Save the model to a directory."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        torch.save(self.hparams, os.path.join(path, "hparams.pt"))
        self.tokenizer.save_pretrained(path)
    
    @classmethod
    def from_pretrained(cls, path):
        """Load a model saved by ``save_pretrained`` from ``path``.

        Changes from the previous version:
        * Both files are loaded with ``map_location="cpu"`` so a checkpoint
          written on a GPU machine also loads on a CPU-only host.
        * The hyperparameter file is loaded with ``weights_only=False``,
          because it may contain Lightning's hyperparameter container, which
          the safe unpickler rejects.
        * The saved tokenizer is passed to the constructor, so the default
          tokenizer is not loaded first and then thrown away.

        SECURITY: ``weights_only=False`` unpickles arbitrary Python objects.
        Only call this on directories you trust.

        Args:
            path: Directory containing ``model.pt``, ``hparams.pt`` and the
                tokenizer files.

        Returns:
            The restored model, with its tokenizer loaded from ``path``.
        """
        hparams = torch.load(
            os.path.join(path, "hparams.pt"),
            map_location="cpu",
            weights_only=False,
        )
        init_kwargs = dict(hparams)
        # The tokenizer is supplied explicitly below; drop any stored entry.
        init_kwargs.pop("tokenizer", None)

        tokenizer = AutoTokenizer.from_pretrained(path)
        model = cls(tokenizer=tokenizer, **init_kwargs)

        # The state dict holds tensors only, so the safe loader is enough.
        state = torch.load(
            os.path.join(path, "model.pt"),
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(state)
        return model