import anthropic
import json
import os
import re
import random


def extract_text_from_response(response):
    """Return the text of the first ``text`` content block in a Claude response.

    Content blocks of any other type (e.g. thinking blocks produced when
    thinking mode is enabled) are skipped. Returns None when the response
    carries no text block at all.
    """
    text_blocks = (block.text for block in response.content if block.type == "text")
    return next(text_blocks, None)


# Request settings shared by both Claude calls in process_graph_file.
_MODEL_NAME = "claude-3-7-sonnet-20250219"
_MAX_TOKENS = 20000
_THINKING_BUDGET_TOKENS = 18000

# Greedy span from the first "{" to the last "}" in a response.
_JSON_OBJECT_RE = re.compile(r'({[\s\S]*})')


def _write_json(path, payload):
    """Write *payload* to *path* as pretty-printed UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _extract_json_object(text):
    """Return the JSON object embedded in *text*, or None if *text* has no "{...}" span.

    The greedy first-"{"-to-last-"}" span is tried first. If that does not
    parse (for example because prose after the JSON contains "}" or prose
    before it contains "{"), each "{" inside the span is tried in turn as the
    start of a JSON object, ignoring whatever text follows that object.

    Raises:
        json.JSONDecodeError: a "{...}" span exists but no JSON object parses.
    """
    match = _JSON_OBJECT_RE.search(text)
    if match is None:
        return None
    candidate = match.group(1)
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as err:
        parse_error = err

    decoder = json.JSONDecoder()
    start = candidate.find('{')
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and reports where it
            # ended, so trailing prose no longer breaks parsing.
            return decoder.raw_decode(candidate, start)[0]
        except json.JSONDecodeError:
            start = candidate.find('{', start + 1)
    raise parse_error


def _ask_claude(client, messages):
    """Send *messages* to Claude with extended thinking enabled; return the raw response."""
    return client.messages.create(
        model=_MODEL_NAME,
        max_tokens=_MAX_TOKENS,
        thinking={
            "type": "enabled",
            "budget_tokens": _THINKING_BUDGET_TOKENS
        },
        messages=messages
    )


def _save_description(output_text_file, response_content):
    """Persist the second-stage (natural-language) response to *output_text_file*.

    Saves the parsed JSON object when one is found. Otherwise it saves
    ``{"full_response": ...}``, or ``{"error": ...}`` when the response had no text.
    """
    if not response_content:
        print("No text content found in second response")
        _write_json(output_text_file, {"error": "No text content found in response"})
        return

    print(f"Successfully accessed second response content")
    try:
        text_json = _extract_json_object(response_content)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON from second response: {e}")
        text_json = None
    else:
        if text_json is None:
            print("Could not find JSON content in the second response")

    if text_json is None:
        # Save the full response as a fallback
        _write_json(output_text_file, {"full_response": response_content})
        print(f"Saved full second response to {output_text_file}")
        return

    _write_json(output_text_file, text_json)
    print(f"Successfully saved natural language description to {output_text_file}")


def process_graph_file(graph_file_path):
    """Generate real-world concepts and a natural-language description for one causal DAG.

    Makes two Claude requests. The first assigns concepts from a randomly
    chosen domain to the DAG's nodes; its JSON answer is saved to
    ``variables/variables_<num>.json``. The second request continues that
    conversation and asks for a paragraph that conveys the relationships
    implicitly; its answer is saved to ``text/text_<num>.json``. When an
    answer cannot be parsed, a ``{"full_response": ...}`` or
    ``{"error": ...}`` record is written in its place.

    If both output files already exist, no API call is made. The concepts are
    read back from the existing variables file instead.

    Args:
        graph_file_path: path to a ``graph_<num>.json`` file containing
            ``adjacency_matrix`` and ``params.n``.

    Returns:
        The "Real concepts assigned to variables" list ([] if the first answer
        lacks that key). Returns None when the file name does not match
        ``graph_<num>.json``, the first answer cannot be parsed, or processing fails.
    """
    # List of domains to randomly select from
    domains = ["Business", "Medical", "Legal"]

    # Extract file number from the path (e.g., "graph_0.json" -> "0")
    number_match = re.search(r'graph_(\d+)\.json', graph_file_path)
    if number_match is None:
        # Not a graph_<num>.json file: skip it instead of crashing on .group().
        print(f"Skipping {graph_file_path}: file name does not match graph_<number>.json")
        return None
    file_number = number_match.group(1)
    output_variables_file = f'variables/variables_{file_number}.json'
    output_text_file = f'text/text_{file_number}.json'

    # Check if files already exist to avoid reprocessing
    if os.path.exists(output_variables_file) and os.path.exists(output_text_file):
        print(f"Files for graph_{file_number}.json already exist. Skipping...")

        # Read existing variables file to return concepts
        try:
            with open(output_variables_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if "Real concepts assigned to variables" in data:
                return data["Real concepts assigned to variables"]
            return None
        except (OSError, ValueError, TypeError) as e:
            print(f"Error reading existing file: {e}")
            return None

    # Create output directories if they don't exist
    os.makedirs(os.path.dirname(output_variables_file), exist_ok=True)
    os.makedirs(os.path.dirname(output_text_file), exist_ok=True)

    # Load the graph data from the JSON file
    with open(graph_file_path, 'r', encoding='utf-8') as f:
        graph_data = json.load(f)

    # Extract the matrix and n value
    adjacency_matrix = graph_data["adjacency_matrix"]
    n = graph_data["params"]["n"]

    # Format the matrix for display in the prompt, one row per line
    matrix_str = "\n".join(str(row) for row in adjacency_matrix)

    # Randomly select a domain from the list
    selected_domain = random.choice(domains)

    # Construct the first prompt with the randomly selected domain
    first_prompt = f"""
Adjacency Matrix:
{matrix_str}

Task: Please assign concepts from a meaningful real-world {selected_domain} field to the {n} nodes in the causal DAG represented by this adjacency matrix, while fully conforming to the causal relationships between nodes.

Requirements: In your thinking, please use the following separators to assist your reasoning, and only output the final result when you are satisfied with it:
----Let me first analyze carefully---- 
(First list all relationships between nodes represented by 1s in the matrix and all non-existent relationships represented by 0s in the matrix)
----First attempt---- 
(Then write out the concepts corresponding to the nodes)
----Check for errors---- 
(Please use the complete paradigm '''First, imagine that in the real world, [variable A] occurs (or takes some value) and [variable B] subsequently occurs (or takes some value). If [variable A] had not occurred (or had taken a different value), would [variable B] still occur in the same way (or maintain the same value) under the same background conditions? If in the counterfactual scenario where [variable A] did not occur, [variable B] significantly changes (either does not occur at all, or occurs in a substantially different way, time, intensity, or characteristics), and this change is systematic rather than accidental, while all other potential background conditions and common causes that might affect [variable B] remain constant, then we can reasonably infer a causal relationship between [variable A] and [variable B], meaning [variable A] is a cause of [variable B]. Conversely, if in the counterfactual scenario, even when [variable A] does not occur, [variable B] still occurs in essentially the same way, or changes in [variable B] can be fully explained by changes in other variables, and this situation stably repeats across various background conditions, this indicates there is no direct or substantial causal relationship between [variable A] and [variable B], and the observed correlation between them may be coincidental, a spurious association due to common causes, or an indirect effect mediated through other variables rather than a true causal connection.''' to check whether the concepts conform to ALL relationships marked by 1s and do NOT conform to ALL relationships marked by 0s. If causal relationships are unreasonable, consider the reasons for errors and avoid them in the next attempt)
Begin second analysis
----Second attempt---- 
...

Your answer should be in JSON format:
{{
  "Existing causal relationships (values of 1 in the matrix)": [
    "Node 0 → Node 1",
    ...
  ],
  "Non-existing causal relationships (values of 0 in the matrix)": [
    "Node 0 → Node 1",
    ...
  ],
  "Domain": [
    "{selected_domain}"
  ],
  "Real concepts assigned to variables": [
    "Node 0: ___",
    ...
  ],
  "Relationship verification": {{
    "Existing causal relationships": [
      "___ (natural language description conforming to the reasoning paradigm)",
      ...
    ],
    "Non-existing causal relationships": [
      "___ (natural language description conforming to the reasoning paradigm)",
      ...
    ]
  }}
}}
"""

    # Read the key from the environment instead of hard-coding a credential in
    # source; the SDK reports an error if ANTHROPIC_API_KEY is unset.
    client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

    # First request: assign concepts to the DAG nodes
    first_response = _ask_claude(client, [{"role": "user", "content": first_prompt}])
    print(f"First response received. Processing...")
    first_response_content = extract_text_from_response(first_response)

    try:
        if not first_response_content:
            print("No text content found in first response")
            # Save error info to file
            _write_json(output_variables_file, {"error": "No text content found in response"})
            return None
        print(f"Successfully accessed first response content")

        try:
            response_json = _extract_json_object(first_response_content)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from first response: {e}")
            print("Raw response content:")
            print(first_response_content[:500])  # Print first 500 chars
            response_json = None
        else:
            if response_json is None:
                print("Could not find JSON content in the first response")

        if response_json is None:
            # Save the full response as a fallback
            _write_json(output_variables_file, {"full_response": first_response_content})
            print(f"Saved full first response to {output_variables_file}")
            return None

        variables_concepts = response_json.get("Real concepts assigned to variables", [])
        _write_json(output_variables_file, response_json)
        print(f"Successfully saved variable concepts to {output_variables_file}")

        # Second prompt: turn the concepts into an implicit natural-language description
        second_prompt = f"""
Concepts:
{json.dumps(variables_concepts, ensure_ascii=False, indent=2)}

Adjacency matrix between concepts:
{matrix_str}

Task: Please express all concepts clearly in a paragraph of natural language (implicitly conveying relationships between concepts rather than explicitly stating them), without introducing any additional concepts.

Requirements: In your thinking, please use the following separators to assist your reasoning, and only output the final result when you are satisfied with it:
----Let me first analyze carefully---- 
(Which relationships between concepts should be indirectly described and which should not appear)
----First attempt---- 
(Try writing your paragraph)
----Check for implicitness---- 
(Even though all concepts appear in this paragraph, the causal relationships between them are not clearly stated, ensuring readers must make their own judgments)
----Check for errors---- 
(Check if the description avoids expressing relationships that don't exist, i.e., 0s in the matrix. If the description is not rigorous or not implicit, consider the reasons and begin a second analysis)
Begin second analysis
----Second attempt---- 
...

Your answer should be in JSON format:
{{
  "Natural language description": "..."
}}
"""

        # The second request continues the conversation; only the text of the
        # first answer is replayed as the assistant turn.
        second_response = _ask_claude(client, [
            {"role": "user", "content": first_prompt},
            {"role": "assistant", "content": first_response_content},
            {"role": "user", "content": second_prompt},
        ])
        _save_description(output_text_file, extract_text_from_response(second_response))
        return variables_concepts
    except Exception as e:
        # Per-file boundary: report and move on so a batch run keeps going.
        print(f"Error processing response: {e}")
        # Check response structure for debugging
        print(f"Response structure: {dir(first_response)}")
        if hasattr(first_response, 'content'):
            print(f"Content structure: {first_response.content}")
            print(f"Content types: {[block.type for block in first_response.content if hasattr(block, 'type')]}")
        return None


# Process all graph files in the directory
def process_all_graph_files(input_dir='./dag_data', max_files=None):
    """Run process_graph_file over every ``graph_*.json`` in *input_dir*, in sorted order.

    Args:
        input_dir: directory scanned (non-recursively) for graph files.
        max_files: optional cap on how many matching files are processed;
            None means no cap.

    Returns:
        dict mapping each file-number string to the concepts returned for it.
        Files for which process_graph_file returns a falsy value are omitted.
    """
    concepts_by_number = {}
    handled = 0

    graph_names = [
        name for name in sorted(os.listdir(input_dir))
        if name.startswith('graph_') and name.endswith('.json')
    ]
    for name in graph_names:
        # Stop once the requested number of graph files has been handled
        if max_files is not None and handled >= max_files:
            break

        print(f"\nProcessing {name}...")
        concepts = process_graph_file(os.path.join(input_dir, name))
        if concepts:
            number = re.search(r'graph_(\d+)\.json', name).group(1)
            concepts_by_number[number] = concepts

        handled += 1

    print(f"\nCompleted processing {handled} graph files.")
    return concepts_by_number


if __name__ == "__main__":
    # Process up to max_graph_files graph files from the specified directory.
    # Guarded so that importing this module does not start any API calls.
    input_dir = './dag_data'
    max_graph_files = 10000
    print(f"Looking for graph files in: {input_dir}")

    processed_count = 0
    for filename in sorted(os.listdir(input_dir)):
        if not (filename.startswith('graph_') and filename.endswith('.json')):
            continue
        # Once the cap is reached no further file would be processed, so stop scanning
        if processed_count >= max_graph_files:
            break

        print(f"\nProcessing {filename}...")
        file_path = os.path.join(input_dir, filename)
        variables_concepts = process_graph_file(file_path)

        if variables_concepts:
            file_number = re.search(r'graph_(\d+)\.json', filename).group(1)
            print(f"\nExtracted variable concepts for graph_{file_number}.json:")
            for concept in variables_concepts:
                print(concept)

        processed_count += 1

    print(f"\nCompleted processing {processed_count} graph files.")