
import copy
import datetime
import json
import os
import re
from typing import Any, Dict

from groq import Groq
from openai import OpenAI

# API credentials. Each key is read from its conventional environment variable
# so real secrets do not have to be hard-coded in source. The "Your_API_KEY"
# placeholder remains the fallback, so behaviour is unchanged when a variable
# is unset.
groq_api_key = os.environ.get("GROQ_API_KEY", "Your_API_KEY")
groq_client = Groq(api_key=groq_api_key)

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "Your_API_KEY"))
claude_api_key = os.environ.get("ANTHROPIC_API_KEY", "Your_API_KEY")
gemini_api_key = os.environ.get("GEMINI_API_KEY", "Your_API_KEY")

# -- Call LLM
def gpt_call(prompt: str, model: str = "gpt-4o") -> str:
    """Send a single-turn user prompt to an OpenAI chat model.

    Args:
        prompt: The user message to send.
        model: OpenAI chat model name.

    Returns:
        The first choice's message content with surrounding whitespace removed.

    Raises:
        Exception: Any error raised while calling the API or reading the
            response is printed and then re-raised unchanged.
    """
    try:
        conversation = [{"role": "user", "content": prompt}]
        completion = client.chat.completions.create(
            model=model,
            messages=conversation,
            temperature=0.7,
        )
        first_choice = completion.choices[0]
        reply_text = first_choice.message.content
        return reply_text.strip()
    except Exception as err:
        print(f"GPT call error: {err}")
        raise

# Function for Claude model
def claude_call(prompt: str, model: str = "claude-3-5-sonnet-20241022") -> str:
    """Send a single-turn user prompt to an Anthropic Claude model.

    The ``anthropic`` package is imported lazily, so it is only required when
    this function is actually used. A new client is built on every call from
    the module-level ``claude_api_key``.

    Args:
        prompt: The user message to send.
        model: Anthropic model identifier.

    Returns:
        The ``text`` of the first content block in the response.

    Raises:
        ImportError: Printed as a missing-package message, then re-raised.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        import anthropic

        anthropic_client = anthropic.Anthropic(api_key=claude_api_key)
        request = {
            "model": model,
            "max_tokens": 4096,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.7,
        }
        reply = anthropic_client.messages.create(**request)
        first_block = reply.content[0]
        return first_block.text
    except ImportError:
        print("Error: anthropic package is not installed.")
        raise
    except Exception as exc:
        print(f"Claude call error: {exc}")
        raise

# Function for Gemini model
def gemini_call(prompt: str, model: str = "gemini-2.0-flash") -> str:
    """Send a prompt to a Google Gemini model via ``google.generativeai``.

    The SDK is imported lazily and configured on every call with the
    module-level ``gemini_api_key``.

    Args:
        prompt: The prompt text to send.
        model: Gemini model name.

    Returns:
        The ``text`` attribute of the generated response.

    Raises:
        ImportError: Printed as a missing-package message, then re-raised.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        import google.generativeai as genai

        genai.configure(api_key=gemini_api_key)
        generator = genai.GenerativeModel(model)
        sampling = {"temperature": 0.7}
        reply = generator.generate_content(prompt, generation_config=sampling)
        return reply.text
    except ImportError:
        print("Error: google-generativeai package is not installed.")
        raise
    except Exception as exc:
        print(f"Gemini call error: {exc}")
        raise

# Function for LLaMA models served through the Groq API
def groq_call(prompt: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Send a single-turn user prompt to a LLaMA model hosted on the Groq API.

    Args:
        prompt: The user message to send.
        model: Groq model ID. Defaults to the Llama 3.3 70B "versatile" model.
            The former default, "llama-3.3-7b-versatile", is not a Groq model
            ID, so calls relying on it were rejected by the API.

    Returns:
        The first choice's message content with surrounding whitespace removed.

    Raises:
        Exception: Any error raised while calling the API or reading the
            response is printed and then re-raised unchanged.
    """
    try:
        response = groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"LLaMa call error: {e}")
        raise

# -- Integrated LLM call function --
def llm_call(prompt: str, model: str = "gpt-4o") -> str:
    """Route a prompt to the provider helper selected by the model name.

    The prefix of ``model`` picks the backend, checked in this order:
    "claude" -> claude_call, "gemini" -> gemini_call, "llama" -> groq_call.
    Any other name is treated as an OpenAI model and sent to gpt_call.

    Args:
        prompt: The prompt text to send.
        model: Model name; its prefix determines the provider.

    Returns:
        The response text produced by the selected helper.
    """
    prefix_routes = (
        ("claude", claude_call),
        ("gemini", gemini_call),
        ("llama", groq_call),
    )
    for prefix, backend in prefix_routes:
        if model.startswith(prefix):
            return backend(prompt, model)
    # Fallback: GPT models
    return gpt_call(prompt, model)


def _strip_control_chars(value: str) -> str:
    """Replace every ASCII control character (0x00-0x1F, 0x7F) with a space.

    Raw newlines/tabs inside JSON string literals make ``json.loads`` fail.
    The substitution is one-for-one, so character offsets in the result match
    those in *value*.
    """
    return re.sub(r'[\x00-\x1F\x7F]', ' ', value)


def extract_json(text: str) -> Dict[str, Any]:
    """
    Extracts and cleans a JSON value (normally an object) from model output.

    Strategies are tried in order; the first one that parses wins:
    1. The contents of the first ``` / ```json fenced code block.
    2. The JSON object that starts at the first '{' in the text. It is decoded
       with ``json.JSONDecoder.raw_decode``, so braces inside string literals
       are handled correctly and any trailing prose is ignored.
    3. The whole text.
    4. Last resort: unescape \\" and \\' and parse the greedy span from the
       first '{' to the last '}'.
    Control characters are replaced with spaces before each attempt.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value. This is a dict in the common case, but a list
        or scalar is returned if that is what the text contains.

    Raises:
        TypeError: If ``text`` is not a ``str``.
        json.JSONDecodeError: If no strategy yields valid JSON. The error from
            the first failed attempt is raised.
    """
    if not isinstance(text, str):
        raise TypeError(f"extract_json expects a str, got {type(text).__name__}")

    cleaned = _strip_control_chars(text)
    first_error = None

    # 1. JSON inside a fenced code block.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, re.DOTALL)
    if fence:
        try:
            return json.loads(_strip_control_chars(fence.group(1).strip()))
        except json.JSONDecodeError as e:
            # Fall through: the real JSON may live outside this fence.
            first_error = e

    # 2. First JSON object in the text. raw_decode understands string
    #    literals, unlike naive brace counting.
    start = cleaned.find('{')
    if start >= 0:
        try:
            obj, _ = json.JSONDecoder().raw_decode(cleaned, start)
            return obj
        except json.JSONDecodeError as e:
            if first_error is None:
                first_error = e

    # 3. The entire text (covers arrays/scalars, or a '{' that sits inside a string).
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as e:
        if first_error is None:
            first_error = e

    print(f"🛑 JSON parsing error: {first_error}")
    print(f"🧪 Raw text (first 500 chars):\n{text[:500]}")

    # 4. Best-effort heuristic: undo escaped quotes and take the greedy
    #    '{ ... }' span.
    relaxed = cleaned.replace('\\"', '"').replace('\\\'', '\'')
    greedy = re.search(r'(\{.*\})', relaxed, re.DOTALL)
    if greedy:
        try:
            return json.loads(greedy.group(1))
        except json.JSONDecodeError:
            pass  # deliberate best-effort; report the original failure below

    raise first_error

# -- Round-robin generator for topics/styles --
def round_robin(items):
    """Yield the elements of *items* forever, starting over after the last one.

    Re-iterable inputs (lists, tuples, ...) are iterated afresh on every pass,
    so a list mutated between passes is seen in its current state. One-shot
    iterators (generators, ``iter(...)``) cannot be re-iterated. They are
    cycled with ``itertools.cycle``, which remembers their elements.

    If a pass produces no elements (e.g. *items* is empty), the generator
    stops instead of spinning in an endless loop that never yields.
    """
    import itertools

    if iter(items) is items:
        # One-shot iterator: a second pass would be empty, so cache via cycle.
        yield from itertools.cycle(items)
        return

    while True:
        yielded_any = False
        for item in items:
            yielded_any = True
            yield item
        if not yielded_any:
            return


# Global, in-process log storage. log_step() appends one dict per call,
# get_logs() returns a shallow copy and clear_logs() empties it in place.
process_logs = []

def log_step(task_id, sample_index, phase, agent, action, input_content=None, output_content=None, metadata=None):
    """Append one structured entry to the module-level ``process_logs`` list.

    Dict and list payloads are deep-copied before storing. Later mutation of
    the caller's objects therefore cannot retroactively change an entry that
    was already logged.

    Args:
        task_id: The task ID (T1, T2, etc.)
        sample_index: The sample index (loop variable i)
        phase: The current step (init, student_evaluation, difficulty_increase, etc.)
        agent: The agent (teacher, student, orchestrator, system, etc.)
        action: The action (prompt, response, validate, etc.)
        input_content: The input content (prompt, etc.)
        output_content: The output content (response, etc.)
        metadata: Additional metadata; a falsy value is stored as ``{}``.

    The entry's "timestamp" is ``datetime.datetime.now().isoformat()``, i.e.
    naive local time.
    """

    def _snapshot(value):
        # Copy mutable containers so the log keeps the value as of this call.
        if isinstance(value, (dict, list)):
            return copy.deepcopy(value)
        return value

    timestamp = datetime.datetime.now().isoformat()

    log_entry = {
        "timestamp": timestamp,
        "task_id": task_id,
        "sample_index": sample_index,
        "phase": phase,
        "agent": agent,
        "action": action,
        "input": _snapshot(input_content),
        "output": _snapshot(output_content),
        "metadata": _snapshot(metadata) or {}
    }
    process_logs.append(log_entry)

def get_logs():
    """Return a snapshot list of every log entry recorded so far.

    The list itself is new (a shallow copy), so appending to or clearing it
    leaves the stored logs untouched; the entry dicts inside are shared.
    """
    return process_logs[:]

def clear_logs():
    """Remove every entry from the shared log storage.

    The existing list object is emptied in place rather than replaced, so any
    other reference to ``process_logs`` also observes the cleared state.
    """
    del process_logs[:]