import os
import json
from typing import List, Tuple
from prompt_store import instruction_prompts
from common_utils import clean_eos_tokens


class HCSFTDataToGo:
    """
    Prepare SFT data for hypothesis composition (HC) from reasoning traces.

    For every ``*.json`` file in the reasoning-trace directory, the file with
    the same name in the QA-data directory supplies the research question and
    background survey. Each entry of the reasoning-trace file becomes one
    (user prompt, assistant response) pair. Pairs are written as JSONL rows of
    the form ``{"conversations": [...]}`` to ``train.jsonl`` / ``eval.jsonl``.
    The train/eval split follows the chronological file order: the first
    ``train_ratio`` fraction of samples goes to train, the rest to eval.

    Usage examples:
        # 100% eval data
        python script.py --qa_data_dir /path/to/qa --reasoning_trace_dir /path/to/traces \\
                        --output_dir /path/to/output --train_ratio 0.0 --write_eval
        
        # 95/5 train/eval split (default)
        python script.py --qa_data_dir /path/to/qa --reasoning_trace_dir /path/to/traces \\
                        --output_dir /path/to/output --write_train --write_eval
        
        # Only training data
        python script.py --qa_data_dir /path/to/qa --reasoning_trace_dir /path/to/traces \\
                        --output_dir /path/to/output --train_ratio 1.0 --write_train
        
        # Run-specific directory names (e.g. with a "_run8" suffix)
        python script.py --qa_data_dir <folder_sft_qa_data>/pubmed_sft_qa_data_run8 \\
                        --reasoning_trace_dir <folder_sft_hc_reasoning_trace>/pubmed_sft_HC_reasoning_trace_run8 \\
                        --output_dir <folder_sft_hc_sft_data_to_go>/pubmed_sft_HC_sft_data_to_go_run8 \\
                        --write_train --write_eval
    """

    def __init__(self, sft_qa_data_dir: str, sft_HC_reasoning_trace_dir: str, sft_HC_sft_data_to_go_dir: str):
        """
        Initialize the HC SFT data preparation.

        Side effects: creates the output directory if it does not exist and
        lists and sorts the reasoning-trace files immediately.

        Args:
            sft_qa_data_dir: Directory containing QA data
            sft_HC_reasoning_trace_dir: Directory containing reasoning traces
            sft_HC_sft_data_to_go_dir: Directory to save prepared SFT data
        """
        self.sft_qa_data_dir = sft_qa_data_dir
        self.sft_HC_reasoning_trace_dir = sft_HC_reasoning_trace_dir
        self.sft_HC_sft_data_to_go_dir = sft_HC_sft_data_to_go_dir

        os.makedirs(self.sft_HC_sft_data_to_go_dir, exist_ok=True)
        self.paper_with_reasoning_trace_list = self.find_paper_with_reasoning_trace()

    def find_paper_with_reasoning_trace(self) -> List[str]:
        """
        Find all reasoning trace files and sort them chronologically.

        File names are expected to look like ``YYYY_PMID.json``. Files are
        sorted by year, where the placeholder year ``0000`` is treated as 2020,
        and then by PMID.

        Returns:
            Sorted list of ``.json`` file names (not full paths) found in
            ``self.sft_HC_reasoning_trace_dir``.

        Raises:
            ValueError: if a file name's year part is not an integer.
        """
        reasoning_trace_file_collection = [
            cur_file
            for cur_file in os.listdir(self.sft_HC_reasoning_trace_dir)
            if cur_file.endswith(".json")
        ]

        def get_sort_key(filename: str) -> Tuple[int, str]:
            # Parse the year from the stem (extension removed). This lets a name
            # without an underscore, e.g. "2019.json", yield a valid year instead
            # of failing on int("2019.json"). Conforming names parse as before.
            stem = os.path.splitext(filename)[0]
            year_str = stem.split('_')[0]
            year = int(year_str) if year_str != '0000' else 2020  # Special case: 0000 means 2020
            # NOTE: the PMID is compared as a string, i.e. lexicographically
            # ("10" < "9"). It only orders files within the same year.
            pmid = filename.split('_')[1].split('.')[0] if '_' in filename else ''
            return (year, pmid)

        reasoning_trace_file_collection.sort(key=get_sort_key)
        print(f"reasoning_trace_file_collection[0:5]: {reasoning_trace_file_collection[0:5]}")
        print(f"reasoning_trace_file_collection[-5:]: {reasoning_trace_file_collection[-5:]}")
        print(f"Total number of reasoning trace files: {len(reasoning_trace_file_collection)}")
        return reasoning_trace_file_collection

    def save_sft_data_to_go(self, train_ratio: float,
                            write_train: bool,
                            write_eval: bool,
                            min_hypothesis_word_length: int) -> None:
        """
        Save the prepared SFT data in JSONL format for training.

        Samples keep the chronological order produced by
        ``find_paper_with_reasoning_trace``. The first
        ``int(total * train_ratio)`` samples become train data; the rest
        become eval data.

        Args:
            train_ratio: Ratio of data to use for training (0.0 to 1.0)
            write_train: Whether to write training data (``train.jsonl``)
            write_eval: Whether to write evaluation data (``eval.jsonl``)
            min_hypothesis_word_length: Minimum word length for hypothesis to be included

        Note: EOS tokens are handled by the training framework based on the template setting.
        We clean any existing tokens and provide clean text for the framework to process.
        """
        sft_data_to_go_collection = self.prepare_sft_data_input_output(min_hypothesis_word_length=min_hypothesis_word_length)

        if not sft_data_to_go_collection:
            print("No data to save!")
            return

        # Filter out samples with malformed reasoning traces.
        # A sample must start with "<think>\n" and contain "\n</think>\n\n".
        # NOTE(review): prepare_sft_data_input_output always builds outputs as
        # "<think>\n" + trace + "\n</think>\n\n" + hypothesis, so this check
        # passes for everything it produces. It only guards against other
        # producers or future format changes.
        original_count = len(sft_data_to_go_collection)
        sft_data_to_go_collection = [
            (input_text, output_text)
            for input_text, output_text in sft_data_to_go_collection
            if output_text.startswith('<think>\n') and '\n</think>\n\n' in output_text
        ]
        filtered_count = original_count - len(sft_data_to_go_collection)
        if filtered_count > 0:
            print(f"\n⚠ Filtered out {filtered_count} samples missing </think> ({filtered_count/original_count*100:.2f}%)")
            print(f"  Remaining samples: {len(sft_data_to_go_collection)}")

        # Chronological (not random) train/eval split.
        total_samples = len(sft_data_to_go_collection)
        train_size = int(total_samples * train_ratio)
        train_data = sft_data_to_go_collection[:train_size]
        eval_data = sft_data_to_go_collection[train_size:]

        print("\n=== Final Dataset Statistics ===")
        print(f"Total samples: {total_samples}")
        print(f"Train samples: {len(train_data)}")
        print(f"Eval samples: {len(eval_data)}")

        def write_data(data: List[Tuple[str, str]], filename: str, data_type: str):
            """Write (input, output) pairs as conversation rows to a UTF-8 JSONL file."""
            if not data:
                print(f"Warning: write_{data_type} is True but no {data_type} data available")
                return

            filepath = os.path.join(self.sft_HC_sft_data_to_go_dir, filename)
            with open(filepath, "w", encoding="utf-8") as f:
                for input_text, output_text in data:
                    # Clean any existing EOS tokens - let the training framework handle them
                    output_clean = clean_eos_tokens(output_text)
                    # R1-Distill native format:
                    # - No system prompt (DeepSeek R1-Distill was trained without system prompts)
                    # - The user prompt already contains the task instructions (no prefix needed)
                    # - output_text starts with "<think>\n" (enforced by the filter above).
                    #   NOTE(review): whether output_clean still does depends on
                    #   clean_eos_tokens; prepare_sft_data_input_output states that
                    #   LLaMA-Factory's ReasoningTemplate expects it - confirm.
                    row = {
                        "conversations": [
                            {"role": "user", "content": input_text},
                            {"role": "assistant", "content": output_clean}
                        ]
                    }
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
            print(f"{data_type.capitalize()} data saved to {filepath}")

        if write_train:
            write_data(train_data, "train.jsonl", "train")
        if write_eval:
            write_data(eval_data, "eval.jsonl", "eval")

        self._save_summary_report(total_samples, train_size, len(eval_data),
                                  write_train, write_eval)

    def _save_summary_report(self, total_samples: int, train_size: int, eval_size: int,
                             write_train: bool, write_eval: bool) -> None:
        """
        Save a Markdown summary report of the data preparation.

        The report file name encodes which splits were written, so that
        train-only and eval-only runs into the same directory do not overwrite
        each other's report.

        Args:
            total_samples: Total number of samples
            train_size: Number of training samples
            eval_size: Number of evaluation samples
            write_train: Whether training data was written
            write_eval: Whether evaluation data was written
        """
        if write_train and write_eval:
            report_name = "data_preparation_summary_train_eval.md"
        elif write_train:
            report_name = "data_preparation_summary_train.md"
        elif write_eval:
            report_name = "data_preparation_summary_eval.md"
        else:
            report_name = "data_preparation_summary.md"

        report_path = os.path.join(self.sft_HC_sft_data_to_go_dir, report_name)

        # Explicit UTF-8 (as for the JSONL files) so non-ASCII paths are written
        # the same way regardless of the platform's locale encoding.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("# Hypothesis Composition SFT Data Preparation Summary\n\n")
            f.write("## Source\n")
            f.write(f"- QA data directory: {self.sft_qa_data_dir}\n")
            f.write(f"- Reasoning trace directory: {self.sft_HC_reasoning_trace_dir}\n")
            f.write(f"- Number of papers processed: {len(self.paper_with_reasoning_trace_list)}\n\n")
            f.write("## Prepared Dataset\n")
            f.write(f"- Total samples: {total_samples}\n")
            train_pct = train_size/total_samples*100 if total_samples > 0 else 0
            eval_pct = eval_size/total_samples*100 if total_samples > 0 else 0
            f.write(f"- Train samples: {train_size} ({train_pct:.1f}%)\n")
            f.write(f"- Eval samples: {eval_size} ({eval_pct:.1f}%)\n\n")
            f.write("## Strategy\n")
            f.write("- Chronologically sorted by year and PMID\n")
            f.write("- Each sample includes research context and inspiration integration\n")
            f.write("- Reasoning traces preserved in <think> tags\n")

        print(f"Summary report saved to {report_path}")

    def prepare_sft_data_input_output(self, min_hypothesis_word_length: int) -> List[Tuple[str, str]]:
        """
        Prepare input-output pairs for SFT training.

        For each reasoning-trace file, this loads ``research_question`` and
        ``background_survey`` from the same-named QA file. It then builds one
        sample per trace entry. The code reads these entry positions:
        ``[1]`` previous hypothesis (or None), ``[2]`` inspiration title,
        ``[3]`` inspiration abstract, ``[5]`` reasoning trace and ``[6]`` next
        hypothesis. Positions 0 and 4 are not used here.

        Entries whose reasoning trace or next hypothesis is None are skipped.
        So are entries whose next hypothesis has fewer than
        ``min_hypothesis_word_length`` whitespace-separated words.

        Args:
            min_hypothesis_word_length: Minimum word count for the next hypothesis.

        Returns:
            List of (input, output) tuples in file order.

        Raises:
            FileNotFoundError: if a QA file matching a reasoning-trace file is missing.
            KeyError: if a QA file lacks ``research_question`` or ``background_survey``.
        """
        # v2: delta hypothesis format (Inspiration/Motivation/Mechanism/Methodology).
        # Earlier prompt versions: "prepare_HC_sft_data_to_go",
        # "prepare_HC_sft_data_to_go_comprehensive".
        # The prompt is indexed as 6 fragments (prompts[0]..prompts[5]).
        prompts = instruction_prompts("prepare_HC_sft_data_to_go_comprehensive_v2_delta")
        sft_data_to_go_collection: List[Tuple[str, str]] = []
        # Track word lengths of kept hypotheses for statistics
        hypothesis_word_lengths = []
        # Track filtered-out and kept counts separately for with/without previous hypothesis
        filtered_out_counts = {"with_prev_hyp": 0, "no_prev_hyp": 0}
        kept_counts = {"with_prev_hyp": 0, "no_prev_hyp": 0}

        num_files = len(self.paper_with_reasoning_trace_list)
        for cur_file_idx, cur_file in enumerate(self.paper_with_reasoning_trace_list):
            if cur_file_idx % 1000 == 0:
                print(f"Processing {cur_file_idx} / {num_files} papers...")
            # Load research question and background survey (JSON is UTF-8 by spec)
            with open(os.path.join(self.sft_qa_data_dir, cur_file), "r", encoding="utf-8") as f:
                sft_qa_data = json.load(f)
            research_question = sft_qa_data["research_question"]
            survey = sft_qa_data["background_survey"]

            # Load the reasoning trace entries
            with open(os.path.join(self.sft_HC_reasoning_trace_dir, cur_file), "r", encoding="utf-8") as f:
                reasoning_trace = json.load(f)

            for cur_item in reasoning_trace:
                has_prev_hyp = cur_item[1] is not None
                bucket = "with_prev_hyp" if has_prev_hyp else "no_prev_hyp"
                cur_prev_hyp = cur_item[1] if has_prev_hyp else "No previous hypothesis."
                cur_insp_title = cur_item[2]
                cur_insp_abstract = cur_item[3]
                cur_next_hyp_reasoning_trace = cur_item[5]
                cur_next_hyp = cur_item[6]

                # Skip entries where generation produced no trace or no hypothesis
                if cur_next_hyp_reasoning_trace is None or cur_next_hyp is None:
                    continue

                word_length = len(cur_next_hyp.split())
                if word_length < min_hypothesis_word_length:
                    filtered_out_counts[bucket] += 1
                    continue

                hypothesis_word_lengths.append(word_length)
                kept_counts[bucket] += 1

                # R1-Distill native format (matching LLaMA-Factory ReasoningTemplate):
                # - No system prompt (user prompt already contains task instructions)
                # - Output starts with <think>\n (LLaMA-Factory expects this for ReasoningTemplate)
                # - Output ends with \n</think>\n\n[answer] (matches thought_words default)
                cur_input = prompts[0] + research_question + prompts[1] + survey + prompts[2] + cur_prev_hyp + prompts[3] + cur_insp_title + prompts[4] + cur_insp_abstract + prompts[5]
                cur_output = "<think>\n" + cur_next_hyp_reasoning_trace + "\n</think>\n\n" + cur_next_hyp
                sft_data_to_go_collection.append((cur_input, cur_output))

        if hypothesis_word_lengths:
            sorted_lengths = sorted(hypothesis_word_lengths)
            n = len(sorted_lengths)
            mean_length = sum(sorted_lengths) / n
            median_length = sorted_lengths[n // 2] if n % 2 == 1 else (sorted_lengths[n // 2 - 1] + sorted_lengths[n // 2]) / 2
            # Nearest-rank-style percentiles; int(n * q) < n for q < 1, so always in range
            p10 = sorted_lengths[int(n * 0.1)]
            p25 = sorted_lengths[int(n * 0.25)]
            p75 = sorted_lengths[int(n * 0.75)]
            p90 = sorted_lengths[int(n * 0.9)]

            # Counts only the word-length filter; entries skipped for a None
            # trace/hypothesis are not included in "before filtering".
            total_filtered = sum(filtered_out_counts.values())

            print("\n=== Hypothesis Word Length Statistics ===")
            print(f"Total samples (before filtering): {n + total_filtered}")
            print(f"Total samples (after filtering): {n}")
            print(f"Filtered out (< {min_hypothesis_word_length} words): {total_filtered}")
            print(f"  - With previous hypothesis: {filtered_out_counts['with_prev_hyp']}")
            print(f"  - Without previous hypothesis (first step): {filtered_out_counts['no_prev_hyp']}")
            print(f"Kept (>= {min_hypothesis_word_length} words): {n}")
            print(f"  - With previous hypothesis: {kept_counts['with_prev_hyp']}")
            print(f"  - Without previous hypothesis (first step): {kept_counts['no_prev_hyp']}")
            print(f"Min: {sorted_lengths[0]} words")
            print(f"Max: {sorted_lengths[-1]} words")
            print(f"Mean: {mean_length:.1f} words")
            print(f"Median: {median_length:.1f} words")
            print(f"10th percentile: {p10} words")
            print(f"25th percentile: {p25} words")
            print(f"75th percentile: {p75} words")
            print(f"90th percentile: {p90} words")
        else:
            print("\n=== Hypothesis Word Length Statistics ===")
            print("No valid samples found!")

        return sft_data_to_go_collection
    




def main(argv=None):
    """
    Command-line entry point: prepare SFT data for hypothesis composition.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. Defaults to None, in which case argparse reads ``sys.argv``.
            Passing a list allows calling ``main`` programmatically or from tests.

    Exits with status 2 (via ``parser.error``) when ``--train_ratio`` is
    outside [0.0, 1.0], or when neither ``--write_train`` nor ``--write_eval``
    is given.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Prepare SFT data for hypothesis composition training')
    parser.add_argument('--qa_data_dir', type=str, required=True,
                        help='Directory containing QA data')
    parser.add_argument('--reasoning_trace_dir', type=str, required=True,
                        help='Directory containing reasoning traces')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save prepared SFT data')
    parser.add_argument('--train_ratio', type=float, default=0.95,
                        help='Ratio of data to use for training (0.0 to 1.0, default: 0.95)')
    parser.add_argument('--write_train', action='store_true',
                        help='Write training data to train.jsonl')
    parser.add_argument('--write_eval', action='store_true',
                        help='Write evaluation data to eval.jsonl')
    parser.add_argument('--min_hypothesis_word_length', type=int, default=100,
                        help='Minimum word length for hypothesis to be included (default: 100)')

    args = parser.parse_args(argv)

    # Validate arguments before touching the file system
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error("train_ratio must be between 0.0 and 1.0")

    if not args.write_train and not args.write_eval:
        parser.error("At least one of --write_train or --write_eval must be specified")

    # Non-fatal: the split is still computed, but one side is never written
    if args.train_ratio > 0 and not args.write_train:
        print("Warning: train_ratio > 0 but write_train is False. No training data will be written.")

    if args.train_ratio < 1.0 and not args.write_eval:
        print("Warning: train_ratio < 1.0 but write_eval is False. No evaluation data will be written.")

    sft_qa_data_dir = args.qa_data_dir
    sft_HC_reasoning_trace_dir = args.reasoning_trace_dir
    sft_HC_sft_data_to_go_dir = args.output_dir

    print("\nPreparing hypothesis composition SFT data")
    print("Note: EOS tokens will be handled by the training framework (use template: deepseekr1 in training config)")
    print(f"QA data directory: {sft_qa_data_dir}")
    print(f"Reasoning trace directory: {sft_HC_reasoning_trace_dir}")
    print(f"Output directory: {sft_HC_sft_data_to_go_dir}")

    sft_data_to_go = HCSFTDataToGo(sft_qa_data_dir, sft_HC_reasoning_trace_dir, sft_HC_sft_data_to_go_dir)
    sft_data_to_go.save_sft_data_to_go(
        train_ratio=args.train_ratio,
        write_train=args.write_train,
        write_eval=args.write_eval,
        min_hypothesis_word_length=args.min_hypothesis_word_length
    )
    print("\n✓ SFT data preparation complete!")


if __name__ == "__main__":
    main()