import base64
import mimetypes
from typing import List, Dict

import os
import argparse
from litellm import completion
from keys import BEDROCK_KEY, BEDROCK_SECRET_KEY
import json
from tqdm import tqdm
from time import sleep

def extract_boxed_content(text: str) -> str:
    r"""
    Return the contents of the last ``\boxed{...}`` group found in ``text``.

    Braces nested inside the group are balanced, so the group ends at the
    closing brace that matches the ``\boxed{`` opener itself. The extracted
    content has surrounding whitespace stripped.

    Returns the literal string "None" (not the ``None`` object) when there is
    no ``\boxed{`` marker or its closing brace is never reached.
    """
    marker = r"\boxed{"
    marker_at = text.rfind(marker)
    if marker_at == -1:
        return "None"

    tail = text[marker_at + len(marker):]
    open_braces = 0  # braces opened inside the boxed group and not yet closed
    for offset, ch in enumerate(tail):
        if ch == "{":
            open_braces += 1
        elif ch == "}":
            if open_braces == 0:
                # this brace closes the \boxed{ group itself
                return tail[:offset].strip()
            open_braces -= 1

    return "None"

def construct_multimodal_messages_from_sharegpt_format(
    messages: List[Dict[str, str]],
    image_paths: List[str]
) -> List[Dict[str, object]]:
    """
    Convert a two-turn ShareGPT conversation into Litellm multimodal messages.

    Only the user turn is converted; the assistant turn is validated but left
    out of the returned prompt.

    Given:
      - messages: exactly two dicts {"role": "...", "content": "..."}: a "user" message
        followed by an "assistant" message. The user content may include one or more
        "<image>" placeholders,
      - image_paths: a list of local file paths in the exact order images should be slotted in,
    Returns:
      - a list holding one Litellm message {"role": role, "content": [<blocks>]}
        where each block is either {"type":"text", "text": ...}
        or      {"type":"image_url", "image_url":{"url": data_url}}
    Raises:
      - AssertionError: if messages is not a user/assistant pair,
      - ValueError: if the number of images does not match the number of "<image>"
        placeholders, or a path's guessed mime type is not an image type,
      - OSError: if an image file cannot be opened or read.
    """
    # iterator over image paths
    img_iter = iter(image_paths)

    litellm_messages = []
    assert len(messages) == 2, "Expected exactly two messages (user, assistant)"
    assert messages[0]["role"] == "user", "First message must be user role"
    assert messages[1]["role"] == "assistant", "Second message must be assistant role"
    for msg in messages[:1]:  # skip the assistant message
        role = msg["role"]
        text = msg["content"]
        blocks: List[Dict[str, object]] = []

        # split on the literal placeholder "<image>"
        parts = text.split("<image>")
        for i, part in enumerate(parts):
            # emit the text segment unless it is empty
            if part:
                blocks.append({"type": "text", "text": part})

            # after every part except the last, emit an image block
            if i < len(parts) - 1:
                try:
                    image_path = next(img_iter)
                except StopIteration:
                    raise ValueError("Not enough images for the <image> placeholders.") from None
                mime_type, _ = mimetypes.guess_type(image_path)
                if mime_type is None or not mime_type.startswith("image/"):
                    raise ValueError(f"Unrecognized image mime type for file {image_path}")
                # read & encode as a base64 data URL
                with open(image_path, "rb") as f:
                    img_b64 = base64.b64encode(f.read()).decode("ascii")
                data_url = f"data:{mime_type};base64,{img_b64}"
                blocks.append({
                    "type": "image_url",
                    "image_url": {"url": data_url}
                })

        litellm_messages.append({
            "role": role,
            "content": blocks
        })

    # leftover images mean the placeholders and image list are out of sync
    exhausted = object()
    if next(img_iter, exhausted) is not exhausted:
        raise ValueError("More images provided than <image> placeholders.")

    return litellm_messages


def load_valid_instances_and_get_done_indices(output_json_path):
    """
    Load previously saved results so that a run can resume from them.

    A saved entry counts as done only when its "response_text" is not None;
    entries with a None response are dropped, so the caller can retry them.

    Args:
        output_json_path: path to a JSON file holding a list of result dicts,
            each with at least the keys "index" and "response_text".

    Returns:
        A tuple (done_indices, valid_results): the set of "index" values of the
        kept entries, and the kept entries themselves in file order.
    """
    print(f"Resuming from existing results in {output_json_path}")
    # context manager so the file handle is closed as soon as parsing is done
    with open(output_json_path, "r", encoding="utf-8") as fin:
        results = json.load(fin)
    valid_results = [res for res in results if res["response_text"] is not None]
    done_indices = {res["index"] for res in valid_results}
    print(f"Skipping indices: {done_indices}")
    return done_indices, valid_results

def run(
    gen_config: Dict[str, object],
    sharegpt_dataset_path: str,
    output_json_path: str
):
    """
    Run inference over all datapoints in a ShareGPT JSON, save prompt, raw response,
    and processed (boxed) response to output_json_path.

    If output_json_path already exists, entries in it that have a response are
    kept and their indices are skipped; all other datapoints are (re)processed.
    The results file is rewritten after every processed datapoint so progress
    survives an interruption.

    Args:
        gen_config: keyword arguments forwarded to `completion` (model, max_tokens, ...).
            It is not modified; "messages" is supplied per call.
        sharegpt_dataset_path: path to a JSON list of datapoints, each with a
            "messages" list (user turn, assistant turn) and optionally "images".
        output_json_path: path of the results JSON file to create or extend.
    """
    with open(sharegpt_dataset_path, "r", encoding="utf-8") as fin:
        data = json.load(fin)
    results = []

    # check if there are already results saved, if so, resume from there
    if os.path.exists(output_json_path):
        done_indices, results = load_valid_instances_and_get_done_indices(output_json_path)
    else:
        done_indices = set()

    # checkpoints are written here first and then moved over the real file
    tmp_output_path = f"{output_json_path}.tmp"

    for idx, dp in tqdm(enumerate(data), total=len(data), desc=f"Processing datapoints from {sharegpt_dataset_path}"):
        if idx in done_indices:
            continue
        print(f"Processing datapoint {idx}/{len(data)}")
        # the assistant turn holds the reference answer
        gt = extract_boxed_content(dp["messages"][1]["content"])
        try:
            # build messages
            msgs = construct_multimodal_messages_from_sharegpt_format(
                messages=dp["messages"],
                image_paths=dp.get("images", []),
            )
            # call model with per-call kwargs, leaving the caller's gen_config untouched
            resp = completion(**{**gen_config, "messages": msgs})
            resp_text = resp.choices[0].message.content
            processed = extract_boxed_content(resp_text)
        except Exception as e:
            # best effort: record the failure as a None response so a later run retries it
            print(f"Error processing datapoint {idx}: {e}")
            resp = None
            resp_text = None
            processed = None

        results.append({
            "index": idx,
            "input": dp,
            "raw_response": resp.to_dict() if resp else None,
            "response_text": resp_text,
            "extracted_answer": processed,
            "ground_truth": gt
        })
        sleep(1.0)  # to avoid rate limiting
        # write to a temp file and swap it in, so an interrupt mid-write cannot
        # leave a truncated results file that would break resuming
        with open(tmp_output_path, "w", encoding="utf-8") as fout:
            json.dump(results, fout, indent=4)
        os.replace(tmp_output_path, output_json_path)

    print(f"Saved {len(results)} results to {output_json_path}")


# Export the Bedrock credentials and region through the AWS_* environment
# variables. This runs at import time, before any model call is made.
os.environ.update({
    "AWS_ACCESS_KEY_ID": BEDROCK_KEY,
    "AWS_SECRET_ACCESS_KEY": BEDROCK_SECRET_KEY,
    # alternative region: "us-east-2"
    "AWS_DEFAULT_REGION": "us-west-2",
})

# Default generation settings; each can be overridden on the command line.
MODEL = "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
MAX_TOKENS = 8192 + 1024
TEMPERATURE = 0.0

if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the output
    # directory exists, then run inference over the dataset.
    parser = argparse.ArgumentParser(description="Run inference with Claude model on ShareGPT dataset.")
    parser.add_argument("--sharegpt_dataset_path", type=str, required=True, help="Path to the ShareGPT dataset JSON file.")
    parser.add_argument("--output_json_path", type=str, required=True, help="Path to save the output JSON file.")
    parser.add_argument("--model", type=str, required=False, default=MODEL)
    parser.add_argument("--max_tokens", type=int, required=False, default=MAX_TOKENS)
    parser.add_argument("--temperature", type=float, required=False, default=TEMPERATURE)
    args = parser.parse_args()

    # dirname is "" for a bare filename such as "out.json", and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named
    output_dir = os.path.dirname(args.output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    gen_config = {
        "model": args.model,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "stream": False  # set to True if you want streaming responses
    }
    print(f"Using generation config: {gen_config}")
    run(
        gen_config=gen_config,
        sharegpt_dataset_path=args.sharegpt_dataset_path,
        output_json_path=args.output_json_path
    )
