#!/usr/bin/env python3
"""
Synthesize.py - Main synthesis script that chains all generators
This script coordinates the generation of profiles, policies, and tools with consistent parameters.
"""

import os
import json
import argparse
from policy_generator import PolicyGenerator
from profile_generator import ProfileGenerator  
from tool_generator import ToolGenerator
from qa_generator import UserRequestGenerator

class DataSynthesizer:
    """Coordinate profile, policy, tool and Q&A generation from one configuration.

    Every generator is constructed from the values stored on this object so
    the artifacts written below ``base_output_dir`` are produced with the same
    depth, task count and attribute layout.
    """

    def __init__(self, 
                 profile_complexity_depth=3,
                 profile_complexity_width=500, 
                 number_of_tasks=3,
                 k1=1, k2=1, k3=1,
                 attribute_dict=None,
                 base_output_dir=None,
                 structure_complexity=1,
                 args_num=5):
        """
        Initialize the DataSynthesizer with configurable parameters.

        Creating an instance also creates the output directory tree on disk.

        Args:
            profile_complexity_depth (int): Number of layers/depth levels (default: 3)
            profile_complexity_width (int): Number of profiles per layer (default: 500)
            number_of_tasks (int): Number of task types to generate (default: 3)
            k1, k2, k3 (int): Constants for profile generation (default: 1 each)
            attribute_dict (list): List of dictionaries defining attribute types for each layer
            base_output_dir (str): Base directory for all generated files; when
                None a name is derived from depth, tasks, structure complexity
                and args_num
            structure_complexity (int): Number of if-else rounds for policy generation (default: 1)
            args_num (int): Number of arguments for each generated task (default: 5)

        Raises:
            ValueError: If ``attribute_dict`` is None.
        """
        # Validate before touching the filesystem so a bad call leaves no directories behind.
        if attribute_dict is None:
            raise ValueError("attribute_dict must be provided. Please define the attribute structure for each layer.")

        self.profile_complexity_depth = profile_complexity_depth
        self.profile_complexity_width = profile_complexity_width
        self.number_of_tasks = number_of_tasks
        self.k1 = k1
        self.k2 = k2
        self.k3 = k3
        self.structure_complexity = structure_complexity
        self.args_num = args_num
        self.attribute_dict = attribute_dict

        # Dynamically set base_output_dir if not provided
        if base_output_dir is None:
            base_output_dir = (
                f"Generated_data_layer_{profile_complexity_depth}_task_{number_of_tasks}"
                f"_structure_{structure_complexity}_args_{args_num}"
            )
        self.base_output_dir = base_output_dir

        # Create output directory structure
        self.setup_output_directories()

    def setup_output_directories(self):
        """Create the base output directory and its per-artifact subdirectories."""
        os.makedirs(self.base_output_dir, exist_ok=True)
        for subdirectory in ("Profiles", "Policy", "Tools", "Task", "Queries"):
            os.makedirs(os.path.join(self.base_output_dir, subdirectory), exist_ok=True)

    def generate_profiles(self):
        """Generate hierarchical profiles using ProfileGenerator.

        Returns:
            Whatever ``ProfileGenerator.save_profile_files`` returns; the rest
            of this class unpacks it as an iterable of
            ``(layer_profiles, layer_keys, layer_lookup_tables)`` tuples.
        """
        profile_generator = ProfileGenerator(
            profile_complexity_depth=self.profile_complexity_depth,
            profile_complexity_width=self.profile_complexity_width,
            k1=self.k1,
            k2=self.k2, 
            k3=self.k3,
            attribute_dict=self.attribute_dict
        )

        profiles_output_dir = os.path.join(self.base_output_dir, "Profiles")
        return profile_generator.save_profile_files(profiles_output_dir)

    def generate_policy(self):
        """Generate the policy document using PolicyGenerator.

        Returns:
            tuple: ``(policy_content, task_requirements, global_attributes)``.
        """
        policy_generator = PolicyGenerator(structure_complexity=self.structure_complexity, args_num=self.args_num)
        # Override the default settings to match our configuration
        policy_generator.task_types = self.number_of_tasks
        policy_generator.layers = self.profile_complexity_depth
        policy_generator.attribute_dict = self.attribute_dict

        # Global attributes are generated before the policy is saved, as in
        # the original call order.
        global_attributes = policy_generator.generate_global_attributes()

        policy_output_path = os.path.join(self.base_output_dir, "Policy", "Policy.md")
        policy_content = policy_generator.save_policy(policy_output_path)

        # Get task requirements for summary
        task_requirements = policy_generator.get_task_requirements()

        return policy_content, task_requirements, global_attributes

    def generate_tools(self):
        """Generate tools using ToolGenerator and return the generated content."""
        tools_output_path = os.path.join(self.base_output_dir, "Tools", "all_tools.py")

        tool_generator = ToolGenerator(
            depth_complexity=self.profile_complexity_depth,
            number_of_tasks=self.number_of_tasks,
            output_path=tools_output_path,
            attribute_dict=self.attribute_dict,
            # The current working directory is used as the root path.
            root_path=os.getcwd()
        )

        return tool_generator.generate_all_tools()

    def extract_lookup_tables(self, layers_data):
        """
        Extract lookup tables from profile generator results and format them for qa_generator.

        Values from all attributes of a layer are merged and de-duplicated.
        First-seen order is kept so the result is reproducible between runs.

        Args:
            layers_data: List of tuples (layer_profiles, layer_keys, layer_lookup_tables)

        Returns:
            Dict[int, list]: De-duplicated lookup values keyed by 1-based layer number
        """
        layer_lookup_tables = {}

        for layer_idx, (_profiles, _keys, lookup_tables) in enumerate(layers_data, 1):
            # dict.fromkeys de-duplicates while preserving insertion order.
            unique_values = dict.fromkeys(
                value
                for values_list in lookup_tables.values()
                for value in values_list
            )
            layer_lookup_tables[layer_idx] = list(unique_values)

        return layer_lookup_tables

    def format_global_attributes_for_qa(self, policy_global_attributes):
        """
        Transform policy generator global attributes format to qa_generator format.

        Keys that do not start with 'Global_Attribute_Value' are dropped.

        Args:
            policy_global_attributes: Dict with keys like 'Global_Attribute_Value1'

        Returns:
            Dict with keys like 'global_attribute_1'
        """
        prefix = 'Global_Attribute_Value'
        # Strip only the leading prefix; the remainder of the key is the suffix.
        return {
            f'global_attribute_{key[len(prefix):]}': value
            for key, value in policy_global_attributes.items()
            if key.startswith(prefix)
        }

    def save_details_to_json(self, layers_data, global_attributes, task_requirements):
        """
        Save the metadata from first stage generation to details.json for later QA generation.

        Args:
            layers_data: Profile generation results
            global_attributes: Global attributes from policy generator
            task_requirements: Task layer requirements from policy generator

        Returns:
            str: Path of the written details.json file.
        """
        details = {
            "configuration": {
                "profile_complexity_depth": self.profile_complexity_depth,
                "profile_complexity_width": self.profile_complexity_width,
                "number_of_tasks": self.number_of_tasks,
                "k1": self.k1,
                "k2": self.k2,
                "k3": self.k3,
                "attribute_dict": self.attribute_dict,
                "base_output_dir": self.base_output_dir,
                "structure_complexity": self.structure_complexity,
                # Persisted so the saved configuration covers every constructor parameter.
                "args_num": self.args_num
            },
            # Integer layer keys are written as strings by the JSON encoder.
            "layer_lookup_tables": self.extract_lookup_tables(layers_data),
            "task_layer_requirements": task_requirements,
            "global_attributes": self.format_global_attributes_for_qa(global_attributes),
            "original_global_attributes": global_attributes,
            "profiles_path": os.path.join(self.base_output_dir, "Profiles/"),
            "generation_stage": "first_stage_complete"
        }

        # Save to details.json
        details_output_path = os.path.join(self.base_output_dir, "details.json")
        with open(details_output_path, 'w', encoding='utf-8') as f:
            json.dump(details, f, indent=2, ensure_ascii=False)

        print(f"\nSaved generation details to: {details_output_path}")
        return details_output_path

    def generate_qa(self, layers_data, global_attributes, task_requirements, num_requests=10100):
        """
        Generate Q&A using the qa_generator.

        Args:
            layers_data: Profile generation results
            global_attributes: Global attributes from policy generator
            task_requirements: Task layer requirements from policy generator
            num_requests: Number of Q&A pairs to generate (default: 10100 for train/eval split)

        Returns:
            The value returned by ``UserRequestGenerator.save_requests_to_json``.
        """
        qa_generator = UserRequestGenerator(
            attribute_dict=self.attribute_dict,
            layer_lookup_tables=self.extract_lookup_tables(layers_data),
            task_layer_requirements=task_requirements,
            global_attributes=self.format_global_attributes_for_qa(global_attributes),
            base_output_dir=self.base_output_dir
        )

        # Generate and save Q&A
        return qa_generator.save_requests_to_json(
            output_path=os.path.join(self.base_output_dir, "Queries", "qa.json"),
            num_requests=num_requests,
            include_rollouts=True,
            profiles_path=os.path.join(self.base_output_dir, "Profiles/")
        )

    @staticmethod
    def _print_layers_summary(layers_data):
        """Print a per-layer overview: profile count, first keys, lookup tables, one sample."""
        print("Generated Layers Data:")
        for layer_idx, (layer_profiles, layer_keys, lookup_tables) in enumerate(layers_data, 1):
            print(f"\nLayer {layer_idx}:")
            print(f"  Number of profiles: {len(layer_profiles)}")
            print(f"  Profile keys: {layer_keys[:5]}...")  # Show first 5 keys
            print(f"  Lookup tables:")
            for attr_name, lookup_values in lookup_tables.items():
                # Order-preserving de-duplication keeps the preview stable between runs.
                unique_values = list(dict.fromkeys(lookup_values))
                print(f"    {attr_name}: {len(unique_values)} unique values - {unique_values[:10]}...")

            # Show sample profile
            if layer_keys:
                sample_key = layer_keys[0]
                print(f"  Sample profile ({sample_key}): {layer_profiles[sample_key]}")

    def _generate_core_artifacts(self):
        """Run profile, policy and tool generation (steps 1-3 shared by both pipelines).

        Returns:
            tuple: ``(layers_data, task_requirements, global_attributes)``.
        """
        # Step 1: Generate profiles
        layers_data = self.generate_profiles()
        self._print_layers_summary(layers_data)

        # Step 2: Generate policy
        _policy_content, task_requirements, global_attributes = self.generate_policy()
        print(f"\nGenerated Policy with:")
        print(f"  Task requirements: {task_requirements}")
        print(f"  Global attributes: {global_attributes}")

        # Step 3: Generate tools
        self.generate_tools()
        print("\nGenerated Tools")

        return layers_data, task_requirements, global_attributes

    def synthesize_first_stage(self):
        """Run the first stage of synthesis: profiles, policy, and tools generation.

        Returns:
            bool: True on success, False if any step raised an exception.
        """
        try:
            layers_data, task_requirements, global_attributes = self._generate_core_artifacts()

            # Step 4: Save details for QA generation
            details_path = self.save_details_to_json(layers_data, global_attributes, task_requirements)

            print(f"\nFirst stage synthesis completed successfully!")
            print(f"Details saved to: {details_path}")
            print(f"You can now run QA generation using Synthesize_QA.py")

            return True

        except Exception as e:
            # Top-level boundary: report and signal failure instead of crashing,
            # but keep the traceback so the failing generator can be located.
            import traceback
            print(f"First stage synthesis failed with error: {e}")
            traceback.print_exc()
            return False

    def synthesize_all(self, num_qa_requests=10100):
        """Run the complete synthesis pipeline including Q&A generation.

        Args:
            num_qa_requests: Number of Q&A pairs to request from the QA generator.

        Returns:
            bool: True on success, False if any step raised an exception.
        """
        try:
            layers_data, task_requirements, global_attributes = self._generate_core_artifacts()

            # Step 4: Generate Q&A
            qa_data = self.generate_qa(layers_data, global_attributes, task_requirements, num_qa_requests)

            print(f"\nGenerated {len(qa_data)} Q&A pairs")

            return True

        except Exception as e:
            # Top-level boundary: report and signal failure instead of crashing,
            # but keep the traceback so the failing generator can be located.
            import traceback
            print(f"Synthesis failed with error: {e}")
            traceback.print_exc()
            return False

def generate_attribute_dict(depth):
    """
    Generate the per-layer attribute dictionary for ``depth`` layers.

    Each layer maps the string slots "1".."8" to an attribute type. Slots are
    filled in priority order: two conditions, one lookup, a self-reference, a
    reference to the next layer (if it exists), references to other layers,
    and finally conditions for whatever slots remain.

    Args:
        depth (int): Number of layers. Values below 1 yield an empty list.

    Returns:
        list[dict]: One dictionary per layer, in layer order (layer 1 first).
    """
    max_attributes = 8
    # Extra references stop one slot early, so the last slot is always a
    # "condition". NOTE(review): the original used 7 here and 8 elsewhere;
    # kept as-is, confirm this is intended.
    max_extra_reference_slot = 7

    attribute_dict = []

    for layer in range(1, depth + 1):
        layer_dict = {
            "1": "condition",
            "2": "condition",
            # Priority 1: ALWAYS add lookup attribute (mandatory)
            "3": "lookup",
            # Priority 2: Add reference to current layer (self-reference)
            "4": f"reference_{layer}",
        }
        attr_count = 5

        # Priority 3: Add reference to next layer (if it exists)
        if layer < depth:
            layer_dict[str(attr_count)] = f"reference_{layer + 1}"
            attr_count += 1

        # Priority 4: Add references to other layers (excluding current and next)
        for ref_layer in range(1, depth + 1):
            if ref_layer in (layer, layer + 1):
                continue
            if attr_count > max_extra_reference_slot:
                break
            layer_dict[str(attr_count)] = f"reference_{ref_layer}"
            attr_count += 1

        # Priority 5: Fill remaining with conditions if needed
        while attr_count <= max_attributes:
            layer_dict[str(attr_count)] = "condition"
            attr_count += 1

        attribute_dict.append(layer_dict)

    return attribute_dict


def main():
    """Parse command line arguments and run the first synthesis stage.

    Returns:
        bool: The result of ``DataSynthesizer.synthesize_first_stage()``.
    """

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Synthesize hierarchical data with profiles, policies, and tools',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python Synthesize.py --depth 3 --tasks 5 --args_num 8
  python Synthesize.py --depth 2 --tasks 3 --width 1000 --args_num 6
  python Synthesize.py --depth 4 --tasks 7 --k1 2 --k2 5 --k3 1 --args_num 10
        """
    )

    parser.add_argument('--depth', '--profile_complexity_depth', 
                       type=int, required=True,
                       help='Number of layers/depth levels (required)')

    parser.add_argument('--tasks', '--number_of_tasks',
                       type=int, required=True, 
                       help='Number of task types to generate (required)')

    parser.add_argument('--width', '--profile_complexity_width',
                       type=int, default=500,
                       help='Number of profiles per layer (default: 500)')

    parser.add_argument('--k1', type=int, default=1,
                       help='K1 constant for profile generation (default: 1)')

    parser.add_argument('--k2', type=int, default=3,
                       help='K2 constant for profile generation (default: 3)')

    parser.add_argument('--k3', type=int, default=1,
                       help='K3 constant for profile generation (default: 1)')

    parser.add_argument('--output_dir', type=str, default=None,
                       help='Base output directory (default: auto-generated based on depth, tasks, and structure complexity)')

    parser.add_argument('--structure_complexity', type=int, default=1,
                       help='Number of if-else rounds for policy generation (default: 1)')

    parser.add_argument('--args_num', type=int, default=5,
                       help='Number of arguments for each generated task (default: 5)')

    args = parser.parse_args()

    # Generate custom attribute configuration based on depth
    custom_attribute_dict = generate_attribute_dict(args.depth)

    print(f"Generated attribute configuration for {args.depth} layers:")
    for i, layer_dict in enumerate(custom_attribute_dict, 1):
        print(f"  Layer {i}: {layer_dict}")

    # Create synthesizer with command line arguments.
    # An unset or empty --output_dir is passed as None so DataSynthesizer
    # derives the default directory name itself; the naming rule then lives
    # in one place only.
    synthesizer = DataSynthesizer(
        profile_complexity_depth=args.depth,
        profile_complexity_width=args.width,
        number_of_tasks=args.tasks,
        k1=args.k1, 
        k2=args.k2, 
        k3=args.k3,
        attribute_dict=custom_attribute_dict,
        base_output_dir=args.output_dir or None,
        structure_complexity=args.structure_complexity,
        args_num=args.args_num
    )

    # Run first stage synthesis only
    success = synthesizer.synthesize_first_stage()

    # If you want to run the complete pipeline in one go, use:
    # success = synthesizer.synthesize_all(num_qa_requests=10100)

    return success

if __name__ == "__main__":
    # Propagate the outcome to the shell: exit status 0 when main() reports
    # success, 1 otherwise, so callers and CI can detect a failed run.
    raise SystemExit(0 if main() else 1)


