import numpy as np
import json
from param import args  # Import args to get exec_cap

def log_observation(obs, detailed_task_info=False, reward=0):
    """
    Record system state information contained in the observation, according to the following rules:
    1. If source_job == "None" && num_source_exec == args.exec_cap, return complete observation data
    2. If the above condition is not met:
       - If reward == 0, return {} (empty dictionary)
       - If reward != 0, return normal observation data except for job_dags being empty

    A source_job without a ``name`` attribute (e.g. None) counts as "None".
    The reward == 0 case of rule 2 is decided before any state is collected,
    so steps whose log would be discarded do no extra work.

    Parameters:
        obs: 8-tuple (job_dags, source_job, num_source_exec, frontier_nodes,
             executor_limits, exec_commit, moving_executors, action_map);
             exec_commit and action_map are unpacked but not recorded
        detailed_task_info: Whether to record detailed task information
            (only used under rule 1, where per-node data is collected)
        reward: Reward value for the current step

    Returns:
        state_dict: State dictionary returned according to conditions, with keys
        "source_job", "num_source_exec", "frontier_nodes", "executor_limits",
        "moving_executors", "system_stats" and "job_dags" -- or {} per rule 2
    """
    # Destructure observation tuple
    (job_dags, source_job, num_source_exec,
     frontier_nodes, executor_limits,
     _exec_commit, moving_executors, _action_map) = obs

    has_source_name = hasattr(source_job, "name")
    is_source_job_none = not has_source_name or source_job.name == "None"
    is_exec_cap_condition = num_source_exec == args.exec_cap
    record_full_dags = is_source_job_none and is_exec_cap_condition

    # Condition 2 (reward == 0): nothing is logged, so return before walking
    # job_dags / frontier_nodes / moving executors.
    if not record_full_dags and reward == 0:
        return {}

    state_dict = {
        "source_job": source_job.name if has_source_name else "None",
        "num_source_exec": num_source_exec,
        "frontier_nodes": [[node.idx, node.job_dag.name] for node in frontier_nodes],
        # Executor limits keyed by job name rather than by job_dag object
        "executor_limits": {job_dag.name: limit for job_dag, limit in executor_limits.items()},
        # exec_commit and action_map are intentionally not recorded.
        "moving_executors": [
            {
                "executor_id": exec_id,
                "target_node": "None" if node is None else (node.idx, node.job_dag.name),
            }
            for exec_id, node in moving_executors.moving_executors.items()
        ],
        "system_stats": {      # System-level statistics
            "total_jobs": len(job_dags),
            "total_nodes": sum(len(job_dag.nodes) for job_dag in job_dags if hasattr(job_dag, "nodes")),
            "total_frontier_nodes": len(frontier_nodes),
            "total_executors": sum(len(job_dag.executors) for job_dag in job_dags if hasattr(job_dag, "executors")),
            "total_moving_executors": len(moving_executors.moving_executors),
        },
        # Condition 1 fills this in; condition 3 (reward != 0) leaves it empty.
        "job_dags": {},
    }

    # Condition 1: source_job == "None" && num_source_exec == args.exec_cap
    if record_full_dags:
        for job_dag in job_dags:
            state_dict["job_dags"][job_dag.name] = _record_job_dag(
                job_dag, frontier_nodes, detailed_task_info)

    return state_dict


def _record_job_dag(job_dag, frontier_nodes, detailed_task_info):
    """Build the per-job entry stored under state_dict["job_dags"][job_dag.name].

    Optional job attributes (timing, completion, frontier, nodes) are only
    recorded when the job_dag object actually has them.
    """
    job_info = {
        "name": job_dag.name,
        "num_nodes": getattr(job_dag, "num_nodes", None),
        "num_nodes_done": getattr(job_dag, "num_nodes_done", None),
        "executor_count": len(job_dag.executors) if hasattr(job_dag, "executors") else 0,
        "nodes": {},  # Node information recorded directly under job_dag
        "critical_path": [],  # Placeholder; not populated here
    }

    # Add job time information (if available)
    if hasattr(job_dag, "arrived"):
        job_info["arrived"] = job_dag.arrived
    if hasattr(job_dag, "start_time"):
        job_info["start_time"] = job_dag.start_time
    if hasattr(job_dag, "completion_time"):
        completion_time = job_dag.completion_time
        # Infinity is stored as the string "inf" (not valid in strict JSON)
        job_info["completion_time"] = "inf" if completion_time == float('inf') else completion_time
    if hasattr(job_dag, "completed"):
        job_info["completed"] = job_dag.completed

    # Record the job's own frontier nodes (if available)
    if hasattr(job_dag, "frontier_nodes"):
        job_info["frontier_nodes"] = [node.idx for node in job_dag.frontier_nodes]

    # Record node information, flagging nodes in the global frontier
    for node in getattr(job_dag, "nodes", ()):
        node_info = record_node_info(node, detailed_task_info)
        node_info["is_frontier"] = node in frontier_nodes
        job_info["nodes"][str(node.idx)] = node_info

    return job_info


def record_node_info(node, detailed_task_info=False):
    """
    Record detailed node information, based on the structure of the Node class

    Parameters:
        node: Node object. Must expose ``num_tasks`` and the iterables
            ``parent_nodes``, ``child_nodes`` and ``descendant_nodes`` (whose
            elements have an ``idx``). ``get_node_duration()`` and ``tasks``
            are optional.
        detailed_task_info: Whether to record detailed task information

    Returns:
        dict: Dictionary containing detailed node information. The
        "get_node_duration()" entry is None when the duration cannot be
        computed. The "tasks" entry is present only when requested and
        available. NaN task times become None, and an infinite finish time
        becomes "inf".
    """
    node_info = {
        "num_tasks": node.num_tasks,
        "parent_nodes": [parent.idx for parent in node.parent_nodes],
        "child_nodes": [child.idx for child in node.child_nodes],
        "descendant_nodes": [desc.idx for desc in node.descendant_nodes],
    }

    # Best effort: any ordinary failure (including a missing method) records
    # None. KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        node_info["get_node_duration()"] = node.get_node_duration()
    except Exception:
        node_info["get_node_duration()"] = None

    # If needed, record detailed task information
    if detailed_task_info and hasattr(node, "tasks"):
        tasks = {}
        for task in node.tasks:
            task_info = {"idx": task.idx, "duration": task.duration}

            if hasattr(task, "start_time"):
                # A NaN start time marks a task that has not been scheduled yet
                not_started = np.isnan(task.start_time)
                task_info["scheduled"] = not not_started
                task_info["start_time"] = None if not_started else float(task.start_time)
            else:
                task_info["scheduled"] = False

            if hasattr(task, "finish_time"):
                finish_time = task.finish_time
                if np.isnan(finish_time):
                    task_info["finish_time"] = None
                elif finish_time == np.inf:
                    task_info["finish_time"] = "inf"
                else:
                    task_info["finish_time"] = float(finish_time)

            # Use task IDX as key
            tasks[str(task.idx)] = task_info
        node_info["tasks"] = tasks

    return node_info

class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that handles various complex data types

    ``default`` is only called for objects the stock encoder cannot serialize.
    It converts:
      - numpy arrays -> lists; numpy integer/floating/bool scalars -> Python scalars
      - Python and numpy complex numbers -> {'real': ..., 'imag': ...}
      - numpy datetime64 -> str; bytes -> UTF-8 decoded str; set -> list
      - objects whose class is named 'Executor' -> "Executor(<id>)"
    Anything else falls through to the base class, which raises TypeError.
    """
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        # np.complexfloating covers complex64/complex128/... and, unlike the
        # np.complex_ alias (removed in NumPy 2.0), exists in every NumPy version.
        if isinstance(obj, (complex, np.complexfloating)):
            return {'real': obj.real, 'imag': obj.imag}
        if isinstance(obj, np.datetime64):
            return str(obj)
        if isinstance(obj, bytes):
            return obj.decode('utf-8')
        if isinstance(obj, set):
            return list(obj)
        # Matched by class name so this module does not need to import Executor
        if obj.__class__.__name__ == 'Executor':
            # Identity-based representation; distinguishes executors within one run only
            return f"Executor({id(obj)})"
        return super().default(obj)

def convert_dict_for_json(d):
    """
    Recursively rewrite ``d`` into a structure that ``json`` can serialize.

    - dict keys of any type are converted with ``str`` and values are converted recursively
    - lists, tuples and sets all become lists of converted elements
    - numpy arrays become (nested) Python lists via ``tolist()``
    - numpy integer/floating/bool scalars become the matching Python scalar
    - any other value is returned unchanged
    """
    if isinstance(d, dict):
        converted = {}
        for key, value in d.items():
            converted[str(key)] = convert_dict_for_json(value)
        return converted
    if isinstance(d, (list, tuple, set)):
        return list(map(convert_dict_for_json, d))
    if isinstance(d, np.ndarray):
        return d.tolist()
    if not isinstance(d, (np.integer, np.floating, np.bool_)):
        return d
    return d.item()