import os
import json
import boto3
import sys
from botocore.exceptions import ClientError
import base64

# Credentials, region and model-id mappings are read from a local JSON file.
# This block reads the keys "region", "access_key_id" and "secret_access_key";
# generate() additionally looks up model names (e.g. "claude-4-sonnet") in it.
# A context manager closes the file handle (the previous bare open() leaked it),
# and UTF-8 is used explicitly rather than the platform's locale encoding.
with open("api_keys/claude.json", encoding="utf-8") as config_file:
    config = json.load(config_file)
region = config["region"]

# Bedrock runtime client built from the explicit credentials in the config file.
client = boto3.client(
    service_name="bedrock-runtime",
    region_name=region,
    aws_access_key_id=config["access_key_id"],
    aws_secret_access_key=config["secret_access_key"],
)

# Extended-thinking request fields with a 2048-token budget.
# NOTE(review): not referenced by the code in this file (generate() builds its
# own additionalModelRequestFields); kept in case other modules import it.
reasoning_config = {
    "thinking": {
        "type": "enabled",
        "budget_tokens": 2048,
    }
}

# Sentinel text returned when a Converse response contains no text block.
_NO_CONTENT = "No response content found"


def _build_request_params(model_id, messages, max_tokens, temperature, enable_reasoning, reasoning_budget):
    """Assemble the keyword arguments for ``client.converse``."""
    inference_config = {"temperature": temperature, "maxTokens": max_tokens}
    request_params = {
        "modelId": model_id,
        "messages": messages,
        "inferenceConfig": inference_config,
    }
    if enable_reasoning:
        # Reasoning mode forces temperature 1.0 and keeps maxTokens strictly above
        # the thinking budget -- presumably constraints of Claude extended
        # thinking; confirm against the Bedrock model-parameter docs.
        inference_config["temperature"] = 1.0
        if max_tokens <= reasoning_budget:
            inference_config["maxTokens"] = reasoning_budget + 1
        # NOTE(review): the module-level reasoning_config dict uses a "thinking"
        # key, while this request sends "reasoning_config". Verify which key the
        # target model accepts.
        request_params["additionalModelRequestFields"] = {
            "reasoning_config": {
                "type": "enabled",
                "budget_tokens": reasoning_budget,
            }
        }
    return request_params


def _parse_response(response):
    """Flatten a Converse response into ``(text, metadata)``.

    Every reasoning block is wrapped in ``<thinking>...</thinking>\\n``, and every
    text block is appended in the order the blocks appear. Redacted reasoning
    blocks (no ``reasoningText``) are skipped. If no text block is present, the
    text is the ``_NO_CONTENT`` sentinel.
    """
    metadata = {
        "prompt_token_count": 0,
        "candidates_token_count": 0,
        # Always 0 here: only inputTokens/outputTokens are read from usage.
        "thoughts_token_count": 0,
    }
    try:
        usage = response.get("usage") or {}
        metadata["prompt_token_count"] = usage.get("inputTokens", 0)
        metadata["candidates_token_count"] = usage.get("outputTokens", 0)

        content_blocks = response.get("output", {}).get("message", {}).get("content") or []
        parts = []
        found_text = False
        for block in content_blocks:
            reasoning = block.get("reasoningContent")
            if reasoning is not None:
                thought = reasoning.get("reasoningText", {}).get("text")
                if thought is not None:
                    parts.append("<thinking>" + thought + "</thinking>\n")
            if "text" in block:
                parts.append(block["text"])
                found_text = True
        if not found_text:
            return _NO_CONTENT, metadata
        return "".join(parts), metadata
    except (AttributeError, KeyError, TypeError) as e:
        # Malformed response shape: report it and fall back to the sentinel.
        print(e)
        return _NO_CONTENT, metadata


def generate(model, messages, generation_config=None, enable_reasoning=True, reasoning_buget=2048, client=client):
    """Send ``messages`` to a Bedrock model via the Converse API and return the reply.

    Args:
        model: Key into the module-level ``config`` mapping. Its value is sent as
            the Bedrock ``modelId``. An unknown key raises ``KeyError``.
        messages: Converse-format message list, passed through unchanged.
        generation_config: Optional dict with ``max_output_tokens`` and/or
            ``temperature``. Missing keys fall back to 16384 and 0.0 (the same
            values as the previous default argument).
        enable_reasoning: When true, request extended thinking. This forces
            temperature 1.0 and raises maxTokens above the budget if needed.
        reasoning_buget: Thinking-token budget. The misspelled name is kept for
            backward compatibility with keyword callers.
        client: Bedrock runtime client. Defaults to the module-level client.

    Returns:
        ``(text, metadata)``:
        - ``text`` contains all reasoning blocks wrapped in
          ``<thinking>...</thinking>``, interleaved with all text blocks in
          response order.
        - ``metadata`` holds the prompt/candidate token counts.
        - If the response has no text block, ``text`` is
          ``"No response content found"``.
        - If the API call itself fails, the error is printed and
          ``(None, None)`` is returned.
    """
    # None sentinel instead of a shared mutable dict default.
    settings = {"max_output_tokens": 16384, "temperature": 0.0}
    if generation_config:
        settings.update(generation_config)

    request_params = _build_request_params(
        config[model],
        messages,
        settings["max_output_tokens"],
        settings["temperature"],
        enable_reasoning,
        reasoning_buget,
    )
    try:
        response = client.converse(**request_params)
    except Exception as e:
        # Deliberately broad: client, validation and connection errors all map
        # to the (None, None) failure result.
        print(e)
        return None, None

    return _parse_response(response)

if __name__ == "__main__":
    # Smoke test: send one screenshot plus a text prompt with reasoning enabled
    # and print the (text, metadata) tuple returned by generate().
    print('Running boto3 version: ', boto3.__version__)
    print('Using region: ', region)
    with open("unit_test/data/screenshots/screenshot_001.png", "rb") as f:
        image_data = f.read()

    # Converse API image input: pass the raw file bytes under
    # {"image": {"format": ..., "source": {"bytes": ...}}}. Per the Converse
    # ImageSource docs, SDK callers do not base64-encode the bytes themselves
    # (the base64 "media_type"/"data" shape belongs to the Anthropic Messages API).
    print(generate("claude-4-sonnet", [
        {"role": "user", "content": [
            {
                "image": {
                    "format": "png",
                    "source": {"bytes": image_data}
                }
            },
            {"text": "Describe the image"}
        ]},
    ], enable_reasoning=True))