import os, glob, gc
import json
import re
from typing import List, Dict, Tuple, Optional
from common_utils import clean_eos_tokens


class InspirationRetrievalSFTDataToGo:
    """
    Prepare SFT data from inspiration retrieval rejection sampling results.
    
    Uses memory-efficient streaming: loads files one by one, selects best sample
    immediately, keeping only lightweight data in memory.
    
    Usage:
        # 100% eval data
        python script.py --results_path /path/to/checkpoint_dir --output_dir output/ \\
                        --train_ratio 0.0 --write_eval
        
        # 100% training data
        python script.py --results_path /path/to/checkpoint_dir --output_dir output/ \\
                        --train_ratio 1.0 --write_train
    """
    
    def __init__(self, rejection_sampling_results_path: str, output_dir: str):
        """
        Load and merge the checkpoint files, keeping one best sample per instance.

        Args:
            rejection_sampling_results_path: Directory containing checkpoint files
                                            (intermediate_results_*.json, shard_*.json)
            output_dir: Directory to save the prepared SFT data (created if missing)

        Raises:
            ValueError: If rejection_sampling_results_path is not an existing directory.
        """
        # Validate the input before touching the filesystem, so a bad path does
        # not leave an empty output directory behind.
        if not os.path.isdir(rejection_sampling_results_path):
            raise ValueError(
                f"Expected an existing directory, got: {rejection_sampling_results_path}"
            )

        self.results_path = rejection_sampling_results_path
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        print(f"Loading checkpoint files from: {self.results_path}")
        self.rejection_results = self._merge_checkpoint_files(rejection_sampling_results_path)
        print(f"Loaded {len(self.rejection_results)} instances")
        if self.rejection_results:
            # `or 0` tolerates an explicit null accuracy in the checkpoint data.
            total_accuracy = sum((r.get('accuracy') or 0) for r in self.rejection_results)
            mean_acc = total_accuracy / len(self.rejection_results)
            print(f"Mean accuracy: {mean_acc:.2%}")
    
    def _merge_checkpoint_files(self, directory: str) -> List[Dict]:
        """
        Merge all checkpoint files, keeping only the best sample of each instance.

        Files are loaded one at a time. For every instance the best sample is
        selected immediately via select_best_trace_per_instance and only a
        lightweight record (prompt, response, metadata) is retained, so the
        full per-instance sample lists are never held for more than one file.

        Handles:
        - intermediate_results_*.json files
        - shard_*.json files (new format)
        - Files whose top level is either a list of instances or a dict with
          an 'all_results' list; anything else is skipped.

        A file that fails to load or process is reported and skipped
        (best effort); records already collected from it are kept.

        Args:
            directory: Directory containing checkpoint files

        Returns:
            List of lightweight results with best sample data, sorted by
            (year, numeric PMID). Entries whose year_pmid is missing or
            cannot be parsed are placed last.
        """
        # Find all relevant checkpoint files
        patterns = [
            os.path.join(directory, 'intermediate_results_*.json'),
            os.path.join(directory, 'shard_*.json'),
        ]

        all_files = []
        for pattern in patterns:
            all_files.extend(glob.glob(pattern))

        # Filter out .tmp files and _current.json (incomplete shards); match on
        # the file name only so a directory name cannot trigger the filter.
        all_files = [
            f for f in all_files
            if not f.endswith('.tmp') and '_current.json' not in os.path.basename(f)
        ]
        # Sort by filename (chronological order based on naming convention)
        all_files.sort(key=os.path.basename)

        print(f"Found {len(all_files)} checkpoint files")
        for f in all_files:
            print(f"  - {os.path.basename(f)} ({os.path.getsize(f) / (1024**3):.2f} GB)")

        print("\n=== Processing files (memory-efficient mode) ===")
        all_results = []  # Lightweight: only best sample data
        # NOTE(review): the key is (year_pmid, filename, file_step_id). Because it
        # contains the file name and the per-file index it is unique for every
        # entry, so the duplicate branch below cannot currently fire. Kept as-is
        # to preserve behaviour; redefine the key if cross-file de-duplication
        # is actually wanted.
        seen_keys = set()

        stats = {
            'total_instances': 0,
            'instances_with_best': 0,
            'skipped_no_correct': 0,
            'skipped_poor_extraction': 0,
            'skipped_invalid_format': 0,
            'duplicates_skipped': 0
        }
        # Maps the skip reason returned by the selector to its global counter.
        skip_stat_keys = {
            'no_correct': 'skipped_no_correct',
            'poor_extraction': 'skipped_poor_extraction',
            'invalid_format': 'skipped_invalid_format',
        }

        for filepath in all_files:
            filename = os.path.basename(filepath)
            file_size_gb = os.path.getsize(filepath) / (1024**3)
            print(f"\n  Loading {filename} ({file_size_gb:.2f} GB)...")

            data = results = None
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                if isinstance(data, list):
                    results = data
                elif isinstance(data, dict) and 'all_results' in data:
                    results = data['all_results']
                else:
                    print(f"    Skipping: unknown format")
                    continue

                file_added = 0
                file_duplicates = 0
                file_skipped = dict.fromkeys(skip_stat_keys, 0)

                for file_step_id, instance_results in enumerate(results):
                    stats['total_instances'] += 1

                    # A malformed (non-dict) entry must not abort the rest of the
                    # file; the selector below reports it as 'invalid_format'.
                    if isinstance(instance_results, dict):
                        year_pmid = instance_results.get('year_pmid', '')
                    else:
                        year_pmid = ''
                    unique_key = (year_pmid, filename, file_step_id)

                    if unique_key in seen_keys:
                        file_duplicates += 1
                        stats['duplicates_skipped'] += 1
                        continue
                    seen_keys.add(unique_key)

                    # Select best sample immediately
                    best_sample, skip_reason = self.select_best_trace_per_instance(instance_results)

                    if best_sample is None:
                        if skip_reason in skip_stat_keys:
                            file_skipped[skip_reason] += 1
                            stats[skip_stat_keys[skip_reason]] += 1
                        continue

                    # Keep only lightweight data (no full results list!)
                    lightweight_result = {
                        # Metadata
                        'year_pmid': year_pmid,
                        'accuracy': instance_results.get('accuracy', 0),
                        '_source_file': filename,
                        '_file_step_id': file_step_id,
                        # Best sample data (for SFT)
                        'prompt': best_sample.get('prompt', ''),
                        'response': best_sample.get('response', ''),
                        'selected_label': best_sample.get('selected_label', ''),
                        'selected_title_recovered': best_sample.get('selected_title_recovered', ''),
                        'is_correct': best_sample.get('is_correct', False),
                        # Keep ground_truth for reference
                        'ground_truth': instance_results.get('ground_truth', []),
                    }

                    all_results.append(lightweight_result)
                    file_added += 1
                    stats['instances_with_best'] += 1

                print(f"    Added: {file_added}, No correct: {file_skipped['no_correct']}, "
                      f"Poor extraction: {file_skipped['poor_extraction']}, "
                      f"Invalid format: {file_skipped['invalid_format']}, "
                      f"Duplicates: {file_duplicates}")

            except Exception as e:
                # Deliberate best effort: one broken file must not stop the merge.
                print(f"    Error loading {filename}: {e}")
                import traceback
                traceback.print_exc()
                continue
            finally:
                # Drop the references on every path (success, skip, error) so the
                # parsed file can be reclaimed before the next one is loaded.
                data = results = None
                gc.collect()

        # Entries without a usable year_pmid sort after everything else.
        unknown_key = (9999, float('inf'), '')

        def get_sort_key(r):
            year_pmid = r.get('year_pmid')
            if not isinstance(year_pmid, str) or not year_pmid:
                return unknown_key
            parts = year_pmid.split('_')
            try:
                # '0000' is treated as year 2020 (convention kept from the
                # original implementation).
                year = int(parts[0]) if parts[0] != '0000' else 2020
            except ValueError:
                return unknown_key
            pmid = parts[1] if len(parts) > 1 else '0'
            try:
                # Compare PMIDs numerically so that '9' sorts before '10'.
                pmid_number = int(pmid)
            except ValueError:
                pmid_number = float('inf')
            return (year, pmid_number, pmid)

        all_results.sort(key=get_sort_key)

        print(f"\n=== Summary ===")
        print(f"Total instances scanned: {stats['total_instances']}")
        print(f"Instances with best sample: {stats['instances_with_best']}")
        print(f"Skipped (no correct): {stats['skipped_no_correct']}")
        print(f"Skipped (poor extraction): {stats['skipped_poor_extraction']}")
        print(f"Skipped (invalid format): {stats['skipped_invalid_format']}")
        print(f"Duplicates skipped: {stats['duplicates_skipped']}")
        print(f"Final result count: {len(all_results)}")

        return all_results
    
    # Input:
    #   instance_results: {
    #     'results': [list of individual sampling attempts],
    #     'accuracy': float (0.0 to 1.0),
    #     'correct_count': int,
    #     'total_samples': int,
    #     'background': [research_question, survey, pre_hyp],
    #     'ground_truth': [title, abstract, inspiration, relation]
    #   }
    #   Each item in 'results' contains:
    #   {
    #     'prompt': str,
    #     'response': str (full LLM response with reasoning),
    #     'selected_label': str (e.g., 'A', 'B', 'L'),
    #     'selected_title_recovered': str (mapped from label_to_title),
    #     'selection_reason': str (extracted reason),
    #     'ground_truth_label': str (e.g., 'A'),
    #     'label_to_title': dict (e.g., {'A': 'title1', 'B': 'title2', ...}),
    #     'is_correct': bool
    #   }
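    # Output:
    #   (best_sample, skip_reason): best_sample is the best correct sample dictionary, or None;
    #   skip_reason is '' on success, otherwise the reason no sample was selected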
    def select_best_trace_per_instance(self, instance_results: Dict) -> Tuple[Optional[Dict], str]:
        """
        Pick one sampling attempt to represent an instance.

        Only attempts flagged 'is_correct' are considered. Of those, only the
        ones with a non-empty 'selected_label' and 'selected_title_recovered'
        remain eligible, and the eligible attempt with the longest 'response'
        wins (the earliest one on ties).

        Args:
            instance_results: Results for a single instance; expected to be a
                dict holding the attempts under the 'results' key

        Returns:
            Tuple of (best_sample, skip_reason):
            - best_sample: the chosen attempt, or None when nothing qualifies
            - skip_reason: '' on success, otherwise 'invalid_format',
              'no_correct' or 'poor_extraction'
        """
        # Structural guards: anything that is not a dict carrying 'results' is rejected.
        if not isinstance(instance_results, dict):
            print(f"WARNING: instance_results is not a dict, it's {type(instance_results)}")
            return None, 'invalid_format'

        if 'results' not in instance_results:
            print(f"WARNING: 'results' key missing in instance. Available keys: {list(instance_results.keys())}")
            return None, 'invalid_format'

        # Stage 1: attempts the grader marked as correct.
        correct_attempts = []
        for attempt in instance_results['results']:
            if attempt.get('is_correct', False):
                correct_attempts.append(attempt)
        if not correct_attempts:
            return None, 'no_correct'

        # Stage 2: attempts where both the label and its title were extracted.
        eligible = []
        for attempt in correct_attempts:
            if not attempt.get('selected_label'):
                continue
            if not attempt.get('selected_title_recovered'):
                continue
            eligible.append(attempt)
        if not eligible:
            return None, 'poor_extraction'

        def response_length(attempt):
            # A missing or empty response counts as length zero.
            text = attempt.get('response')
            return len(text) if text else 0

        # Stage 3: longest response wins; strict '>' keeps the first on ties.
        winner = eligible[0]
        winner_length = response_length(winner)
        for attempt in eligible[1:]:
            candidate_length = response_length(attempt)
            if candidate_length > winner_length:
                winner, winner_length = attempt, candidate_length

        return winner, ''
    
    def prepare_sft_data_input_output(self) -> List[Tuple[str, str]]:
        """
        Prepare input-output pairs for SFT training.

        Reads self.rejection_results, whose entries are in lightweight format
        (best sample pre-selected):
        {
            'year_pmid', 'accuracy', '_source_file', '_file_step_id',
            'prompt', 'response', 'selected_label', 'selected_title_recovered',
            'is_correct', 'ground_truth'
        }

        For each entry the response is passed through clean_eos_tokens
        (imported from common_utils; presumably strips end-of-sequence
        markers - confirm against that helper) and normalized to the
        R1-Distill native layout:
        - output starts with "<think>\\n" (whitespace before an existing
          "<think>" is removed, otherwise the tag is prepended)
        - the single "</think>" is surrounded as "\\n</think>\\n\\n"

        Entries with an empty prompt/response, or with zero or several
        "</think>" tags, are skipped and counted in the printed statistics.

        Returns:
            List of (input, output) tuples, in the order of self.rejection_results
        """
        sft_data_collection = []
        stats = {
            'total_instances': 0,
            'instances_with_correct': 0,
            'easy_count': 0,
            'medium_count': 0,
            'hard_count': 0,
            'skipped_no_reasoning': 0,
            'skipped_poor_format': 0
        }
        # Hoisted out of the loop: whitespace around the closing tag is replaced
        # by exactly one newline before and two after.
        think_close_pattern = re.compile(r'\s*</think>\s*')

        for idx, instance_data in enumerate(self.rejection_results):
            stats['total_instances'] += 1

            if idx % 5000 == 0:
                print(f"Processing instance {idx}/{len(self.rejection_results)}")

            # `or 0` tolerates an explicit null accuracy in the checkpoint data.
            accuracy = instance_data.get('accuracy', 0) or 0
            prompt = instance_data.get('prompt', '')
            response = instance_data.get('response', '')

            # Check if we have valid prompt and response
            if not prompt or not response:
                stats['skipped_no_reasoning'] += 1
                continue

            # The prompt already contains all the formatted input
            input_text = prompt

            # Clean any existing EOS tokens
            output_text = clean_eos_tokens(response)

            # R1-Distill Native Format:
            # - Output must start with <think>\n
            # - Output ends with \n</think>\n\n[answer]
            without_leading_space = output_text.lstrip()
            if without_leading_space.startswith("<think>"):
                # Drop whitespace in front of the tag: it would otherwise survive
                # and make the sample fail a later startswith('<think>\n') check
                # even though it was counted as included here.
                output_text = without_leading_space
                if not output_text.startswith("<think>\n"):
                    output_text = output_text.replace("<think>", "<think>\n", 1)
            else:
                output_text = "<think>\n" + output_text

            # Exactly one closing tag is required; anything else is ambiguous.
            if output_text.count('</think>') != 1:
                stats['skipped_poor_format'] += 1
                continue

            output_text = think_close_pattern.sub('\n</think>\n\n', output_text)

            sft_data_collection.append((input_text, output_text))
            stats['instances_with_correct'] += 1

            # Track difficulty distribution
            if accuracy >= 0.8:
                stats['easy_count'] += 1
            elif accuracy >= 0.4:
                stats['medium_count'] += 1
            else:
                stats['hard_count'] += 1

        # Print statistics
        print("\n=== Data Collection Statistics ===")
        print(f"Total instances processed: {stats['total_instances']}")
        print(f"Instances included: {stats['instances_with_correct']}")
        print(f"Skipped (no reasoning): {stats['skipped_no_reasoning']}")
        print(f"Skipped (poor format): {stats['skipped_poor_format']}")
        print(f"\nDifficulty distribution:")
        print(f"  Easy (≥80% acc): {stats['easy_count']}")
        print(f"  Medium (40-79% acc): {stats['medium_count']}")
        print(f"  Hard (<40% acc): {stats['hard_count']}")

        return sft_data_collection
    
    def save_sft_data_to_go(self, train_ratio: float,
                            write_train: bool,
                            write_eval: bool):
        """
        Save the prepared SFT data in JSONL format for training.

        Steps: build the (input, output) pairs, drop pairs whose output does not
        start with "<think>\\n" or lacks "\\n</think>\\n\\n", sort the rest by
        (year, numeric PMID) when year_pmid information is available, split into
        the first `train_ratio` share (train) and the remainder (eval), write
        the requested files and a summary report into self.output_dir.

        Each JSONL row is {"conversations": [user message, assistant message]};
        no system prompt is emitted.

        Args:
            train_ratio: Ratio of data to use for training (0.0 to 1.0)
            write_train: Whether to write training data (train.jsonl)
            write_eval: Whether to write evaluation data (eval.jsonl)

        Note: EOS tokens are handled by the training framework based on the template setting.
        We clean any existing tokens and provide clean text for the framework to process.
        """
        # Prepare the data
        sft_data_collection = self.prepare_sft_data_input_output()

        if not sft_data_collection:
            print("No data to save!")
            return

        # Filter out samples with malformed reasoning traces
        # Must have both <think>\n at start and \n</think>\n\n in the middle
        # (consistent with HC format: reasoning ends with newline, then </think>, then double newline before answer)
        original_count = len(sft_data_collection)
        sft_data_collection = [
            (input_text, output_text)
            for input_text, output_text in sft_data_collection
            if output_text.startswith('<think>\n') and '\n</think>\n\n' in output_text
        ]
        filtered_count = original_count - len(sft_data_collection)
        if filtered_count > 0:
            print(f"\n⚠ Filtered out {filtered_count} samples missing proper </think> format ({filtered_count/original_count*100:.2f}%)")
            print(f"  Remaining samples: {len(sft_data_collection)}")

        # Nothing survived the filter: stop here instead of writing empty
        # outputs and a report with undefined percentages.
        if not sft_data_collection:
            print("No data to save after format filtering!")
            return

        # Samples without a usable year_pmid sort after everything else.
        unknown_key = (9999, float('inf'), '')

        def parse_year_pmid(year_pmid):
            # Return a sortable (year, numeric pmid, raw pmid) key, or None when
            # the value is missing or its year part is not a number.
            if not isinstance(year_pmid, str) or not year_pmid:
                return None
            parts = year_pmid.split('_')
            try:
                # '0000' is treated as year 2020 (convention kept from the
                # original implementation).
                year = int(parts[0]) if parts[0] != '0000' else 2020
            except ValueError:
                return None
            pmid = parts[1] if len(parts) > 1 else '0'
            try:
                # Numeric comparison so that PMID '9' sorts before '10'.
                pmid_number = int(pmid)
            except ValueError:
                pmid_number = float('inf')
            return (year, pmid_number, pmid)

        def write_jsonl(path, pairs):
            # One JSON object per line: a two-turn user/assistant conversation.
            with open(path, "w", encoding="utf-8") as f:
                for input_text, output_text in pairs:
                    row = {
                        "conversations": [
                            {"role": "user", "content": input_text},
                            {"role": "assistant", "content": output_text}
                        ]
                    }
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")

        # SORT DATA BY YEAR for chronological ordering (curriculum learning)
        # Check if year_pmid is preserved in the rejection results
        has_year_pmid = (bool(self.rejection_results) and
                         'year_pmid' in self.rejection_results[0])

        if has_year_pmid:
            print("\n✓ Year_PMID information found - sorting chronologically")

            # Map the full prompt to its sort key. The whole prompt is used
            # (not a prefix) because prompts that share a long common preamble
            # would otherwise collide and all inherit one instance's year.
            prompt_to_year = {}
            for result in self.rejection_results:
                if 'prompt' not in result:
                    continue
                sort_key = parse_year_pmid(result.get('year_pmid'))
                if sort_key is not None:
                    prompt_to_year[result['prompt']] = sort_key

            # item = (input_text, output_text); input_text is the prompt itself
            sft_data_collection.sort(key=lambda item: prompt_to_year.get(item[0], unknown_key))

            # Print summary
            print(f"Sorted {len(sft_data_collection)} samples by year and PMID")
            if prompt_to_year:
                years = [key[0] for key in prompt_to_year.values()]
                print(f"Year range: {min(years)} to {max(years)}")
        else:
            print("\n⚠ Year_PMID not found - keeping original order (approximately chronological)")
            # Don't shuffle - keep original order

        # Calculate split point for train/eval
        total_samples = len(sft_data_collection)
        train_size = int(total_samples * train_ratio)

        # Split the data
        train_data = sft_data_collection[:train_size]
        eval_data = sft_data_collection[train_size:]

        print(f"\n=== Final Dataset Statistics ===")
        print(f"Total samples: {total_samples}")
        print(f"Train samples: {len(train_data)}")
        print(f"Eval samples: {len(eval_data)}")

        # Write training data if requested
        # R1-Distill Native Format: No system prompt (user prompt already contains task instructions)
        if write_train and train_data:
            train_path = os.path.join(self.output_dir, "train.jsonl")
            write_jsonl(train_path, train_data)
            print(f"Training data saved to {train_path}")
        elif write_train:
            print("Warning: write_train is True but train_ratio is 0 or no train data available")

        # Write evaluation data if requested
        if write_eval and eval_data:
            eval_path = os.path.join(self.output_dir, "eval.jsonl")
            write_jsonl(eval_path, eval_data)
            print(f"Evaluation data saved to {eval_path}")
        elif write_eval:
            print("Warning: write_eval is True but no eval data available")

        # Save a summary report
        self._save_summary_report(total_samples, train_size, len(eval_data),
                                   write_train, write_eval)
    
    def _save_summary_report(self, total_samples: int, train_size: int, eval_size: int,
                            write_train: bool, write_eval: bool):
        """
        Write a Markdown summary of the data preparation into self.output_dir.

        The file name depends on which outputs were requested:
        data_preparation_summary_train_eval.md, ..._train.md, ..._eval.md, or
        data_preparation_summary.md when neither flag is set.

        Args:
            total_samples: Total number of samples (may be 0; shares are then
                reported as 0.0%)
            train_size: Number of training samples
            eval_size: Number of evaluation samples
            write_train: Whether training data was written
            write_eval: Whether evaluation data was written
        """
        # Determine report file name based on what was written
        if write_train and write_eval:
            report_name = "data_preparation_summary_train_eval.md"
        elif write_train:
            report_name = "data_preparation_summary_train.md"
        elif write_eval:
            report_name = "data_preparation_summary_eval.md"
        else:
            report_name = "data_preparation_summary.md"

        report_path = os.path.join(self.output_dir, report_name)

        def share(count: int) -> float:
            # Percentage of the total; guarded so an empty dataset does not
            # raise ZeroDivisionError.
            return count / total_samples * 100 if total_samples else 0.0

        # Explicit UTF-8 so a non-ASCII checkpoint path cannot break the write
        # under a non-UTF-8 locale.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("# Inspiration Retrieval SFT Data Preparation Summary\n\n")
            f.write("## Source\n")
            f.write(f"- Checkpoint directory: {self.results_path}\n")
            f.write(f"- Original instances: {len(self.rejection_results)}\n\n")
            f.write("## Prepared Dataset\n")
            f.write(f"- Total samples: {total_samples}\n")
            f.write(f"- Train samples: {train_size} ({share(train_size):.1f}%)\n")
            f.write(f"- Eval samples: {eval_size} ({share(eval_size):.1f}%)\n\n")
            f.write("## Strategy\n")
            f.write("- One best reasoning trace per instance\n")
            f.write("- Selected based on reasoning quality and length\n")
            f.write("- Includes easy, medium, and hard instances\n")
            f.write("- Preserves surprising/non-obvious inspirations from hard cases\n")

        print(f"Summary report saved to {report_path}")
    


def main():
    """
    Command-line entry point.

    Parses --results_path, --output_dir, --train_ratio, --write_train and
    --write_eval, rejects a ratio outside [0.0, 1.0] or a run with neither
    output flag, prints a warning when the ratio implies data that will not
    be written, then builds the preparer and saves the SFT data.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Prepare SFT data from inspiration retrieval rejection sampling')

    # Required string-valued paths.
    required_paths = (
        ('--results_path',
         'Directory containing checkpoint files (intermediate_results_*.json, shard_*.json)'),
        ('--output_dir',
         'Directory to save prepared SFT data'),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument('--train_ratio', type=float, required=True,
                     help='Ratio of data to use for training (0.0 to 1.0)')

    # Boolean switches selecting which files are produced.
    output_switches = (
        ('--write_train', 'Write training data to train.jsonl'),
        ('--write_eval', 'Write evaluation data to eval.jsonl'),
    )
    for flag, help_text in output_switches:
        cli.add_argument(flag, action='store_true', help=help_text)

    opts = cli.parse_args()
    ratio = opts.train_ratio
    wants_train = opts.write_train
    wants_eval = opts.write_eval

    # Hard errors: the parser prints usage and exits.
    if not 0.0 <= ratio <= 1.0:
        cli.error("train_ratio must be between 0.0 and 1.0")

    if not (wants_train or wants_eval):
        cli.error("At least one of --write_train or --write_eval must be specified")

    # Soft warnings: part of the split exists but will not be written.
    if ratio > 0 and not wants_train:
        print("Warning: train_ratio > 0 but write_train is False. No training data will be written.")

    if ratio < 1.0 and not wants_eval:
        print("Warning: train_ratio < 1.0 but write_eval is False. No evaluation data will be written.")

    # Load the checkpoints, then write the requested outputs.
    preparer = InspirationRetrievalSFTDataToGo(
        rejection_sampling_results_path=opts.results_path,
        output_dir=opts.output_dir,
    )
    preparer.save_sft_data_to_go(
        train_ratio=ratio,
        write_train=wants_train,
        write_eval=wants_eval,
    )
    print("\n✓ SFT data preparation complete!")


if __name__ == "__main__":
    main()
