import os
import json
import time
from typing import List, Dict, Any, Tuple
from dotenv import load_dotenv
import requests

import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from tqdm import tqdm

# Project-local helper used for token accounting
from scripts.utils.token_utils import analyze_response_tokens

# Load .env file
load_dotenv()

# one-time banner guard
_printed_model_note = False


def load_api_key():
    """
    Loads the DeepInfra API key from the environment.

    The key is read from the ``DEEPINFRA_API_KEY`` environment variable.

    Returns:
    - str: The API key as a string.

    Raises:
    - ValueError: If ``DEEPINFRA_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("DEEPINFRA_API_KEY")
    if not api_key:
        # The message must name the variable that is actually read above,
        # otherwise users set the wrong variable and the error never goes away.
        raise ValueError(
            "DeepInfra API key not found. Set DEEPINFRA_API_KEY in your environment."
        )
    return api_key


def query_deepinfra(
    prompt, model="Qwen/QwQ-32B", verbose: bool = False
) -> Tuple[str, Dict[str, Any]]:
    """
    Sends a text prompt to DeepInfra's chat-completions API and returns the
    generated text together with token usage.

    Transient failures (network errors and HTTP 429/500/502/503/504) are retried
    up to five times with exponential backoff (1s, 2s, 4s, ... capped at 20s).
    Other HTTP errors are not retried. HTTP and network failures do not raise:
    when no usable completion is obtained the returned text is an empty string.

    Parameters:
    - prompt (str): The user message to send.
    - model (str): A DeepInfra model name, or one of the short aliases
      'qwq', 'qwq-32b', 'deepseek-r1' (matched case-insensitively).
    - verbose (bool): Print a one-time model banner and retry/failure notices.

    Returns:
    - Tuple[str, Dict[str, Any]]: The response text ('' on failure) and the
      value returned by analyze_response_tokens for this prompt/response.

    Raises:
    - ValueError: If the API key is missing (raised by load_api_key).
    """
    global _printed_model_note

    api_key = load_api_key()

    # Canonical model names on DeepInfra
    model_mapping = {
        "qwq-32b": "Qwen/QwQ-32B",
        "qwq": "Qwen/QwQ-32B",
        "deepseek-r1": "deepseek-ai/DeepSeek-R1",
    }
    model_key = model.lower()

    # Use mapping if available, otherwise the raw string
    final_model = model_mapping.get(model_key, model)

    # One-time banner: whichever variant is printed, the guard is set so the
    # banner is not repeated on every call of a batch run.
    if verbose and not _printed_model_note:
        if model_key in ("qwq", "qwq-32b"):
            print(f"⚠️ Note: Attempting to use {final_model}")
            print(
                "If this fails, the model might not be available or named differently."
            )
        else:
            print(f"🤖 Using model: {final_model}")
        _printed_model_note = True

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = {
        "model": final_model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 1,
        "top_p": 1,
        # "max_tokens": 4096,  # enable if you want a hard cap
    }

    # Status codes worth retrying; anything else (e.g. 401, 404) will not
    # change on a retry, so it ends the loop immediately.
    retryable_statuses = (429, 500, 502, 503, 504)
    max_attempts = 5

    response = None
    backoff = 1.0
    for attempt in range(max_attempts):
        is_last_attempt = attempt == max_attempts - 1
        try:
            response = requests.post(
                "https://api.deepinfra.com/v1/openai/chat/completions",
                headers=headers,
                json=payload,
                timeout=300,  # up to 5 minutes for long generations
            )
        except requests.exceptions.RequestException as e:
            # Drop any response from an earlier attempt so it is not parsed below.
            response = None
            if is_last_attempt:
                if verbose:
                    print(f"❌ DeepInfra API call failed: {e}")
                break
            if verbose:
                print(
                    f"⚠️ Request error, retrying in {backoff:.1f}s (attempt {attempt+1}/5): {e}"
                )
        else:
            if response.status_code not in retryable_statuses:
                if not response.ok and verbose:
                    print(
                        f"❌ DeepInfra API call failed: HTTP {response.status_code}"
                    )
                break
            if is_last_attempt:
                if verbose:
                    print(
                        f"❌ DeepInfra API call failed: HTTP {response.status_code}"
                    )
                break
            if verbose:
                print(
                    f"⚠️ DeepInfra {response.status_code}, retrying in {backoff:.1f}s (attempt {attempt+1}/5)"
                )

        # Only reached when another attempt follows (no sleep after the last one).
        time.sleep(backoff)
        backoff = min(backoff * 2, 20)

    # Parse the body only for a successful response; error pages (for example
    # an HTML gateway error) are frequently not JSON.
    data: Any = {}
    if response is not None and response.ok:
        try:
            data = response.json()
        except ValueError as e:
            # The JSON decode errors raised by response.json() are ValueErrors.
            if verbose:
                print(f"❌ DeepInfra returned a non-JSON body: {e}")
            data = {}
    if not isinstance(data, dict):
        data = {}

    # Walk choices[0].message.content defensively: any level may be missing
    # or null, and callers write the result straight to a text file.
    choices = data.get("choices")
    first_choice = choices[0] if isinstance(choices, list) and choices else None
    message = first_choice.get("message") if isinstance(first_choice, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    response_text = content if isinstance(content, str) else ""

    # Extract API usage data when present
    api_usage = None
    usage = data.get("usage")
    if isinstance(usage, dict):
        api_usage = {
            "prompt_tokens": usage.get("prompt_tokens"),
            "completion_tokens": usage.get("completion_tokens"),
            "total_tokens": usage.get("total_tokens"),
        }

    # Analyze tokens comprehensively
    token_usage = analyze_response_tokens(
        prompt_text=prompt,
        response_text=response_text,
        api_usage=api_usage,
    )

    return response_text, token_usage


def solve_graph(prompt, output_path, model="qwq", verbose: bool = False):
    """
    Sends a graph transformation prompt to DeepInfra and saves the resulting output with token usage.

    The response text is written to ``output_path``. Token usage is written as
    JSON next to it, in a file named after ``output_path`` with its extension
    replaced by ``_tokens.json`` (``result.txt`` -> ``result_tokens.json``).

    Parameters:
    - prompt (str): The prompt text describing the input-output transformation.
    - output_path (str): Path to the file where the response should be saved.
    - model (str): The DeepInfra model to use (default: 'qwq').
    - verbose (bool): Print per-item details.
    """
    if verbose:
        print(f"Solving benchmark using DeepInfra API with model: {model}...")

    ai_response, token_usage = query_deepinfra(prompt, model, verbose=verbose)

    # dirname is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the response
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(ai_response)

    # Save token usage metadata alongside response. Only the final extension
    # is swapped: a plain str.replace(".txt", ...) would also rewrite ".txt"
    # inside directory names, and would leave a non-.txt path unchanged so
    # that the token JSON overwrote the response file.
    token_output_path = os.path.splitext(output_path)[0] + "_tokens.json"
    with open(token_output_path, "w", encoding="utf-8") as f:
        json.dump(token_usage, f, indent=2)

    if verbose:
        rel = os.path.relpath(output_path)
        print(f"✅ Saved: {rel}")
        print(
            f"📊 Tokens: {token_usage['total_tokens']} total "
            f"({token_usage['input_tokens']} in, {token_usage['output_tokens']} out)"
        )
        if token_usage.get("reasoning_tokens"):
            print(
                f"🧠 Reasoning tokens: {token_usage['reasoning_tokens']} "
                f"({token_usage.get('reasoning_ratio', 0):.2%})"
            )


def solve_graphs_batch(
    prompts: List[str],
    output_paths: List[str],
    model: str = "qwq",
    batch_size: int = 10,
    delay_between_requests: float = 0.5,
) -> List[Dict[str, Any]]:
    """
    Process multiple prompts sequentially with rate limiting.
    DeepInfra doesn't have a batch API like OpenAI, so we process sequentially.

    Prompts whose output file already exists are skipped and do not appear in
    the returned list. An empty (or whitespace-only) response is recorded as a
    failure and nothing is written for it, so a later run retries that prompt
    instead of skipping it as "already done".

    Parameters:
    - prompts: List of prompts to process
    - output_paths: List of output file paths
    - model: Model to use
    - batch_size: Number of requests to process before a longer (5 second)
      pause; a value <= 0 disables the longer pause
    - delay_between_requests: Seconds to wait between requests

    Returns:
    - List of results, one dict per attempted prompt: 'output_path',
      'success', and either 'token_usage' (success) or 'error' (failure)

    Raises:
    - ValueError: If prompts and output_paths differ in length
    """
    if len(prompts) != len(output_paths):
        raise ValueError("prompts and output_paths must have the same length")

    results = []
    total = len(prompts)
    skipped = 0

    print(f"🚀 Processing {total} prompts using DeepInfra {model} (sequential)...")
    print(f"⏱️ Rough ETA: {total * delay_between_requests / 60:.1f} minutes")

    # Context manager so the bar is closed even if an unexpected error escapes.
    with tqdm(total=total, unit="prompt") as pbar:
        for i, (prompt, output_path) in enumerate(zip(prompts, output_paths), 1):
            # Short status label: last 4 components of the output path
            tail = "/".join(Path(output_path).parts[-4:])

            # Check if output already exists
            if os.path.exists(output_path):
                skipped += 1
                pbar.update(1)
                pbar.set_postfix_str(f"skipped: {tail}")
                continue

            try:
                # Process the prompt
                ai_response, token_usage = query_deepinfra(
                    prompt, model, verbose=False
                )

                # An empty response means the API call did not produce a
                # completion; writing it would make the next run skip it.
                if not ai_response or not ai_response.strip():
                    raise ValueError(
                        "Empty response - likely API error or model not found"
                    )

                # dirname is '' for a bare filename; os.makedirs('') would raise.
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)

                # Save response
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(ai_response)

                # Save token usage; swap only the final extension so the
                # sidecar can never be the response file itself.
                token_output_path = os.path.splitext(output_path)[0] + "_tokens.json"
                with open(token_output_path, "w", encoding="utf-8") as f:
                    json.dump(token_usage, f, indent=2)

                results.append(
                    {
                        "output_path": output_path,
                        "success": True,
                        "token_usage": token_usage,
                    }
                )

                pbar.update(1)
                tokens = token_usage.get("total_tokens", 0)
                pbar.set_postfix_str(f"ok: {tail} ({tokens} tok)")

            except (requests.exceptions.RequestException, OSError, ValueError) as e:
                pbar.update(1)
                pbar.set_postfix_str(f"FAIL: {tail} ({e})")
                results.append(
                    {"output_path": output_path, "success": False, "error": str(e)}
                )

            # Rate limiting
            if i < total:  # Don't sleep after the last request
                if batch_size > 0 and i % batch_size == 0:
                    pbar.set_postfix_str(f"pause after {batch_size}…")
                    time.sleep(5)
                else:
                    time.sleep(delay_between_requests)

    # Summary
    successful = sum(1 for r in results if r.get("success", False))
    total_tokens = sum(
        r.get("token_usage", {}).get("total_tokens", 0)
        for r in results
        if r.get("success", False)
    )

    print(f"\n✅ Completed: {successful}/{total} successful")
    if skipped:
        print(f"⏭️ Skipped (output already exists): {skipped}")
    print(f"📊 Total tokens used: {total_tokens:,}")

    return results


def solve_graphs_parallel(
    prompts: List[str],
    output_paths: List[str],
    model: str = "qwq",
    max_workers: int = 5,
    delay_between_batches: float = 1.0,
    verbose: bool = False,
) -> List[Dict[str, Any]]:
    """
    Process multiple prompts in parallel using thread pool.

    Prompts whose output file already exists are reported as skipped. An empty
    (or whitespace-only) response is reported as a failure and nothing is
    written for it.

    Parameters:
    - prompts: List of prompts to process
    - output_paths: List of output file paths
    - model: Model to use
    - max_workers: Maximum number of concurrent requests
    - delay_between_batches: Seconds to pause in the result-collection loop
      after every 2 x max_workers completions
    - verbose: Print high-level info

    Returns:
    - List of result dicts in the original prompt order, each with 'index',
      'output_path', 'success', 'skipped', plus 'token_usage' or 'error'

    Raises:
    - ValueError: If prompts and output_paths differ in length
    """
    if len(prompts) != len(output_paths):
        raise ValueError("prompts and output_paths must have the same length")

    results = []
    total = len(prompts)
    completed = 0

    if verbose:
        print(
            f"🚀 Processing {total} prompts using DeepInfra {model} with {max_workers} workers..."
        )
        print(f"⏱️ Rough ETA: {total / max_workers * 2 / 60:.1f} minutes")

    def process_single(prompt: str, output_path: str, index: int) -> Dict[str, Any]:
        """Process a single prompt and return its result record (never raises)."""
        # Check if output already exists
        if os.path.exists(output_path):
            return {
                "index": index,
                "output_path": output_path,
                "success": True,
                "skipped": True,
            }

        try:
            # Process the prompt
            ai_response, token_usage = query_deepinfra(prompt, model, verbose=False)

            # CHECK IF RESPONSE IS EMPTY (indicates API error)
            if not ai_response or not ai_response.strip():
                return {
                    "index": index,
                    "output_path": output_path,
                    "success": False,
                    "error": "Empty response - likely API error or model not found",
                    "skipped": False,
                }

            # Save response only if non-empty. dirname is '' for a bare
            # filename, and os.makedirs('') raises, so guard the call.
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(ai_response)

            # Save token usage; swap only the final extension so the sidecar
            # can never be the response file itself (str.replace(".txt", ...)
            # leaves a non-.txt path unchanged).
            token_output_path = os.path.splitext(output_path)[0] + "_tokens.json"
            with open(token_output_path, "w", encoding="utf-8") as f:
                json.dump(token_usage, f, indent=2)

            return {
                "index": index,
                "output_path": output_path,
                "success": True,
                "token_usage": token_usage,
                "skipped": False,
            }

        except Exception as e:
            # Worker boundary: convert any error into a failure record so one
            # bad item cannot abort the whole pool.
            return {
                "index": index,
                "output_path": output_path,
                "success": False,
                "error": str(e),
                "skipped": False,
            }

    # Process in parallel
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        futures = [
            executor.submit(process_single, prompt, output_path, i)
            for i, (prompt, output_path) in enumerate(zip(prompts, output_paths))
        ]

        # Progress
        with tqdm(total=total, unit="prompt") as pbar:
            for future in as_completed(futures):
                result = future.result()
                results.append(result)
                completed += 1
                pbar.update(1)

                # Rich short status (tail of path: last 4 parts)
                tail = "/".join(Path(result["output_path"]).parts[-4:])
                if result.get("skipped"):
                    pbar.set_postfix_str(f"skipped: {tail}")
                elif result.get("success"):
                    tokens = result.get("token_usage", {}).get("total_tokens", 0)
                    pbar.set_postfix_str(f"ok: {tail} ({tokens} tok)")
                else:
                    pbar.set_postfix_str(
                        f"FAIL: {tail} ({result.get('error', 'unknown')})"
                    )

                # Light pacing every ~2×worker completions.
                # NOTE(review): every task is already submitted above, so this
                # pause only delays result collection here; it does not slow
                # the requests made by the worker threads.
                if completed % (max_workers * 2) == 0 and completed < total:
                    time.sleep(delay_between_batches)

    # Sort results by original index (as_completed yields in finish order)
    results.sort(key=lambda x: x["index"])

    # Summary
    successful = sum(
        1 for r in results if r.get("success", False) and not r.get("skipped", False)
    )
    skipped = sum(1 for r in results if r.get("skipped", False))
    failed = sum(1 for r in results if not r.get("success", False))
    total_tokens = sum(
        r.get("token_usage", {}).get("total_tokens", 0)
        for r in results
        if r.get("success", False)
    )

    print(
        f"\n✅ Completed: {successful} successful, {skipped} skipped, {failed} failed"
    )
    print(f"📊 Total tokens used: {total_tokens:,}")

    return results
