"""
Direct online best-of-n implementation using direct UHead scorer
"""

import gc
import logging
from typing import Any, Dict, List, Optional

import torch
from datasets import Dataset
from lm_polygraph import WhiteboxModel
from tqdm import tqdm

from online_bestofn.step_detection import StepBoundaryDetector
from online_bestofn.step_generation import StepCandidateGenerator
from online_bestofn.scorers.uhead_batch import BatchUHeadScorer
from utils import parse_ans

log = logging.getLogger(__name__)


class DirectOnlineBestOfN:
    """
    Simplified online best-of-n that uses UHead directly without stat calculator pipeline.

    Each step generates several candidate continuations, scores them with the
    uncertainty head and appends the lowest-scoring candidate to the
    trajectory. When the selected step reports that the trajectory is
    complete, a final answer is generated and selected the same way.

    Key improvements:
    1. Direct UHead application - no stat calculator overhead
    2. Cleaner separation of scoring from generation
    3. More efficient batch processing
    """

    # Token budget for the final answer. Answers are generated without the
    # step stopping criteria, so they get a larger budget than a single step.
    ANSWER_MAX_NEW_TOKENS = 1024

    def __init__(
        self,
        model: "WhiteboxModel",
        uhead_path: str,
        candidates_per_step: int = 10,
        max_steps: int = 20,
        max_new_tokens: int = 350,
        temperature: float = 0.7,
        device: str = "cuda",
        verbose: bool = True,
        generation_batch_size: Optional[int] = None,
        feature_batch_size: int = 1,
        memory_efficient: bool = True
    ):
        """
        Build the step detector, the candidate generator and the scorer.

        Args:
            model: Whitebox language model used for generation.
            uhead_path: Path to the uncertainty head; forwarded to the step
                generator, whose loaded head is shared with the scorer.
            candidates_per_step: Number of candidates generated per step
                (also the number of final-answer candidates).
            max_steps: Maximum number of steps per trajectory.
            max_new_tokens: Token budget per step; forwarded to the detector
                and to the step generator.
            temperature: Sampling temperature.
            device: Device the tokenized inputs are moved to for generation.
            verbose: Emit progress information through the module logger.
            generation_batch_size: Stored on the instance; falls back to
                ``candidates_per_step`` when not given (or falsy).
            feature_batch_size: Stored on the instance; not otherwise used by
                this class.
            memory_efficient: When true, intermediate results are released
                and the CUDA cache is emptied after each step, and generation
                outputs are moved to CPU before feature extraction.
        """
        self.model = model
        self.candidates_per_step = candidates_per_step
        self.max_steps = max_steps
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.device = device
        self.verbose = verbose
        self.generation_batch_size = generation_batch_size or candidates_per_step
        # Kept so the configured value stays inspectable instead of being dropped.
        self.feature_batch_size = feature_batch_size
        self.memory_efficient = memory_efficient

        # Initialize components
        self.detector = StepBoundaryDetector(
            step_patterns=["- Step", "<Answer>:", "\n<Answer>:"],
            answer_patterns=["<Answer>:", "\n<Answer>:"],
            max_tokens_per_step=max_new_tokens
        )

        self.step_generator = StepCandidateGenerator(
            model=model,
            uhead_path=uhead_path,  # Pass uhead for feature extraction
            detector=self.detector,
            candidates_per_step=candidates_per_step,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device,
            memory_efficient=memory_efficient
        )

        self.scorer = BatchUHeadScorer(
            uncertainty_head=self.step_generator.uncertainty_head,  # Share the same uhead
            model=model,  # Pass WhiteboxModel for StepsExtractor
            device=device
        )

    def generate_trajectory(self, prompt: str) -> Dict[str, Any]:
        """
        Generate a trajectory step-by-step using direct UHead scoring.

        Args:
            prompt: Initial prompt/question

        Returns:
            Dictionary with:
                - trajectory: Final generated trajectory (prompt included)
                - steps: List of selected steps (plus the final answer, if any)
                - step_scores: Scores for each selected step
                - completed: Whether trajectory reached completion, i.e. a
                  selected step reported the trajectory as complete and a
                  final answer was generated
        """

        trajectory = prompt
        selected_steps = []
        step_scores = []
        # Only set once a selected step signals completion; running out of
        # steps or candidates leaves the trajectory marked as incomplete.
        completed = False

        for step_num in range(self.max_steps):
            if self.verbose:
                log.info(f"\n=== Step {step_num} ===")
                log.info(f"Generating candidates with temperature={self.temperature}")

            # Generate candidates and extract features in single pass
            result_dict = self.step_generator.generate_candidates(
                trajectory,
                verbose=self.verbose
            )

            candidates = result_dict["step_candidates"]

            if not candidates:
                if self.verbose:
                    log.info("No candidates generated, stopping")
                break

            # Score candidates using pre-extracted features
            # Pass the trajectory as input_texts for proper claim extraction context
            candidate_scores = self.scorer.compute_uncertainties(
                result_dict,
                input_texts=[trajectory] * len(candidates)
            )

            # Log all candidates
            if self.verbose:
                log.info(f"Generated {len(candidates)} candidates:")
                for i, (candidate, score) in enumerate(zip(candidates, candidate_scores)):
                    log.info(f"  [{i}] Score: {score:.3f} | Text: '{candidate.text}'")

            # Select best candidate (lowest uncertainty)
            best_idx = min(range(len(candidate_scores)), key=lambda i: candidate_scores[i])
            selected_candidate = candidates[best_idx]
            selected_score = candidate_scores[best_idx]

            if self.verbose:
                log.info(f"Selected candidate {best_idx} (score: {selected_score:.3f})")
                log.info(f"Text: {selected_candidate.text}")

            # Update trajectory
            trajectory += selected_candidate.text + '\n'
            selected_steps.append(selected_candidate.text)
            step_scores.append(selected_score)

            # Clean up GPU memory after each step
            if self.memory_efficient:
                # Drop the per-step containers; the selected candidate stays
                # alive through its own reference.
                del candidates, result_dict, candidate_scores
                torch.cuda.empty_cache()
                gc.collect()
                if self.verbose:
                    log.debug(f"Cleaned up memory after step {step_num}")

            # Check if trajectory is complete
            if selected_candidate.is_trajectory_complete:
                if self.verbose:
                    log.info("Answer pattern detected - generating final answer")

                # Generate final answer
                final_answer = self._generate_final_answer(trajectory)
                trajectory += final_answer
                selected_steps.append(final_answer)
                completed = True
                break

        # Log the complete trajectory at the end
        if self.verbose:
            log.info(f"\n{'='*60}")
            log.info("FINAL TRAJECTORY:")
            log.info(f"{'='*60}")
            log.info(trajectory)
            log.info(f"{'='*60}")
            log.info(f"Total steps: {len(selected_steps)}")
            if step_scores:
                log.info(f"Average step score: {sum(step_scores)/len(step_scores):.3f}")

        return {
            "trajectory": trajectory,
            "steps": selected_steps,
            "step_scores": step_scores,
            "completed": completed
        }

    def _generate_final_answer(self, trajectory: str) -> str:
        """
        Generate final-answer candidates and return the lowest-scoring one.

        Args:
            trajectory: Text generated so far (prompt plus selected steps).

        Returns:
            Text of the selected answer, or an empty string when no answer
            candidate was produced.
        """

        # For answer generation, we need to generate without step detection,
        # so a dedicated generation path without stopping criteria is used.

        if self.verbose:
            log.info(f"Generating {self.candidates_per_step} answer candidates")

        # Generate answer candidates with features
        result_dict = self._generate_answer_candidates_with_features(trajectory)

        answer_candidates = result_dict["greedy_texts"]

        if not answer_candidates:
            # Nothing to score or select from; avoid min() on an empty sequence.
            log.warning("No answer candidates generated, returning empty answer")
            return ""

        # Score candidates using pre-extracted features
        # Pass the trajectory as input_texts for proper claim extraction context
        answer_scores = self.scorer.compute_uncertainties(
            result_dict,
            input_texts=[trajectory] * len(answer_candidates)
        )

        # Log all answer candidates with scores (similar to trajectory generation)
        if self.verbose:
            log.info(f"Generated {len(answer_candidates)} answer candidates:")
            for i, (candidate, score) in enumerate(zip(answer_candidates, answer_scores)):
                # Truncate long answers for readability
                candidate_preview = candidate[:200] + "..." if len(candidate) > 200 else candidate
                log.info(f"  [{i}] Score: {score:.3f} | Answer: '{candidate_preview}'")

        # Select best answer (lowest score)
        best_idx = min(range(len(answer_scores)), key=lambda i: answer_scores[i])

        if self.verbose:
            log.info(f"Selected answer {best_idx} (score: {answer_scores[best_idx]:.3f})")

        selected_answer = answer_candidates[best_idx]

        # Clean up memory
        if self.memory_efficient:
            del answer_candidates, answer_scores, result_dict
            torch.cuda.empty_cache()
            gc.collect()

        return selected_answer

    def _generate_answer_candidates_with_features(self, trajectory: str) -> Dict[str, Any]:
        """
        Generate answer candidates with features (no step detection).

        Generation is first attempted with all candidates in one batch. On a
        CUDA out-of-memory error the batch size is reduced (to 2, then 1) and
        generation continues at that size until ``candidates_per_step``
        candidates exist.

        Args:
            trajectory: Text generated so far; used as the generation prompt.

        Returns:
            Dictionary with:
                - greedy_texts: Flat list of candidate answer texts
                - greedy_tokens: Flat list of candidate token-id lists
                - uhead_features: List with one entry per generated batch
                - full_attention_mask: List with one entry per generated batch
                - llm_inputs: List with one entry per generated batch

        Raises:
            torch.cuda.OutOfMemoryError: If generation runs out of memory even
                with a batch size of 1.
        """

        # Tokenize trajectory
        inputs = self.model.tokenize([trajectory])
        input_length = inputs['input_ids'].shape[1]

        target_count = self.candidates_per_step

        # Use adaptive batch sizes: everything at once, then progressively smaller
        batch_sizes = [target_count]
        if target_count > 2:
            batch_sizes.append(2)
        if target_count > 1:
            batch_sizes.append(1)

        # Lists to store results
        all_greedy_texts = []
        all_greedy_tokens = []
        all_uhead_features = []
        all_full_attention_masks = []
        all_llm_inputs = []

        generated_count = 0

        for batch_size in batch_sizes:
            try:
                # Keep generating at this batch size until enough candidates exist
                while generated_count < target_count:
                    current_batch_size = min(batch_size, target_count - generated_count)

                    # Generate batch without step detection
                    batch_result = self._generate_answer_batch_with_features(
                        inputs, input_length, current_batch_size, trajectory
                    )

                    # Count what was actually returned, which may differ from
                    # the requested size.
                    produced = len(batch_result["greedy_texts"])
                    if produced == 0:
                        # No progress is possible; stop instead of looping forever
                        break

                    # Accumulate results
                    all_greedy_texts.extend(batch_result["greedy_texts"])
                    all_greedy_tokens.extend(batch_result["greedy_tokens"])
                    all_uhead_features.append(batch_result["uhead_features"])
                    all_full_attention_masks.append(batch_result["full_attention_mask"])
                    all_llm_inputs.append(batch_result["llm_inputs"])

                    generated_count += produced

            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                gc.collect()
                if batch_size == 1:
                    # Nothing smaller left to try
                    raise
            else:
                # Finished at this batch size without running out of memory
                break

        return {
            "greedy_texts": all_greedy_texts,
            "greedy_tokens": all_greedy_tokens,
            "uhead_features": all_uhead_features,
            "full_attention_mask": all_full_attention_masks,
            "llm_inputs": all_llm_inputs
        }

    def _generate_answer_batch_with_features(
        self,
        inputs: Dict[str, torch.Tensor],
        input_length: int,
        batch_size: int,
        trajectory: str
    ) -> Dict[str, Any]:
        """
        Generate batch of answers with features (no step detection).

        Args:
            inputs: Tokenized prompt as returned by ``self.model.tokenize``;
                row 0 of its ``attention_mask`` is used as the prompt mask.
            input_length: Number of prompt tokens; generated tokens start at
                this offset in every returned sequence.
            batch_size: Number of sequences to sample.
            trajectory: Prompt text; not used by this method.

        Returns:
            Dictionary with:
                - greedy_texts: Decoded answer texts, one per sequence
                - greedy_tokens: Generated token ids, one list per sequence
                - uhead_features: Output of the uncertainty head's feature extractor
                - full_attention_mask: Boolean mask over prompt + generated tokens
                - llm_inputs: Model inputs expanded to the batch, plus context lengths
        """

        # Move inputs to device
        batch_inputs = {k: v.to(self.device) for k, v in inputs.items()}

        gen_params = {
            "max_new_tokens": self.ANSWER_MAX_NEW_TOKENS,  # Longer for answers
            "do_sample": True,
            "temperature": self.temperature,
            "top_p": 0.95,
            "top_k": 50,
            "num_return_sequences": batch_size,
            "output_scores": True,
            "output_attentions": True,
            "output_hidden_states": True,
            "return_dict_in_generate": True,
            # No stopping criteria - only stop at EOS
            "pad_token_id": self.model.tokenizer.eos_token_id,
            "eos_token_id": self.model.tokenizer.eos_token_id
        }

        # Generate
        with torch.no_grad():
            out = self.model.generate(**batch_inputs, **gen_params)

        # Check actual number of generated sequences
        actual_generated = out.sequences.shape[0]
        if actual_generated != batch_size:
            log.warning(f"Answer generation: Requested {batch_size} sequences but got {actual_generated}")
            batch_size = actual_generated  # Use actual generated count

        # Offload to CPU if memory efficient
        if self.memory_efficient:
            out_cpu = self.step_generator._move_generation_output_to_cpu(out)
            batch_inputs_cpu = {k: v.cpu() if torch.is_tensor(v) else v for k, v in batch_inputs.items()}

            del out, batch_inputs
            torch.cuda.empty_cache()

            out = out_cpu
            batch_inputs = batch_inputs_cpu

        # Extract texts
        greedy_texts = []
        greedy_tokens = []

        for i in range(batch_size):
            new_tokens = out.sequences[i, input_length:]
            answer_text = self.model.tokenizer.decode(new_tokens, skip_special_tokens=True)
            greedy_texts.append(answer_text)
            greedy_tokens.append(new_tokens.tolist())

        # Create attention mask: prompt mask followed by ones for every
        # generated position.
        # NOTE(review): positions padded after EOS are marked as attended too,
        # because the generated length is taken from the padded sequence -
        # confirm this is what the feature extractor expects.
        prompt_mask = inputs["attention_mask"][0]
        full_attn_mask = torch.zeros_like(out.sequences).bool()
        for i in range(batch_size):
            full_attn_mask[i, :input_length] = prompt_mask
            length = len(greedy_tokens[i])
            full_attn_mask[i][input_length: input_length + length] = 1

        # Expand batch_inputs to match the number of generated sequences
        expanded_batch_inputs = {}
        for key, value in batch_inputs.items():
            if torch.is_tensor(value) and value.shape[0] == 1:
                # Expand single input to batch_size
                expanded_batch_inputs[key] = value.expand(batch_size, -1)
            else:
                expanded_batch_inputs[key] = value

        # Add context lengths
        context_lengths = torch.tensor([input_length] * batch_size)
        # NOTE(review): the key is spelled "context_lenghts"; presumably the
        # feature extractor reads it under this exact name - confirm before
        # renaming.
        expanded_batch_inputs["context_lenghts"] = context_lengths

        # Add attributes to out object (following bestofn pattern)
        out.full_attention_mask = full_attn_mask
        out.context_lengths = context_lengths

        # Extract features
        uhead_features = self.step_generator.uncertainty_head.feature_extractor(expanded_batch_inputs, out)

        return {
            "greedy_texts": greedy_texts,
            "greedy_tokens": greedy_tokens,
            "uhead_features": uhead_features,
            "full_attention_mask": full_attn_mask,
            "llm_inputs": expanded_batch_inputs
        }

    def cleanup(self):
        """Clean up resources"""
        # No cleanup needed - uncertainty head is owned by step_generator
        pass


def run_direct_online_bestofn(
    dataset: Dataset,
    model: WhiteboxModel,
    uhead_path: str,
    save_path: str,
    n: int = 10,
    max_new_tokens: int = 250,
    subset: Optional[int] = None,
    verbose: bool = True
):
    """
    Run direct online best-of-n evaluation on a dataset.

    Results are written to ``save_path`` with ``torch.save`` after every 10
    samples and once more when the loop ends, including when it ends with an
    exception.

    Args:
        dataset: Evaluation dataset; each sample is read through its
            ``"question"`` and ``"answer"`` fields
        model: Language model
        uhead_path: Path to UHead model
        save_path: Path to save results
        n: Number of candidates per step
        max_new_tokens: Max tokens per step
        subset: Evaluate only first N samples; a falsy value (``None`` or 0)
            evaluates the whole dataset
        verbose: Enable verbose logging

    Returns:
        List with one result dictionary per processed sample.
    """

    if subset:
        dataset = dataset.select(range(min(subset, len(dataset))))
        log.info(f"Using subset of {len(dataset)} samples")

    # Initialize generator
    generator = DirectOnlineBestOfN(
        model=model,
        uhead_path=uhead_path,
        candidates_per_step=n,
        # Forward the caller's per-step budget; it used to be silently ignored.
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        device=str(model.device()),
        verbose=verbose
    )

    results = []

    try:
        for i, sample in enumerate(tqdm(dataset, desc="Processing samples")):
            question = sample["question"]
            gold_answer = sample["answer"]

            log.info(f"\n{'='*60}")
            log.info(f"Sample {i+1}/{len(dataset)}")
            log.info(f"Question: {question}")
            log.info(f"Gold Answer: {gold_answer}")

            # Generate trajectory
            result = generator.generate_trajectory(question)

            # The trajectory starts with the prompt. Strip only that leading
            # copy so a question repeated inside the generated text survives.
            trajectory = result["trajectory"]
            if trajectory.startswith(question):
                generated_text = trajectory[len(question):].strip()
            else:
                generated_text = trajectory.strip()

            # Check correctness
            is_correct = _is_correct_answer(generated_text, gold_answer)

            # Store result
            results.append({
                "question": question,
                "gold_answer": gold_answer,
                "generated_trajectory": trajectory,
                "generated_answer": generated_text,
                "steps": result["steps"],
                "step_scores": result["step_scores"],
                "is_correct": is_correct,
                "completed": result["completed"]
            })

            # The parsed answer is only logged, so a parsing failure must not
            # abort the whole evaluation run.
            try:
                parsed_answer = parse_ans(generated_text)
            except Exception:
                parsed_answer = None

            log.info(f"Generated: {generated_text}...")
            log.info(f"Gold: {gold_answer}")
            log.info(f"Generated answer: {parsed_answer}")
            log.info(f"Correct: {is_correct}")

            # Clean up memory after each sample
            if generator.memory_efficient:
                torch.cuda.empty_cache()
                gc.collect()

            # Save periodically
            if (i + 1) % 10 == 0:
                torch.save(results, save_path)
                log.info(f"Saved {len(results)} results to {save_path}")

    finally:
        # Final save (also runs when the loop is interrupted by an exception)
        torch.save(results, save_path)
        log.info(f"Final save: {len(results)} results to {save_path}")

        # Cleanup
        generator.cleanup()

    # Print summary
    correct = sum(r["is_correct"] for r in results)
    if results:
        log.info(f"\nAccuracy: {correct}/{len(results)} = {correct/len(results):.2%}")
    else:
        # Nothing was evaluated; avoid dividing by zero
        log.info("\nAccuracy: no samples were evaluated")

    return results


def _is_correct_answer(generated: str, gold: str) -> bool:
    """
    Check if generated answer matches gold answer.

    The answer parsed from ``generated`` is compared with ``gold`` as given;
    ``gold`` itself is not parsed.

    Args:
        generated: Generated text to extract the predicted answer from.
        gold: Reference answer.

    Returns:
        True when a parsed prediction exists and equals ``gold``; False when
        either value is missing, they differ, or parsing/comparison fails.
    """
    if gold is None:
        return False

    try:
        pred = parse_ans(generated)
        if pred is None:
            return False
        return pred == gold
    except Exception:
        # Parsing is best-effort: an unparseable generation counts as wrong.
        # Interrupts (KeyboardInterrupt/SystemExit) are deliberately not caught.
        return False


if __name__ == "__main__":
    # Example usage
    import argparse
    from datasets import load_dataset

    # Progress and results are reported through log.info; without a
    # configured handler those records are dropped at the default level.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-path", type=str, required=True)
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen3-1.7B")
    parser.add_argument("--uhead-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--subset", type=int, default=None)
    parser.add_argument("--n", type=int, default=10)
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()

    # Load model
    # NOTE(review): load_qwen_tokenizer and load_qwen_model are not defined or
    # imported in this file as visible here; running the script raises
    # NameError unless they are provided - import them from their module.
    tokenizer = load_qwen_tokenizer(args.model_path)
    base_model = load_qwen_model(args.model_path, args.device)
    model = WhiteboxModel(base_model, tokenizer)

    # Load dataset
    dataset = load_dataset(args.dataset_path, split="test")

    # Run evaluation
    run_direct_online_bestofn(
        dataset=dataset,
        model=model,
        uhead_path=args.uhead_path,
        save_path=args.save_path,
        n=args.n,
        subset=args.subset,
        verbose=True
    )