import os
import json
import time
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from dotenv import load_dotenv
import openai

# Token-accounting helpers used to analyze prompt/response token usage
from scripts.utils.token_utils import analyze_response_tokens, extract_reasoning_tokens

# Load variables from a .env file into the process environment at import time,
# so that load_api_key() can find OPENAI_API_KEY even when it is not exported
# in the shell.
load_dotenv()


def load_api_key():
    """
    Read the OpenAI API key from the process environment.

    The key comes from the OPENAI_API_KEY variable, which may have been
    populated from a .env file by load_dotenv() at import time.

    Returns:
    - str: The API key.

    Raises:
    - ValueError: If OPENAI_API_KEY is unset or empty.
    """
    key = os.getenv("OPENAI_API_KEY")
    if key:
        return key
    raise ValueError(
        "OpenAI API key not found. Set it in a .env file or as an environment variable."
    )


def query_openai(prompt, model="gpt-4o-mini") -> Tuple[str, Dict[str, Any]]:
    """
    Send a single-turn text prompt to OpenAI's Chat Completions API.

    Parameters:
    - prompt (str): The user message sent to the model.
    - model (str): The OpenAI model to use (default: 'gpt-4o-mini').

    Returns:
    - Tuple[str, Dict]: (response_text, token_usage_dict). ``response_text`` is
      ``""`` when the API returns no message content.

    Raises:
    - ValueError: If no OpenAI API key is configured.
    - Any exception raised by the OpenAI client (network, auth, rate limits, ...).

    Notes:
    - Falls back to 'gpt-4o-mini' (printing a warning) if an unsupported model
      is specified.
    - Token usage is produced by analyze_response_tokens() from the prompt, the
      response text and the API-reported usage (None when the API omits it).
    """
    api_key = load_api_key()

    supported_models = ("gpt-4o-mini", "o3-mini", "o4-mini", "gpt-4.1", "gpt-4.1-nano")
    if model not in supported_models:
        # Make the substitution visible; otherwise results silently come from
        # a different model than the caller asked for.
        print(f"⚠️ Model '{model}' is not supported; falling back to 'gpt-4o-mini'.")
        model = "gpt-4o-mini"

    # Pass the key explicitly (like the batch helpers) instead of mutating the
    # module-global ``openai.api_key``.
    client = openai.OpenAI(api_key=api_key)

    response = client.chat.completions.create(
        model=model, messages=[{"role": "user", "content": prompt}]
    )

    # ``message.content`` may be None (e.g. refusals / non-text replies);
    # normalize to "" so callers can always write the result to a file.
    response_text = response.choices[0].message.content or ""

    # API-reported usage, when the response carries it.
    usage = getattr(response, "usage", None)
    api_usage = None
    if usage:
        api_usage = {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    token_usage = analyze_response_tokens(
        prompt_text=prompt,
        response_text=response_text,
        api_usage=api_usage,
    )

    return response_text, token_usage


def submit_batch(
    prompts: List[str],
    output_paths: List[str],
    model: str = "gpt-4o-mini",
    batch_name: Optional[str] = None,
) -> str:
    """
    Queue a list of prompts as a single job on OpenAI's Batch API.

    Parameters:
    - prompts: Prompt strings; each becomes one chat-completion request.
    - output_paths: Destination file for each prompt's response, aligned by index.
    - model: OpenAI model used for every request in the batch.
    - batch_name: Optional label for the local tracking files; defaults to
      ``batch_<YYYYmmdd_HHMMSS>``.

    Returns:
    - str: ID of the created batch.

    Raises:
    - ValueError: If ``prompts`` is empty, the two lists differ in length, or
      no API key is configured.
    - Any exception from the upload / batch-creation calls is printed and re-raised.

    Notes:
    - Writes ``<batch_name>_input.jsonl`` and ``<batch_name>_metadata.json``
      under ``batch_jobs/``, plus ``<batch_name>_info.json`` once the job exists.
    - Does not wait for the job; poll with check_batch_status() and collect the
      output with retrieve_batch_results().
    """
    if not prompts or len(prompts) != len(output_paths):
        raise ValueError(
            "Prompts and output_paths must be non-empty and of the same length"
        )

    api = openai.OpenAI(api_key=load_api_key())

    jobs_dir = "batch_jobs"
    os.makedirs(jobs_dir, exist_ok=True)

    if not batch_name:
        batch_name = "batch_" + datetime.now().strftime("%Y%m%d_%H%M%S")

    # Persisted alongside the job so retrieval can route each response to its file.
    metadata = {
        "output_paths": output_paths,
        "model": model,
        "timestamp": datetime.now().isoformat(),
        "total_requests": len(prompts),
    }

    prefix = f"{jobs_dir}/{batch_name}"
    input_path = f"{prefix}_input.jsonl"

    # custom_id is 1-based ("request-1", ...); retrieval maps it back to an index.
    requests = [
        {
            "custom_id": f"request-{position}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": [{"role": "user", "content": text}],
            },
        }
        for position, text in enumerate(prompts, start=1)
    ]
    with open(input_path, "w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(req) + "\n" for req in requests)

    with open(f"{prefix}_metadata.json", "w", encoding="utf-8") as handle:
        json.dump(metadata, handle)

    try:
        print(f"Uploading batch file with {len(prompts)} requests...")
        with open(input_path, "rb") as handle:
            uploaded = api.files.create(file=handle, purpose="batch")

        print(f"Creating batch job for model {model}...")
        job = api.batches.create(
            input_file_id=uploaded.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={"batch_name": batch_name},
        )

        info = {
            "batch_id": job.id,
            "input_file_id": uploaded.id,
            "status": job.status,
            "created_at": job.created_at,
            "batch_name": batch_name,
        }
        with open(f"{prefix}_info.json", "w", encoding="utf-8") as handle:
            json.dump(info, handle)

        print(f"Batch job submitted successfully! Batch ID: {job.id}")
        print(
            f"The job contains {len(prompts)} requests and will complete within 24 hours."
        )
        print(f"Check status with: check_batch_status('{job.id}')")
        return job.id
    except Exception as exc:
        print(f"Error submitting batch: {str(exc)}")
        raise


def check_batch_status(batch_id: str) -> Dict[str, Any]:
    """
    Fetch and print the current status of a batch job.

    Parameters:
    - batch_id: The ID of the batch to check

    Returns:
    - dict with keys: batch_id, status, created_at, total_requests,
      completed_requests, failed_requests, progress (percentage string such as
      "42.0%") and is_complete (True only when status == "completed").

    Raises:
    - ValueError: If no OpenAI API key is configured.
    - Any exception from the OpenAI client is printed and re-raised.
    """
    client = openai.OpenAI(api_key=load_api_key())

    try:
        batch = client.batches.retrieve(batch_id)

        # The SDK types ``request_counts`` as optional; treat a missing value
        # as zero counts rather than failing with AttributeError.
        counts = batch.request_counts
        total = counts.total if counts is not None else 0
        completed = counts.completed if counts is not None else 0
        failed = counts.failed if counts is not None else 0

        status_info = {
            "batch_id": batch.id,
            "status": batch.status,
            "created_at": batch.created_at,
            "total_requests": total,
            "completed_requests": completed,
            "failed_requests": failed,
            # max(1, ...) avoids division by zero before any requests are counted.
            "progress": f"{completed / max(1, total) * 100:.1f}%",
            "is_complete": batch.status == "completed",
        }

        print(f"Batch {batch_id} status: {batch.status}")
        print(f"Progress: {status_info['progress']} ({completed}/{total})")

        if batch.status == "completed":
            print(
                "✅ Batch has completed! Retrieve results with retrieve_batch_results()"
            )
        elif batch.status in ["failed", "expired", "cancelled"]:
            print(f"❌ Batch has {batch.status}!")
        else:
            print(f"⏳ Batch is still {batch.status}...")

        return status_info

    except Exception as e:
        print(f"Error checking batch status: {str(e)}")
        raise


def _find_local_batch_name(batch_dir: str, batch_id: str) -> Optional[str]:
    """Return the ``batch_name`` recorded for *batch_id* in ``<batch_dir>/*_info.json``, or None."""
    # A missing directory just means no batch was submitted from here.
    if not os.path.isdir(batch_dir):
        return None
    for filename in sorted(os.listdir(batch_dir)):
        if not filename.endswith("_info.json"):
            continue
        with open(os.path.join(batch_dir, filename), "r", encoding="utf-8") as f:
            info = json.load(f)
        if info.get("batch_id") == batch_id:
            return info.get("batch_name")
    return None


def _download_batch_records(client, file_id: str, local_path: str) -> List[Dict[str, Any]]:
    """Save a Batch API result file to *local_path* and return its parsed JSONL records."""
    content = client.files.content(file_id)
    with open(local_path, "wb") as f:
        f.write(content.read())
    records = []
    with open(local_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank/trailing lines instead of failing on json.loads("").
            if line.strip():
                records.append(json.loads(line))
    return records


def _index_batch_records(
    records: List[Dict[str, Any]],
    expected_count: int,
    request_map: Dict[int, Dict[str, Any]],
) -> None:
    """Map records to 0-based prompt indices using their ``request-<N>`` custom_id.

    Records with a malformed or out-of-range id are ignored; an index that is
    already mapped is not overwritten.
    """
    prefix = "request-"
    for record in records:
        custom_id = record.get("custom_id") or ""
        if not custom_id.startswith(prefix):
            continue
        try:
            idx = int(custom_id[len(prefix):]) - 1
        except ValueError:
            continue
        if 0 <= idx < expected_count:
            request_map.setdefault(idx, record)


def _parse_batch_record(
    record: Dict[str, Any],
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Return ``(body, None)`` for a usable chat-completion record, else ``(None, error_message)``."""
    if record.get("error"):
        return None, f"Request failed: {record['error']}"
    # "response" can be present but null, so fall back with `or`, not a get() default.
    response = record.get("response") or {}
    body = response.get("body") or {}
    if isinstance(body, str):
        try:
            body = json.loads(body)
        except json.JSONDecodeError:
            return None, "Failed to parse response body"
    if not isinstance(body, dict):
        return None, "Failed to parse response body"
    status_code = response.get("status_code")
    if status_code is not None and status_code != 200:
        return None, f"Request returned HTTP {status_code}: {body.get('error')}"
    if not body.get("choices"):
        return None, "Response body contains no choices"
    return body, None


def _batch_token_usage(body: Dict[str, Any], response_content: str) -> Dict[str, Any]:
    """Build the token-usage record for one batch response from its API-reported ``usage``."""
    # The original prompts are not stored locally, so only API counts are available
    # (no full analyze_response_tokens() pass as in the synchronous path).
    api_usage = body.get("usage") or {}
    token_usage = {
        "input_tokens": api_usage.get("prompt_tokens"),
        "output_tokens": api_usage.get("completion_tokens"),
        "total_tokens": api_usage.get("total_tokens"),
        "api_reported": api_usage,
        "estimation_method": "api",
        "reasoning_tokens": None,
        "answer_tokens": None,
    }
    if response_content:
        breakdown = extract_reasoning_tokens(response_content)
        token_usage.update(
            {
                "reasoning_tokens": breakdown.get("reasoning_tokens"),
                "answer_tokens": breakdown.get("answer_tokens"),
                "has_thinking_section": breakdown.get("has_thinking_section"),
                "has_answer_section": breakdown.get("has_answer_section"),
            }
        )
    return token_usage


def _tokens_path_for(output_path: str) -> str:
    """Return the token-usage sidecar path for *output_path* (``x.txt`` -> ``x_tokens.json``).

    Unlike ``str.replace(".txt", ...)``, the result always differs from
    *output_path*, so the response file is never overwritten.
    """
    root, _ext = os.path.splitext(output_path)
    return f"{root}_tokens.json"


def retrieve_batch_results(batch_id: str) -> Dict[str, Any]:
    """
    Download a completed batch job's results and save each response to its output path.

    Responses are routed to the ``output_paths`` recorded by submit_batch() in
    ``batch_jobs/<batch_name>_metadata.json``. Next to each response file, a
    ``<path without extension>_tokens.json`` file receives its token usage.

    Parameters:
    - batch_id: The ID of the batch to retrieve

    Returns:
    - dict: ``{"status", "error"}`` if the batch is not completed. Otherwise
      ``batch_id``, ``status``, ``saved_count``, ``error_count``,
      ``total_tokens`` and ``details`` (per-index saved outputs and errors).

    Raises:
    - ValueError: If no API key is configured or no local info file matches
      *batch_id*.
    - Any other exception (API, I/O) is printed and re-raised.
    """
    client = openai.OpenAI(api_key=load_api_key())
    batch_dir = "batch_jobs"

    try:
        batch = client.batches.retrieve(batch_id)

        if batch.status != "completed":
            print(
                f"⚠️ Batch is not completed (status: {batch.status}). Cannot retrieve results."
            )
            return {"status": batch.status, "error": "Batch not completed"}

        batch_name = _find_local_batch_name(batch_dir, batch_id)
        if not batch_name:
            raise ValueError(
                f"Could not find batch information for batch ID {batch_id}"
            )

        with open(
            f"{batch_dir}/{batch_name}_metadata.json", "r", encoding="utf-8"
        ) as f:
            metadata = json.load(f)
        output_paths = metadata.get("output_paths", [])

        print(f"Downloading results for batch {batch_id}...")
        request_map: Dict[int, Dict[str, Any]] = {}
        # Successful responses live in the output file; requests that failed
        # outright appear only in the error file. Either file id may be None.
        if batch.output_file_id:
            records = _download_batch_records(
                client, batch.output_file_id, f"{batch_dir}/{batch_name}_output.jsonl"
            )
            _index_batch_records(records, len(output_paths), request_map)
        error_file_id = getattr(batch, "error_file_id", None)
        if error_file_id:
            records = _download_batch_records(
                client, error_file_id, f"{batch_dir}/{batch_name}_errors.jsonl"
            )
            _index_batch_records(records, len(output_paths), request_map)

        print("Processing outputs and saving to specified paths...")
        results = {"saved_outputs": [], "errors": [], "total_tokens": 0}

        for idx, output_path in enumerate(output_paths):
            record = request_map.get(idx)
            if record is None:
                results["errors"].append(
                    {
                        "index": idx,
                        "output_path": output_path,
                        "error": "No response found for this index",
                    }
                )
                continue

            body, error = _parse_batch_record(record)
            if error:
                results["errors"].append(
                    {"index": idx, "output_path": output_path, "error": error}
                )
                continue

            message = body["choices"][0].get("message") or {}
            # content may be null (e.g. refusals); write "" rather than crash.
            response_content = message.get("content") or ""
            token_usage = _batch_token_usage(body, response_content)

            try:
                # os.makedirs("") raises, so only create a directory when there is one.
                parent_dir = os.path.dirname(output_path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(response_content)
                with open(_tokens_path_for(output_path), "w", encoding="utf-8") as f:
                    json.dump(token_usage, f, indent=2)
            except OSError as e:
                results["errors"].append(
                    {"index": idx, "output_path": output_path, "error": str(e)}
                )
                continue

            results["saved_outputs"].append(
                {
                    "index": idx,
                    "output_path": output_path,
                    "token_usage": token_usage,
                }
            )
            if token_usage.get("total_tokens"):
                results["total_tokens"] += token_usage["total_tokens"]

        print(f"✅ Saved {len(results['saved_outputs'])} responses!")
        print(f"📊 Total tokens used: {results['total_tokens']:,}")
        if results["errors"]:
            print(f"⚠️ Encountered {len(results['errors'])} errors.")

        return {
            "batch_id": batch_id,
            "status": "completed",
            "saved_count": len(results["saved_outputs"]),
            "error_count": len(results["errors"]),
            "total_tokens": results["total_tokens"],
            "details": results,
        }

    except Exception as e:
        print(f"Error retrieving batch results: {str(e)}")
        raise


def solve_graph(
    prompt,
    output_path,
    model="gpt-4o-mini",
    batch_timeout=600,
    poll_interval=10,
):
    """
    Send a graph transformation prompt to OpenAI and save the response plus token usage.

    If ``model`` ends with '-batch-api', the prompt is submitted as a one-request
    Batch API job. The function then polls until the job completes, ends
    unsuccessfully (failed / expired / cancelled), or ``batch_timeout`` seconds
    elapse.

    Parameters:
    - prompt (str): The prompt text describing the input-output transformation.
    - output_path (str): Path to the file where the response should be saved.
    - model (str): The OpenAI model to use (default: 'gpt-4o-mini').
      Append '-batch-api' to use the Batch API (e.g., 'gpt-4o-mini-batch-api').
    - batch_timeout (float): Seconds to wait for a Batch API job (default: 600).
    - poll_interval (float): Seconds between Batch API status checks (default: 10).

    Side Effects:
    - Creates the parent directory of ``output_path`` if it does not exist.
    - Writes the response text to ``output_path`` and the token usage to
      ``<output_path without extension>_tokens.json``.
    - Prints progress and token statistics.
    """
    batch_suffix = "-batch-api"
    if model.endswith(batch_suffix):
        # Strip only the trailing marker; str.replace would also strip it mid-name.
        base_model = model[: -len(batch_suffix)]
        print(f"Solving benchmark using OpenAI Batch API with model: {base_model}...")

        batch_id = submit_batch(
            prompts=[prompt], output_paths=[output_path], model=base_model
        )

        print("Waiting for batch completion...")
        start_time = time.time()
        while time.time() - start_time < batch_timeout:
            status = check_batch_status(batch_id)
            if status["is_complete"]:
                retrieve_batch_results(batch_id)
                return
            # These states never turn into "completed"; stop polling instead
            # of sleeping until the timeout.
            if status["status"] in ("failed", "expired", "cancelled"):
                print(
                    f"Batch {batch_id} ended with status '{status['status']}'; no results to retrieve."
                )
                return

            print(f"Still processing... waiting {poll_interval} seconds.")
            time.sleep(poll_interval)

        print("Batch is taking longer than expected. You can check status later with:")
        print(f"check_batch_status('{batch_id}')")
        print(
            f"And retrieve results when complete with: retrieve_batch_results('{batch_id}')"
        )
        return

    print(f"Solving benchmark using OpenAI API with model: {model}...")
    ai_response, token_usage = query_openai(prompt, model)

    # os.makedirs("") raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        # Guard against a None message content coming back from the API.
        f.write(ai_response or "")

    # Derive the sidecar name from the extension; the old
    # output_path.replace(".txt", ...) returned output_path itself for
    # non-.txt paths and overwrote the response.
    root, _ext = os.path.splitext(output_path)
    token_output_path = f"{root}_tokens.json"
    with open(token_output_path, "w", encoding="utf-8") as f:
        json.dump(token_usage, f, indent=2)

    print(f"✅ Solution saved to {output_path}")
    print(
        f"📊 Token usage: {token_usage.get('total_tokens')} total "
        f"({token_usage.get('input_tokens')} input, {token_usage.get('output_tokens')} output)"
    )
    if token_usage.get("reasoning_tokens"):
        ratio = token_usage.get("reasoning_ratio") or 0
        print(
            f"🧠 Reasoning tokens: {token_usage['reasoning_tokens']} ({ratio:.2%} of response)"
        )


def solve_graphs_batch(
    prompts, output_paths, model="gpt-4o-mini-batch-api", batch_name=None
):
    """
    Submit multiple graph transformation tasks as a single batch job.

    This function does not wait for completion. It returns the batch ID, which
    can be used to check status and retrieve results later.

    Parameters:
    - prompts (List[str]): List of prompts to submit.
    - output_paths (List[str]): Paths where responses should be saved.
    - model (str): OpenAI model to use, with or without '-batch-api' suffix.
    - batch_name (str, optional): Custom name for the batch (for tracking purposes).

    Returns:
    - str: Batch ID for status checking and result retrieval.

    Raises:
    - ValueError: Propagated from submit_batch() for empty or mismatched inputs.
    """
    # Accept the model with or without the marker and strip only the trailing
    # suffix; str.replace would also remove "-batch-api" from inside a name.
    batch_suffix = "-batch-api"
    base_model = model[: -len(batch_suffix)] if model.endswith(batch_suffix) else model

    print(f"Submitting batch of {len(prompts)} prompts using model {base_model}...")
    batch_id = submit_batch(
        prompts=prompts,
        output_paths=output_paths,
        model=base_model,
        batch_name=batch_name,
    )

    print("\nBatch submitted successfully!")
    print(f"To check status later, use: check_batch_status('{batch_id}')")
    print(
        f"To retrieve results when complete, use: retrieve_batch_results('{batch_id}')"
    )

    return batch_id
