import numpy as np
import json
import subprocess
import re
import os
from skopt import gp_minimize
from skopt.space import Integer, Categorical
from skopt.utils import use_named_args
import time

class BayesianPlacementOptimizer:
    def __init__(self, num_cells, external_script_path="placement_script.py",
                 original_json_path="./data/circuit_data_2.json", json_file_name="circuit_data_2"):
        """
        Initialize the Bayesian optimizer.

        Creates the ./bayes output directory, builds the search space and loads
        the external script as an in-memory template in which the hard-coded
        circuit JSON path is redirected to ``original_json_path``.

        Args:
            num_cells: Number of cells
            external_script_path: Path to external Python script
            original_json_path: Path to original JSON file
            json_file_name: JSON file name (without extension), used for naming output files

        Raises:
            OSError: If the output directory cannot be created or the external
                script cannot be read.
        """
        self.num_cells = num_cells
        self.external_script_path = external_script_path
        self.original_json_path = original_json_path
        self.json_file_name = json_file_name
        self.iteration = 0
        self.best_result = None
        self.history = []

        # Create output directory
        os.makedirs("./bayes", exist_ok=True)

        # Define search space
        self.setup_search_space()

        # Read external script template
        with open(external_script_path, 'r', encoding='utf-8') as f:
            self.script_template = f.read()

        # Render the path as a double-quoted, escaped string literal so that
        # backslashes or quotes in it (e.g. Windows paths) stay intact inside
        # the generated script. For plain paths this is simply "<path>".
        path_literal = json.dumps(str(original_json_path), ensure_ascii=False)

        # Point the hard-coded JSON path in the template at the current file.
        # A function replacement is used because re.sub() would otherwise
        # interpret backslashes in the path as escape sequences/group references.
        self.script_template = re.sub(
            r'with open\(".*?circuit_data.*?\.json"',
            lambda _match: f'with open({path_literal}',
            self.script_template
        )
        self.script_template = re.sub(
            r'open\("/research/d7/gds/yhhan25/Research/LLM-PnR/data/.*?\.json"',
            lambda _match: f'open({path_literal}',
            self.script_template
        )
    
    def setup_search_space(self):
        """
        Build the list of dimensions handed to the optimizer.

        The cell order is encoded with num_cells-1 integer dimensions, where
        dimension i ranges over 0..num_cells-i-1; decode_order() turns such a
        vector back into an order list. Each cell additionally gets one
        categorical rotation dimension with the choices 'R0' and 'MY'.
        """
        cell_count = self.num_cells

        # One integer per position except the last
        order_dims = []
        for slot in range(cell_count - 1):
            upper = cell_count - slot - 1
            order_dims.append(Integer(0, upper, name=f'order_{slot}'))
        self.order_dimensions = order_dims

        # One rotation choice per cell
        rotation_dims = []
        for cell in range(cell_count):
            rotation_dims.append(Categorical(['R0', 'MY'], name=f'rotation_{cell}'))
        self.rotation_dimensions = rotation_dims

        # Order dimensions first, rotation dimensions after
        self.dimensions = [*order_dims, *rotation_dims]
    
    def decode_order(self, encoded_order):
        """
        Turn an encoded order vector into the actual order_list.

        Entry i of ``encoded_order`` is an index into the cells that have not
        been picked yet; the picked cell is removed from that pool. The single
        cell left after num_cells-1 picks is appended last.

        Args:
            encoded_order: Sequence with at least num_cells-1 indices

        Returns:
            list: A permutation of 0..num_cells-1
        """
        pool = list(range(self.num_cells))
        picked = [pool.pop(encoded_order[step]) for step in range(self.num_cells - 1)]
        # Exactly one cell is left in the pool at this point
        return picked + [pool[0]]
    
    def update_external_script(self, order_list, rotation_list):
        """
        Update order_list and rotation_list in external Python script

        Works on the in-memory template only; the result is also stored in
        ``self.current_script``. The router's ``output_file_path`` and
        ``image_path`` arguments are redirected to fixed files under the
        project's bayes directory.

        Args:
            order_list: New order list
            rotation_list: New rotation list

        Returns:
            str: The patched script source
        """
        def plain(values):
            # NumPy scalars print as e.g. np.str_('R0') under NumPy >= 2, which
            # would be written verbatim into the script; unwrap them first.
            return [v.item() if hasattr(v, 'item') else v for v in values]

        # Convert lists to string format
        order_str = str(plain(order_list))
        rotation_str = str(plain(rotation_list))

        substitutions = (
            # order_list / rotation_list assignments
            (r'order_list\s*=\s*\[.*?\]', f'order_list = {order_str}'),
            (r'rotation_list\s*=\s*\[.*?\]', f'rotation_list = {rotation_str}'),
            # ss_router's output_file_path parameter -> bayes/routed.json
            (r'output_file_path\s*=\s*[\'"][^\'"]*routed\.json[\'"]',
             'output_file_path=\'/research/d7/gds/yhhan25/Research/LLM-PnR/bayes/routed.json\''),
            # image_path -> bayes/routed.png
            (r'image_path\s*=\s*[\'"][^\'"]*routed\.png[\'"]',
             'image_path=\'/research/d7/gds/yhhan25/Research/LLM-PnR/bayes/routed.png\''),
        )

        updated_script = self.script_template
        for pattern, replacement in substitutions:
            # DOTALL lets a list literal that spans several lines be replaced
            # as a whole. The function replacement inserts the text verbatim
            # instead of parsing it as a regex replacement template.
            updated_script = re.sub(
                pattern,
                lambda _match, text=replacement: text,
                updated_script,
                flags=re.DOTALL,
            )

        # Don't save temporary script, only keep in memory
        self.current_script = updated_script

        return updated_script
    
    def run_placement_and_routing(self, order_list, rotation_list):
        """
        Run placement and routing, and get results

        The script template is patched with the given lists, written to a
        temporary file under ./bayes and executed in a child Python process
        with a 60 second limit. The metrics are then read from the routed.json
        file that the patched script is expected to write.

        Args:
            order_list: Cell order to evaluate
            rotation_list: Per-cell rotation to evaluate

        Returns:
            tuple: (total_wirelength, total_via_count) or None (if execution
            fails, times out, or the metrics cannot be read from routed.json)
        """
        import sys
        import traceback

        # Must stay identical to the output path that update_external_script()
        # writes into the patched script.
        routed_json_path = '/research/d7/gds/yhhan25/Research/LLM-PnR/bayes/routed.json'
        temp_script_path = f"./bayes/temp_exec_{self.iteration}.py"

        try:
            # Update external script
            updated_script = self.update_external_script(order_list, rotation_list)

            # Remove the previous run's output so that a script which exits
            # cleanly without writing routed.json cannot be scored with stale
            # metrics from an earlier configuration.
            if os.path.exists(routed_json_path):
                os.remove(routed_json_path)

            try:
                # Create temporary script file for execution
                with open(temp_script_path, 'w', encoding='utf-8') as f:
                    f.write(updated_script)

                # Execute with the interpreter running this optimizer so the
                # child sees the same environment; fall back to PATH lookup.
                result = subprocess.run([sys.executable or 'python', temp_script_path],
                                        capture_output=True,
                                        text=True,
                                        timeout=60)
            finally:
                # Delete the temporary script even on timeout or error
                if os.path.exists(temp_script_path):
                    os.remove(temp_script_path)

            if result.returncode != 0:
                print(f"Script execution failed: {result.stderr}")
                return None

            # Read generated routed.json file (using absolute path)
            if not os.path.exists(routed_json_path):
                print("Cannot find routed.json file")
                print(f"Expected path: {routed_json_path}")
                return None

            with open(routed_json_path, 'r') as f:
                routed_data = json.load(f)

            # Debug: print routed_data structure
            print(f"routed_data type: {type(routed_data)}")
            if isinstance(routed_data, list):
                print(f"routed_data length: {len(routed_data)}")
                if len(routed_data) > 1:
                    print(f"Second element type: {type(routed_data[1])}")
                    if isinstance(routed_data[1], dict):
                        print(f"Second element keys: {list(routed_data[1].keys())}")

            routing_info = self._find_routing_info(routed_data)
            if not isinstance(routing_info, dict):
                print("routing_info not found in JSON file")
                print(f"routed_data content: {routed_data}")
                return None
            print("Successfully found routing_info")

            total_wirelength = routing_info.get('total_wirelength')
            total_via_count = routing_info.get('total_via_count')
            if total_wirelength is None or total_via_count is None:
                # A missing metric is a failed run, not an infinitely bad one:
                # an infinite objective cannot be modelled by the optimizer.
                print("routing_info is missing total_wirelength or total_via_count")
                return None

            print(f"Extracted wirelength: {total_wirelength}, via_count: {total_via_count}")

            return total_wirelength, total_via_count

        except subprocess.TimeoutExpired:
            print("Script execution timeout")
            return None
        except Exception as e:
            print(f"Error during execution: {e}")
            traceback.print_exc()
            return None

    @staticmethod
    def _find_routing_info(routed_data):
        """Return the 'routing_info' entry of parsed routed.json content, or None if absent."""
        if isinstance(routed_data, list):
            # For list-shaped files routing_info is looked up in the second element
            container = routed_data[1] if len(routed_data) > 1 else None
        else:
            container = routed_data

        if not isinstance(container, dict):
            return None
        if 'routing_info' in container:
            return container['routing_info']

        # Otherwise look one level down, in the dictionary's values
        for value in container.values():
            if isinstance(value, dict) and 'routing_info' in value:
                return value['routing_info']
        return None
    
    def objective_function(self, params):
        """
        Objective function for Bayesian optimization (without decorator)

        Decodes the parameter vector, runs placement and routing for it and
        scores the outcome as ``1.0 * total_wirelength + 0.1 * total_via_count``.
        Successful runs are appended to ``self.history``; a new best result is
        stored in ``self.best_result`` and written out via save_best_result().

        Args:
            params: Parameter list; the first num_cells-1 entries are the order
                encoding, the remaining num_cells entries are the rotations

        Returns:
            float: Objective value to minimize (1e10 if the run failed or
            produced unusable metrics)
        """
        import math

        self.iteration += 1
        print(f"\n========== Iteration {self.iteration} ==========")

        # Parse parameters
        # First num_cells-1 parameters are order encoding, next num_cells are rotation
        split = self.num_cells - 1
        encoded_order = params[:split]
        # Unwrap NumPy scalars so that plain Python values end up in the
        # generated script and in the saved history.
        rotation_list = [r.item() if hasattr(r, 'item') else r for r in params[split:]]

        # Decode order_list
        order_list = self.decode_order(encoded_order)

        print(f"Order list: {order_list}")
        print(f"Rotation list: {rotation_list}")

        # Run placement and routing
        result = self.run_placement_and_routing(order_list, rotation_list)

        if result is not None:
            # Infinite, NaN or non-numeric metrics would yield an objective the
            # optimizer cannot model and would be recorded as a bogus "best"
            # result, so such runs are handled like failed executions.
            try:
                usable = all(math.isfinite(metric) for metric in result)
            except (TypeError, ValueError, OverflowError):
                usable = False
            if not usable:
                print(f"Unusable metrics returned: {result}")
                result = None

        if result is None:
            # If execution fails, return a large penalty value
            objective_value = 1e10
            print(f"Execution failed, objective value set to: {objective_value}")
            return objective_value

        total_wirelength, total_via_count = result
        # Define objective function (weights can be adjusted)
        wirelength_weight = 1.0
        via_weight = 0.1  # Via weight can be adjusted as needed
        objective_value = wirelength_weight * total_wirelength + via_weight * total_via_count

        print(f"Total wirelength: {total_wirelength}")
        print(f"Total via count: {total_via_count}")
        print(f"Objective value: {objective_value}")

        def snapshot():
            # Fresh list copies so history and best result never share state
            return {
                'iteration': self.iteration,
                'order_list': list(order_list),
                'rotation_list': list(rotation_list),
                'total_wirelength': total_wirelength,
                'total_via_count': total_via_count,
                'objective': objective_value
            }

        # Record history
        self.history.append(snapshot())

        # Update best result
        if self.best_result is None or objective_value < self.best_result['objective']:
            self.best_result = snapshot()
            print(f"*** Found new best result! ***")

            # Save best result
            self.save_best_result()

        return objective_value
    
    def optimize(self, n_calls=10, n_initial_points=10):
        """
        Execute Bayesian optimization

        Runs gp_minimize over ``self.dimensions`` with objective_function() as
        the target, prints a summary of the best configuration and saves the
        evaluation history. If the optimization is aborted by an exception
        (including a keyboard interrupt), the evaluations recorded so far are
        still saved before the exception is re-raised.

        Args:
            n_calls: Total number of function evaluations
            n_initial_points: Number of initial random sampling points

        Returns:
            Optimization result
        """
        print("Starting Bayesian optimization...")
        print(f"Total iterations: {n_calls}")
        print(f"Initial random sampling points: {n_initial_points}")
        print(f"Search space dimensions: {len(self.dimensions)}")

        # Execute Bayesian optimization
        try:
            result = gp_minimize(
                func=self.objective_function,
                dimensions=self.dimensions,
                n_calls=n_calls,
                n_initial_points=n_initial_points,
                acq_func='EI',  # Expected Improvement
                noise=1e-10,
                random_state=42,
                verbose=True
            )
        except BaseException:
            # Each evaluation can take up to a minute; do not throw away the
            # ones already completed. Skipped when nothing was evaluated so an
            # existing history file is not overwritten with an empty list.
            if self.history:
                self.save_history()
            raise

        print("\n========== Optimization complete ==========")
        print(f"Best objective value: {result.fun}")

        if self.best_result:
            print(f"\nBest configuration:")
            print(f"  Order list: {self.best_result['order_list']}")
            print(f"  Rotation list: {self.best_result['rotation_list']}")
            print(f"  Total wirelength: {self.best_result['total_wirelength']}")
            print(f"  Total via count: {self.best_result['total_via_count']}")
            print(f"  Objective: {self.best_result['objective']}")

        # Save optimization history
        self.save_history()

        return result
    
    def save_best_result(self):
        """
        Save best result to file

        Writes ``self.best_result`` to ./bayes/<json_file_name>_config.json and
        a copy of the script template with the best order_list/rotation_list
        filled in to ./bayes/<json_file_name>_script.py. Does nothing when no
        best result has been recorded yet.
        """
        if not self.best_result:
            return

        # Save best configuration (using json file name as prefix)
        best_config_path = f"./bayes/{self.json_file_name}_config.json"
        with open(best_config_path, 'w') as f:
            json.dump(self.best_result, f, indent=2)

        def plain(values):
            # NumPy scalars print as e.g. np.str_('R0') under NumPy >= 2, which
            # would be written verbatim into the script; unwrap them first.
            return [v.item() if hasattr(v, 'item') else v for v in values]

        # Generate best configuration script (using json file name as prefix)
        order_str = str(plain(self.best_result['order_list']))
        rotation_str = str(plain(self.best_result['rotation_list']))  # make sure it is rendered as a list

        best_script = self.script_template
        for pattern, replacement in (
            (r'order_list\s*=\s*\[.*?\]', f'order_list = {order_str}'),
            (r'rotation_list\s*=\s*\[.*?\]', f'rotation_list = {rotation_str}'),
        ):
            # DOTALL lets a list literal that spans several lines be replaced
            # as a whole. The function replacement inserts the text verbatim
            # instead of parsing it as a regex replacement template.
            best_script = re.sub(
                pattern,
                lambda _match, text=replacement: text,
                best_script,
                flags=re.DOTALL,
            )

        best_script_path = f"./bayes/{self.json_file_name}_script.py"
        with open(best_script_path, 'w', encoding='utf-8') as f:
            f.write(best_script)
        print(f"Best script saved to: {best_script_path}")
    
    def save_history(self):
        """Write all recorded evaluations to ./bayes/<json_file_name>_history.json."""
        target = f"./bayes/{self.json_file_name}_history.json"

        # Dump the whole history list as indented JSON
        with open(target, 'w') as out:
            json.dump(self.history, out, indent=2)

        print(f"Optimization history saved to {target}")


def main():
    """
    Main function - Iterate through all JSON files in data folder for optimization

    Every ``*.json`` file in ./data is processed in sorted order: the number of
    cells is read from it, an optimizer is built around the ``main.py`` script
    template and a 10-evaluation optimization is run. A file whose structure is
    not recognized, or whose optimization fails, is reported and skipped so the
    remaining files are still processed.
    """
    import traceback

    # Set data folder path
    data_folder = "./data"

    # Get all JSON files (sorted: os.listdir order is arbitrary)
    try:
        json_files = sorted(f for f in os.listdir(data_folder) if f.endswith('.json'))
    except OSError as e:
        print(f"Cannot read data folder {data_folder}: {e}")
        return

    if not json_files:
        print(f"No JSON files found in {data_folder} folder")
        return

    print(f"Found {len(json_files)} JSON files: {json_files}")
    print("="*80)

    # Iterate through each JSON file for optimization
    for idx, json_file in enumerate(json_files, 1):
        print(f"\n{'='*80}")
        print(f"Processing file {idx}/{len(json_files)}: {json_file}")
        print(f"{'='*80}\n")

        # Build file path
        json_path = os.path.join(data_folder, json_file)

        # Get file name (without extension)
        json_file_name = os.path.splitext(json_file)[0]

        # Read JSON file to determine number of cells
        try:
            with open(json_path, 'r') as f:
                num_cells = _count_cells(json.load(f))
        except (OSError, ValueError, TypeError) as e:
            print(f"Failed to read JSON file: {e}")
            print(f"Using default cell count: 9")
            num_cells = 9
        else:
            if num_cells is None:
                # Skip only this file instead of aborting the whole batch
                print(f"Unrecognized JSON structure in {json_file}, skipping")
                continue
            print(f"Detected {num_cells} cells")

        # Create the optimizer and run it. Construction is inside the try
        # because it reads the script template and may fail as well.
        # You can adjust n_calls (total iterations) and n_initial_points (initial random points)
        try:
            optimizer = BayesianPlacementOptimizer(
                num_cells=num_cells,
                external_script_path="main.py",
                original_json_path=json_path,
                json_file_name=json_file_name
            )
            optimizer.optimize(n_calls=10, n_initial_points=10)
        except Exception as e:
            print(f"\nError during optimization of file {json_file}: {e}")
            traceback.print_exc()
            print("Continuing to next file...")
            continue

        print(f"\nFile {json_file} optimization complete!")
        print(f"Best configuration saved to ./bayes/{json_file_name}_config.json")
        print(f"Best placement script saved to ./bayes/{json_file_name}_script.py")
        print(f"Optimization history saved to ./bayes/{json_file_name}_history.json")

    print(f"\n{'='*80}")
    print("All files processed!")
    print(f"{'='*80}")


def _count_cells(circuit_data):
    """Return the number of entries under 'cells' in parsed circuit JSON, or None if not found."""
    # JSON may be dictionary or list format; for a list, the first element is checked
    if isinstance(circuit_data, list):
        circuit_data = circuit_data[0] if circuit_data else None
    if isinstance(circuit_data, dict) and 'cells' in circuit_data:
        return len(circuit_data['cells'])
    return None


# Run the batch optimization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()