import os
import sys
import fcntl
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import gc
import logging
import json
import shutil
import tarfile
from threading import Thread


# Append the project root (the parent of the directory containing this file)
# to sys.path; presumably this is what lets the `python_src` imports below
# resolve when this file is run directly as a script.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)


import random
import time
import math
from multiprocessing import cpu_count
from python_src.dag import DAG
from python_src.node import Node
from tqdm import tqdm

class LoggerFactory:
    """Placeholder logger factory; it currently never constructs a logger."""

    # Reserved slots for a shared instance / cached logger (never assigned here).
    _instance = None
    _logger = None

    @classmethod
    def get_logger(cls, task_tag=None):
        """Return a logger for *task_tag*; this stub always yields None."""
        produced = None
        return produced

class ConsoleLogger:
    """Print structured log lines to stdout, coloured by level.

    Each line has the form ``[YYYY-mm-dd HH:MM:SS] [LEVEL] message | k=v, ...``.
    ``task_id`` and ``task_tag`` are printed first, followed by the remaining
    keyword fields in the order they were passed. Fields named ``timestamp`` or
    ``node_type`` are never printed.
    """

    # ANSI colour escape per level; unknown levels are printed uncoloured.
    _LEVEL_COLORS = {
        'ERROR': "\033[91m",    # red
        'WARNING': "\033[93m",  # yellow
        'INFO': "\033[92m",     # green
        'DEBUG': "\033[94m",    # blue
    }
    _RESET = "\033[0m"
    # Fields printed first, in this order.
    _IMPORTANT_FIELDS = ('task_id', 'task_tag')
    # Fields that are never printed.
    _HIDDEN_FIELDS = ('timestamp', 'node_type')

    def __init__(self, task_tag=None):
        """Remember *task_tag*; when truthy it is attached to every record."""
        self.task_tag = task_tag

    def log(self, level, message, **kwargs):
        """Print *message* at *level* together with the keyword fields.

        Logging is best-effort: ordinary errors raised while formatting or
        printing are swallowed so logging never breaks the caller, but
        ``KeyboardInterrupt``/``SystemExit`` (and other non-``Exception``
        signals) propagate.
        """
        try:
            if self.task_tag:
                kwargs['task_tag'] = self.task_tag

            timestamp_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
            console_msg = f"[{timestamp_str}] [{level}] {message}"

            log_details = [
                f"{key}={kwargs[key]}" for key in self._IMPORTANT_FIELDS if key in kwargs
            ]
            log_details.extend(
                f"{k}={v}"
                for k, v in kwargs.items()
                if k not in self._IMPORTANT_FIELDS and k not in self._HIDDEN_FIELDS
            )
            if log_details:
                console_msg += " | " + ", ".join(log_details)

            color = self._LEVEL_COLORS.get(level) if isinstance(level, str) else None
            if color:
                print(f"{color}{console_msg}{self._RESET}")
            else:
                print(console_msg)
        except Exception:
            # Deliberate best-effort: a logging failure must not crash the caller.
            pass
        
class DagTaskCombiner:
    """Bundle DAG files from a shared pool into task files for workers.

    Everything lives under ``<base_op_folder>/<task_tag>``:

    * ``task/current/`` -- pending task files (names ending in ``.json``);
      names containing ``.processing_`` are counted separately as in-progress.
    * ``task/.lock`` -- ``flock`` file taken while writing task files and while
      removing bundled DAGs from the pool.
    * ``dag/pool/`` -- DAG files (``*.json``) waiting to be bundled.
    * ``dag/archive/`` -- copies of the pool (``archive_batch_<n>``).
    * ``status.json`` -- progress snapshot written by ``write_status``.
    * ``STOP`` -- while this file exists, no new tasks are created.
    """

    def __init__(
        self,
        base_op_folder="./op",
        task_tag="default",
        batch_size=40,
        check_interval=60,
        logger=None,
        max_tasks=400,
        archive_threshold=1000,
        id_manager=None,
        worker_num=10,
    ):
        """Set up folder paths, create the folders, and reset counters.

        Args:
            base_op_folder: Root folder; all files live under
                ``<base_op_folder>/<task_tag>``.
            task_tag: Run identifier used in folder and task file names.
            batch_size: Maximum number of DAG files bundled into one task.
            check_interval: Seconds to wait between monitor iterations.
            logger: Object exposing ``log(level, message, **fields)``. May be
                None, in which case log records are dropped.
            max_tasks: Upper bound on the number of pending task files.
            archive_threshold: After this many created tasks the pool is
                archived.
            id_manager: Stored as an attribute; not used by this class.
            worker_num: Worker count; the pending-task target is a third of it.
        """
        # Base folders
        self.base_folder = os.path.join(base_op_folder, task_tag)
        self.task_base_folder = os.path.join(self.base_folder, "task")
        self.dag_base_folder = os.path.join(self.base_folder, "dag")
        self.task_current_folder = os.path.join(self.task_base_folder, "current")

        # DAG folder structure
        self.dag_pool_folder = os.path.join(self.dag_base_folder, "pool")
        self.dag_archive_folder = os.path.join(self.dag_base_folder, "archive")

        self.batch_size = batch_size
        self.check_interval = check_interval
        self.current_dags = []
        self.task_tag = task_tag
        self.max_tasks = max_tasks
        self.archive_threshold = archive_threshold
        self.worker_num = worker_num
        self.start_time = time.time()
        self.status_file_path = os.path.join(self.base_folder, "status.json")
        self.stop_file_path = os.path.join(self.base_folder, "STOP")

        # exist_ok avoids FileExistsError when another process creates the same
        # folder between an existence check and makedirs.
        for folder in (
            self.task_base_folder,
            self.task_current_folder,
            self.dag_base_folder,
            self.dag_pool_folder,
            self.dag_archive_folder,
        ):
            os.makedirs(folder, exist_ok=True)

        self.logger = logger
        self.lock_file_path = os.path.join(self.task_base_folder, '.lock')

        self.task_id_counter = 0  # next task_id handed to save_combined_tasks
        self.tasks_created_since_last_archive = 0
        self.id_manager = id_manager
        self.total_tasks_created = 0
        self.stats_output_frequency = 10000  # print stats every N created tasks

    # ------------------------------------------------------------------ helpers

    def _log(self, level, message, **fields):
        """Forward a log record to ``self.logger`` when one is configured."""
        if self.logger is not None:
            self.logger.log(level, message, **fields)

    def _stop_requested(self):
        """Return True while the STOP file exists."""
        return os.path.exists(self.stop_file_path)

    def _safe_write_json(self, path, payload):
        """Atomically write *payload* as JSON to *path*.

        Data is written to ``<path>.tmp`` and then moved over *path* with
        ``os.replace``, so readers never observe a half-written file. The
        temporary name does not end in ``.json``, so it is never counted as a
        task. On failure the temporary file is removed and the error re-raised.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as handle:
                json.dump(payload, handle, indent=2, ensure_ascii=True)
            os.replace(tmp_path, path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _count_task_files(self, log_processing=False):
        """Return ``(pending, processing)`` file counts in the current task folder.

        Pending tasks are names ending in ``.json``; processing tasks are names
        containing ``.processing_``. Raises OSError if the folder cannot be
        listed.
        """
        names = os.listdir(self.task_current_folder)
        pending = sum(1 for name in names if name.endswith('.json'))
        processing = sum(1 for name in names if '.processing_' in name)
        if log_processing and processing:
            self._log('DEBUG', f"Task count - pending: {pending}, processing: {processing} (not counted)")
        return pending, processing

    def _acquire_task_lock(self, max_wait_time=30, wait_interval=0.1):
        """Open the task lock file and try to lock it exclusively.

        Polls a non-blocking ``flock`` for up to *max_wait_time* seconds. On
        timeout a warning is logged and the (unlocked) file is still returned,
        so the caller proceeds without mutual exclusion.
        """
        lock_start_time = time.time()
        lock_file = open(self.lock_file_path, 'w')
        deadline = time.monotonic() + max_wait_time
        while True:
            try:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                break
            except OSError:
                if time.monotonic() >= deadline:
                    self._log('WARNING', f"File lock acquisition timed out ({max_wait_time}s), forcing continue")
                    break
                time.sleep(wait_interval)

        lock_acquire_time = time.time() - lock_start_time
        if lock_acquire_time > 5:
            self._log('WARNING', f"File lock acquisition took too long: {lock_acquire_time:.2f}s")
        return lock_file

    def _release_lock(self, lock_file):
        """Unlock and close *lock_file*; failures are logged, not raised."""
        try:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
        except Exception as e:
            self._log('ERROR', f"Failed to release file lock: {str(e)}")
        try:
            # Close even if LOCK_UN failed so the descriptor is not leaked.
            lock_file.close()
        except Exception as e:
            self._log('ERROR', f"Failed to release file lock: {str(e)}")

    @staticmethod
    def _wait(stop_event, seconds):
        """Sleep *seconds*, waking early if *stop_event* has ``wait()`` and gets set."""
        wait = getattr(stop_event, "wait", None)
        if callable(wait):
            wait(seconds)
        else:
            time.sleep(seconds)

    # ------------------------------------------------------------------ status

    def _get_best_summary(self):
        """Return the best ``bestof/*.json`` entry as ``{"error", "ops", "source"}``.

        The best file has the lowest finite ``optimization_error``; ties go to
        the file with fewer nodes whose ``type`` is 2-6. Unreadable or malformed
        files are skipped. Returns None when no file qualifies.
        """
        bestof_folder = os.path.join(self.base_folder, "bestof")
        if not os.path.isdir(bestof_folder):
            return None
        best_error = float("inf")
        best_ops = None
        best_source = None
        for name in os.listdir(bestof_folder):
            if not name.endswith(".json"):
                continue
            try:
                with open(os.path.join(bestof_folder, name), "r", encoding="utf-8") as handle:
                    data = json.load(handle)
                error = float(data.get("optimization_error"))
                if not math.isfinite(error):
                    continue
                ops = sum(
                    1 for node in data.get("nodes", []) if node.get("type") in (2, 3, 4, 5, 6)
                )
            except Exception:
                # Best-effort scan: corrupt or half-written files are ignored.
                continue
            if error < best_error or (
                error == best_error and (best_ops is None or ops < best_ops)
            ):
                best_error, best_ops, best_source = error, ops, name
        if best_source is None:
            return None
        return {
            "error": best_error,
            "ops": best_ops,
            "source": best_source,
        }

    def write_status(self, status="running", note=None):
        """Write a progress snapshot to ``status.json``.

        A "running" status is reported as "stopping" (with note
        ``stop_file_detected`` unless a note was given) while the STOP file
        exists. If the task folder cannot be listed, the task counts are written
        as null instead of raising. Failures to write the file are logged.
        """
        if status == "running" and self._stop_requested():
            status = "stopping"
            if note is None:
                note = "stop_file_detected"
        try:
            pending, processing = self._count_task_files()
        except OSError as exc:
            # Keep reporting even if the task folder is unreadable; this method
            # is also called from the monitor's error handler.
            self._log('WARNING', f"Failed to count task files: {str(exc)}", path=self.task_current_folder)
            pending = processing = completed = None
        else:
            completed = max(self.total_tasks_created - pending - processing, 0)
        payload = {
            "status": status,
            "task_tag": self.task_tag,
            "progress": {
                "pending_tasks": pending,
                "processing_tasks": processing,
                "completed_tasks": completed,
                "total_created": self.total_tasks_created,
                "max_tasks": self.max_tasks,
            },
            "elapsed_s": round(time.time() - self.start_time, 3),
            "best_summary": self._get_best_summary(),
            "updated_at": int(time.time()),
        }
        if note:
            payload["note"] = str(note)
        try:
            self._safe_write_json(self.status_file_path, payload)
        except Exception as exc:
            self._log('WARNING', f"Failed to write status.json: {str(exc)}", path=self.status_file_path)

    # ------------------------------------------------------------------ tasks

    def save_combined_tasks(self, dag_files, task_id):
        """Write one task file embedding the contents of *dag_files*.

        Each DAG file is read from the pool, tagged with its ``filename``, and
        embedded under ``dags``; files that cannot be read or tagged are logged
        and skipped. The task file is written while holding the task lock
        (waiting up to 30 s, then proceeding without it) and is moved into place
        atomically, so readers never see partial JSON.

        Returns:
            The task file path, or None if the task file could not be written.
        """
        timestamp = int(time.time() * 1000)
        filename = f"{self.task_tag}_task_{timestamp}_{task_id}.json"
        file_path = os.path.join(self.task_current_folder, filename)

        dag_contents = []
        dag_ids = []
        for dag_filename in dag_files:
            dag_file_path = os.path.join(self.dag_pool_folder, dag_filename)
            try:
                with open(dag_file_path, 'r') as f:
                    dag_data = json.load(f)
                # Record the source filename for traceability.
                dag_data['filename'] = dag_filename
                if 'id' in dag_data:
                    dag_ids.append(dag_data['id'])
                dag_contents.append(dag_data)
            except Exception as e:
                self._log('ERROR', f"Failed to read DAG file: {str(e)}", filename=dag_filename)

        task_data = {
            'meta': {
                'combination_time': str(int(time.time())),
                'dag_count': len(dag_contents),
                'type': 'evolution',
                'task_tag': self.task_tag,
                'task_id': task_id,
                'dag_ids': dag_ids
            },
            'dags': dag_contents
        }

        lock_file = None
        try:
            lock_file = self._acquire_task_lock()
            self._safe_write_json(file_path, task_data)
        except Exception as e:
            self._log('ERROR', f"Failed to save task file: {str(e)}")
            # Signal failure so the caller keeps these DAGs in the pool.
            return None
        finally:
            if lock_file is not None:
                self._release_lock(lock_file)
        return file_path

    def get_current_task_count(self):
        """Return the number of pending task files (in-progress ones excluded)."""
        pending, _ = self._count_task_files(log_processing=True)
        return pending

    def get_worker_count(self):
        """Return the configured worker count."""
        return self.worker_num

    # ------------------------------------------------------------------ archive

    def archive_pool_dags(self):
        """Start a background copy of the current pool into the archive.

        Returns:
            The number of pool files handed to the archive thread (0 if the
            pool is empty or the thread could not be started). Does not wait
            for the copy to finish.
        """
        try:
            dag_files = [f for f in os.listdir(self.dag_pool_folder) if f.endswith('.json')]
            if not dag_files:
                self._log('INFO', "No files in DAG pool to archive")
                return 0

            batch_number = self.total_tasks_created // self.archive_threshold
            archive_thread = Thread(
                target=self._background_archive_task,
                args=(list(dag_files), batch_number),
                daemon=True,
            )
            archive_thread.start()

            self._log('INFO', f"Background archive task started",
                      file_count=len(dag_files),
                      batch_number=batch_number)
            return len(dag_files)
        except Exception as e:
            self._log('ERROR', f"Error starting archive task: {str(e)}")
            return 0

    def _background_archive_task(self, dag_files, batch_number):
        """Copy *dag_files* from the pool into ``archive/archive_batch_<n>``.

        Runs in the thread started by ``archive_pool_dags``. Files that left the
        pool meanwhile are skipped silently; other per-file copy errors are
        logged. Errors are logged rather than raised.
        """
        try:
            start_time = time.time()
            archive_subfolder = os.path.join(self.dag_archive_folder, f"archive_batch_{batch_number}")
            os.makedirs(archive_subfolder, exist_ok=True)

            def copy_single_file(filename):
                """Copy one pool file into the archive; return 1 if copied, else 0."""
                src_path = os.path.join(self.dag_pool_folder, filename)
                dst_path = os.path.join(archive_subfolder, filename)
                try:
                    shutil.copy2(src_path, dst_path)
                    return 1
                except Exception as e:
                    if not os.path.exists(src_path):
                        # The DAG was bundled into a task and removed after the
                        # pool was listed; expected, not an error.
                        return 0
                    self._log('ERROR', f"Background archive failed for DAG file: {str(e)}", filename=filename)
                    return 0

            # At least one worker: ThreadPoolExecutor rejects max_workers=0.
            max_workers = max(1, min(8, len(dag_files)))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                copied_count = sum(executor.map(copy_single_file, dag_files))

            duration = time.time() - start_time
            avg_speed = copied_count / duration if duration > 0 else 0
            self._log('INFO', f"Background archive task completed",
                      copied_files=copied_count,
                      total_files=len(dag_files),
                      duration=f"{duration:.2f}s",
                      speed=f"{avg_speed:.1f} files/s",
                      archive_folder=archive_subfolder)
        except Exception as e:
            self._log('ERROR', f"Background archive task failed: {str(e)}",
                      batch_number=batch_number,
                      file_count=len(dag_files))

    # ------------------------------------------------------------------ creation

    def check_and_create_tasks(self):
        """Top up the pending-task queue towards its target.

        The target is a third of the worker count (at least 1), capped at
        ``max_tasks``. Also prints statistics each time the total crosses a
        multiple of ``stats_output_frequency`` and starts a pool archive every
        ``archive_threshold`` created tasks.

        Returns:
            True if the queue was below target and creation was attempted;
            False otherwise, including when the STOP file is present.
        """
        if self._stop_requested():
            self._log('INFO', "STOP file detected, stopping task creation")
            return False

        pending_count, processing_count = self._count_task_files(log_processing=True)
        current_task_count = pending_count
        worker_count = self.get_worker_count()
        target_task_count = max(1, worker_count // 3)
        # Never plan beyond max_tasks, even if a third of the workers asks for more.
        capped_target = min(target_task_count, self.max_tasks)

        if current_task_count >= capped_target:
            return False

        self._log('INFO', f"Current tasks {current_task_count} < target {target_task_count} (worker_count={worker_count}/3), creating new tasks",
                  pending_tasks=pending_count,
                  processing_tasks=processing_count)
        tasks_created = self.create_new_tasks(capped_target - current_task_count)

        previous_total = self.total_tasks_created
        self.tasks_created_since_last_archive += tasks_created
        self.total_tasks_created += tasks_created

        # Totals grow by variable steps, so compare buckets instead of testing
        # ``total % frequency == 0``, which would skip most milestones.
        frequency = self.stats_output_frequency
        if frequency > 0 and self.total_tasks_created // frequency > previous_total // frequency:
            self._log('INFO', f"Task combination stats",
                      total_tasks=self.total_tasks_created,
                      current_tasks=current_task_count,
                      pending_tasks=pending_count,
                      processing_tasks=processing_count,
                      worker_count=worker_count)
            print(f"\n{'='*40}")
            print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Task combination stats:")
            print(f"Total tasks created: {self.total_tasks_created}")
            print(f"Current task queue: {current_task_count} (pending: {pending_count}, processing: {processing_count})")
            print(f"Worker count: {worker_count}")
            print(f"{'='*40}\n")

        if self.tasks_created_since_last_archive >= self.archive_threshold:
            self._log('INFO', f"Created {self.tasks_created_since_last_archive} tasks, archiving DAG pool")
            self.archive_pool_dags()
            self.tasks_created_since_last_archive = 0

        return True

    def create_new_tasks(self, num_tasks_to_create):
        """Bundle pool DAGs into up to *num_tasks_to_create* task files.

        DAG files are shuffled, split into batches of ``batch_size``, and each
        batch is saved as one task. Only DAGs whose task file was written are
        removed from the pool, so a failed save does not lose them.

        Returns:
            The number of task files written.
        """
        if num_tasks_to_create <= 0:
            return 0

        all_dag_files = [f for f in os.listdir(self.dag_pool_folder) if f.endswith('.json')]
        if not all_dag_files:
            self._log('INFO', "No DAG files in pool")
            return 0

        self._log('INFO', f"Found in DAG pool: {len(all_dag_files)}  DAG files")

        # Shuffle so each task mixes DAGs from across the pool.
        random.shuffle(all_dag_files)
        num_batches = min(
            num_tasks_to_create,
            (len(all_dag_files) + self.batch_size - 1) // self.batch_size,
        )

        files_to_remove = []
        tasks_created = 0
        for i in tqdm(range(num_batches),
                      desc="Creating tasks",
                      ncols=100,
                      unit="batch",
                      mininterval=1.0):
            batch_files = all_dag_files[i * self.batch_size:(i + 1) * self.batch_size]
            task_id = self.task_id_counter
            self.task_id_counter += 1
            try:
                saved_path = self.save_combined_tasks(batch_files, task_id)
            except Exception as e:
                self._log('ERROR', f"Error processing DAG batch: {str(e)}")
                continue
            if saved_path is None:
                # Task file not written: keep these DAGs for a later round.
                continue
            files_to_remove.extend(batch_files)
            tasks_created += 1

        self._remove_pool_files(files_to_remove)
        self._log('INFO', f"Created {tasks_created} new tasks, removed {len(files_to_remove)} DAG files from pool")
        return tasks_created

    def _remove_pool_files(self, filenames):
        """Delete *filenames* from the DAG pool while holding the task lock."""
        if not filenames:
            return
        lock_file = None
        try:
            lock_file = open(self.lock_file_path, 'w')
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            for filename in filenames:
                try:
                    os.remove(os.path.join(self.dag_pool_folder, filename))
                except FileNotFoundError:
                    pass  # already gone; nothing to remove
                except Exception as e:
                    self._log('ERROR', f"Failed to remove DAG file: {str(e)}", filename=filename)
        except Exception as e:
            self._log('ERROR', f"Error removing DAG files: {str(e)}")
        finally:
            if lock_file is not None:
                self._release_lock(lock_file)

    # ------------------------------------------------------------------ monitor

    def _log_heartbeat(self, heartbeat_counter, last_successful_check):
        """Print and log a periodic liveness report for the monitor loop."""
        time_since_last_check = time.time() - last_successful_check
        try:
            current_task_count = self.get_current_task_count()
            worker_count = self.get_worker_count()
            target_task_count = max(1, worker_count // 3)

            print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Task creation heartbeat #{heartbeat_counter}")
            print(f"  - Since last successful check: {time_since_last_check:.1f}s")
            print(f"  - Current tasks: {current_task_count}, target: {target_task_count}, workers: {worker_count}")
            print(f"  - Total tasks created: {self.total_tasks_created}")

            self._log('INFO', "Task creation thread heartbeat",
                      heartbeat_counter=heartbeat_counter,
                      current_tasks=current_task_count,
                      target_tasks=target_task_count,
                      worker_count=worker_count,
                      total_created=self.total_tasks_created,
                      time_since_last_check=f"{time_since_last_check:.1f}s")
        except Exception as status_error:
            print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Task creation heartbeat #{heartbeat_counter} - status check failed: {str(status_error)}")
            self._log('ERROR', f"Heartbeat status check failed: {str(status_error)}", heartbeat_counter=heartbeat_counter)

    def monitor_and_create_tasks(self, stop_event=None):
        """Run the task-creation loop until stopped.

        Each iteration calls ``check_and_create_tasks``, refreshes
        ``status.json`` and waits ``check_interval`` seconds; every tenth
        iteration prints a heartbeat. The loop ends when *stop_event* is set,
        the STOP file appears, or KeyboardInterrupt/SystemExit is raised in the
        loop body. Other exceptions are logged, reported as status "error", and
        followed by a longer back-off wait.

        Args:
            stop_event: Optional object with ``is_set()`` (e.g.
                ``threading.Event``). If it also has ``wait()``, waits between
                iterations return as soon as it is set.
        """
        self._log('INFO', "Starting task count monitor")
        if self._stop_requested():
            self._log('INFO', "STOP file detected, stopping task monitor")
            self.write_status(status="stopped", note="stop_file_detected")
            return
        self.write_status(status="running")

        heartbeat_counter = 0
        last_successful_check = time.time()
        final_note = None  # note attached to the final "stopped" status

        while not (stop_event and stop_event.is_set()):
            try:
                if self._stop_requested():
                    self._log('INFO', "STOP file detected, stopping task monitor")
                    final_note = "stop_file_detected"
                    break

                heartbeat_counter += 1
                if heartbeat_counter % 10 == 0:
                    self._log_heartbeat(heartbeat_counter, last_successful_check)

                check_start_time = time.time()
                try:
                    self.check_and_create_tasks()
                    last_successful_check = time.time()
                    check_duration = last_successful_check - check_start_time
                    if check_duration > 30:
                        self._log('WARNING', f"Task check took too long", duration=f"{check_duration:.2f}s")
                except Exception as check_error:
                    check_duration = time.time() - check_start_time
                    self._log('ERROR', f"Task check failed: {str(check_error)}",
                              duration=f"{check_duration:.2f}s",
                              error_type=type(check_error).__name__)
                    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Task check error: {str(check_error)}")

                if self._stop_requested():
                    self._log('INFO', "STOP file detected, stopping task monitor")
                    final_note = "stop_file_detected"
                    break
                self.write_status(status="running")
                self._wait(stop_event, self.check_interval)

            except Exception as e:
                error_details = {
                    'error_message': str(e),
                    'error_type': type(e).__name__,
                    'heartbeat_counter': heartbeat_counter
                }
                self._log('ERROR', f"Critical error in task monitor: {str(e)}", **error_details)
                print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Critical task monitor error: {str(e)}")
                self.write_status(status="error", note=str(e))
                # Recovery: back off longer before the next attempt.
                self._wait(stop_event, min(60, self.check_interval * 3))
            except KeyboardInterrupt:
                self._log('INFO', "Interrupt signal received, task creation monitor shutting down")
                break
            except SystemExit:
                self._log('INFO', "System exit signal received, task creation monitor shutting down")
                break

        self._log('INFO', f"Task creation monitor exited, total heartbeats: {heartbeat_counter}")
        self.write_status(status="stopped", note=final_note)

class DagIdManager:
    """Hand out increasing DAG IDs backed by a counter file and a file lock.

    The counter lives in ``<base_op_folder>/<task_tag>/dag_id_counter.txt``.
    Each increment in ``get_next_id`` happens under an exclusive ``flock`` on
    ``dag_id_counter.lock``, so instances sharing the folder do not receive the
    same counter value. The counter file is always replaced atomically, so an
    interrupted write cannot leave it empty and reset the sequence.
    """

    def __init__(self, base_op_folder="./op", task_tag="default", logger=None):
        """Create the folder if needed and load (or initialise) the counter.

        Args:
            base_op_folder: Root folder; files live under ``<base_op_folder>/<task_tag>``.
            task_tag: Run identifier used as the sub-folder name.
            logger: Object exposing ``log(level, message, **fields)``, or None
                to drop log records.
        """
        self.task_tag = task_tag
        self.logger = logger
        self.current_id = 0

        self.base_folder = os.path.join(base_op_folder, task_tag)
        os.makedirs(self.base_folder, exist_ok=True)

        self.id_file = os.path.join(self.base_folder, "dag_id_counter.txt")
        self.lock_file_path = os.path.join(self.base_folder, "dag_id_counter.lock")

        self.current_id = self._load_or_create_counter()

    def _log(self, level, message, **fields):
        """Forward a log record to ``self.logger`` when one is configured."""
        if self.logger is not None:
            self.logger.log(level, message, **fields)

    def _read_counter(self):
        """Return the integer stored in the counter file (OSError/ValueError on failure)."""
        with open(self.id_file, 'r') as f:
            return int(f.read().strip())

    def _load_or_create_counter(self):
        """Return the stored counter, writing 0 if it is missing or unreadable."""
        if os.path.exists(self.id_file):
            try:
                current_id = self._read_counter()
            except (OSError, ValueError):
                self._log('WARNING', f"Failed to read ID file, reinitializing")
            else:
                self._log('INFO', f"Loaded DAG ID counter from file",
                          current_id=current_id,
                          id_file=self.id_file)
                return current_id

        self._save_counter(0)
        self._log('INFO', f"Initialized DAG ID counter", initial_id=0)
        return 0

    def _save_counter(self, value):
        """Atomically replace the counter file's contents with *value*."""
        # Per-process temp name so concurrent writers never share a temp file;
        # os.replace makes the new value visible all at once.
        tmp_path = f"{self.id_file}.{os.getpid()}.tmp"
        with open(tmp_path, 'w') as f:
            f.write(str(value))
        os.replace(tmp_path, self.id_file)

    def get_next_id(self):
        """Increment the shared counter under the file lock and return the new value.

        If anything fails, an error is logged and the current time in
        milliseconds is returned as a fallback ID instead.
        """
        lock_file = None
        try:
            lock_file = open(self.lock_file_path, 'w')
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

            current_id = self._read_counter() if os.path.exists(self.id_file) else 0
            new_id = current_id + 1
            self._save_counter(new_id)
            self.current_id = new_id
            return new_id
        except Exception as e:
            self._log('ERROR', f"Failed to get DAG ID: {str(e)}")
            backup_id = int(time.time() * 1000)
            self._log('WARNING', f"Using timestamp as fallback DAG ID", id=backup_id)
            return backup_id
        finally:
            if lock_file is not None:
                try:
                    fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
                except OSError:
                    pass
                try:
                    lock_file.close()
                except OSError:
                    pass

    def get_current_id(self):
        """Return the last counter value this instance loaded or issued."""
        return self.current_id

class BestOfGraphManager:
    """Persist best-of graphs as ``<id>.json`` files with IDs from a locked counter file.

    Layout under ``base_op_folder``::

        <task_tag>/bestof/<id>.json            saved graphs
        <task_tag>/bestof_id_counter.txt       last issued ID
        <task_tag>/bestof_id_counter.lock      flock target guarding the counter

    ``logger`` must provide ``log(level, message, **fields)``. Despite the
    ``None`` default it is required, because every method calls
    ``self.logger.log``.
    """

    def __init__(self, base_op_folder="./op", task_tag="default", logger=None):
        self.task_tag = task_tag
        self.logger = logger
        self.current_id = 0

        # Storage directories
        self.base_folder = os.path.join(base_op_folder, task_tag)
        self.bestof_folder = os.path.join(self.base_folder, "bestof")

        # exist_ok avoids the exists()/makedirs() race when several processes
        # start for the same task at once.
        os.makedirs(self.bestof_folder, exist_ok=True)

        # Local counter + lock file paths
        self.id_file = os.path.join(self.base_folder, "bestof_id_counter.txt")
        self.lock_file_path = os.path.join(self.base_folder, "bestof_id_counter.lock")

        self.current_id = self._load_or_create_counter()

        self.logger.log('INFO', f"BestOfGraphManager initialized",
                        bestof_folder=self.bestof_folder,
                        current_id=self.current_id)

    @staticmethod
    def _atomic_write_text(path, text):
        """Write ``text`` to ``path`` via a sibling temp file and ``os.replace``.

        Readers see either the previous content or the new content, never a
        truncated file, and a failed write leaves no partial file behind (the
        temp file is removed and the exception re-raised).
        """
        import uuid  # only needed for the temp-file name

        # ".tmp" suffix keeps the temp file out of "*.json" listings.
        tmp_path = f"{path}.{os.getpid()}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'w') as f:
                f.write(text)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _highest_saved_id(self):
        """Return the largest integer stem among ``<id>.json`` files in ``bestof_folder`` (0 if none)."""
        try:
            names = os.listdir(self.bestof_folder)
        except OSError:
            return 0
        highest = 0
        for name in names:
            stem, ext = os.path.splitext(name)
            if ext == '.json' and stem.isdecimal():
                highest = max(highest, int(stem))
        return highest

    def _load_or_create_counter(self):
        """Return the persisted counter value, (re)creating the counter file if needed.

        If the counter file is missing or unreadable, it is re-seeded from the
        highest ``<id>.json`` already in ``bestof_folder``. Resetting to 0 would
        make later saves overwrite existing graphs.
        """
        if os.path.exists(self.id_file):
            try:
                with open(self.id_file, 'r') as f:
                    current_id = int(f.read().strip())
            except (OSError, ValueError):
                self.logger.log('WARNING', f"Failed to read BestOfGraph ID file, reinitializing")
            else:
                self.logger.log('INFO', f"Loaded BestOfGraph ID counter from file",
                                current_id=current_id,
                                id_file=self.id_file)
                return current_id

        initial_id = self._highest_saved_id()
        self._save_counter(initial_id)
        self.logger.log('INFO', f"Initialized BestOfGraph ID counter", initial_id=initial_id)
        return initial_id

    def _save_counter(self, value):
        """Atomically persist ``value`` as the counter file's contents."""
        self._atomic_write_text(self.id_file, str(value))

    def get_next_id(self):
        """Allocate the next best-of ID under an exclusive ``flock`` on the lock file.

        Returns:
            int: The new ID, also cached in ``self.current_id``. On any failure
            the error is logged and ``int(time.time() * 1000)`` is returned
            instead; ``self.current_id`` is not updated in that case.
        """
        try:
            # ``with`` guarantees the descriptor is closed (releasing the lock)
            # even if the explicit unlock fails.
            with open(self.lock_file_path, 'w') as lock_file:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
                try:
                    if os.path.exists(self.id_file):
                        with open(self.id_file, 'r') as f:
                            current_id = int(f.read().strip())
                    else:
                        current_id = 0

                    new_id = current_id + 1
                    self._save_counter(new_id)
                    self.current_id = new_id
                    return new_id
                finally:
                    try:
                        fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
                    except OSError:
                        pass

        except Exception as e:
            # Deliberate best-effort fallback so callers still get an ID.
            self.logger.log('ERROR', f"Failed to get BestOfGraph ID: {str(e)}")
            backup_id = int(time.time() * 1000)
            self.logger.log('WARNING', f"Using timestamp as fallback BestOfGraph ID", id=backup_id)
            return backup_id

    def save_bestof_graph(self, graph_data):
        """Save a best-of graph to ``<bestof_folder>/<id>.json``.

        Args:
            graph_data: JSON-serializable structure, stored as-is (indent=2).

        Returns:
            str: Path to the saved file.

        Raises:
            Exception: Any serialization or I/O error is logged and re-raised.
        """
        try:
            # Serialize first, so a non-serializable payload fails before an ID
            # is consumed and before anything touches disk. json.dump straight
            # into the target used to leave a truncated <id>.json behind.
            payload = json.dumps(graph_data, indent=2)

            graph_id = self.get_next_id()
            file_path = os.path.join(self.bestof_folder, f"{graph_id}.json")
            self._atomic_write_text(file_path, payload)
            return file_path

        except Exception as e:
            self.logger.log('ERROR', f"Failed to save BestOfGraph: {str(e)}",
                            error_type=type(e).__name__)
            raise

    def get_current_id(self):
        """Return the counter value cached on this instance (file not re-read)."""
        return self.current_id

    def get_bestof_count(self):
        """Return the number of ``*.json`` files in ``bestof_folder`` (0 if it cannot be listed)."""
        try:
            bestof_files = [f for f in os.listdir(self.bestof_folder) if f.endswith('.json')]
            return len(bestof_files)
        except OSError as e:
            self.logger.log('ERROR', f"Failed to get BestOfGraph count: {str(e)}")
            return 0

class InitDagGenerator:
    """Seed ``<base_op_folder>/<task_tag>/dag/pool`` with copies of a template DAG.

    ``generate_all`` writes ``worker_num * batch_size * init_multiplier`` JSON
    files in parallel. Each file is a copy of a ``DAG(num_inputs=1)`` template.
    Its ``id`` comes from ``id_manager.get_next_id()`` when an ID manager is
    given, otherwise from a millisecond timestamp plus the index. Its name,
    ``dag_<ms timestamp>_<index>``, is also the file stem.

    ``logger`` must provide ``log(level, message, **fields)``. Despite the
    ``None`` default it is required (``__init__`` logs unconditionally).
    ``poly`` is stored but not read anywhere in this class.
    """

    def __init__(self, base_op_folder="./op", task_tag="default", worker_num=None, batch_size=40, poly=3, type_string="float", logger=None, id_manager=None, init_multiplier=1):
        # Output folder is <base_op_folder>/<task_tag>/dag/pool
        self.output_folder = os.path.join(base_op_folder, task_tag, "dag", "pool")
        self.batch_size = batch_size
        self.poly = poly
        self.type_string = type_string
        # A falsy worker_num (None or 0) means "CPU count - 1", but at least 1.
        self.worker_num = worker_num if worker_num else max(1, cpu_count() - 1)
        self.init_multiplier = init_multiplier
        self.init_num = self.worker_num * self.batch_size * self.init_multiplier
        self.logger = logger
        self.id_manager = id_manager

        # exist_ok avoids the exists()/makedirs() race with concurrent starters.
        os.makedirs(self.output_folder, exist_ok=True)

        self.logger.log('INFO', f"Init generator: target {self.init_num}  DAGs (= {self.worker_num} workers x {self.batch_size} batch x {self.init_multiplier} multiplier) into {self.output_folder}")

        # "float" is mapped to "float64"; any other type string is passed through.
        dtype = "float64" if self.type_string == "float" else self.type_string
        self.template_dag = DAG(num_inputs=1, dtype=dtype)

    def generate_single_dag(self, index):
        """Write one copy of the template DAG to the pool and return its path.

        Args:
            index: Position of this DAG within the run. It is used in the name
                (and in the fallback ID), so names produced in the same
                millisecond stay distinct within one run.

        Returns:
            str: Path of the written ``<name>.json`` file.

        Raises:
            Whatever ``id_manager.get_next_id()``, JSON serialization or file I/O
            raise. No partial ``.json`` file is left behind in that case.
        """
        dag = self.template_dag.copy()

        if self.id_manager:
            next_id = self.id_manager.get_next_id()
        else:
            next_id = int(time.time() * 1000) + index  # fallback without an ID manager

        timestamp = int(time.time() * 1000)
        dag.name = f"dag_{timestamp}_{index}"
        file_path = os.path.join(self.output_folder, f"{dag.name}.json")

        dag_dict = {
            "id": next_id,
            "name": dag.name,
            "num_inputs": dag.num_inputs,
            "dtype": dag.dtype,
            "nodes": []
        }

        # Edges are stored as positions in dag.nodes (list.index lookup per edge).
        for i, node in enumerate(dag.nodes):
            dag_dict["nodes"].append({
                "id": i,
                "type": node.type,
                "value": node.value,
                "prev": [dag.nodes.index(prev_node) for prev_node in node.prev],
                "next": [dag.nodes.index(next_node) for next_node in node.next]
            })

        # Serialize before touching disk, then publish via rename. Pool readers
        # never see a half-written .json, and a non-serializable node value
        # leaves nothing behind in the pool.
        payload = json.dumps(dag_dict, indent=2)
        tmp_path = f"{file_path}.tmp"
        try:
            with open(tmp_path, 'w') as f:
                f.write(payload)
            os.replace(tmp_path, file_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        return file_path

    def generate_all(self):
        """Generate ``self.init_num`` DAGs on a thread pool (at most 32 threads).

        Per-DAG failures are logged and counted rather than raised. The final
        "DAG generation completed" log carries ``succeeded`` / ``failed`` counts.

        Returns:
            int: Number of DAG files written successfully.
        """
        self.logger.log('INFO', f"Starting parallel generation of {self.init_num}  DAGs")

        succeeded = 0
        failed = 0
        with ThreadPoolExecutor(max_workers=min(self.worker_num, 32)) as executor:
            futures = [executor.submit(self.generate_single_dag, i) for i in range(self.init_num)]

            for future in tqdm(as_completed(futures),
                               total=self.init_num,
                               desc="Generate initial DAGs",
                               ncols=100,
                               unit="dag",
                               mininterval=1.0):
                try:
                    future.result()  # re-raises the worker's exception, if any
                except Exception as e:
                    failed += 1
                    self.logger.log('ERROR', f"Failed to generate DAG: {str(e)}")
                else:
                    succeeded += 1

        self.logger.log('INFO', f"DAG generation completed", succeeded=succeeded, failed=failed)
        return succeeded

def check_existing_task_setup(base_op_folder, task_tag, logger, min_dag_count=10):
    """Decide whether ``<base_op_folder>/<task_tag>`` already holds a usable setup.

    A setup is usable when five directories all exist: the task root,
    ``dag/pool``, ``task/current``, ``dag/result`` and ``bestof``. In addition,
    ``dag/pool`` must contain at least ``min_dag_count`` ``*.json`` files.

    Args:
        base_op_folder: Root folder that contains one sub-folder per task.
        task_tag: Task identifier (name of the task sub-folder).
        logger: Object with ``log(level, message, **fields)``.
        min_dag_count: Minimum number of ``*.json`` files required in the pool.

    Returns:
        tuple: ``(is_existing, dag_count, status_message)``. If any exception
        occurs, it is logged at ERROR level and ``(False, 0, <message>)`` is
        returned.
    """
    try:
        task_root = os.path.join(base_op_folder, task_tag)
        pool_dir = os.path.join(task_root, "dag", "pool")
        queue_dir = os.path.join(task_root, "task", "current")
        result_dir = os.path.join(task_root, "dag", "result")
        bestof_dir = os.path.join(task_root, "bestof")

        absent = [
            path
            for path in (task_root, pool_dir, queue_dir, result_dir, bestof_dir)
            if not os.path.exists(path)
        ]
        if absent:
            return False, 0, f"Missing required directories: {absent}"

        dag_count = sum(1 for entry in os.listdir(pool_dir) if entry.endswith('.json'))
        if dag_count < min_dag_count:
            return False, dag_count, f"Insufficient DAG files: {dag_count} < {min_dag_count}"

        # Queued and in-flight task files; a non-zero count may mean the system is live.
        queue_entries = os.listdir(queue_dir)
        queued = sum(1 for entry in queue_entries if entry.endswith('.json'))
        in_flight = sum(1 for entry in queue_entries if '.processing_' in entry)

        summary = (
            f"Found complete task setup: pool={dag_count} files, "
            f"queue={queued} files, processing={in_flight} files"
        )

        logger.log('INFO', "Checking existing task setup",
                   task_tag=task_tag,
                   dag_count=dag_count,
                   task_count=queued,
                   processing_count=in_flight,
                   status="complete")

        return True, dag_count, summary

    except Exception as exc:
        failure = f"Error checking existing task setup: {str(exc)}"
        logger.log('ERROR', failure, task_tag=task_tag)
        return False, 0, failure


def _resolve_op_root() -> str:
    """Return the root folder under which per-task output folders live.

    ``$ANUM_OP_ROOT`` wins when it is set to a non-empty value. Otherwise the
    ``op`` folder under the module-level ``root_dir`` is used.
    """
    env_root = os.getenv("ANUM_OP_ROOT")
    if env_root:
        return env_root
    return os.path.join(root_dir, "op")


if __name__ == "__main__":
    # Master-node entry point:
    #   1. parse CLI options;
    #   2. check whether a usable task setup already exists;
    #   3. seed the DAG pool if needed;
    #   4. run DagTaskCombiner.monitor_and_create_tasks in a background thread
    #      until --run_time elapses, a STOP file appears, or Ctrl+C is pressed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_tag', type=str, default='default', help='Task identifier')
    parser.add_argument('--batch_size', type=int, default=40, help='Batch size')
    parser.add_argument('--worker_num', type=int, default=10, help='Worker count')
    parser.add_argument('--max_tasks', type=int, default=400, help='Max task count')
    parser.add_argument('--check_interval', type=int, default=60, help='Check interval (seconds)')
    parser.add_argument('--poly', type=int, default=2, help='Polynomial degree')
    parser.add_argument('--type_string', type=str, default='float', help='Data type')
    parser.add_argument('--run_time', type=int, default=0, help='Run time (seconds), 0=unlimited')
    parser.add_argument('--archive_threshold', type=int, default=1000, help='Archive DAG pool every N tasks')
    parser.add_argument('--stats_frequency', type=int, default=10000, help='Stats output frequency')
    parser.add_argument('--init_multiplier', type=int, default=1, help='Init DAG multiplier (init_count = worker_num * batch_size * multiplier)')
    # Re-initialization controls: --force_init skips the existing-setup check,
    # and --min_dag_count is the pool size that check requires.
    parser.add_argument('--force_init', action='store_true', help='Force re-initialization even if task setup exists')
    parser.add_argument('--min_dag_count', type=int, default=10, help='Min DAG count for existing setup validation (default: 10)')

    args = parser.parse_args()
    
    # Output root: $ANUM_OP_ROOT when set and non-empty, else <root_dir>/op.
    base_op_root = _resolve_op_root()

    # Print startup information
    print("=" * 80)
    print(f"Master node started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Task tag: {args.task_tag}")
    print(f"Batch size: {args.batch_size}")
    print(f"Worker count: {args.worker_num}")
    print(f"Max tasks: {args.max_tasks}")
    print(f"Check interval: {args.check_interval}s")
    print(f"Archive threshold: {args.archive_threshold} tasks")
    print(f"Stats frequency: every {args.stats_frequency} tasks")
    print(f"Init DAG multiplier: {args.init_multiplier}")
    print(f"Run time: {'unlimited' if args.run_time == 0 else f'{args.run_time}s'}")
    print(f"ID management: local file lock")
    print(f"Force re-init: {'yes' if args.force_init else 'no'}")
    print(f"Min DAG threshold: {args.min_dag_count}")
    
    # Directory structure check (report only; nothing is created here)
    base_folder = os.path.join(base_op_root, args.task_tag)
    task_current_folder = os.path.join(base_folder, "task", "current")
    dag_pool_folder = os.path.join(base_folder, "dag", "pool")
    dag_result_folder = os.path.join(base_folder, "dag", "result")
    bestof_folder = os.path.join(base_folder, "bestof")
    
    print(f"Directory structure check:")
    print(f"  Task dir: {task_current_folder} - {'exists' if os.path.exists(task_current_folder) else 'missing'}")
    print(f"  DAG pool dir: {dag_pool_folder} - {'exists' if os.path.exists(dag_pool_folder) else 'missing'}")
    print(f"  Result dir: {dag_result_folder} - {'exists' if os.path.exists(dag_result_folder) else 'missing'}")
    print(f"  BestOf dir: {bestof_folder} - {'exists' if os.path.exists(bestof_folder) else 'missing'}")
    print("=" * 80)
    
    print("Creating local logger...")

    # Console logger tagged with the task; it is passed to every component below.
    logger = ConsoleLogger(task_tag=args.task_tag)
    
    # Check if complete task setup already exists. With --force_init the check
    # is skipped entirely and initialization always runs.
    skip_initialization = False
    if not args.force_init:
        print("Checking existing task setup...")
        is_existing, existing_dag_count, status_message = check_existing_task_setup(
            base_op_root, args.task_tag, logger, args.min_dag_count
        )
        
        if is_existing:
            skip_initialization = True
            print(f"✓ {status_message}")
            print(f"✓ Skipping initialization, using existing setup")
            logger.log('INFO', f"Skipping initialization, using existing setup", 
                      task_tag=args.task_tag,
                      existing_dag_count=existing_dag_count,
                      reason="Found complete task setup")
        else:
            print(f"✗ {status_message}")
            print(f"✓ Will perform full initialization")
            logger.log('INFO', f"Initialization needed", 
                      task_tag=args.task_tag,
                      reason=status_message)
    else:
        print("✓ Force re-initialization mode")
        logger.log('INFO', f"Force re-init", task_tag=args.task_tag)
    
    print("=" * 80)
    
    # Create ID manager (local file mode). It allocates IDs from a counter file
    # under an flock. The same instance is shared by the init generator and the
    # combiner below.
    id_manager = DagIdManager(
        base_op_folder=base_op_root,
        task_tag=args.task_tag,
        logger=logger
    )
    
    logger.log('INFO', f"Using local file storage for DAG IDs")
    
    # Generate initial DAGs if needed
    if not skip_initialization:
        print("Generating initial DAGs...")
        print("=" * 80)
        
        # Create initial DAG generator. Note: this constructor call is outside
        # the try below, so exceptions raised here are not caught by it.
        init_generator = InitDagGenerator(
            base_op_folder=base_op_root,
            task_tag=args.task_tag,
            worker_num=args.worker_num,
            batch_size=args.batch_size,
            poly=args.poly,
            type_string=args.type_string,
            logger=logger,
            id_manager=id_manager,
            init_multiplier=args.init_multiplier
        )
        
        # Generate initial DAGs. generate_all logs per-DAG failures itself, so
        # this except only sees errors raised outside those per-DAG results.
        # The "complete" message reports the target count (init_num), not the
        # number of files actually written.
        try:
            init_generator.generate_all()
            print(f"✓ Initial DAG generation complete, total: {init_generator.init_num}dag")
            logger.log('INFO', f"Initial DAG generation completed", 
                      total_dags=init_generator.init_num,
                      worker_num=args.worker_num,
                      batch_size=args.batch_size,
                      init_multiplier=args.init_multiplier)
        except Exception as e:
            print(f"✗ Initial DAG generation failed: {str(e)}")
            logger.log('ERROR', f"Initial DAG generation failed: {str(e)}")
            print("Exiting")
            sys.exit(1)
        
        print("=" * 80)
    else:
        print("✓ Skipping initial DAG generation, using existing DAGs")
        logger.log('INFO', "Skipping initial DAG generation", reason="using existing setup")
        print("=" * 80)
    
    # Create combiner instance
    combiner = DagTaskCombiner(
        base_op_folder=base_op_root,
        # dag_folder and task_folder are now derived inside
        batch_size=args.batch_size,
        check_interval=args.check_interval,
        task_tag=args.task_tag,
        logger=logger,
        max_tasks=args.max_tasks,
        archive_threshold=args.archive_threshold,  # archive threshold param
        id_manager=id_manager,
        worker_num=args.worker_num  # pass worker_num
    )
    
    # Set stats output frequency (set as an attribute after construction, not via the constructor)
    combiner.stats_output_frequency = args.stats_frequency
    
    # Note: Workers write DAGs directly to pool, no result processor needed
    logger.log('INFO', "Workers write DAGs directly to pool, result processor not needed")
    
    # Run task creation monitor in a thread
    from threading import Thread, Event
    
    # Stop event to signal threads. It is shared with any restarted monitor
    # thread below. Presumably monitor_and_create_tasks returns once it is set;
    # confirm in DagTaskCombiner.
    stop_event = Event()
    
    # Create and start task creation thread. It is not a daemon thread, so the
    # final join() waits for it to return.
    task_monitor_thread = Thread(
        target=combiner.monitor_and_create_tasks,
        args=(stop_event,)
    )
    
    print("Starting task creation monitor thread...")
    task_monitor_thread.start()
    
    try:
        # If run_time is set, stop after specified duration. In this mode the
        # STOP file and the health-check/restart logic below are not used.
        if args.run_time > 0:
            print(f"System will run for {args.run_time} s then stop")
            time.sleep(args.run_time)
            stop_event.set()
        else:
            # Run indefinitely until interrupted
            print("System running, press Ctrl+C to stop...")
            last_health_check = time.time()
            # <op root>/<task_tag>/STOP requests shutdown. This code never
            # deletes the file, so a leftover STOP file also stops the next run
            # immediately unless something else removes it.
            stop_file_path = os.path.join(base_op_root, args.task_tag, "STOP")
            
            while True:
                if os.path.exists(stop_file_path):
                    print("STOP file detected, shutting down...")
                    logger.log('INFO', "STOP file detected, stopping main loop")
                    stop_event.set()
                    break
                # Health check every 60s
                current_time = time.time()
                if current_time - last_health_check >= 60:
                    task_alive = task_monitor_thread.is_alive()
                    
                    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Thread health check:")
                    print(f"  - Task creation thread: {'running' if task_alive else 'stopped'}")
                    
                    logger.log('INFO', "Thread health check", 
                              task_thread_alive=task_alive)
                    
                    # If thread unexpectedly exited, log details
                    if not task_alive:
                        logger.log('ERROR', "Task creation thread unexpectedly exited!")
                        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Error: Task creation thread unexpectedly exited!")
                        
                        # Attempt to restart task creation thread. A new Thread
                        # object is required because a Thread can only be
                        # started once. Restarts repeat at most once per check.
                        try:
                            print("Attempting to restart task creation thread...")
                            task_monitor_thread = Thread(
                                target=combiner.monitor_and_create_tasks,
                                args=(stop_event,)
                            )
                            task_monitor_thread.start()
                            logger.log('INFO', "Task creation thread restarted")
                            print("Task creation thread restarted successfully")
                        except Exception as restart_error:
                            logger.log('ERROR', f"Failed to restart task creation thread: {str(restart_error)}")
                            print(f"Failed to restart task creation thread: {str(restart_error)}")
                    
                    last_health_check = current_time
                
                time.sleep(1)
            
    except KeyboardInterrupt:
        # First Ctrl+C: signal the monitor thread. A second Ctrl+C during the
        # join() below is not handled here.
        print("\nInterrupt received, shutting down...")
        stop_event.set()
    
    # Wait for threads to finish
    task_monitor_thread.join()
    print("System stopped")
