import os
import json
import time
import random
from google import genai
from dotenv import load_dotenv
from typing import Dict, Any, List
from concurrent.futures import ThreadPoolExecutor, as_completed

# Token analysis helper used to build the token_usage dicts returned below
from scripts.utils.token_utils import analyze_response_tokens

# Load .env file
load_dotenv()


def load_api_key():
    """
    Return the Gemini API key stored in the ``GEMINI_API_KEY`` environment variable.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    key = os.environ.get("GEMINI_API_KEY")
    if key:
        return key

    raise ValueError(
        "Gemini API key not found. Set GEMINI_API_KEY in your .env file or as an environment variable."
    )


def query_gemini(prompt, model="gemini-2.5-pro"):
    """
    Send ``prompt`` to Gemini through google-genai's
    ``client.models.generate_content(...)`` and return the reply text.

    - Robust parsing (falls back to joining the first candidate's text parts)
    - Gentle retries with exponential backoff (3 attempts)
    - Starts with a high max_output_tokens (65535) to avoid silent truncation
    - Surfaces real HTTP/API errors via token_usage["error"]

    Args:
        prompt: Prompt text passed as ``contents``.
        model: Gemini model name.

    Returns:
        (response_text, token_usage). ``response_text`` is "" on failure, in
        which case ``token_usage["error"]`` describes what went wrong.

    Raises:
        ValueError: propagated from load_api_key() when no API key is set.
    """

    api_key = load_api_key()
    client = genai.Client(api_key=api_key)

    # Start with a generous output cap; adjust if needed per your costs.
    gen_config = {
        "temperature": 0.3,
        "top_p": 0.9,
        "max_output_tokens": 65535,
    }
    # NOTE: omit "thinking" config; some accounts/regions reject it silently.

    # HTTP statuses worth retrying when the exception exposes an integer `code`.
    retryable_status_codes = frozenset({408, 429, 500, 502, 503, 504})
    # Message fallback for exceptions without a status code (network errors...).
    # "rate" is deliberately spelled out as "rate limit": the bare substring
    # also matches "generateContent", which made permanent errors (e.g. an
    # unknown model) look transient.
    transient_markers = (
        "429",
        "503",
        "500",
        "temporar",
        "timeout",
        "timed out",
        "rate limit",
        "rate-limit",
        "rate_limit",
        "quota",
        "unavailable",
        "reset",
    )

    def _is_transient(exc):
        """Return True when ``exc`` looks like a retryable (temporary) failure."""
        code = getattr(exc, "code", None)
        if isinstance(code, int):
            # Structured status beats guessing from the message text.
            return code in retryable_status_codes
        message = str(exc).lower()
        return any(marker in message for marker in transient_markers)

    def _extract_text_and_meta(resp):
        """Return (text, meta) for a response; text falls back to candidate parts."""
        meta = {}

        # Usage metadata (if provided by the API)
        um = getattr(resp, "usage_metadata", None)
        if um:
            meta["api_usage"] = {
                "prompt_tokens": getattr(um, "prompt_token_count", None),
                "completion_tokens": getattr(um, "candidates_token_count", None),
                "total_tokens": getattr(um, "total_token_count", None),
            }

        # Prompt feedback (blocking)
        pf = getattr(resp, "prompt_feedback", None)
        if pf and getattr(pf, "block_reason", None):
            meta["prompt_blocked"] = pf.block_reason

        # NOTE(review): google-genai responses appear to expose `.text`, not
        # `.output_text`, so this lookup is probably always empty and the
        # parts fallback below supplies the text - confirm against the SDK.
        text = (getattr(resp, "output_text", None) or "").strip()
        finish_reason = None
        cand0 = None
        cands = getattr(resp, "candidates", None) or []
        if cands:
            cand0 = cands[0]
            fr = getattr(cand0, "finish_reason", None)
            # Enum members carry .name; anything else is stringified.
            finish_reason = (
                fr.name
                if hasattr(fr, "name")
                else (str(fr) if fr is not None else None)
            )

            # If no text yet, join all text parts of the first candidate
            if not text:
                content = getattr(cand0, "content", None)
                parts = getattr(content, "parts", None) if content else None
                if parts:
                    chunks = [
                        part_text
                        for part_text in (getattr(p, "text", None) for p in parts)
                        if part_text
                    ]
                    text = "".join(chunks).strip()
        else:
            finish_reason = "NO_CANDIDATES"

        meta["finish_reason"] = finish_reason

        # Optional: safety categories. Purely informational, so a malformed
        # ratings object must never fail the whole request.
        try:
            sr_list = getattr(cand0, "safety_ratings", None) or []
            meta["safety_ratings"] = [getattr(sr, "category", None) for sr in sr_list]
        except Exception:
            pass

        return text, meta

    # Gentle retry policy: sleeps of 0.6s, then ~1.0s between the 3 attempts
    max_attempts = 3
    backoff_base = 0.6
    backoff_factor = 1.7

    last_text, last_meta = "", {}

    for attempt in range(max_attempts):
        is_last_attempt = attempt == max_attempts - 1
        try:
            resp = client.models.generate_content(
                model=model,
                contents=prompt,  # string is fine; SDK wraps it
                config=gen_config,
            )
        except Exception as e:
            # Back off on likely transient errors; report everything else
            if _is_transient(e) and not is_last_attempt:
                time.sleep(backoff_base * (backoff_factor**attempt))
                continue
            token_usage = analyze_response_tokens(
                prompt_text=prompt, response_text="", api_usage=None
            )
            token_usage["error"] = str(e)
            return "", token_usage

        last_text, last_meta = _extract_text_and_meta(resp)

        if last_text:
            break  # success

        if last_meta.get("prompt_blocked"):
            # Respect safety blocks; don't retry blindly
            break

        # We already set a large cap; still back off and retry in case the
        # empty reply was a fluke
        if not is_last_attempt:
            time.sleep(backoff_base * (backoff_factor**attempt))

    # Build token_usage for the analyzer (copy so the analyzer owns its dict)
    api_usage = dict(last_meta["api_usage"]) if last_meta.get("api_usage") else None

    token_usage = analyze_response_tokens(
        prompt_text=prompt,
        response_text=last_text,
        api_usage=api_usage,
    )
    token_usage["finish_reason"] = last_meta.get("finish_reason")
    token_usage["prompt_blocked"] = last_meta.get("prompt_blocked")
    if not last_text and not token_usage.get("error"):
        token_usage["error"] = (
            f"no_text finish_reason={last_meta.get('finish_reason')} blocked={last_meta.get('prompt_blocked')}"
        )

    return last_text, token_usage


def solve_graph(prompt, output_path, model="gemini-2.0-flash-exp"):
    """
    Sends a graph transformation prompt to Gemini and saves the resulting output with token usage.

    The response text is written to ``output_path``; the token usage dict is
    written next to it as ``<output_path without extension>_tokens.json``.

    Args:
        prompt: Prompt text forwarded to query_gemini().
        output_path: Destination file for the response text. May be a bare
            filename (no directory component).
        model: Gemini model name forwarded to query_gemini().
    """
    print(f"Solving benchmark using Gemini API with model: {model}...")

    ai_response, token_usage = query_gemini(prompt, model)

    # dirname is "" for a bare filename, and os.makedirs("") raises
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save the response
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(ai_response)

    # Save token usage metadata alongside response. splitext (instead of
    # str.replace(".txt", ...)) keeps directory names intact and guarantees a
    # distinct file even when output_path does not end in ".txt", so the
    # response written above is never overwritten.
    token_output_path = os.path.splitext(output_path)[0] + "_tokens.json"
    with open(token_output_path, "w", encoding="utf-8") as f:
        json.dump(token_usage, f, indent=2)

    print(f"✅ Solution saved to {output_path}")
    if not ai_response:
        # The file is still written (unchanged behavior), but make it visible
        # that it is empty and why.
        print(f"⚠️ Empty response: {token_usage.get('error', 'unknown error')}")
    print(
        f"📊 Token usage: {token_usage['total_tokens']} total ({token_usage['input_tokens']} input, {token_usage['output_tokens']} output)"
    )
    if token_usage.get("reasoning_tokens"):
        print(
            f"🧠 Reasoning tokens: {token_usage['reasoning_tokens']} ({token_usage.get('reasoning_ratio', 0):.2%} of response)"
        )


def solve_graphs_parallel(
    prompts: List[str],
    output_paths: List[str],
    model: str = "gemini-2.5-pro",
    max_workers: int = 5,
    delay_between_workers: float = 0.3,  # slightly higher default helps smooth bursts
    batch_delay: float = 1.0,  # short pause between batches
) -> List[Dict[str, Any]]:
    """
    Process multiple prompts in parallel using Gemini API with smart throttling.

    - Caps concurrency to the number of prompts (effective_workers)
    - Adds per-task outer retries specifically for 500/INTERNAL/unavailable errors
    - Uses jittered stagger between worker starts to smooth spikes
    - Submits work in batches of effective_workers, pausing between batches
    - Prompts whose output file already exists are skipped

    Args:
        prompts: Prompt texts, one per output file.
        output_paths: Destination file for each prompt's response; token usage
            is saved next to it as ``<path without extension>_tokens.json``.
        model: Gemini model name forwarded to query_gemini().
        max_workers: Upper bound on concurrent requests.
        delay_between_workers: Base stagger (seconds) between worker starts.
        batch_delay: Pause (seconds) between batches.

    Returns:
        One result dict per prompt, ordered by original index, with keys
        "index", "output_path", "success", "skipped", plus "token_usage"
        (when available) and "error" (on failure).

    Raises:
        ValueError: if prompts and output_paths differ in length.
    """

    if len(prompts) != len(output_paths):
        raise ValueError("prompts and output_paths must have the same length")

    total = len(prompts)
    if total == 0:
        return []

    # Keep concurrency reasonable for Gemini to avoid 500 bursts
    effective_workers = max(1, min(max_workers, total))
    batch_size = effective_workers  # avoid double-sized batches

    print(
        f"🚀 Processing {total} prompts using Gemini {model} with "
        f"{effective_workers} worker{'s' if effective_workers>1 else ''}..."
    )

    # Small heuristic for eta: ~2 sec per batch. Ceiling division so a partial
    # final batch is counted too.
    batch_count = -(-total // effective_workers)
    approx_seconds = batch_count * 2
    print(f"⏱️ Estimated time: {approx_seconds/60:.1f} minutes")

    results: List[Dict[str, Any]] = []
    completed = 0

    transient_markers = ("500", "INTERNAL", "UNAVAILABLE", "RESET")

    def failure(index: int, output_path: str, error: str, token_usage) -> Dict[str, Any]:
        """Build a failed-result dict, attaching token usage when we have it."""
        result: Dict[str, Any] = {
            "index": index,
            "output_path": output_path,
            "success": False,
            "error": error,
            "skipped": False,
        }
        if token_usage is not None:
            result["token_usage"] = token_usage
        return result

    def process_single(prompt: str, output_path: str, index: int) -> Dict[str, Any]:
        """
        Process one prompt with gentle outer retry for INTERNAL errors.
        """
        # Skip if output exists
        if os.path.exists(output_path):
            return {
                "index": index,
                "output_path": output_path,
                "success": True,
                "skipped": True,
            }

        # jittered stagger at worker start
        # use a small random component + index-based spacing
        jitter = random.uniform(0.0, delay_between_workers)
        time.sleep((index % effective_workers) * delay_between_workers + jitter)

        # outer retries for INTERNAL/500/unavailable (query_gemini already retries inside)
        outer_attempts = 2
        backoff_base = 0.7
        backoff_factor = 1.8

        last_err = None
        last_token_usage = None

        for attempt in range(outer_attempts):
            try:
                ai_response, token_usage = query_gemini(prompt, model)
                last_token_usage = token_usage

                if ai_response and ai_response.strip():
                    # dirname is "" for a bare filename; os.makedirs("") raises
                    out_dir = os.path.dirname(output_path)
                    if out_dir:
                        os.makedirs(out_dir, exist_ok=True)

                    # Save response
                    with open(output_path, "w", encoding="utf-8") as f:
                        f.write(ai_response)

                    # Save token usage. splitext (not str.replace) so a path
                    # without ".txt" never maps onto the response file itself.
                    token_output_path = (
                        os.path.splitext(output_path)[0] + "_tokens.json"
                    )
                    with open(token_output_path, "w", encoding="utf-8") as f:
                        json.dump(token_usage, f, indent=2)

                    return {
                        "index": index,
                        "output_path": output_path,
                        "success": True,
                        "token_usage": token_usage,
                        "skipped": False,
                    }

                # No text -> inspect reason to decide retry
                err = str(
                    (token_usage or {}).get("error", "")
                    or (token_usage or {}).get("finish_reason", "")
                    or ""
                )
            except Exception as e:
                # Worker boundary: report any failure as a result instead of
                # letting it escape through future.result().
                err = str(e) or repr(e)

            last_err = err

            transient = any(marker in err.upper() for marker in transient_markers)
            if transient and attempt < outer_attempts - 1:
                # jittered backoff before retrying this item
                time.sleep(
                    backoff_base * (backoff_factor**attempt)
                    + random.uniform(0, 0.25)
                )
                continue

            # Non-transient failure or out of attempts
            return failure(
                index, output_path, err or "Empty response", last_token_usage
            )

        # Shouldn't reach here: the final attempt always returns above
        return failure(index, output_path, last_err or "Unknown", last_token_usage)

    with ThreadPoolExecutor(max_workers=effective_workers) as executor:
        # Submit in batches of effective_workers
        for batch_start in range(0, total, batch_size):
            batch_end = min(batch_start + batch_size, total)
            batch = zip(
                prompts[batch_start:batch_end], output_paths[batch_start:batch_end]
            )

            futures = [
                executor.submit(process_single, prompt, path, index)
                for index, (prompt, path) in enumerate(batch, start=batch_start)
            ]

            for future in as_completed(futures):
                result = future.result()
                results.append(result)
                completed += 1

                name = os.path.basename(result["output_path"])
                if result.get("skipped"):
                    print(f"[{completed}/{total}] ⏩ Skipped: {name}")
                elif result.get("success"):
                    tokens = (result.get("token_usage") or {}).get("total_tokens", 0)
                    print(
                        f"[{completed}/{total}] ✅ Completed: {name} ({tokens} tokens)"
                    )
                else:
                    print(
                        f"[{completed}/{total}] ❌ Failed: {name} - {result.get('error', 'Unknown')}"
                    )

            if batch_end < total:
                print(
                    f"⏸️ Batch complete. Pausing {batch_delay:.1f}s before next batch..."
                )
                time.sleep(batch_delay)

    # Order by original index
    results.sort(key=lambda x: x["index"])

    # Summary
    successful = sum(
        1 for r in results if r.get("success", False) and not r.get("skipped", False)
    )
    skipped = sum(1 for r in results if r.get("skipped", False))
    failed = sum(1 for r in results if not r.get("success", False))
    # `or` guards: token_usage / total_tokens may be missing or None
    total_tokens = sum(
        (r.get("token_usage") or {}).get("total_tokens") or 0
        for r in results
        if r.get("success", False)
    )

    print(
        f"\n✅ Completed: {successful} successful, {skipped} skipped, {failed} failed"
    )
    print(f"📊 Total tokens used: {total_tokens:,}")

    return results
