"""
Candidate step generation system for online best-of-n
"""

import gc
import logging
import time
from dataclasses import fields
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from lm_polygraph import WhiteboxModel
from luh import AutoUncertaintyHead
from transformers import StoppingCriteriaList

from .step_detection import (
    StepBoundaryDetector,
    BatchStepStoppingCriteria,
    StepExtractionResult,
)

log = logging.getLogger(__name__)


class StepCandidate:
    """A candidate next step in a reasoning trajectory.

    Attributes:
        text: Step text (as extracted by the step boundary detector).
        token_ids: Token ids of the generated continuation.
        is_complete: Whether the step was judged complete.
        is_trajectory_complete: Whether the whole trajectory was judged
            finished after this step.
        generation_scores: Optional generation scores for this sample.
        raw_text: Full decoded continuation; falls back to ``text`` when not
            given (or given as an empty string).
        batch_idx: Optional index of the generation batch that produced it.
        sample_idx: Optional index of the candidate within that batch.
        trajectory_context: Optional trajectory text this step continues.
    """

    # Number of characters of ``text`` shown by ``__str__`` / ``__repr__``.
    _PREVIEW_CHARS = 50

    def __init__(
        self,
        text: str,
        token_ids: List[int],
        is_complete: bool,
        is_trajectory_complete: bool,
        generation_scores: Optional[torch.Tensor] = None,
        raw_text: Optional[str] = None,
        # Batch context (optional)
        batch_idx: Optional[int] = None,
        sample_idx: Optional[int] = None,
        trajectory_context: Optional[str] = None,
    ):
        self.text = text
        self.token_ids = token_ids
        self.is_complete = is_complete
        self.is_trajectory_complete = is_trajectory_complete
        self.generation_scores = generation_scores
        # Empty/None raw text falls back to the extracted step text.
        self.raw_text = raw_text or text
        self.batch_idx = batch_idx
        self.sample_idx = sample_idx
        self.trajectory_context = trajectory_context

    def __str__(self):
        preview = self.text[: self._PREVIEW_CHARS]
        # Only signal truncation when text was actually cut off.
        suffix = "..." if len(self.text) > self._PREVIEW_CHARS else ""
        return (
            f"StepCandidate(text='{preview}{suffix}', complete={self.is_complete}, "
            f"batch={self.batch_idx}, sample={self.sample_idx})"
        )

    # Containers (lists, dicts) format their items with repr(); reuse the
    # readable summary instead of the default ``<object at 0x...>``.
    __repr__ = __str__


class StepCandidateGenerator:
    """Generates N candidate next steps for online best-of-n.

    Candidates are sampled from ``model`` with step-boundary stopping criteria.
    The generation output is optionally offloaded to CPU and, when an
    uncertainty head is configured, passed through the head's feature
    extractor.
    """

    def __init__(
        self,
        model: WhiteboxModel,
        uncertainty_head=None,
        uhead_path: Optional[str] = None,
        detector: Optional[StepBoundaryDetector] = None,
        candidates_per_step: int = 5,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50,
        max_new_tokens: int = 250,
        device: str = "cuda",
        memory_efficient: bool = True,
        offload_to_cpu: bool = True,
    ):
        """
        Args:
            model: Whitebox model used for tokenization and generation.
            uncertainty_head: Pre-built uncertainty head; takes precedence
                over ``uhead_path``.
            uhead_path: Passed to ``AutoUncertaintyHead.from_pretrained`` when
                no head instance is given.
            detector: Step boundary detector; ``StepBoundaryDetector()`` is
                created when omitted.
            candidates_per_step: Number of candidates produced per call.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            top_k: Top-k sampling cutoff.
            max_new_tokens: Generation budget per candidate.
            device: Device inputs (and a loaded uncertainty head) are moved to.
            memory_efficient: Together with ``offload_to_cpu``, enables moving
                generation outputs to CPU before feature extraction.
            offload_to_cpu: See ``memory_efficient``.
        """
        self.model = model
        self.detector = detector if detector is not None else StepBoundaryDetector()
        self.candidates_per_step = candidates_per_step
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.memory_efficient = memory_efficient
        self.offload_to_cpu = offload_to_cpu

        # Uncertainty head is optional - only needed for UHead-based scoring.
        if uncertainty_head is not None:
            self.uncertainty_head = uncertainty_head
        elif uhead_path is not None:
            self.uncertainty_head = AutoUncertaintyHead.from_pretrained(
                uhead_path,
                self.model.model,
            ).to(device)
            self.uncertainty_head.eval()
        else:
            self.uncertainty_head = None

    @classmethod
    def _tensors_to_cpu(cls, value):
        """Return ``value`` with every tensor (arbitrarily nested in tuples or
        lists) moved to CPU; other leaves are returned unchanged."""
        if value is None:
            return None
        if torch.is_tensor(value):
            return value.cpu()
        if isinstance(value, tuple):
            return tuple(cls._tensors_to_cpu(item) for item in value)
        if isinstance(value, list):
            return [cls._tensors_to_cpu(item) for item in value]
        return value

    def _move_generation_output_to_cpu(self, out):
        """Return a copy of a dataclass-style generation output with its
        tensors moved to CPU (following the bestofn approach).

        Nested tuple/list fields (scores, logits, attentions, hidden states,
        legacy KV caches, ...) are walked recursively, so their nesting
        structure is preserved regardless of depth.
        """
        new_values = {}
        for field in fields(out):
            value = getattr(out, field.name)

            if value is None or torch.is_tensor(value) or isinstance(value, (tuple, list)):
                new_values[field.name] = self._tensors_to_cpu(value)
            elif field.name == "past_key_values":
                # Cache objects (not plain tuples): materialise them by
                # iteration into nested tuples on CPU. Leave non-iterable
                # caches untouched rather than failing.
                try:
                    layers = tuple(value)
                except TypeError:
                    new_values[field.name] = value
                else:
                    new_values[field.name] = self._tensors_to_cpu(layers)
            elif callable(getattr(value, "to", None)):
                # Other objects exposing a ``.to`` method; do not assume
                # they also expose ``.cpu()``.
                new_values[field.name] = value.to("cpu")
            else:
                new_values[field.name] = value

        return type(out)(**new_values)

    def generate_candidates(
        self,
        trajectory: str,
        verbose: bool = False,
    ) -> Dict[str, Any]:
        """Generate ``candidates_per_step`` candidate next steps with features.

        Following the bestofn pattern:
        1. Try to generate all candidates in one batch; after a CUDA OOM, step
           down to batches of 2, then 1, and keep generating at the reduced
           size until the requested number of candidates is reached.
        2. Extract features immediately after each batch is generated.
        3. Return a dict with features, attention masks and inputs (as lists,
           one entry per generated batch).

        Returns:
            Dict containing:
                - greedy_texts: List of generated step texts
                - greedy_tokens: List of token sequences
                - uhead_features: List of extracted features per batch
                - full_attention_mask: List of attention masks per batch
                - llm_inputs: List of per-batch inputs, each carrying a
                  ``"context_lenghts"`` key (spelling kept for compatibility)
                - step_candidates: List of StepCandidate objects

        Raises:
            torch.cuda.OutOfMemoryError: If generation runs out of memory even
                at the smallest batch size.
        """
        target = self.candidates_per_step

        if verbose:
            log.info(f"Generating {self.candidates_per_step} candidates from trajectory")

        inputs = self.model.tokenize([trajectory])
        input_length = inputs["input_ids"].shape[1]

        # Adaptive batch sizes [n, 2, 1]: the next size is only used after an OOM.
        batch_sizes = [target]
        if target > 2:
            batch_sizes.append(2)
        if target > 1:
            batch_sizes.append(1)

        all_greedy_texts = []
        all_greedy_tokens = []
        all_uhead_features = []
        all_full_attention_masks = []
        all_llm_inputs = []
        all_step_candidates = []

        generated_count = 0
        size_idx = 0

        while generated_count < target:
            batch_size = batch_sizes[size_idx]
            current_batch_size = min(batch_size, target - generated_count)

            if verbose:
                log.info(f"Attempting to generate {current_batch_size} candidates with batch size {batch_size}")

            try:
                batch_result = self._generate_batch_with_features(
                    inputs,
                    input_length,
                    current_batch_size,
                    trajectory,
                    verbose,
                    # Position of this batch in the per-batch result lists.
                    batch_idx=len(all_uhead_features),
                )
            except torch.cuda.OutOfMemoryError as e:
                if verbose:
                    log.warning(f"OOM with batch size {batch_size}: {e}")
                torch.cuda.empty_cache()
                gc.collect()

                if size_idx + 1 >= len(batch_sizes):
                    log.error("CUDA OOM even with single candidate generation")
                    raise
                # Retry the remaining candidates at the next smaller size.
                size_idx += 1
                continue

            all_greedy_texts.extend(batch_result["greedy_texts"])
            all_greedy_tokens.extend(batch_result["greedy_tokens"])
            all_uhead_features.append(batch_result["uhead_features"])
            all_full_attention_masks.append(batch_result["full_attention_mask"])
            all_llm_inputs.append(batch_result["llm_inputs"])
            all_step_candidates.extend(batch_result["step_candidates"])

            # Count what was actually produced: the model may return fewer
            # sequences than requested.
            produced = len(batch_result["step_candidates"])
            generated_count += produced

            if verbose:
                log.info(f"Successfully generated {produced} candidates")

            if produced == 0:
                # Guard against looping forever on a model that yields nothing.
                log.warning("Generation produced no candidates; stopping early")
                break

        return {
            "greedy_texts": all_greedy_texts,
            "greedy_tokens": all_greedy_tokens,
            "uhead_features": all_uhead_features,
            "full_attention_mask": all_full_attention_masks,
            "llm_inputs": all_llm_inputs,
            "step_candidates": all_step_candidates,
        }

    def _generate_batch_with_features(
        self,
        inputs: Dict[str, torch.Tensor],
        input_length: int,
        batch_size: int,
        trajectory: str,
        verbose: bool = False,
        batch_idx: int = 0,
    ) -> Dict[str, Any]:
        """Generate one batch of candidates and extract features immediately.

        Following the bestofn pattern:
        1. Generate candidates with step-boundary stopping criteria.
        2. Offload the output to CPU if memory-efficient offloading is enabled.
        3. Build candidates, the full attention mask and expanded inputs.
        4. Extract features with the uncertainty head (if configured).

        The model's sampling parameters are overridden for the call and
        restored afterwards, even if generation raises.

        Args:
            inputs: Tokenized single-prompt trajectory (from
                ``self.model.tokenize([trajectory])``).
            input_length: Number of prompt tokens in ``inputs["input_ids"]``.
            batch_size: Number of sequences to request.
            trajectory: Trajectory text stored on each candidate.
            verbose: Emit progress logging.
            batch_idx: Value stored in each candidate's ``batch_idx``.

        Returns:
            Dict with ``greedy_texts``, ``greedy_tokens``, ``uhead_features``
            (``None`` without an uncertainty head), ``full_attention_mask``,
            ``llm_inputs`` and ``step_candidates``.
        """
        batch_inputs = {k: v.to(self.device) for k, v in inputs.items()}

        stopping_criteria = BatchStepStoppingCriteria(
            tokenizer=self.model.tokenizer,
            start_length=input_length,
            detector=self.detector,
            batch_size=batch_size,
        )

        gen_params = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": True,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "num_return_sequences": batch_size,
            "output_scores": True,
            "output_attentions": True,  # Needed for feature extraction
            "output_hidden_states": True,  # Needed for feature extraction
            "return_dict_in_generate": True,
            "stopping_criteria": StoppingCriteriaList([stopping_criteria]),
            "pad_token_id": self.model.tokenizer.eos_token_id,
            "eos_token_id": self.model.tokenizer.eos_token_id,
        }

        # Save the model's generation parameters so they can be restored.
        old_do_sample = self.model.generation_parameters.do_sample
        old_temperature = self.model.generation_parameters.temperature
        old_top_p = self.model.generation_parameters.top_p
        old_top_k = self.model.generation_parameters.top_k

        try:
            self.model.generation_parameters.do_sample = True
            self.model.generation_parameters.temperature = self.temperature
            self.model.generation_parameters.top_p = self.top_p
            self.model.generation_parameters.top_k = self.top_k

            start_time = time.time()
            with torch.no_grad():
                out = self.model.generate(**batch_inputs, **gen_params)

            actual_generated = out.sequences.shape[0]
            if actual_generated != batch_size:
                log.warning(f"Requested {batch_size} sequences but got {actual_generated}")
                batch_size = actual_generated  # Use the actual generated count

            if verbose:
                log.info(f"Generated {batch_size} candidates in {time.time() - start_time:.2f}s")

            if self.memory_efficient and self.offload_to_cpu:
                gpu_mem_before = 0
                if verbose:
                    gpu_mem_before = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0
                    log.debug(f"GPU memory before offloading: {gpu_mem_before:.2f} GB")

                out_cpu = self._move_generation_output_to_cpu(out)
                batch_inputs_cpu = {k: v.cpu() if torch.is_tensor(v) else v for k, v in batch_inputs.items()}

                # Drop the GPU references so the memory can be released.
                del out, batch_inputs

                if verbose and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gpu_mem_after = torch.cuda.memory_allocated() / 1024**3
                    log.debug(f"GPU memory after offloading: {gpu_mem_after:.2f} GB (saved: {gpu_mem_before - gpu_mem_after:.2f} GB)")

                out = out_cpu
                batch_inputs = batch_inputs_cpu

            # Stack the per-step scores once and share it across candidates
            # (presumably (batch, generated_steps, vocab) - TODO confirm).
            # Previously each candidate held a view into its own full copy.
            stacked_scores = torch.stack(out.scores, dim=1) if out.scores else None

            greedy_texts = []
            greedy_tokens = []
            step_candidates = []

            for i in range(batch_size):
                new_tokens = out.sequences[i, input_length:]
                raw_generated_text = self.model.tokenizer.decode(new_tokens, skip_special_tokens=True)
                step_text = self.detector.extract_step_text(raw_generated_text)
                token_list = new_tokens.tolist()

                greedy_texts.append(step_text)
                greedy_tokens.append(token_list)
                step_candidates.append(
                    StepCandidate(
                        text=step_text,
                        token_ids=list(token_list),
                        is_complete=self.detector.is_step_complete(raw_generated_text),
                        is_trajectory_complete=self.detector.is_trajectory_complete(raw_generated_text),
                        generation_scores=stacked_scores[i] if stacked_scores is not None else None,
                        raw_text=raw_generated_text,
                        batch_idx=batch_idx,
                        sample_idx=i,
                        trajectory_context=trajectory,
                    )
                )

            # Full attention mask (bestofn pattern): prompt positions copy the
            # single prompt's mask; every generated position is marked
            # attended. Generated token lists all span
            # sequences.shape[1] - input_length positions.
            # NOTE(review): this includes pad tokens appended after a sequence
            # stopped early - confirm the feature extractor expects that.
            full_attn_mask = torch.zeros_like(out.sequences).bool()
            full_attn_mask[:, :input_length] = inputs["attention_mask"][0]
            full_attn_mask[:, input_length:] = True

            # Duplicate the single prompt's inputs for each returned sequence.
            expanded_batch_inputs = {}
            for key, value in batch_inputs.items():
                if torch.is_tensor(value) and value.shape[0] == 1:
                    expanded_batch_inputs[key] = value.expand(batch_size, -1)
                else:
                    expanded_batch_inputs[key] = value

            # Key spelling ("lenghts") matches what the bestofn feature
            # extractor reads.
            context_lengths = torch.tensor([input_length] * batch_size)
            expanded_batch_inputs["context_lenghts"] = context_lengths

            # Extra attributes expected on the output object (bestofn pattern).
            out.full_attention_mask = full_attn_mask
            out.context_lengths = context_lengths

            if self.uncertainty_head is not None:
                if verbose:
                    log.info("Extracting features using uncertainty head...")
                uhead_features = self.uncertainty_head.feature_extractor(expanded_batch_inputs, out)
                if verbose:
                    log.info("Feature extraction complete")
            else:
                uhead_features = None

            return {
                "greedy_texts": greedy_texts,
                "greedy_tokens": greedy_tokens,
                "uhead_features": uhead_features,
                "full_attention_mask": full_attn_mask,
                "llm_inputs": expanded_batch_inputs,
                "step_candidates": step_candidates,
            }

        finally:
            self.model.generation_parameters.do_sample = old_do_sample
            self.model.generation_parameters.temperature = old_temperature
            self.model.generation_parameters.top_p = old_top_p
            self.model.generation_parameters.top_k = old_top_k

    def _generate_single_candidate(
        self,
        trajectory: str,
        verbose: bool = False,
    ) -> List[StepCandidate]:
        """Generate a single candidate without feature extraction.

        Intended as a low-memory fallback. Sampling parameters are passed
        directly to ``model.generate``.

        Returns:
            A one-element list with the generated ``StepCandidate``.
        """
        if verbose:
            log.warning("Falling back to single candidate generation due to memory constraints")

        inputs = self.model.tokenize([trajectory])
        input_length = inputs["input_ids"].shape[1]
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                num_return_sequences=1,
                pad_token_id=self.model.tokenizer.eos_token_id,
                eos_token_id=self.model.tokenizer.eos_token_id,
            )

        new_tokens = outputs[0][input_length:]
        raw_text = self.model.tokenizer.decode(new_tokens, skip_special_tokens=True)
        step_text = self.detector.extract_step_text(raw_text)

        # EOS is detected from the token ids because the decoded text has
        # special tokens stripped. With a single sequence there is no padding,
        # so a trailing EOS id means generation actually ended.
        reached_eos = len(new_tokens) > 0 and new_tokens[-1].item() == self.model.tokenizer.eos_token_id

        candidate = StepCandidate(
            text=step_text,
            token_ids=new_tokens.tolist(),
            is_complete=self.detector.is_step_complete(raw_text),
            is_trajectory_complete=self.detector.is_trajectory_complete(raw_text, reached_eos=reached_eos),
            raw_text=raw_text,
            # Same batch context the batched path records.
            batch_idx=0,
            sample_idx=0,
            trajectory_context=trajectory,
        )
        return [candidate]

    def filter_valid_candidates(self, candidates: List[StepCandidate]) -> List[StepCandidate]:
        """Drop empty, very short, or punctuation/whitespace-only candidates.

        A candidate is kept when its stripped text has at least 3 characters
        and contains at least one alphanumeric character. If every candidate
        is rejected, the first input candidate is returned so callers still
        get one (an empty input yields an empty list).
        """
        valid_candidates = [
            candidate
            for candidate in candidates
            if len(candidate.text.strip()) >= 3 and any(ch.isalnum() for ch in candidate.text)
        ]

        log.debug("valid_candidates: %s", valid_candidates)

        if not valid_candidates and candidates:
            valid_candidates = [candidates[0]]

        return valid_candidates