"""
Base node class - Base class for all node types
"""

from __future__ import annotations
from dataclasses import dataclass, fields, asdict
from typing import List, Dict, Any, Optional, Tuple
from enum import Enum
import uuid
from llm.llm_basics import LLMMessage, LLMResponse
import os
import json
from datetime import datetime

class NodeType(Enum):
    """Kinds of node that can appear in the context tree.

    A member's value is the lowercase string written to the ``node_type``
    field when a node is serialized, and used as the JSON file-name prefix
    when a node is saved.
    """

    ROOT = "root"  # default type of a newly constructed BaseNode
    REACT = "react"
    SUMMARY = "summary"
    DONE = "done"

class BaseNode:
    """
    Base node class, representing a single node in the context tree.

    A node holds a reference to its parent, a list of children, the LLM
    messages that produced it and the LLM response attached to it.
    ``set_messages`` / ``set_response`` persist the node (without its
    children) as a JSON file under ``BaseNode.directory``.
    """

    # Node save directory (class variable), can be set uniformly from outside.
    # Saving raises ValueError while this is unset and no explicit directory
    # is passed to save_to_json.
    directory: Optional[str] = None

    def __init__(self,
                 parent: Optional[BaseNode] = None,
                 node_type: NodeType = NodeType.ROOT):
        """
        Create a node and, if ``parent`` is given, append it to the parent's children.

        Args:
            parent: Node to attach this node under; ``None`` creates a detached root.
            node_type: Kind of node; defaults to ``NodeType.ROOT``.

        Note:
            ``depth`` is always initialised to 0 here, even when a parent is
            given; callers (e.g. ``duplicate``) set it explicitly.
        """
        self.id = uuid.uuid4()
        self.parent: Optional[BaseNode] = parent
        self.timestamp: str = datetime.now().isoformat()
        self.children: List[BaseNode] = []
        # How this node came to exist: "generate" (default) or "duplicate"
        # (set by duplicate()).
        self.source: str = "generate"
        self.node_type: NodeType = node_type

        # Node generation related information
        self.messages: List[LLMMessage] = []
        self.response: Optional[LLMResponse] = None

        # Node quality assessment: "good", "bad" or "unknown"
        self.quality: str = "unknown"

        # Free-form node metadata (not serialized by _to_dict_content)
        self.metadata: Dict[str, Any] = {}

        self.depth: int = 0

        # Register with the parent; add_child also sets self.parent.
        if parent is not None:
            parent.add_child(self)

    def set_response(self, response: LLMResponse):
        """
        Set the node's LLM response and persist the node to JSON.

        Raises:
            ValueError: if ``BaseNode.directory`` is not set (the response is
                still assigned before the error is raised).
        """
        self.response = response
        self.save_to_json(BaseNode.directory, filename_prefix=self.node_type.value)

    def set_messages(self, messages: List[LLMMessage]):
        """
        Set the node's message list (stored as-is, not copied) and persist the node to JSON.

        Raises:
            ValueError: if ``BaseNode.directory`` is not set (the messages are
                still assigned before the error is raised).
        """
        self.messages = messages
        self.save_to_json(BaseNode.directory, filename_prefix=self.node_type.value)

    def add_child(self, node: BaseNode):
        """
        Append ``node`` to this node's children and make this node its parent.

        The node is not removed from a previous parent's ``children`` list and
        its ``depth`` is left unchanged.
        """
        self.children.append(node)
        node.parent = self

    def add_source(self, source_type: str):
        """Record how this node was created (overwrites the previous value), e.g. "duplicate"."""
        self.source = source_type

    def get_path_from_root(self) -> List[BaseNode]:
        """Return the nodes from the root down to (and including) this node."""
        path = []
        current = self
        while current is not None:
            path.append(current)
            current = current.parent
        return list(reversed(path))

    def get_siblings(self) -> List[BaseNode]:
        """Return the other children of this node's parent (empty for a parentless node)."""
        if self.parent is None:
            return []
        return [child for child in self.parent.children if child is not self]

    def get_closest_ancestor(self, node_type: NodeType) -> Optional[BaseNode]:
        """Return the nearest strict ancestor whose type is ``node_type``, or ``None``."""
        current = self.parent
        while current is not None:
            if current.node_type == node_type:
                return current
            current = current.parent
        return None

    def estimate_token_count(self) -> int:
        """
        Roughly estimate the number of tokens held by this node.

        Counts the characters of every message's content plus the response
        content and divides by 3, a compromise between ~4 characters/token for
        English and ~1.5 characters/token for Chinese. Always returns at least 1.
        """
        # NOTE(review): assumes ``content`` is a string; len() of any other
        # sized type would count its items instead — confirm.
        total_chars = 0

        # Message content
        for message in self.messages:
            if message.content:
                total_chars += len(message.content)

        # Response content
        if self.response and self.response.content:
            total_chars += len(self.response.content)

        return max(1, total_chars // 3)  # Average estimation

    def get_id(self) -> str:
        """Return the unique ID string of the node."""
        return str(self.id)

    def duplicate(self, parent: Optional[BaseNode] = None) -> BaseNode:
        """
        Recursively copy this node and its whole subtree under ``parent``.

        Args:
            parent: Node to attach the copy under. Defaults to this node's own
                parent, so the copy becomes a sibling of the original. If both
                are ``None`` the copy is a detached root with depth 0.

        Returns:
            The new node: fresh id and timestamp, a shallow copy of the message
            list, the same response object and ``source == "duplicate"``.
            Metadata and quality are not copied.

        Raises:
            ValueError: if ``BaseNode.directory`` is not set, because every
                copied node is persisted to JSON.
        """
        # TODO: Need to copy to original node's node type (subclass); this
        # always creates a plain BaseNode.
        parent = self.parent if parent is None else parent
        # Snapshot the children before the copy exists: when ``parent`` is this
        # node, the copy is appended to self.children and would otherwise be
        # picked up by the loop below and duplicated without end.
        original_children = list(self.children)
        new_node = BaseNode(parent, self.node_type)
        # A parentless copy (duplicating a root) starts at depth 0.
        new_node.depth = 0 if parent is None else parent.depth + 1
        new_node.messages = self.messages.copy()
        new_node.response = self.response
        new_node.add_source("duplicate")
        # Each child copy attaches itself to new_node through its constructor.
        for child in original_children:
            child.duplicate(parent=new_node)
        # Persist the newly created node (not the original).
        new_node.save_to_json(BaseNode.directory, filename_prefix=new_node.node_type.value)
        return new_node

    @property
    def short_id(self) -> str:
        """Return the short ID of the node (first 8 characters), used for log recording."""
        return str(self.id)[:8]

    def _to_dict_content(self) -> Dict[str, Any]:
        """
        Convert the node's own content to a dictionary, excluding children.

        ``metadata`` and ``quality`` are not included. ``asdict`` requires the
        messages and the response to be dataclass instances (it raises
        TypeError otherwise).
        """
        return {
            "id": str(self.id),
            "parent_id": str(self.parent.id) if self.parent is not None else None,
            "timestamp": self.timestamp,
            "node_type": self.node_type.value,
            "messages": [asdict(message) for message in self.messages],
            "response": asdict(self.response) if self.response is not None else None,
            "source": self.source,
            "depth": self.depth,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the node and, recursively, its whole subtree to a dictionary."""
        data = self._to_dict_content()
        data["children"] = [child.to_dict() for child in self.children]
        return data

    def to_dict_wo_children(self) -> Dict[str, Any]:
        """Convert the node to a dictionary without its children."""
        return self._to_dict_content()

    def save_to_json(self, directory: Optional[str] = None, filename_prefix: str = "node") -> str:
        """
        Write this node (without children) to ``<dir>/<prefix>_<id>_<timestamp>.json``.

        Args:
            directory: Target directory; falls back to ``BaseNode.directory``.
                Created if it does not exist.
            filename_prefix: Leading part of the file name.

        Returns:
            Path of the written file. Saving again overwrites the same file,
            since id and timestamp do not change.

        Raises:
            ValueError: if neither ``directory`` nor ``BaseNode.directory`` is set.
        """
        # Prefer the argument, then the class variable
        effective_directory = directory or BaseNode.directory
        if not effective_directory:
            raise ValueError("Failed to save node: directory not specified and BaseNode.directory not set")

        os.makedirs(effective_directory, exist_ok=True)
        # NOTE(review): the timestamp comes from datetime.isoformat() and
        # contains ':' characters, which are not valid in Windows file names.
        file_path = os.path.join(effective_directory, f"{filename_prefix}_{self.get_id()}_{self.timestamp}.json")
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict_wo_children(), f, ensure_ascii=False, indent=2)
        return file_path

    def from_dict(self, data: Dict[str, Any]):
        """
        Restore this node's fields in place from a dict produced by ``to_dict``/``to_dict_wo_children``.

        Parent and children are not restored; they should be set when the node
        is constructed in the outer layer (e.g. ``BaseNode(parent, node_type)``).
        """
        self._from_dict(data)

    def _from_dict(self, data: Dict[str, Any]):
        """Overwrite id, timestamp, type, messages, response, source and depth from ``data``."""
        self.id = uuid.UUID(data["id"])
        self.timestamp = data["timestamp"]
        self.node_type = NodeType(data["node_type"])
        self.messages = [LLMMessage(**message) for message in data["messages"]]
        self.response = LLMResponse(**data["response"]) if data["response"] else None
        self.source = data["source"]
        self.depth = data["depth"]