import anthropic
import json
import os
import re
import glob
import sys
from datetime import datetime


def extract_text_from_response(response):
    """Return the text of the first ``text``-type block in a Claude response.

    The blocks in ``response.content`` are scanned in order and any block
    whose ``type`` is not ``"text"`` (for instance the blocks emitted when
    thinking mode is enabled) is passed over.

    Returns:
        The ``text`` attribute of the first matching block, or ``None`` when
        the response contains no text block at all.
    """
    # Lazy scan: stops at the first text block, exactly like an early return.
    text_candidates = (
        item.text for item in response.content if item.type == "text"
    )
    return next(text_candidates, None)


def _log_event(message, log_file=None):
    """Print *message* and, when a log file is given, append it with an ISO timestamp."""
    print(message)
    if log_file:
        log_file.write(f"{datetime.now().isoformat()} - {message}\n")


def _write_json(path, payload):
    """Serialize *payload* to *path* as indented UTF-8 JSON, replacing any existing file."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _is_error_placeholder(path):
    """Return True when *path* holds only the ``{"error": ...}`` stub written after a failed run."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            existing = json.load(f)
    except (OSError, ValueError):
        # Unreadable or non-JSON content: treat it as a real output and leave it alone.
        return False
    return isinstance(existing, dict) and set(existing) == {"error"}


def _extract_concepts(json_data, json_file_path, log_file=None):
    """Pull the list of concepts out of the loaded JSON, whatever its top-level shape."""
    if isinstance(json_data, list):
        return json_data
    if isinstance(json_data, dict):
        # Try the keys that might hold the concepts, in priority order.
        for key in ("concepts", "important_concepts", "entities", "keywords"):
            if isinstance(json_data.get(key), list):
                return json_data[key]
        # No recognized key: fall back to using every value as a concept.
        return list(json_data.values())
    _log_event(f"Warning: Unexpected JSON structure in {json_file_path}", log_file)
    return [str(json_data)]


def _build_prompt(text_data, concepts_str):
    """Assemble the causal-relationship annotation prompt sent to the model."""
    return f"""
Text:
{text_data}

Important concepts appearing in the text:
{concepts_str}

====================================

Task: For the text and the important concepts appearing in it, please infer the **direct causal relationships** between each concept based on the text and common sense reasoning (causal relationships are not the same as correlations. For example, high temperature has causal relationships with both the number of drownings and ice cream sales, but the number of drownings and ice cream sales only have correlation without direct causal relationship).

Requirements: The format for annotating causal relationships for each text should be:
0101 (means that the first concept has direct causal relationships with the second and fourth concepts, and the first concept is the cause of the second and fourth concepts)
0010 (means that the second concept has a direct causal relationship with the third concept, and the second concept is the cause of the third concept)
0000 (means that the third concept is not the cause of any other concept)
0100 (similarly, ...)

Your response must be in .json format containing the following:
{{
  "adjacency matrix": [
    [0,1,0,...],
    [0,0,1,...],
    ...
  ]
}}
"""


def _parse_json_object(response_content, file_name, log_file=None):
    """Return the JSON object embedded in the model's reply, or None (after logging why)."""
    # Greedy match: everything from the first "{" to the last "}" in the reply.
    json_match = re.search(r'({[\s\S]*})', response_content)
    if json_match is None:
        _log_event(f"Could not find JSON content in the response for {file_name}", log_file)
        return None
    try:
        return json.loads(json_match.group(1))
    except json.JSONDecodeError as e:
        _log_event(f"Error decoding JSON from response for {file_name}: {e}", log_file)
        return None


def process_file_pair(text_file_path, json_file_path, output_dir, api_key, log_file=None):
    """Build a prompt from a text/JSON file pair, query the LLM, and save the reply.

    The output is written to ``<output_dir>/<stem>.json`` where ``<stem>`` is
    the text file's name with only its final extension removed, so
    ``report.v2.txt`` maps to ``report.v2.json`` and cannot collide with
    ``report.txt``.

    Args:
        text_file_path: Path of the UTF-8 text file to analyse.
        json_file_path: Path of the JSON file listing the important concepts.
        output_dir: Existing directory that receives the result file.
        api_key: Anthropic API key used to create the client.
        log_file: Optional writable text stream; every progress message is
            printed and, when this is given, also appended with a timestamp.

    Returns:
        The parsed JSON object from the model's reply on success. ``None`` when
        the pair is skipped (a finished output already exists), an input file
        cannot be read, the reply contains no parseable JSON (the raw reply is
        saved under ``"full_response"``), or the API call fails (an
        ``{"error": ...}`` stub is saved).
    """
    # splitext strips only the last extension; split('.')[0] would cut a
    # dotted name such as "a.b.txt" down to "a".
    file_name = os.path.splitext(os.path.basename(text_file_path))[0]
    output_file = os.path.join(output_dir, f"{file_name}.json")

    _log_event(f"Processing {file_name}: {text_file_path} with {json_file_path}", log_file)

    # Avoid reprocessing finished work, but do not let the error stub left by
    # a failed attempt block every later retry.
    if os.path.exists(output_file):
        if not _is_error_placeholder(output_file):
            _log_event(f"Output file for {file_name} already exists. Skipping...", log_file)
            return None
        _log_event(f"Output file for {file_name} only records an earlier error. Retrying...", log_file)

    # Load the text data from the txt file
    try:
        with open(text_file_path, 'r', encoding='utf-8') as f:
            text_data = f.read()
    except Exception as e:
        _log_event(f"Error reading text file {text_file_path}: {e}", log_file)
        return None

    # Load the important concepts from the json file
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
    except Exception as e:
        _log_event(f"Error reading JSON file {json_file_path}: {e}", log_file)
        return None

    important_concepts = _extract_concepts(json_data, json_file_path, log_file)
    # One concept per line in the prompt
    concepts_str = "\n".join(str(concept) for concept in important_concepts)
    prompt = _build_prompt(text_data, concepts_str)

    # Initialize the Anthropic client
    client = anthropic.Anthropic(
        api_key=api_key,
    )

    try:
        # Thinking is enabled, so the reply carries non-text blocks as well;
        # budget_tokens is kept below max_tokens to leave room for the answer.
        response = client.messages.create(
            model="claude-3-7-sonnet-20250219",
            max_tokens=20000,
            thinking={
                "type": "enabled",
                "budget_tokens": 18000
            },
            messages=[{
                "role": "user",
                "content": prompt
            }]
        )
        _log_event(f"Response received for {file_name}. Processing...", log_file)

        response_content = extract_text_from_response(response)
        if not response_content:
            _log_event(f"No text content found in response for {file_name}", log_file)
            _write_json(output_file, {"error": "No text content found in response"})
            return None

        response_json = _parse_json_object(response_content, file_name, log_file)
        if response_json is None:
            # Keep the raw reply so it can be inspected or repaired by hand.
            _write_json(output_file, {"full_response": response_content})
            _log_event(f"Saved full response to {output_file}", log_file)
            return None

        _write_json(output_file, response_json)
        _log_event(f"Successfully saved results to {output_file}", log_file)
        return response_json
    except Exception as e:
        # Best-effort boundary: one failing pair must not abort the whole batch.
        _log_event(f"Error calling Anthropic API for {file_name}: {e}", log_file)
        _write_json(output_file, {"error": f"API error: {str(e)}"})
        return None


def process_all_files(datasets_dir, output_dir, api_key):
    """Process every text/JSON pair found in the immediate subdirectories of *datasets_dir*.

    Each ``<name>.txt`` directly inside a subdirectory is paired with the
    ``<name>.json`` beside it (same name, only the final extension swapped)
    and handed to ``process_file_pair``. Text files sitting directly in
    *datasets_dir*, or nested more than one level deep, are not visited.

    Progress is printed and appended to ``processing_log.txt`` in
    *output_dir*, which is created if needed.

    Args:
        datasets_dir: Directory whose subdirectories hold the file pairs.
        output_dir: Directory receiving the results and the log file.
        api_key: Anthropic API key forwarded to ``process_file_pair``.

    Returns:
        None. Errors raised while processing are logged, not propagated;
        only a failure to create *output_dir* or open the log file raises.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    log_file_path = os.path.join(output_dir, "processing_log.txt")
    with open(log_file_path, "a", encoding="utf-8") as log_file:

        def log(message):
            """Echo *message* to stdout and append it, timestamped, to the log file."""
            print(message)
            log_file.write(f"{datetime.now().isoformat()} - {message}\n")

        try:
            # The start line already embeds its own timestamp, so it is
            # written without the usual timestamp prefix.
            start_message = f"Starting processing at {datetime.now().isoformat()}"
            print(start_message)
            log_file.write(f"{start_message}\n")

            # Find all subdirectories in the datasets directory
            try:
                subdirs = [
                    entry for entry in os.listdir(datasets_dir)
                    if os.path.isdir(os.path.join(datasets_dir, entry))
                ]
            except Exception as e:
                log(f"Error accessing datasets directory {datasets_dir}: {e}")
                return

            # processed_count counts pairs handed to process_file_pair,
            # whatever the outcome (it also covers pairs it skipped or failed on).
            processed_count = 0
            skipped_count = 0

            for subdir in subdirs:
                subdir_path = os.path.join(datasets_dir, subdir)

                # Escape the directory part so glob metacharacters in a folder
                # name (e.g. "set[1]") are matched literally.
                try:
                    txt_files = glob.glob(os.path.join(glob.escape(subdir_path), "*.txt"))
                except Exception as e:
                    log(f"Error accessing subdirectory {subdir_path}: {e}")
                    continue

                for txt_file in txt_files:
                    # Same name but .json extension; splitext removes only the
                    # final extension, so "a.b.txt" pairs with "a.b.json".
                    base_name = os.path.splitext(os.path.basename(txt_file))[0]
                    json_file = os.path.join(subdir_path, f"{base_name}.json")

                    if os.path.exists(json_file):
                        process_file_pair(txt_file, json_file, output_dir, api_key, log_file)
                        processed_count += 1
                    else:
                        log(f"Skipping {txt_file} - no corresponding JSON file found.")
                        skipped_count += 1

            log(f"Completed processing {processed_count} file pairs, skipped {skipped_count} files.")

        except Exception as e:
            # Top-level boundary: record the failure instead of crashing the run.
            log(f"Unexpected error: {e}")


def get_api_key():
    """Return the Anthropic API key, or exit the program when none is configured.

    Lookup order:
      1. the ``ANTHROPIC_API_KEY`` environment variable;
      2. the ``api_key`` field of ``config.json`` in the current working
         directory.

    A missing ``config.json`` is silently ignored. One that cannot be read,
    is not valid JSON, or is not a JSON object triggers a warning and is then
    treated as providing no key, so the caller gets the normal "API key not
    found" exit rather than a traceback.

    Returns:
        The API key as found.

    Raises:
        SystemExit: with status 1 when no key is found in either place.
    """
    # Try to get the API key from an environment variable
    api_key = os.environ.get("ANTHROPIC_API_KEY")

    # If not found, try to load it from a config file
    if not api_key:
        config = None
        try:
            with open("config.json", "r", encoding="utf-8") as f:
                config = json.load(f)
        except FileNotFoundError:
            pass
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"WARNING: could not read config.json: {e}")

        if isinstance(config, dict):
            api_key = config.get("api_key")
        elif config is not None:
            print("WARNING: config.json must contain a JSON object with an 'api_key' field.")

    if not api_key:
        print(
            "ERROR: API key not found. Please set the ANTHROPIC_API_KEY environment variable or create a config.json file with an 'api_key' field.")
        sys.exit(1)

    return api_key


# Entry point when the file is run as a script
if __name__ == "__main__":
    # Input and result locations, both relative to the current working directory
    datasets_dir, output_dir = "../../datasets", "./results"

    # Make sure the results folder is present before anything else happens
    os.makedirs(output_dir, exist_ok=True)

    # Resolve the credentials before starting any work
    api_key = get_api_key()

    print(f"Looking for text files in: {datasets_dir}")
    process_all_files(
        datasets_dir=datasets_dir,
        output_dir=output_dir,
        api_key=api_key,
    )