#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pickle
import os
import time
import networkx as nx
from networkx.readwrite import json_graph
import logging


class GraphDB:
    
    def __init__(
            self, 
            db_path : str = None,
            db_name : str = None,
            enable_forgetting: bool = False
                 ) -> None :
 
        self.graph : nx.DiGraph = nx.DiGraph()
        self.db_path = f"{db_path}/{db_name}"
        
        # Forgetting algorithm related properties
        self.enable_forgetting = enable_forgetting #     
        self.access_count = 0  # Total access count, used for periodic cleanup
        
        # Try to load existing graph from disk
        self._load_graph() #        load

    def _standardize_entity_name(self, entity: str) -> str:

        entity = entity.strip()
        if entity.casefold() == 'i':
            return 'I'
        return entity

    def create_sub_graph(
            self,
            sub_graph_dict : list[dict[str, str]]
    ) -> nx.DiGraph:
 
        sub_graph = nx.DiGraph()
        current_time = time.time()
        
        for graph_item in sub_graph_dict:
            node_a = self._standardize_entity_name(graph_item["node_a"])
            relation = graph_item["relation"].strip().casefold()
            node_b = self._standardize_entity_name(graph_item["node_b"])

            sub_graph.add_node(node_a)
            sub_graph.add_node(node_b)
            
            # Add forgetting algorithm attributes to edges
            edge_attrs = {
                'relation': relation,
                'created_time': current_time,
                'access_count': 1,  # Initial access count is 1 at creation
                'last_access_time': current_time
            }
            sub_graph.add_edge(node_a, node_b, **edge_attrs)

        return sub_graph

    def merge_graph(
        self,
        sub_graph : nx.DiGraph,
        save_graph : bool = True
    ) -> None:

        self.graph = nx.compose(self.graph, sub_graph)
        if save_graph:
            self._save_graph()

    def add_node(self, node : str):
        sub_graph = nx.DiGraph()
        sub_graph.add_node(node)

        self.graph = nx.compose(self.graph, sub_graph)
        return self.graph

    def _save_graph(self):

        try:
            # Ensure directory exists
            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
            with open(self.db_path, "wb") as f:
                pickle.dump(self.graph, f)
        except Exception as e:
            logging.error(f"Failed to save graph database: {e}")

    def search_graph(
            self,
            entities : list[str],
            adj_degree : int = 2,
    ) -> list[list[str]]:

        try:
            # Increment total access count
            self.access_count += 1
            current_time = time.time()
            
            subgraph_nodes = set()
            
            #                
            normalized_entities = [self._standardize_entity_name(entity) for entity in entities]
            subgraph_nodes.update(normalized_entities)
            
            # Find neighbors for each entity
            for entity in normalized_entities:
                if entity in self.graph.nodes:
                    # Get all neighbor nodes within specified degree
                    try:
                        neighbors = nx.single_source_shortest_path_length(
                            self.graph.copy().to_undirected(), entity, cutoff=adj_degree
                        ).keys()
                        subgraph_nodes.update(neighbors)
                    except nx.NetworkXError:
                        # Skip if node doesn't exist
                        continue
            
            # Build subgraph
            if not subgraph_nodes:
                return []
                
            subgraph = self.graph.subgraph(subgraph_nodes)
            
            # Extract relationship triplets and update access attributes
            relationships = []
            edges_to_update = []
            
            for u, v, data in subgraph.edges(data=True):
                if 'relation' in data:
                    relationships.append([u, data['relation'], v])
                    
                    # Update edge access attributes if forgetting algorithm is enabled
                    if self.enable_forgetting:
                        edges_to_update.append((u, v))
            
            # Batch update edge access attributes
            if self.enable_forgetting and edges_to_update:
                self._update_edge_access(edges_to_update, current_time)
            
            # Periodically apply forgetting algorithm (every 10 accesses)
            if self.enable_forgetting and self.access_count % 10 == 0:
                self._apply_forgetting_algorithm()
                    
            return relationships
            
        except Exception as e:
            logging.error(f"Graph search failed: {e}")
            return []

    def get_graph_stats(self) -> dict:
 
        return {
            "nodes": self.graph.number_of_nodes(),
            "edges": self.graph.number_of_edges(),
            "is_directed": self.graph.is_directed()
        }

    
    def get_node_count(self) -> int:
 
        return self.graph.number_of_nodes()
    
    def get_edge_count(self) -> int:
 
        return self.graph.number_of_edges()

    def export_graph_to_json(self, output_path: str = None) -> str:

        try:
            graph_data = json_graph.node_link_data(self.graph)
            
            if output_path:
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                with open(output_path, 'w', encoding='utf-8') as f:
                    import json
                    json.dump(graph_data, f, ensure_ascii=False, indent=2)
                    
            return graph_data
        except Exception as e:
            logging.error(f"Failed to export graph data: {e}")
            return ""

    def delete_relationships(self, relationships: list[list[str]]) -> bool:
 
        try:
            deleted_count = 0
            
            for relationship in relationships:
                if len(relationship) != 3:
                    continue
                    
                #                
                node_a = self._standardize_entity_name(relationship[0])
                relation = relationship[1].strip().casefold()
                node_b = self._standardize_entity_name(relationship[2])
                
                # Check if edge exists
                if self.graph.has_edge(node_a, node_b):
                    edge_data = self.graph.get_edge_data(node_a, node_b)
                    # Check if relation matches
                    if edge_data and edge_data.get('relation') == relation:
                        # Delete edge
                        self.graph.remove_edge(node_a, node_b)
                        deleted_count += 1

            if deleted_count > 0:
                self._save_graph()
                return True
            else:
                print("No relationships were deleted")
                return False
                
        except Exception as e:
            print(f"Error during relationship deletion: {e}")
            return False
    
    def delete_isolated_nodes(self) -> list[str]:

        try:
            isolated_nodes = list(nx.isolates(self.graph))
            
            if isolated_nodes:
                self.graph.remove_nodes_from(isolated_nodes)
                self._save_graph()
                return isolated_nodes
            else:
                print("No isolated nodes to delete")
                return []
                
        except Exception as e:
            print(f"Error during isolated node deletion: {e}")
            return []
    
    def update_relationship(self, old_relationship: list[str], new_relationship: list[str]) -> bool: 

        try:
            # Delete old relationship
            old_deleted = self.delete_relationships([old_relationship])
            
            # Add new relationship
            if len(new_relationship) == 3:

                node_a = self._standardize_entity_name(new_relationship[0])
                relation = new_relationship[1].strip().casefold()
                node_b = self._standardize_entity_name(new_relationship[2])
                
                self.graph.add_edge(node_a, node_b, relation=relation)
                self._save_graph()
                return True
            else:
                print("Invalid new relationship format")
                return False
                
        except Exception as e:
            print(f"Error during relationship update: {e}")
            return False

    def _load_graph(self):
   
        try:
            if os.path.exists(self.db_path):
                with open(self.db_path, "rb") as f:
                    self.graph = pickle.load(f)
            else:
                logging.warning(f"Database path does not exist: {self.db_path}")
        except Exception as e:
            logging.error(f"Failed to load graph database: {e}")

    def _update_edge_access(self, edges: list[tuple], current_time: float):
 
        try:
            for u, v in edges:
                if self.graph.has_edge(u, v):
                    # Get current edge attributes
                    edge_data = self.graph[u][v]
                    
                    # Update access count and last access time
                    edge_data['access_count'] = edge_data.get('access_count', 0) + 1
                    edge_data['last_access_time'] = current_time
                    
        except Exception as e:
            logging.error(f"Failed to update edge access attributes: {e}")

    def _remove_isolated_nodes(self):

        try:
            nodes_to_remove = []
            
            for node in self.graph.nodes():
                in_degree = self.graph.in_degree(node)
                out_degree = self.graph.out_degree(node)
                
                # If both in-degree and out-degree are 0, remove the node
                if in_degree == 0 and out_degree == 0:
                    nodes_to_remove.append(node)
            
            if nodes_to_remove:
                self.graph.remove_nodes_from(nodes_to_remove)
                return nodes_to_remove
            
            return []
                
        except Exception as e:
            logging.error(f"Failed to remove isolated nodes: {e}")
            return []

    def _apply_forgetting_algorithm(self):

        try:
            current_time = time.time()
            edges_to_remove = []
            
            # Forgetting parameters (shorter time for testing purposes)
            time_threshold = 30 * 60  # 30 minutes without access
            min_access_threshold = 2  # Minimum 2 accesses to avoid forgetting
            
            # Check all edges
            for u, v, data in self.graph.edges(data=True):
                # Get edge attributes (compatible with old data)
                last_access = data.get('last_access_time', data.get('created_time', current_time))
                access_count = data.get('access_count', 1)
                
                # Calculate time since last access
                time_since_access = current_time - last_access
                
                # Forgetting condition: long time without access AND low access count
                if (time_since_access > time_threshold and 
                    access_count < min_access_threshold):
                    edges_to_remove.append((u, v))
            
            # Batch remove edges
            removed_nodes = []
            if edges_to_remove:
                self.graph.remove_edges_from(edges_to_remove)
                
                # Remove isolated nodes and get list of removed nodes
                removed_nodes = self._remove_isolated_nodes()
                
                # Save updated graph
                self._save_graph()
            
            return {
                "removed_edges": len(edges_to_remove),
                "removed_nodes": removed_nodes
            }
                
        except Exception as e:
            logging.error(f"Failed to apply forgetting algorithm: {e}")
            return {
                "removed_edges": 0,
                "removed_nodes": []
            }

    def get_forgetting_stats(self) -> dict:

        try:
            if not self.enable_forgetting:
                return {"enabled": False}
                
            current_time = time.time()
            edge_ages = []
            access_counts = []
            
            for u, v, data in self.graph.edges(data=True):
                created_time = data.get('created_time', current_time)
                access_count = data.get('access_count', 1)
                
                age = current_time - created_time
                edge_ages.append(age / (24 * 3600))  # Convert to days
                access_counts.append(access_count)
            
            stats = {
                "enabled": True,
                "total_access_count": self.access_count,
                "total_edges": len(edge_ages),
                "avg_edge_age_days": sum(edge_ages) / len(edge_ages) if edge_ages else 0,
                "avg_access_count": sum(access_counts) / len(access_counts) if access_counts else 0,
                "max_access_count": max(access_counts) if access_counts else 0,
                "min_access_count": min(access_counts) if access_counts else 0
            }
            
            return stats
            
        except Exception as e:
            logging.error(f"Failed to get forgetting algorithm statistics: {e}")
            return {"enabled": self.enable_forgetting, "error": str(e)}

    def toggle_forgetting(self, enable: bool = None) -> bool:

        if enable is None:
            self.enable_forgetting = not self.enable_forgetting
        else:
            self.enable_forgetting = enable
            
        print(f"Forgetting algorithm is now {'enabled' if self.enable_forgetting else 'disabled'}")
        return self.enable_forgetting  # This function is used to toggle the forgetting algorithm state in test files

    def manual_forgetting(self) -> dict:

        if not self.enable_forgetting:
            return {"error": "Forgetting algorithm is not enabled"}
            
        try:
            before_stats = self.get_graph_stats()
            forgetting_result = self._apply_forgetting_algorithm()
            after_stats = self.get_graph_stats()
            
            result = {
                "success": True,
                "nodes_before": before_stats["nodes"],
                "edges_before": before_stats["edges"],
                "nodes_after": after_stats["nodes"],
                "edges_after": after_stats["edges"],
                "nodes_removed": before_stats["nodes"] - after_stats["nodes"],
                "edges_removed": before_stats["edges"] - after_stats["edges"],
                "removed_node_list": forgetting_result.get("removed_nodes", [])  # Add list of removed nodes
            }
            
            return result  # This function is used to manually trigger the forgetting algorithm and return result statistics in test files
            
        except Exception as e:
            return {"success": False, "error": str(e)}