import shutil
import json
import argparse
import time
from openai import OpenAI
import re
import os 
import base64
from io import BytesIO
from PIL import Image
import json_repair


# Build the OpenAI-compatible client: the API key is read from a plain-text
# file and requests are sent to the provider's base URL (both values below are
# placeholders to be filled in for the actual deployment).
with open("/path/to/key.txt") as key_file:
    api_key = key_file.read().strip()
openai = OpenAI(base_url="your/provider/endpoint", api_key=api_key)

# Command-line interface: model id, scene JSON input, prompt text file,
# output root directory, and an optional directory of map PNGs that switches
# the request to multimodal (text + image) mode.
parser = argparse.ArgumentParser(description="Run a chat completion with streaming output.")

# (flag, extra add_argument options); every option is a plain string.
for _flag, _options in (
    ('--model', {
        'default': 'meta-llama/Meta-Llama-3-8B-Instruct',
        'help': "The model to use (default: 'meta-llama/Meta-Llama-3-8B-Instruct')",
    }),
    ('--json', {
        'required': True,
        'help': "Path to the input JSON file",
    }),
    ('--txt', {
        'default': "prompts/forecast_A_3D.txt",
        'help': "Path to the input text file",
    }),
    ('--out', {
        'required': True,
        'help': "Path to the output directory for responses",
    }),
    ('--maps', {
        'required': False,
        'default': None,
        'help': "Path to the directory of map png files (for multimodal forecasting)",
    }),
):
    parser.add_argument(_flag, type=str, **_options)

# Parse command-line arguments
args = parser.parse_args()
print("Arguments: ", vars(args))

# Extract parameters from the parsed arguments
model = args.model
json_file_path = args.json
txt_file_path = args.txt

# Output files are keyed by the scene id (input file name up to '.json') and by
# the model name (last '/'-component of the model id, plus '_map' when map
# images are sent). os.path.basename uses the platform's separator(s) (on
# Windows both '\\' and '/'), unlike a hard-coded split on '/'.
scene_id = os.path.basename(args.json).split('.json')[0]
model_name = args.model.split("/")[-1]
if args.maps is not None:
    model_name += '_map'

responses_dir = os.path.join(args.out, 'responses', model_name)
predictions_dir = os.path.join(args.out, 'forecast_pred', model_name)
out_txt_path = os.path.join(responses_dir, f'{scene_id}.txt')
out_json_path = os.path.join(predictions_dir, f'{scene_id}.json')

os.makedirs(responses_dir, exist_ok=True)
os.makedirs(predictions_dir, exist_ok=True)

# Skip if already predicted this entry (TODO: remove).
# NOTE: the "skip" is an uncaught ValueError, so the process exits with a
# non-zero status; a caller checking the exit code will see it as a failure.
if os.path.exists(out_json_path):
    raise ValueError(f'Skipping {out_json_path} as it was already predicted')

# Load the scene JSON and re-serialize it pretty-printed for the prompt.
# Explicit UTF-8 (the JSON standard encoding) instead of the locale default.
with open(json_file_path, "r", encoding="utf-8") as json_file:
    json_data = json.load(json_file)
json_content = json.dumps(json_data, indent=4)

# Read the instruction text that precedes the JSON data in the prompt
with open(txt_file_path, "r", encoding="utf-8") as txt_file:
    message_content = txt_file.read().strip()

# Transcript buffer later written verbatim to the .txt response file;
# it starts with the command-line arguments.
output = ["Arguments: ", str(vars(args))]

completed = False  # Set once the model reports finish_reason == 'stop'

# Combine the text prompt with the JSON content
combined_prompt = f"{message_content} \n Here is the JSON data: {json_content}"

start = time.time()

if args.maps is not None:
    # The map file is named after the main scene: the scene id with its last
    # two '_'-separated components dropped (e.g. 'A_B_x_y' -> 'A_B').
    # Reuse scene_id (derived once from the input file name) so the map lookup
    # and the output paths always agree on the same id.
    main_scene_id = '_'.join(scene_id.split('_')[:-2])
    map_path = os.path.join(args.maps, f'{main_scene_id}.png')
    # The context manager closes the PNG's file handle after conversion
    with Image.open(map_path) as map_image:
        image = map_image.convert("RGB")
    # Re-encode as JPEG and embed it as a base64 data URL
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
    # Create message with map image
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": combined_prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
            ],
        }
    ]
else:
    messages = [{"role": "user", "content": combined_prompt}]

# Safety bound on follow-up "continue" requests issued when a stream ends
# without finish_reason 'stop' (e.g. 'length'), so the loop cannot keep
# sending requests forever. Adjust if longer outputs are expected.
MAX_CONTINUATIONS = 5
continuation_count = 0

# Loop to stream the output and handle continuation
while not completed:
    # Send the conversation so far to the model with streaming enabled
    chat_completion = openai.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=None,
        max_completion_tokens=None,
        stream=True,
        stream_options={"include_usage": True},
    )

    partial_output = ""  # Text produced by this request only
    # With include_usage, token usage may arrive on a trailing chunk whose
    # `choices` list is empty, so remember the last non-None value seen.
    usage = None
    for event in chat_completion:
        if event.usage is not None:
            usage = event.usage
        if not event.choices:
            continue
        choice = event.choices[0]

        # delta.content is None on role-only / finish chunks; treat as empty.
        # Content on the finish chunk itself is kept rather than dropped.
        content = (choice.delta.content if choice.delta is not None else None) or ""
        if content:
            # Print the streamed content immediately and record it
            print(content, end="")
            partial_output += content
            output.append(content)

        if choice.finish_reason:
            output.append(f"Finish reason: {choice.finish_reason}\n")
            print(f"Finish reason: {choice.finish_reason}\n")
            if choice.finish_reason == 'stop':
                completed = True
            else:
                print("The model is not finished yet. Continuing...\n")

    if usage is not None:
        output.append(f"Prompt Tokens: {usage.prompt_tokens}\n")
        output.append(f"Completion Tokens: {usage.completion_tokens}\n")
        print(f"Prompt Tokens: {usage.prompt_tokens}\n")
        print(f"Completion Tokens: {usage.completion_tokens}\n")

    if not completed:
        continuation_count += 1
        if continuation_count > MAX_CONTINUATIONS:
            print(f"Stopping after {MAX_CONTINUATIONS} continuation requests without finish_reason 'stop'.\n")
            break
        if partial_output:
            # Feed the partial answer back as an assistant turn and ask the
            # model to continue. The original user message (prompt, JSON data,
            # optional map image) stays in `messages`, so context is preserved.
            messages.append({"role": "assistant", "content": partial_output})
            messages.append({"role": "user", "content": "continue and generate the full JSON altogether"})
        print("\n\nContinuing generation...\n\n")

print('GENERATION TIME: ', time.time()-start)

# Write the full transcript (arguments header, streamed text, finish reasons,
# token counts) after the stream ends. Explicit UTF-8 matches the JSON writer
# and avoids UnicodeEncodeError on non-UTF-8 locales for non-ASCII model text.
with open(out_txt_path, "w", encoding="utf-8") as file:
    file.writelines(output)

# Extract every ```json ... ``` fenced object from the transcript and merge
# them into one dict; keys from later blocks overwrite earlier ones.
pattern = r'```json\s*(\{.*?\})\s*```'
matches = re.findall(pattern, ''.join(output), re.DOTALL)
if not matches:
    print("Warning: No ```json fenced block found in the model output")

merged_dict = {}

for js in matches:
    try:
        # Try parsing using standard json first
        parsed = json.loads(js)
    except json.JSONDecodeError:
        try:
            # Attempt to repair broken JSON
            repaired = json_repair.repair_json(js)
            parsed = json.loads(repaired)
        except Exception as e:
            print(f"Warning: Could not parse even with repair_json: {e}")
            parsed = {}

    # The repaired text may decode to a non-object (list, string, ...);
    # dict.update would raise or mis-merge on those, so skip them.
    if not isinstance(parsed, dict):
        print(f"Warning: Skipping parsed block of type {type(parsed).__name__}")
        continue
    merged_dict.update(parsed)

# Serialize before opening the file so a serialization failure cannot leave a
# truncated JSON document with the fallback text appended after it.
try:
    serialized = json.dumps(merged_dict, indent=4)
except Exception as e:
    # Fallback: save as plain text if serialization fails
    print(f"Error dumping JSON: {e}")
    serialized = str(merged_dict)

# Save to file regardless of errors
with open(out_json_path, "w", encoding="utf-8") as f:
    f.write(serialized)