#!/usr/bin/env python3
"""
Synthesize_QA.py - Second stage synthesis script for Q&A generation
This script reads the metadata from details.json generated in the first stage
and generates Q&A pairs using the qa_generator, then splits them into train/eval sets.
"""

import os
import json
import random
from qa_generator import UserRequestGenerator, UserRequestGeneratorVariant

class QASynthesizer:
    """Second-stage synthesizer that turns first-stage metadata into Q&A datasets.

    An instance is built from the ``details.json`` file written by the first
    stage. It exposes helpers to inspect that metadata, verify the first-stage
    artifacts on disk, generate Q&A pairs through ``UserRequestGenerator`` and
    split them into train/eval files, and record completion back into
    ``details.json``.
    """

    # Stage markers under which the first-stage artifacts are expected to exist.
    # "fully_complete" is what update_details_with_qa_completion() writes, so
    # accepting it allows the QA stage to be re-run against the same file.
    _READY_STAGES = ("first_stage_complete", "fully_complete")

    def __init__(self, details_json_path):
        """
        Initialize the QASynthesizer by loading details from the first stage.

        Args:
            details_json_path (str): Path to the details.json file from first stage

        Raises:
            FileNotFoundError: If the details file does not exist.
            ValueError: If the file is not valid JSON, is not a JSON object, or
                its "generation_stage" does not indicate a finished first stage.
            KeyError: If a required metadata key is missing.
        """
        self.details_json_path = details_json_path
        self.details = self.load_details()

        # Check the stage before touching other keys so that an unfinished
        # first stage yields this message rather than an opaque KeyError.
        if self.details.get("generation_stage") not in self._READY_STAGES:
            raise ValueError("First stage generation not completed. Please run Synthesize.py first.")

        # Extract configuration and metadata
        self.configuration = self.details["configuration"]
        self.layer_lookup_tables = self.details["layer_lookup_tables"]
        self.task_layer_requirements = self.details["task_layer_requirements"]
        self.global_attributes = self.details["global_attributes"]
        self.profiles_path = self.details["profiles_path"]
        self.base_output_dir = self.configuration["base_output_dir"]
        self.attribute_dict = self.configuration["attribute_dict"]

        print("QA Synthesizer initialized successfully!")
        print(f"Base output directory: {self.base_output_dir}")
        print(f"Profiles path: {self.profiles_path}")
        print(f"Number of layers: {self.configuration['profile_complexity_depth']}")
        print(f"Number of tasks: {self.configuration['number_of_tasks']}")

    @staticmethod
    def _int_keyed(mapping):
        """Return a copy of ``mapping`` whose keys are converted with ``int()``.

        JSON object keys are always strings, so tables loaded from
        details.json need this before being handed to the generators.
        """
        return {int(key): value for key, value in mapping.items()}

    @staticmethod
    def _split_train_eval(all_qa_data, eval_size, seed):
        """Shuffle a copy of ``all_qa_data`` and return ``(train, eval)``.

        A private ``random.Random(seed)`` is used so the split is reproducible
        without reseeding the process-wide generator. The first ``eval_size``
        shuffled items form the eval set; the remainder is the train set.
        """
        shuffled = list(all_qa_data)
        random.Random(seed).shuffle(shuffled)
        return shuffled[eval_size:], shuffled[:eval_size]

    @staticmethod
    def _write_json(path, payload):
        """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

    def load_details(self):
        """Load the details from the JSON file.

        Returns:
            dict: Parsed content of ``self.details_json_path``.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not valid JSON or is not a JSON object.
        """
        if not os.path.exists(self.details_json_path):
            raise FileNotFoundError(f"Details file not found: {self.details_json_path}")

        try:
            with open(self.details_json_path, 'r', encoding='utf-8') as f:
                details = json.load(f)
        except json.JSONDecodeError as e:
            # Chain the cause so the original parse position stays visible.
            raise ValueError(f"Invalid JSON format in details file: {e}") from e

        if not isinstance(details, dict):
            raise ValueError(
                f"Invalid JSON format in details file: expected an object, got {type(details).__name__}"
            )
        return details

    def print_details_summary(self):
        """Print a summary of the loaded details for verification."""
        banner = "=" * 60
        print("\n" + banner)
        print("LOADED GENERATION DETAILS SUMMARY")
        print(banner)

        print("\nConfiguration:")
        for key, value in self.configuration.items():
            if key == "attribute_dict":
                # Only sizes are shown for attribute_dict; it is iterated as a
                # sequence with one entry per layer.
                print(f"  {key}: {len(value)} layers defined")
                for i, layer_attrs in enumerate(value, 1):
                    print(f"    Layer {i}: {len(layer_attrs)} attributes")
            else:
                print(f"  {key}: {value}")

        print("\nLayer Lookup Tables:")
        for layer_num, lookup_values in self.layer_lookup_tables.items():
            print(f"  Layer {layer_num}: {len(lookup_values)} unique lookup values")
            if lookup_values:
                print(f"    Sample values: {lookup_values[:5]}...")

        print("\nTask Layer Requirements:")
        for layer_num, requirements in self.task_layer_requirements.items():
            print(f"  Layer {layer_num}: {requirements}")

        print("\nGlobal Attributes:")
        for key, value in self.global_attributes.items():
            print(f"  {key}: {value}")

        print("\nPaths:")
        print(f"  Profiles path: {self.profiles_path}")
        print(f"  Base output directory: {self.base_output_dir}")
        print(banner)

    def verify_required_files(self):
        """Verify that required files from first stage exist.

        Checks the profiles path plus the "Policy" and "Tools" entries under
        the base output directory, and that the profiles path contains at
        least one ``.json`` file.

        Returns:
            bool: True when every check passes.

        Raises:
            FileNotFoundError: If a required path or all profile files are missing.
        """
        required_paths = [
            self.profiles_path,
            os.path.join(self.base_output_dir, "Policy"),
            os.path.join(self.base_output_dir, "Tools"),
        ]

        missing_paths = [path for path in required_paths if not os.path.exists(path)]
        if missing_paths:
            raise FileNotFoundError(f"Required paths from first stage not found: {missing_paths}")

        # Check for profile files
        profile_files = [f for f in os.listdir(self.profiles_path) if f.endswith('.json')]
        if not profile_files:
            raise FileNotFoundError(f"No profile JSON files found in {self.profiles_path}")

        print(f"Verification successful. Found {len(profile_files)} profile files.")
        return True

    def generate_and_split_qa(self, total_requests=10100, eval_size=100, include_rollouts=True, seed=42):
        """
        Generate Q&A pairs and split them into train/eval sets with consistent policy.

        All pairs are generated in a single pass into a temporary
        ``Queries/qa_full.json``, shuffled with ``seed``, written out as
        ``Queries/qa_eval.json`` and ``Queries/qa_train.json``, and the
        temporary file is then removed.

        Args:
            total_requests (int): Total number of Q&A pairs to generate (default: 10100)
            eval_size (int): Number of pairs for evaluation set (default: 100)
            include_rollouts (bool): Whether to include rollouts in generation (default: True)
            seed (int): Random seed for reproducible splits (default: 42)

        Returns:
            Tuple[List, List]: (train_data, eval_data)

        Raises:
            ValueError: If ``eval_size`` is negative or larger than ``total_requests``.
        """
        # A negative eval_size would slice from the end, and an oversized one
        # would leave an empty train set, so both are rejected up front.
        if eval_size < 0 or eval_size > total_requests:
            raise ValueError(
                f"eval_size must be between 0 and total_requests ({total_requests}), got {eval_size}"
            )

        try:
            print(f"\nStarting Q&A generation with {total_requests} total requests...")
            print(f"Will split into {total_requests - eval_size} training and {eval_size} evaluation pairs")

            # Create qa_generator (JSON keys arrive as strings; the generator
            # is given integer layer keys).
            qa_generator = UserRequestGenerator(
                attribute_dict=self.attribute_dict,
                layer_lookup_tables=self._int_keyed(self.layer_lookup_tables),
                task_layer_requirements=self._int_keyed(self.task_layer_requirements),
                global_attributes=self.global_attributes,
                base_output_dir=self.base_output_dir
            )

            # Ensure Queries directory exists
            queries_dir = os.path.join(self.base_output_dir, "Queries")
            os.makedirs(queries_dir, exist_ok=True)

            # Generate all Q&A data at once
            print("Generating all Q&A data with consistent policy...")
            qa_output_path = os.path.join(queries_dir, "qa_full.json")

            all_qa_data = qa_generator.save_requests_to_json(
                output_path=qa_output_path,
                num_requests=total_requests,
                include_rollouts=include_rollouts,
                profiles_path=self.profiles_path
            )

            print(f"Generated {len(all_qa_data)} total Q&A pairs")

            # Shuffle and split reproducibly without touching global RNG state
            train_data, eval_data = self._split_train_eval(all_qa_data, eval_size, seed)

            # Save the split datasets
            eval_output_path = os.path.join(queries_dir, "qa_eval.json")
            train_output_path = os.path.join(queries_dir, "qa_train.json")
            self._write_json(eval_output_path, eval_data)
            self._write_json(train_output_path, train_data)

            # Remove the temporary full file
            if os.path.exists(qa_output_path):
                os.remove(qa_output_path)

            print("\nQ&A generation and splitting completed successfully!")
            print(f"Training set: {len(train_data)} pairs saved to: {train_output_path}")
            print(f"Evaluation set: {len(eval_data)} pairs saved to: {eval_output_path}")

            return train_data, eval_data

        except Exception as e:
            print(f"Q&A generation failed with error: {e}")
            raise

    def update_details_with_qa_completion(self, train_output_path, eval_output_path, num_train, num_eval):
        """
        Update the details.json file to indicate QA generation completion.

        This is best-effort: any failure is reported as a warning and not
        raised. The file is written to a sibling ``.tmp`` file and then moved
        into place, so a failed write does not truncate the existing
        details.json.

        Args:
            train_output_path (str): Path where training QA data was saved
            eval_output_path (str): Path where evaluation QA data was saved
            num_train (int): Number of training Q&A pairs generated
            num_eval (int): Number of evaluation Q&A pairs generated
        """
        tmp_path = f"{self.details_json_path}.tmp"
        try:
            # Update details with QA information
            self.details["qa_generation"] = {
                "completed": True,
                "train_output_path": train_output_path,
                "eval_output_path": eval_output_path,
                "num_train": num_train,
                "num_eval": num_eval,
                "total_generated": num_train + num_eval,
                "generation_stage": "qa_stage_complete"
            }
            self.details["generation_stage"] = "fully_complete"

            # Save updated details: write aside first, then swap in.
            self._write_json(tmp_path, self.details)
            os.replace(tmp_path, self.details_json_path)

            print("Updated details.json with QA completion information")

        except Exception as e:
            print(f"Warning: Could not update details.json: {e}")
            # Do not leave a partial temp file behind; cleanup itself is
            # best-effort as well.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass

    def generate_variant_evaluation_data(self, variant_type: str = "policy", eval_size: int = 300):
        """
        Generate evaluation data for a variant using ``UserRequestGeneratorVariant``.

        Args:
            variant_type (str): Either "policy" or "task" to specify which variant
            eval_size (int): Number of evaluation pairs to generate (default: 300)

        Returns:
            List: Generated evaluation data
        """
        try:
            print(f"\nGenerating {eval_size} evaluation pairs for {variant_type} variant...")

            # Create qa_generator with variant-specific configuration (JSON
            # keys arrive as strings; the generator is given integer keys).
            qa_generator = UserRequestGeneratorVariant(
                attribute_dict=self.attribute_dict,
                layer_lookup_tables=self._int_keyed(self.layer_lookup_tables),
                task_layer_requirements=self._int_keyed(self.task_layer_requirements),
                global_attributes=self.global_attributes,
                base_output_dir=self.base_output_dir,
                variant_type=variant_type
            )

            # Ensure Queries directory exists
            queries_dir = os.path.join(self.base_output_dir, "Queries")
            os.makedirs(queries_dir, exist_ok=True)

            # Generate evaluation data
            print(f"Generating {variant_type} variant evaluation data...")
            # NOTE(review): "overide" is misspelled, but the file name is kept
            # as-is because other tooling may read this exact path.
            eval_output_path = os.path.join(queries_dir, f"eval_overide_{variant_type}.json")

            eval_data = qa_generator.save_requests_to_json(
                output_path=eval_output_path,
                num_requests=eval_size,
                include_rollouts=True,
                profiles_path=self.profiles_path
            )

            print(f"Generated {len(eval_data)} evaluation pairs for {variant_type} variant")
            print(f"Evaluation data saved to: {eval_output_path}")

            return eval_data

        except Exception as e:
            print(f"Variant evaluation data generation failed with error: {e}")
            raise

def load_qa_synthesizer_from_directory(base_output_dir):
    """
    Build a QASynthesizer from the details.json inside a base output directory.

    Args:
        base_output_dir (str): Directory expected to contain details.json

    Returns:
        QASynthesizer: Synthesizer initialized from <base_output_dir>/details.json
    """
    return QASynthesizer(os.path.join(base_output_dir, "details.json"))

def main():
    """Main function for Q&A generation and splitting.

    Parses command-line arguments, loads the first-stage details, generates
    and splits the Q&A data, records completion in details.json, and then
    attempts to generate the "policy" and "task" variant evaluation sets.
    Variant generation is best-effort: a failure there is reported as a
    warning and does not fail the run.

    Returns:
        bool: True when the main Q&A synthesis succeeded, False otherwise.

    Raises:
        SystemExit: Via argparse, when the arguments are invalid or neither
            --details_path nor --base_dir is supplied.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate Q&A pairs from first stage synthesis results and split into train/eval sets")
    parser.add_argument("--details_path", type=str,
                       help="Path to details.json file from first stage")
    parser.add_argument("--base_dir", type=str, default=None,
                       help="Base directory containing details.json")
    parser.add_argument("--total_requests", type=int, default=10100,
                       help="Total number of Q&A pairs to generate (default: 10100)")
    parser.add_argument("--eval_size", type=int, default=100,
                       help="Number of pairs for evaluation set (default: 100)")
    parser.add_argument("--no_rollouts", action="store_true",
                       help="Disable rollouts in Q&A generation")
    parser.add_argument("--seed", type=int, default=42,
                       help="Random seed for reproducible train/eval splits (default: 42)")

    args = parser.parse_args()

    # Determine details.json path
    if args.details_path:
        details_json_path = args.details_path
    elif args.base_dir is not None:
        details_json_path = os.path.join(args.base_dir, "details.json")
    else:
        # parser.error prints the usage text plus this message and exits with
        # status 2, instead of surfacing a raw traceback to the CLI user.
        parser.error("You must specify either --details_path or --base_dir.")

    # Number of pairs requested for each variant evaluation set.
    variant_eval_size = 300
    variant_labels = (("policy", "Policy"), ("task", "Task"))

    try:
        # Initialize QA synthesizer
        qa_synthesizer = QASynthesizer(details_json_path)

        train_size = args.total_requests - args.eval_size
        print(f"Generating {args.total_requests} total Q&A pairs")
        print(f"Will split into {train_size} training and {args.eval_size} evaluation pairs")

        # Print summary
        qa_synthesizer.print_details_summary()

        # Verify required files
        qa_synthesizer.verify_required_files()

        # Generate and split Q&A
        include_rollouts = not args.no_rollouts
        train_data, eval_data = qa_synthesizer.generate_and_split_qa(
            total_requests=args.total_requests,
            eval_size=args.eval_size,
            include_rollouts=include_rollouts,
            seed=args.seed
        )

        # Update details with completion
        queries_dir = os.path.join(qa_synthesizer.base_output_dir, "Queries")
        train_output_path = os.path.join(queries_dir, "qa_train.json")
        eval_output_path = os.path.join(queries_dir, "qa_eval.json")
        qa_synthesizer.update_details_with_qa_completion(
            train_output_path, eval_output_path, len(train_data), len(eval_data)
        )

        # Generate variant evaluation data. Actual counts are recorded so the
        # final summary only reports what was really produced.
        variant_counts = {}
        try:
            print("\n🔄 Generating variant evaluation data...")

            # A failure in one variant skips the remaining ones.
            for variant_type, _label in variant_labels:
                variant_data = qa_synthesizer.generate_variant_evaluation_data(
                    variant_type=variant_type, eval_size=variant_eval_size
                )
                variant_counts[variant_type] = len(variant_data)

            print("✅ Variant evaluation data generated successfully!")

        except Exception as e:
            print(f"⚠️  Warning: Could not generate variant evaluation data: {e}")
            print("   Main QA generation completed successfully, continuing...")

        print("\n🎉 QA synthesis completed successfully!")
        print(f"📁 All files are in: {qa_synthesizer.base_output_dir}")
        print(f"🚂 Training set: {len(train_data)} Q&A pairs")
        print(f"🔍 Evaluation set: {len(eval_data)} Q&A pairs")
        for variant_type, label in variant_labels:
            if variant_type in variant_counts:
                print(f"🔍 {label} variant evaluation: {variant_counts[variant_type]} Q&A pairs")
            else:
                print(f"🔍 {label} variant evaluation: not generated")
        print("📊 All sets share consistent underlying policies for reliable evaluation!")

        return True

    except Exception as e:
        print(f"❌ QA synthesis failed: {e}")
        return False

if __name__ == "__main__":
    # main() returns True on success and False on failure; propagate that as
    # the process exit status so shells and CI can detect a failed run.
    raise SystemExit(0 if main() else 1)