#!/usr/bin/env python3

import json
import re
import time
import logging
from openai import OpenAI
from config import *

def call_llama_api(client, model, prompt, attempt_context="", *, max_retries=3, retry_delay=5):
    """Call Llama-3.3-70B through an OpenAI-compatible client and return the reply text.

    The request uses fixed, low-variance sampling settings (temperature 0.1,
    top_p 0.9, at most 4096 output tokens, no streaming). An attempt fails
    when the API call raises or when the first choice carries no content.
    Failed attempts are retried with exponential back-off.

    Args:
        client: Object exposing ``chat.completions.create`` (e.g. ``openai.OpenAI``).
        model: Model identifier sent with the request.
        prompt: Content of the single user message.
        attempt_context: Label included in log and error messages.
        max_retries: Total number of attempts (keyword-only).
        retry_delay: Seconds to wait after the first failed attempt. The wait
            doubles after each further failure (keyword-only).

    Returns:
        The non-empty ``message.content`` of the first choice.

    Raises:
        Exception: The exception from the final attempt if that attempt
            raised, or a generic ``Exception`` when the final attempt
            returned an empty reply.
    """
    delay = retry_delay
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=4096,
                temperature=0.1,
                top_p=0.9,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stream=False
            )
            # Kept inside the try so a malformed response object is retried
            # like any other API failure.
            content = response.choices[0].message.content if response.choices else None
        except Exception as e:
            logging.warning(f"    WARNING: API call failed on attempt {attempt + 1} for {attempt_context}: {e}")
            if is_last_attempt:
                raise
        else:
            if content:
                return content
            logging.warning(f"    WARNING: Empty response on attempt {attempt + 1} for {attempt_context}")

        # Back off after both failure modes so empty replies do not trigger
        # immediate re-requests against the API.
        if not is_last_attempt:
            logging.info(f"    Retrying in {delay} seconds...")
            time.sleep(delay)
            delay *= 2

    raise Exception(f"No response generated after multiple attempts for {attempt_context}.")

def call_llama_api_with_reprompt(client, model, prompt, attempt_context, step_name, db_id, iteration):
    """Ask the model for a JSON answer, re-sending the same prompt until it parses.

    At most ``MAX_JSON_REPROMPT_ATTEMPTS`` requests are made, with a 2-second
    pause between attempts. There is no fallback value: the function either
    returns a successfully parsed object or raises.

    Returns:
        The object produced by ``parse_json_response`` that lacks a truthy
        ``"parsing_failed"`` flag.

    Raises:
        Exception: Re-raises the error from the final attempt if that attempt
            raised. Otherwise raises a generic ``Exception`` once every
            attempt yielded unparseable JSON.
    """
    final_attempt = MAX_JSON_REPROMPT_ATTEMPTS - 1

    for attempt in range(MAX_JSON_REPROMPT_ATTEMPTS):
        attempt_no = attempt + 1
        try:
            raw_reply = call_llama_api(client, model, prompt, f"{attempt_context} - Attempt {attempt_no}")
            decoded = parse_json_response(raw_reply, step_name, db_id, iteration)

            if decoded.get("parsing_failed", False):
                logging.info(f"    JSON parsing failed on attempt {attempt_no}, retrying...")
                # No pause after the last attempt; we fall through to the final raise.
                if attempt < final_attempt:
                    time.sleep(2)
                continue

            if attempt > 0:
                logging.info(f"    JSON parsing succeeded on attempt {attempt_no}")
            return decoded

        except Exception as e:
            logging.error(f"    API call failed on attempt {attempt_no}: {e}")
            if attempt == final_attempt:
                raise
            time.sleep(2)

    # Only reached when every attempt produced a parsing-failure marker.
    raise Exception(f"All {MAX_JSON_REPROMPT_ATTEMPTS} attempts failed for {step_name} - {db_id}")

def parse_json_response(response, step_name, db_id, iteration=None):
    """Extract and decode the JSON object embedded in an LLM response.

    Any surrounding markdown code fence is removed. The text is then narrowed
    to the span from the first ``{`` to the last ``}``. That span is decoded
    strictly first. Only if strict decoding fails is it passed through
    ``fix_json_issues`` and decoded again, because the repair heuristics can
    alter otherwise-valid JSON. There is no fallback content: on failure a
    marker dict is returned so the caller can retry.

    Args:
        response: Raw model reply text. ``None`` is treated as an empty reply.
        step_name: Label used in the warning log.
        db_id: Identifier copied into the failure marker.
        iteration: Optional iteration number used in the log and the marker.
            The marker stores 0 when this is ``None``.

    Returns:
        The decoded JSON object (a ``dict``) on success. On failure, a dict
        with keys ``database_id``, ``iteration``, ``parsing_failed`` (``True``)
        and ``raw_response`` (the cleaned reply, cut to 500 characters plus
        ``"..."`` when longer).
    """
    # A missing reply yields the retry marker rather than an AttributeError.
    response = (response or "").strip()
    try:
        # Remove a ```json / ``` fence around the payload.
        if response.startswith('```json'):
            response = response[7:]
        if response.startswith('```'):
            response = response[3:]
        if response.endswith('```'):
            response = response[:-3]

        # Narrow to the outermost braces to drop any prose around the JSON.
        start_idx = response.find('{')
        end_idx = response.rfind('}') + 1
        if start_idx == -1 or end_idx == 0:
            raise ValueError("No JSON found in response")
        json_str = response[start_idx:end_idx]

        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            # Repair only when strict parsing fails, so valid payloads are never
            # rewritten by the heuristic fixes.
            return json.loads(fix_json_issues(json_str))

    except (json.JSONDecodeError, ValueError) as e:
        logging.warning(f"    WARNING: Failed to parse {step_name} JSON in iteration {iteration}: {e}")

        # Return parsing failed flag to trigger retry
        return {
            "database_id": db_id,
            "iteration": iteration if iteration is not None else 0,
            "parsing_failed": True,
            "raw_response": response[:500] + "..." if len(response) > 500 else response
        }

# A complete JSON string literal, including escaped characters such as \" .
# The capturing group makes re.split return the literals alongside the text
# between them.
_JSON_STRING_LITERAL = re.compile(r'("(?:[^"\\]|\\.)*")', re.DOTALL)
# A comma directly before a closing brace/bracket (optional whitespace between).
_TRAILING_COMMA = re.compile(r',(\s*[}\]])')
# Two objects separated only by whitespace, with the comma missing.
_ADJACENT_OBJECTS = re.compile(r'}\s*{')


def fix_json_issues(json_str):
    """Repair common structural JSON mistakes without touching string contents.

    Two fixes are applied, only to text outside JSON string literals, so that
    commas and braces inside string values are preserved:

    * a trailing comma before ``}`` or ``]`` is removed;
    * a missing comma between adjacent objects (``}{``) is inserted.

    Unescaped double quotes inside strings are NOT repaired. The quote
    boundaries are ambiguous, so such input is returned for the caller's
    parser to reject.

    Args:
        json_str: Candidate JSON text.

    Returns:
        The text with the structural fixes applied.
    """
    # With one capturing group, split() alternates [outside, literal, outside, ...],
    # so even indices are always the text between string literals.
    pieces = _JSON_STRING_LITERAL.split(json_str)
    pieces[::2] = [
        _ADJACENT_OBJECTS.sub('},{', _TRAILING_COMMA.sub(r'\1', piece))
        for piece in pieces[::2]
    ]
    return ''.join(pieces)