import json
import os
import pickle
import re
from datetime import datetime
from typing import Any

from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from causallearn.graph.Edge import Edge
from causallearn.graph.Endpoint import Endpoint

def parse_causal_graph(text: str) -> list[tuple[str]]:
    # Adjusted regex to handle both square brackets and no brackets, and ensure proper closing tags
    edges_list_pattern = re.compile(r"<edges>([^<]*?)</edges>", re.DOTALL)

    matches = edges_list_pattern.findall(text)
    edges = []

    if not matches:
        # If the parsing fails (no <edges> tags found), return -1
        return -1
    elif matches == [""]:
        # If there is nothing between the <edges></edges> tags, return an empty list
        return edges
    
    pattern = re.compile(r"\(([^)]+)\)", re.DOTALL)
    matches = pattern.findall(matches[0])
    for match in matches:
        edge = match.split(",")
        edges.append((edge[0].strip(), edge[1].strip()))
    return edges

def edgelist2nodelist(input_list: list[tuple[str]]) -> list:
    """Return the distinct node names referenced by an edge list.

    Names come back in order of first appearance, reading each edge's first
    element and then its second. Any elements past index 1 are ignored.
    """
    # A dict keeps insertion order, so fromkeys() behaves as an ordered set.
    first_seen = dict.fromkeys(
        endpoint for edge in input_list for endpoint in (edge[0], edge[1])
    )
    return list(first_seen)

def subset_graph(graph: list[list[str]] | GeneralGraph, nodes: list[str]) -> list[list[str]] | GeneralGraph:
    """Subset a graph, in place, to only include the specified nodes.

    Args:
        graph (list[list[str]] | GeneralGraph): The input graph, either as an
            edge list (the first two items of each edge are node names) or a
            GeneralGraph object. It is modified in place.
        nodes (list[str]): Names of the nodes to keep.

    Returns:
        list[list[str]] | GeneralGraph: The same ``graph`` object. For a
        GeneralGraph, every node whose name is not in ``nodes`` is removed.
        For an edge list, every edge touching a name not in ``nodes`` is
        removed. Inputs of any other type are left untouched and ``None`` is
        returned.
    """
    # Build the lookup once: O(1) membership instead of a list scan per check.
    keep = set(nodes)

    if isinstance(graph, GeneralGraph):
        # Iterate over a snapshot: removing nodes while walking the graph's
        # own node list would skip the element after each removal.
        # causallearn's remove_node takes the Node object, not its name.
        for node in list(graph.get_nodes()):
            if node.get_name() not in keep:
                graph.remove_node(node)
        return graph

    if isinstance(graph, list):
        # Slice assignment filters in place, so callers holding a reference
        # to this list still see the result. It also avoids the skipped
        # elements caused by list.remove() inside a for-loop over the list.
        graph[:] = [edge for edge in graph if edge[0] in keep and edge[1] in keep]
        return graph

    return None


def check_nodes_exist(proposed_edges_list: list[tuple[str]], true_variable_list: list[str]) -> list:
    """Return node names used by the proposed edges that are not real variables.

    An LLM-proposed causal graph can mention variables that do not exist in
    the dataset. This is a possible hallucination, e.g. produced by combining
    acronyms or words that recur across variable names.

    Args:
        proposed_edges_list: Proposed edges; the first two items of each edge
            are node names.
        true_variable_list: The dataset's actual variable names.

    Returns:
        A list, in no guaranteed order, of the names that appear in
        ``proposed_edges_list`` but not in ``true_variable_list``.
    """
    mentioned = set(edgelist2nodelist(proposed_edges_list))
    known = set(true_variable_list)
    unknown = mentioned - known
    return list(unknown)

def merge_causallearn_graphs(graphs: list[GeneralGraph]) -> GeneralGraph:
    """
    Merge two or more causallearn graphs into one.

    A fresh GeneralGraph is built from the union of the input graphs' nodes
    and edges. Node and edge objects are passed to ``add_node`` / ``add_edge``
    as-is, not copied. Afterwards, ``remove_edge`` is called on the merged
    graph once for every pair returned by ``get_d_separated_nodes``.

    Args:
        graphs: The graphs to merge. A shallow copy of this list is handed to
            ``get_d_separated_nodes``.

    Returns:
        The newly created merged GeneralGraph.

    NOTE(review): ``get_d_separated_nodes`` reports node pairs that are
    *adjacent* in at least one input graph (it calls ``is_adjacent_to``). No
    d-separation test is performed. Confirm what the removal step below is
    meant to achieve.
    """
    merged_graph = GeneralGraph(nodes=[])

    # Pairs are computed from the input graphs before any merging happens.
    d_separated_nodes = get_d_separated_nodes(graphs.copy())

    for graph in graphs:
        # Add nodes to the merged graph, skipping ones already present.
        # Membership relies on the Node objects' equality; presumably by
        # name (verify against causallearn's GraphNode).
        for node in graph.get_nodes():
            if node not in merged_graph.get_nodes():
                merged_graph.add_node(node)
        # Copy edges, skipping None entries and edges already present.
        for edge in graph.get_graph_edges():
            if edge is None:
                continue
            if edge not in merged_graph.get_graph_edges():
                merged_graph.add_edge(edge)

    # Remove an edge (not a node) for every pair reported above.
    # NOTE(review): the Edge built here has TAIL_AND_ARROW on both ends,
    # whereas the edges copied in above keep their original endpoints.
    # Verify that GeneralGraph.remove_edge removes the intended edge for such
    # an argument; if it matches on endpoints, this loop may do nothing.
    for node1, node2 in d_separated_nodes:
        merged_graph.remove_edge(Edge(node1, node2, Endpoint.TAIL_AND_ARROW, Endpoint.TAIL_AND_ARROW))

    return merged_graph

def get_d_separated_nodes(graphs: list[GeneralGraph]) -> list[list[Any]]:
    """Collect node pairs that are adjacent in some pair of input graphs.

    For every ordered pair of graphs ``(g_a, g_b)`` that do not compare equal
    with ``==``, and every ordered pair of distinct nodes ``(u, v)`` present
    in both:
    - ``[v, u]`` is appended once if ``u`` and ``v`` are adjacent in ``g_a``;
    - ``[v, u]`` is appended once more if they are adjacent in ``g_b``.

    The result can therefore contain repeated pairs and both orientations
    of the same pair.

    NOTE(review): despite the name, no d-separation test is performed; the
    check is plain adjacency via ``is_adjacent_to``. Confirm the intent.
    Whether ``==`` on GeneralGraph means identity or structural equality is
    not visible here.

    Args:
        graphs: The graphs to compare pairwise.

    Returns:
        ``[node, node]`` pairs in the order they were discovered.
    """
    pairs = []
    for g_a in graphs:
        for g_b in graphs:
            if g_a == g_b:
                continue
            shared = overlap(g_a, g_b)
            for u in shared:
                for v in shared:
                    if u == v:
                        continue
                    # Record the pair once per graph in which it is adjacent,
                    # checking g_a before g_b.
                    for candidate in (g_a, g_b):
                        if candidate.is_adjacent_to(u, v):
                            pairs.append([v, u])
    return pairs

def overlap(graph1, graph2: GeneralGraph):
    """
    Return the set of nodes that appear in both graphs.

    The result is empty when the graphs share no nodes, so its truthiness
    also answers "do these graphs overlap?".
    """
    return set(graph1.get_nodes()) & set(graph2.get_nodes())

def edgelist_to_generalgraph(edge_list: list[list[str, str]]) -> GeneralGraph:
    """Build a causallearn GeneralGraph from a list of directed edges.

    Args:
        edge_list: ``(source, target)`` name pairs. Every distinct name
            becomes a ``GraphNode``, created in order of first appearance.

    Returns:
        A GeneralGraph over those nodes with one ``add_directed_edge`` call
        per pair, from ``source`` to ``target``.
    """
    graph_nodes = [GraphNode(name) for name in edgelist2nodelist(edge_list)]
    graph = GeneralGraph(nodes=graph_nodes)

    # Resolve names back to the exact GraphNode instances the graph holds.
    lookup = {graph_node.get_name(): graph_node for graph_node in graph_nodes}
    for source, target in edge_list:
        graph.add_directed_edge(lookup[source], lookup[target])

    return graph

def save_graph_to_file(input_file_name: str, causal_graph: list | str | GeneralGraph, filetype='txt') -> str:
    """Persist a causal graph under a timestamped ``runs/`` subdirectory.

    The output directory is ``runs/<stem>_<YYYYmmdd_HHMM>``, where ``stem``
    is the base name of ``input_file_name`` without its final extension.

    Args:
        input_file_name: Path of the input the graph was built from; only its
            base name is used, to name the output directory.
        causal_graph: The graph to save.
        filetype: Output format:
            - ``'txt'`` writes ``str(causal_graph)`` to ``causal_graph.txt``
              (any graph type).
            - ``'json'`` writes ``{"edges": [...]}`` to ``causal_graph.json``
              and requires an edge list.
            - For any other value, or ``'json'`` with a GeneralGraph, a
              ``GeneralGraph`` is pickled to ``causal_graph.pkl``.

    Returns:
        The path of the written file.

    Raises:
        ValueError: If the ``filetype`` / ``causal_graph`` combination is not
            supported. Nothing is created on disk in that case.
    """
    # Decide the output format before touching the filesystem, so an invalid
    # request does not leave an empty run directory behind.
    if filetype == 'txt':
        fmt = 'txt'
    elif filetype == 'json' and isinstance(causal_graph, list):
        fmt = 'json'
    elif isinstance(causal_graph, GeneralGraph):
        fmt = 'pkl'
    else:
        raise ValueError(
            f"Unsupported combination of filetype={filetype!r} and causal_graph "
            f"of type {type(causal_graph).__name__}. Use filetype='txt' for any "
            "graph, filetype='json' for an edge list, or pass a GeneralGraph "
            "to save it as a pickle."
        )

    # splitext keeps dotted stems intact ("exp.a.csv" -> "exp.a"), so
    # different inputs saved in the same minute do not share a directory
    # and overwrite each other.
    base_filename = os.path.splitext(os.path.basename(input_file_name))[0]
    date_time = datetime.now().strftime("%Y%m%d_%H%M")
    subdirectory = f"runs/{base_filename}_{date_time}"
    # exist_ok removes the check-then-create race and also creates 'runs'.
    os.makedirs(subdirectory, exist_ok=True)

    filename = f"{subdirectory}/causal_graph.{fmt}"
    if fmt == 'txt':
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(str(causal_graph))
    elif fmt == 'json':
        graph_data = {"edges": [list(edge) for edge in causal_graph]}
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(graph_data, file, indent=4)
    else:
        with open(filename, 'wb') as file:
            pickle.dump(causal_graph, file)

    return filename