import os
import random
import time
from pprint import pformat
from types import SimpleNamespace

import markdown
import numpy as np
from PIL import Image
from absl import logging
from google import genai
from google.genai import types

from veoplace.utils import downsize_image
from veoplace.utils import image_to_base64_str
from veoplace.utils import parse_valid_rectangle
from veoplace.utils.constants import GEMINI_RETRYABLE_ERRORS
from veoplace.utils.constants import GEMINI_THINKING_BUDGET
from veoplace.utils.constants import GEMINI_THINKING_MODELS


def load_gemini_api_keys(path: str | None = None):
    """
    Load Gemini-API keys.

    Parameters
    ----------
    path : str | None
        • If supplied, read keys from that file.
        • Otherwise fall back to   ~/.gemini_api_keys   (old behaviour).

    Returns
    -------
    str | None
        The key string (whitespace-stripped) or None if the file is missing.
    """
    key_file = os.path.expanduser(path) if path else os.path.expanduser(
            "~/.gemini_api_keys")

    try:
        with open(key_file, "r") as f:
            keys = f.read().strip()
        logging.info("Loaded Gemini API keys from %s", key_file)
        return keys

    except FileNotFoundError:
        logging.warning(
                "Gemini API keys not found.  Looked for %s.  "
                "Supply --gemini_key_path or create ~/.gemini_api_keys",
                key_file,
        )
        return None


def log_gemini_error(err, *, key_idx=None, log=logging.warning) -> float | None:
    """
    Pretty‐print a google.generativeai error object and, if a RetryInfo block
    is present, return its retry delay in seconds.  Otherwise returns None.
    """
    header = f'Key #{key_idx} – ' if key_idx is not None else ''
    header += f'{getattr(err, "code", "<?>")} {getattr(err, "status", type(err).__name__)}'
    log(header)

    msg = getattr(err, "message", None) or str(err)
    for line in msg.splitlines():
        log("  " + line)

    details = getattr(err, "details", None)
    if isinstance(details, dict) and "error" in details:
        details = details["error"].get("details", [])
    if not isinstance(details, list):
        return None

    retry_secs = None

    for d in details:
        dtype = d.get("@type", "")
        if dtype.endswith("QuotaFailure"):
            for v in d.get("violations", []):
                qid = v.get("quotaId", v.get("quotaMetric", "?"))
                limit = v.get("quotaValue", "?")
                dims = v.get("quotaDimensions", {})
                model = dims.get("model", "?")
                location = dims.get("location", "?")
                log(f"  Quota exceeded: {qid} = {limit} "
                    f"(model={model}, location={location})")

        elif dtype.endswith("RetryInfo"):
            rd = d.get("retryDelay", "?")
            log(f"  Server says: retry after {rd}")

            # gRPC form: {"seconds":"4","nanos":"500000000"}
            if isinstance(rd, dict):
                retry_secs = float(rd.get("seconds", 0)) + \
                             float(rd.get("nanos", 0)) / 1e9
            # REST form: "3.5s"
            elif isinstance(rd, str) and rd.endswith("s"):
                try:
                    retry_secs = float(rd[:-1])
                except ValueError:
                    pass

        elif dtype.endswith("Help"):
            for link in d.get("links", []):
                log(f"  Help: {link.get('description', '')} – {link.get('url', '')}")
        else:
            log("  Unrecognised detail:\n" + pformat(d, indent=4))

    return retry_secs


def get_gemini_mask(corners, macro_grid_width, macro_grid_height, grid_size):
    """
    Build a (grid_size, grid_size) float32 mask from a suggested rectangle.

    ``corners`` is normalised by ``parse_valid_rectangle``; its points are
    read as (col, row) = (p[0], p[1]). The mask is 1.0 on rows
    ``[min_row, max_row - macro_grid_height + 1)`` and columns
    ``[min_col, max_col - macro_grid_width + 1)``, and 0.0 elsewhere.
    Presumably the 1-cells are the anchor positions at which the macro fits
    inside the rectangle; confirm against the caller.

    If the macro is larger than the rectangle in either dimension, that range
    is empty and the mask is all zeros.
    """
    corners = parse_valid_rectangle(corners)
    xs = [p[0] for p in corners]
    ys = [p[1] for p in corners]

    # Bounding box of the rectangle
    min_col, max_col = min(xs), max(xs)
    min_row, max_row = min(ys), max(ys)

    # Exclusive end of the valid range, shrunk by the macro's size
    end_valid_col = max_col - macro_grid_width + 1
    end_valid_row = max_row - macro_grid_height + 1

    # Clamp so an oversized macro yields an empty range. Without this, a
    # negative end would be read by NumPy slicing as an offset from the far
    # edge of the grid and mark unrelated cells valid.
    end_valid_col = max(end_valid_col, min_col)
    end_valid_row = max(end_valid_row, min_row)

    valid_mask = np.zeros((grid_size, grid_size), dtype=np.float32)
    valid_mask[min_row:end_valid_row, min_col:end_valid_col] = 1.0
    return valid_mask


def save_prompt_to_markdown_embedded(prompt_elements, output_dir, run_id,
        timestep, response=None, suggestion_image=None, gemini_model_str=None,
        candidate_id=1,
        api_key=None):
    """
    Save a prompt as a Markdown debug file, plus an HTML rendering of it.

    The prompt may optionally be followed by Gemini's response and a
    suggestion image. Both files go to ``<output_dir>/prompts/`` and are named
    ``debug_prompt_run_{run_id:06d}_timestep_{timestep:04d}_cand_{candidate_id:02d}``
    with ``.md`` and ``.html`` extensions.

    Args:
        prompt_elements: Prompt parts (None is treated as empty).
            - ``str`` items are copied verbatim.
            - ``PIL.Image.Image`` items are downsized (max_size=200) and
              embedded as base64 PNG data URIs.
            - Anything else is dumped inside a code fence.
        output_dir: Base directory; ``prompts/`` is created inside it.
        run_id, timestep, candidate_id: Integers used in the file names.
        response: Optional response text. Triple backticks are stripped
            before it is appended.
        suggestion_image: Optional PIL image appended at the end.
        gemini_model_str: Model name used to count the prompt's tokens.
        api_key: Passed as-is to ``genai.Client`` for token counting.

    Token counting is best effort: after a few failed attempts the count is
    reported as "unknown" and the files are still written.
    """
    prompt_elements = prompt_elements or []

    stem = (f"debug_prompt_run_{run_id:06d}_timestep_{timestep:04d}"
            f"_cand_{candidate_id:02d}")
    prompts_dir = os.path.join(output_dir, "prompts")
    os.makedirs(prompts_dir, exist_ok=True)

    md_path = os.path.join(prompts_dir, f"{stem}.md")
    html_path = os.path.join(prompts_dir, f"{stem}.html")

    # Count tokens with the google.genai Client API. The module imported as
    # ``genai`` has no ``configure`` / ``GenerativeModel``; those belong to
    # the older google.generativeai SDK. The count is informational only, so
    # retries are bounded instead of blocking the caller forever.
    num_tokens = "unknown"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            client = genai.Client(api_key=api_key)
            count = client.models.count_tokens(
                    model=gemini_model_str, contents=prompt_elements)
            num_tokens = count.total_tokens
            break
        except Exception as e:  # best effort; never fail the debug dump
            logging.error("Error counting tokens (attempt %d/%d): %s",
                          attempt, max_attempts, e)
            if attempt < max_attempts:
                time.sleep(1)

    md_content = [
            f"# Debug Prompt for run {run_id}\n\n",
            f"## Total number of tokens: {num_tokens}\n\n",
    ]

    for item in prompt_elements:
        if isinstance(item, str):
            md_content.append(item)
        elif isinstance(item, Image.Image):
            # Downsize so the embedded data URI stays small enough to view.
            base64_str = image_to_base64_str(downsize_image(item, max_size=200))
            md_content.append(
                    f"![Embedded Image](data:image/png;base64,{base64_str})\n")
        else:
            md_content.append("```\n")
            md_content.append(f"Unrecognized type: {type(item)}\n{item}\n")
            md_content.append("```\n\n")

    if response:
        md_content.append("\n## Gemini Response\n\n")
        # Strip ``` so the response cannot open or close a code fence.
        md_content.append(f"{response.replace('```', '')}\n\n")

    if suggestion_image is not None:
        base64_str = image_to_base64_str(
                downsize_image(suggestion_image, max_size=200))
        md_content.append(
                f"![Suggestion Image](data:image/png;base64,{base64_str})\n")

    final_markdown = "".join(md_content)

    # 1) Write out the .md file
    with open(md_path, "w", encoding="utf-8") as f_md:
        f_md.write(final_markdown)
    logging.info("Markdown debug file saved to: %s", md_path)

    # 2) Render the same Markdown to HTML with the ``markdown`` package
    with open(html_path, "w", encoding="utf-8") as f_html:
        f_html.write(markdown.markdown(final_markdown))
    logging.info("HTML debug file saved to: %s", html_path)


def save_all_candidates_to_markdown_embedded(
        prompt_elements,
        output_dir,
        run_id,
        timestep,
        candidate_data_list,
        *,
        num_tokens=None,
        candidate_num_tokens=None,
):
    """
    Save the prompt and every candidate response/image to one Markdown file,
    then write an HTML rendering of the same Markdown next to it.

    Files go to ``<output_dir>/prompts/`` as
    ``debug_prompt_run_{run_id:05d}_timestep_{timestep:04d}_ALL.{md,html}``.
    NOTE(review): the per-candidate dump ``save_prompt_to_markdown_embedded``
    pads run_id to 6 digits, not 5; kept as-is so existing file names don't
    change.

    Args:
        prompt_elements: Prompt parts (None is treated as empty).
            - ``str`` items are copied verbatim.
            - ``PIL.Image.Image`` items are downsized (max_size=800) and
              embedded as base64 PNG data URIs.
            - Anything else is dumped inside a code fence.
        output_dir: Base directory for the markdown/html files.
        run_id: Integer ID of the run.
        timestep: Integer timestep / iteration number.
        candidate_data_list: List of dicts, each containing:
            - "candidate_idx": int
            - "candidate_str": str (Gemini response text)
            - "suggestion_image": PIL Image or None
        num_tokens: Prompt token count shown in the header (may be None).
        candidate_num_tokens: Candidates token count shown above the
            candidates section (may be None).
    """
    md_filename = f"debug_prompt_run_{run_id:05d}_timestep_{timestep:04d}_ALL.md"
    html_filename = f"debug_prompt_run_{run_id:05d}_timestep_{timestep:04d}_ALL.html"
    prompts_dir = os.path.join(output_dir, "prompts")
    os.makedirs(prompts_dir, exist_ok=True)

    md_path = os.path.join(prompts_dir, md_filename)
    html_path = os.path.join(prompts_dir, html_filename)

    md_content = []
    md_content.append(f"# Debug Prompt for run {run_id}\n\n")
    md_content.append(f"## Total number of tokens: {num_tokens}\n\n")
    # number of candidates parsed
    md_content.append(
            f"## Number of candidates: {len(candidate_data_list)}\n\n")

    # 1) Write out the prompt. Callers may pass prompt=None, so fall back to
    #    an empty list instead of crashing.
    md_content.append("## Prompt\n\n")
    for item in prompt_elements or []:
        if isinstance(item, str):
            md_content.append(item)
        elif isinstance(item, Image.Image):
            downsized_img = downsize_image(item, max_size=800)
            base64_str = image_to_base64_str(downsized_img)
            embed_str = f"![Embedded Image](data:image/png;base64,{base64_str})\n"
            md_content.append(embed_str)
        else:
            md_content.append("```\n")
            md_content.append(f"Unrecognized type: {type(item)}\n{item}\n")
            md_content.append("```\n\n")

    # 2) Write out all candidates
    md_content.append("\n## Gemini Candidates\n\n")
    md_content.append(f"## Total number of tokens: {candidate_num_tokens}\n\n")
    for data in candidate_data_list:
        candidate_idx = data["candidate_idx"]
        candidate_str = data["candidate_str"]
        suggestion_image = data["suggestion_image"]

        md_content.append(f"### Candidate #{candidate_idx}\n\n")

        # Candidate text
        if candidate_str:
            # Remove triple backticks so the text cannot open/close a fence.
            md_content.append(candidate_str.replace("```", ""))
            md_content.append("\n\n")

        # Candidate image
        if suggestion_image:
            downsized_img = downsize_image(suggestion_image, max_size=800)
            base64_str = image_to_base64_str(downsized_img)
            embed_str = f"![Suggestion Image (candidate #{candidate_idx})](data:image/png;base64,{base64_str})\n"
            md_content.append(embed_str)
            md_content.append("\n\n")

        # Thick divider between candidates (raw HTML passes through Markdown)
        md_content.append('<hr style="border:10px solid gray">\n\n')

    final_markdown = "".join(md_content)

    # 1) Write out the .md file
    with open(md_path, "w", encoding="utf-8") as f_md:
        f_md.write(final_markdown)
    logging.info("Markdown debug file saved to: %s", md_path)

    # 2) Convert to HTML using the ``markdown`` package
    html_data = markdown.markdown(final_markdown)

    # 3) Write out the .html file
    with open(html_path, "w", encoding="utf-8") as f_html:
        f_html.write(html_data)
    logging.info("HTML debug file saved to: %s", html_path)


def query_gemini_local(
        prompt,
        api_keys,
        gemini_model_str,
        temperature: float = 1.0,
        timeout_seconds: int = 500,
        num_candidates: int = 1,
):
    """
    Query Gemini, rotating through API keys with exponential back-off.

    Each round tries every key in order and returns the first successful
    response. Errors in GEMINI_RETRYABLE_ERRORS are logged and move on to the
    next key; any other exception is logged and re-raised.

    Back-off after a round in which every key failed:
    • The first base delay is the first server-supplied RetryInfo value seen
      in the first round, or 10 s if there was none. Later hints are ignored.
    • The base delay doubles after each fully-failed round, capped at 1 h.
    • The actual sleep is the base delay times a random factor in [1, 2].
      Jitter therefore never undercuts the server's hint, and a single sleep
      can reach 2 h.
    • Retries forever; never errors just because we hit the cap.

    Args:
        prompt: Contents passed to ``generate_content``.
        api_keys: Comma-separated API key string; blank entries are ignored.
        gemini_model_str: Model name. Models listed in GEMINI_THINKING_MODELS
            additionally get a thinking budget of GEMINI_THINKING_BUDGET.
        temperature: Sampling temperature for the request config.
        timeout_seconds: Per-request HTTP timeout.
        num_candidates: Number of candidates requested per call.

    Returns:
        The ``generate_content`` response object.

    Raises:
        ValueError: If ``api_keys`` is None or contains no non-empty key.
            Without a key the retry loop would sleep forever without ever
            making a request.
    """
    api_key_list = [k.strip() for k in (api_keys or "").split(",") if k.strip()]
    if not api_key_list:
        raise ValueError("query_gemini_local: no non-empty Gemini API key supplied")

    timeout_ms = timeout_seconds * 1000

    # Back-off bookkeeping  (delay is filled lazily from the first server hint)
    delay: float | None = None
    max_delay = 3600.0  # 1 hour cap

    # Pre-build the request config; it’s the same for every retry.
    safety_settings = [
            types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE),
            types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE),
            types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE),
            types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE),
            types.SafetySetting(
                    category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                    threshold=types.HarmBlockThreshold.BLOCK_NONE),
    ]

    # Safety settings apply to every model. Previously they were attached
    # only for thinking models, although they were always built. Only the
    # thinking budget is model-specific.
    config_kwargs = dict(
            temperature=temperature,
            candidate_count=num_candidates,
            safety_settings=safety_settings,
    )
    if gemini_model_str in GEMINI_THINKING_MODELS:
        config_kwargs["thinking_config"] = types.ThinkingConfig(
                thinking_budget=GEMINI_THINKING_BUDGET)
    config = types.GenerateContentConfig(**config_kwargs)

    # Measured from before the first attempt, so the logged time includes
    # any back-off sleeps.
    start_time = time.perf_counter()

    # ---------------- main retry loop ----------------
    while True:
        for idx, key in enumerate(api_key_list):
            try:
                client = genai.Client(
                        api_key=key,
                        http_options=types.HttpOptions(
                                timeout=timeout_ms)
                )
                resp = client.models.generate_content(
                        model=gemini_model_str,
                        contents=prompt,
                        config=config,
                )

                logging.info("Gemini response time: %.2f s using key #%s",
                             time.perf_counter() - start_time, idx)
                return resp  # ← SUCCESS

            except GEMINI_RETRYABLE_ERRORS as e:
                hint = log_gemini_error(e, key_idx=idx)
                if delay is None and hint:
                    delay = hint  # first hint wins

            except Exception:
                # anything unexpected: surface it
                logging.exception("Key #%s unexpected error", idx)
                raise

        # All keys failed → back-off
        if delay is None:
            delay = 10.0
        sleep_for = delay * random.uniform(1.0, 2.0)
        logging.warning("All keys failed. Sleeping %.1f s …", sleep_for)
        time.sleep(sleep_for)

        delay = min(delay * 2, max_delay)


class MockGeminiResponse:
    """
    Minimal stand-in for a Gemini ``generate_content`` response.

    It exposes:
      * ``candidates``: ``num_candidates`` distinct objects, each with
        ``content.parts == [<part with text "">]``.
      * ``usage_metadata``: ``prompt_token_count`` and
        ``candidates_token_count``, both 0.

    Iterating the object yields its candidates.
    """

    def __init__(self, num_candidates=5):
        def _blank_candidate():
            empty_part = SimpleNamespace(text="")
            return SimpleNamespace(content=SimpleNamespace(parts=[empty_part]))

        self.candidates = [_blank_candidate() for _ in range(num_candidates)]
        # Token counts are always zero for the mock.
        self.usage_metadata = SimpleNamespace(
                candidates_token_count=0,
                prompt_token_count=0,
        )

    def __iter__(self):
        return iter(self.candidates)


def parse_gemini_candidates(
        gemini_resp,
        parse_response_func,
        env,
        first_macros,
        output_dir=None,
        run_id=None,
        timestep=None,
        prompt=None,
):
    """
    Parse every candidate in a Gemini response into macro suggestions.

    Candidates without content/parts are skipped with a warning. The others
    are parsed by ``parse_candidate``; a result is kept only if its suggestion
    dict has at least one truthy value. When ``output_dir``, ``run_id`` and
    ``timestep`` are all given and something was kept, the prompt and kept
    candidates are dumped with ``save_all_candidates_to_markdown_embedded``.

    Returns two aligned lists, ordered by number of non-None suggestions
    (descending; ties keep response order):
        candidates[i]  …  {node_name: coords | None, …}
        strategies[i]  …  dict(candidate_idx=<int>, candidate_str=<str>,
                               suggestion_image=<PIL-image | None>)
    """
    parsed_suggestions = []
    strategy_meta = []  # full text / image for each kept candidate

    # ``candidates`` may be None (e.g. when the prompt itself was blocked).
    for i, cand in enumerate(gemini_resp.candidates or []):
        if not cand.content or not cand.content.parts:
            logging.warning(
                    "[parse_gemini_candidates] Candidate #%d empty – skipping",
                    i)
            continue

        sugg_dict, cand_txt, sugg_img = parse_candidate(
                cand, i, parse_response_func, env, first_macros,
                output_dir, run_id)

        if sugg_dict is not None and any(sugg_dict.values()):
            parsed_suggestions.append(sugg_dict)
            strategy_meta.append({
                    "candidate_idx": i,
                    'candidate_str': cand_txt,
                    'suggestion_image': sugg_img,
            })

    # Optional markdown dump
    if output_dir is not None and run_id is not None and strategy_meta:
        if timestep is None:
            # The dump's file name formats timestep with ``:04d``; None would
            # raise TypeError there.
            logging.warning(
                    "[parse_gemini_candidates] timestep is None – skipping "
                    "markdown dump")
        else:
            usage = getattr(gemini_resp, "usage_metadata", None)
            save_all_candidates_to_markdown_embedded(
                    prompt or [], output_dir, run_id, timestep,
                    candidate_data_list=strategy_meta,
                    num_tokens=getattr(usage, "prompt_token_count", None),
                    candidate_num_tokens=getattr(
                            usage, "candidates_token_count", None),
            )

    # Order by number of non-None suggestions, descending (stable sort).
    counts = [sum(v is not None for v in sugg.values())
              for sugg in parsed_suggestions]
    order = sorted(range(len(counts)), key=counts.__getitem__, reverse=True)

    candidates = [parsed_suggestions[i] for i in order]
    strategies = [strategy_meta[i] for i in order]

    return candidates, strategies


def parse_candidate(
        candidate,
        candidate_idx,
        parse_response_func,
        env,
        first_macros,
        output_dir,
        run_id,
):
    """
    Parse a single Gemini candidate into macro suggestions.

    The candidate's text is the concatenation of the text of every part in
    ``candidate.content.parts``, skipping parts flagged ``thought`` and parts
    without text. It is passed to
    ``parse_response_func(text, env, first_macros)``.

    When both ``output_dir`` and ``run_id`` are set, the suggestions are also
    rendered with ``env.render_all_gemini_suggestions(..., return_bytes=False)``
    for debugging. Rendering failures are logged and leave the image as None.
    Nothing is written to markdown here.

    Returns:
        ``(suggestions, text, image)`` on success, or ``(None, None, None)``
        when the parser returns a falsy result or any step raises.
    """
    try:
        # Join all non-thought text parts; using only parts[0] would drop
        # text split across parts and pass None when the first part has none.
        parts = candidate.content.parts or []
        candidate_str = "".join(
                part.text for part in parts
                if getattr(part, "text", None)
                and not getattr(part, "thought", False))

        # Parse the candidate text into a dict of macro suggestions
        candidate_suggestions = parse_response_func(candidate_str, env,
                                                    first_macros)
        if not candidate_suggestions:
            logging.info(
                    "[parse_candidate] Candidate #%s: no suggestions found.",
                    candidate_idx)
            return None, None, None

        # Optionally render the candidate's suggestions for debugging
        suggestion_image = None
        if output_dir is not None and run_id is not None:
            try:
                suggestion_image = env.render_all_gemini_suggestions(
                        candidate_suggestions, return_bytes=False
                )
            except Exception as viz_err:  # debug-only; keep the suggestions
                logging.warning(
                        "[parse_candidate] Error rendering candidate #%s: %s",
                        candidate_idx, viz_err)

        return candidate_suggestions, candidate_str, suggestion_image

    except Exception as e:
        logging.exception(
                "[parse_candidate] Failed to parse candidate #%s: %s. Skipping.",
                candidate_idx, e)
        return None, None, None
