import openai
import anthropic
from dotenv import load_dotenv
import os
import re
import subprocess
import sys
import tempfile
import copy
import json
import argparse
from itertools import zip_longest
from utils import (
    is_number_string,
    convert_to_number,
    extract_best_objective,
    extract_and_execute_python_code,
    eval_model_result
)

# Read API credentials (and an optional custom OpenAI endpoint) from a .env
# file into the process environment before they are looked up below.
load_dotenv()

# OpenAI / OpenAI-compatible endpoint settings; OPENAI_API_BASE may be unset.
openai_api_data = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": os.getenv("OPENAI_API_BASE"),
}

# Anthropic (Claude) settings.
anthropic_api_data = {"api_key": os.getenv("CLAUDE_API_KEY")}

# API clients shared by query_llm. An unset or empty base URL is passed as
# None rather than as an empty string.
openai_client = openai.OpenAI(
    api_key=openai_api_data["api_key"],
    base_url=openai_api_data["base_url"] or None,
)
anthropic_client = anthropic.Anthropic(api_key=anthropic_api_data["api_key"])

def _to_anthropic_messages(messages):
    """
    Convert OpenAI-style chat messages into Anthropic Messages API arguments.

    Args:
        messages (list): Dicts with "role" ("system", "user" or "assistant")
            and "content" (str).

    Returns:
        tuple: (system_prompt, turns). ``system_prompt`` joins all system
        messages with blank lines ("" if there are none). ``turns`` holds the
        user/assistant messages in their original order, with consecutive
        same-role messages merged so the roles alternate.
    """
    system_parts = []
    turns = []
    for message in messages:
        role, content = message["role"], message["content"]
        if role == "system":
            if content:
                system_parts.append(content)
            continue
        # Empty messages are dropped (the Messages API rejects empty text),
        # as are roles this converter does not know how to map.
        if role not in ("user", "assistant") or not content:
            continue
        if turns and turns[-1]["role"] == role:
            turns[-1]["content"] += "\n\n" + content
        else:
            turns.append({"role": role, "content": content})
    return "\n\n".join(system_parts), turns

def query_llm(messages, model_name="o3", temperature=0.2):
    """
    Call an LLM and return the text of its reply.

    Models whose name starts with "claude" (case-insensitive) are routed to the
    Anthropic Messages API. Any system messages become the top-level ``system``
    prompt, and the user/assistant turns are sent as a real multi-turn
    conversation. All other models go to the OpenAI Chat Completions API with
    ``messages`` unchanged.

    Args:
        messages (list): Conversation context in OpenAI chat format.
        model_name (str): LLM model name, default is "o3".
        temperature (float): Controls the randomness of output, default is 0.2.

    Returns:
        str: Response content generated by the LLM.
    """
    if model_name.lower().startswith("claude"):
        system_prompt, turns = _to_anthropic_messages(messages)
        request = {
            "model": model_name,
            "max_tokens": 8192,
            "temperature": temperature,
            "messages": turns,
        }
        if system_prompt:
            # Anthropic takes the system prompt as a parameter, not as a message.
            request["system"] = system_prompt
        response = anthropic_client.messages.create(**request)
        # Join every text block instead of assuming content[0] is text.
        return "".join(
            block.text for block in response.content if block.type == "text"
        )

    # NOTE(review): OpenAI reasoning models (e.g. "o3") may reject a
    # non-default temperature — confirm against the endpoint in use.
    response = openai_client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
    )
    return response.choices[0].message.content

def generate_or_code_solver(messages_bak, model_name, max_attempts):
    """
    Generate Gurobi code with the LLM, execute it, and ask for fixes on failure.

    Repair turns (the failed code plus its error message) are accumulated on a
    private deep copy of the conversation and never added to ``messages_bak``.
    Only the final code is appended to ``messages_bak`` (in place), so the
    caller can continue the dialogue from it.

    Args:
        messages_bak (list): Conversation history in OpenAI chat format.
            Mutated: the final assistant code message is appended.
        model_name (str): LLM model name forwarded to ``query_llm``.
        max_attempts (int): Maximum number of code executions. A fix is only
            requested from the LLM when another execution will follow it.

    Returns:
        tuple: (success, output, messages_bak). On success ``output`` is the
        second value returned by ``extract_and_execute_python_code``; on
        failure it is None.
    """
    messages = copy.deepcopy(messages_bak)

    gurobi_code = query_llm(messages, model_name)
    print("【Python Gurobi Code】:\n", gurobi_code)

    # 4. Code execution & fixes
    for attempt in range(1, max_attempts + 1):
        success, output = extract_and_execute_python_code(f"{gurobi_code}")
        if success:
            messages_bak.append({"role": "assistant", "content": gurobi_code})
            return True, output, messages_bak

        if attempt == max_attempts:
            # Last allowed execution: requesting a fix here would cost an LLM
            # call whose code is never run.
            print(f"\nAttempt {attempt} failed.\n")
            break

        print(f"\nAttempt {attempt} failed, requesting LLM to fix code...\n")

        # Build repair request (on failure, `output` holds the error message).
        messages.append({"role": "assistant", "content": gurobi_code})
        messages.append({"role": "user", "content": f"Code execution encountered an error, error message is as follows:\n{output}\nPlease fix the code and provide the complete executable code again."})

        # Get the fixed code
        gurobi_code = query_llm(messages, model_name)
        print("\nReceived fixed code, preparing to execute again...\n")

    # Record the last code generated (the one that failed) so the caller can
    # ask the LLM to rebuild it.
    messages_bak.append({"role": "assistant", "content": gurobi_code})
    print(f"Reached maximum number of attempts ({max_attempts}), could not execute code successfully.")
    return False, None, messages_bak

def or_llm_agent(user_question, model_name="o3", max_attempts=3):
    """
    Two-stage OR agent: ask the LLM for a mathematical model, then for Gurobi
    code implementing it, executing and repairing that code as needed.

    Recovery rules:
    - If the first code round succeeds but its output is not a number (per
      ``is_number_string``), the LLM is asked once to re-check the model and
      code for the cause of infeasibility.
    - If the first code round never executes successfully, the LLM is asked to
      re-check the mathematical model and gets two more execution attempts.

    Args:
        user_question (str): User's problem description.
        model_name (str): LLM model name to use, default is "o3".
        max_attempts (int): Maximum number of executions in the first code
            round, default is 3.

    Returns:
        tuple: (success: bool, result), where ``result`` is the output of the
        last successful code execution (presumably the objective value), or
        None if execution never succeeded.
    """
    # Initialize conversation history
    messages = [
        {"role": "system", "content": (
            "You are an operations research expert. Based on the optimization problem provided by the user, construct a mathematical model that effectively models the original problem using mathematical (linear programming) expressions. "
            "Focus on obtaining a correct mathematical model expression without too much concern for explanations. "
            "This model will be used later to guide the generation of Gurobi code, and this step is mainly used to generate effective linear scale expressions."
        )},
        {"role": "user", "content": user_question}
    ]

    # 1. Generate mathematical model
    math_model = query_llm(messages, model_name)
    print("【Mathematical Model】:\n", math_model)

    # # 2. Validate mathematical model (currently disabled)
    # messages.append({"role": "assistant", "content": math_model})
    # messages.append({"role": "user", "content": (
    #     "Please check if the above mathematical model matches the problem description. If there are errors, make corrections; if there are no errors, check if it can be optimized. "
    #     "In any case, please output the final mathematical model again."
    # )})

    # validate_math_model = query_llm(messages, model_name)
    # print("【Validated Mathematical Model】:\n", validate_math_model)

    # With validation disabled, the first model is used as-is.
    validate_math_model = math_model
    messages.append({"role": "assistant", "content": validate_math_model})

    # 3. Ask for Gurobi code implementing the model
    messages.append({"role": "user", "content": (
        "Based on the above mathematical model, write complete and reliable Python code using Gurobi to solve this operations research optimization problem. "
        "The code should include necessary model construction, variable definitions, constraint additions, objective function settings, as well as solving and result output. "
        "Output in the format ```python\n{code}\n```, without code explanations."
    )})
    # Solve on a copy of the conversation; the returned history ends with the
    # last Gurobi code produced.
    is_solve_success, result, messages = generate_or_code_solver(messages, model_name, max_attempts)
    print(f'Stage result: {is_solve_success}, {result}')
    if is_solve_success:
        if not is_number_string(result):
            print('!![No available solution warning]!!')
            # Code ran but produced no numeric objective: ask for a re-check.
            messages.append({"role": "user", "content": (
                "The current model resulted in *no feasible solution*. Please carefully check the mathematical model and Gurobi code for errors that might be causing the infeasibility. "
                "After checking, please reoutput the Gurobi Python code. "
                "Output in the format ```python\n{code}\n```, without code explanations."
            )})
            is_solve_success, result, messages = generate_or_code_solver(messages, model_name, max_attempts=1)
    else:
        print('!![Max attempt debug error warning]!!')
        # Code never ran: suspect the mathematical model itself.
        messages.append({"role": "user", "content": (
                "The model code still reports errors after multiple debugging attempts. Please carefully check if there are errors in the mathematical model. "
                "After checking, please rebuild the Gurobi Python code. "
                "Output in the format ```python\n{code}\n```, without code explanations."
            )})
        is_solve_success, result, messages = generate_or_code_solver(messages, model_name, max_attempts=2)

    return is_solve_success, result

def gpt_code_agent_simple(user_question, model_name="o3", max_attempts=3):
    """
    Baseline solver: ask the LLM once for Gurobi code and execute it.

    Unlike ``or_llm_agent`` there is no separate modelling stage and no repair
    loop; the first generated code is executed exactly once.

    Args:
        user_question (str): User's problem description.
        model_name (str): LLM model name to use, default is "o3".
        max_attempts (int): Currently unused; accepted so the signature
            matches ``or_llm_agent``.

    Returns:
        tuple: (success: bool, result) exactly as returned by
        ``extract_and_execute_python_code`` for the generated code.
    """
    # Initialize conversation history
    messages = [
        {"role": "system", "content": (
            "You are an operations research expert. Based on the optimization problem provided by the user, construct a mathematical model and write complete, reliable Python code using Gurobi to solve the operations research optimization problem. "
            "The code should include necessary model construction, variable definitions, constraint additions, objective function settings, as well as solving and result output. "
            "Output in the format ```python\n{code}\n```, without code explanations."
        )},
        {"role": "user", "content": user_question}
    ]

    # Single shot: generate the code once and execute it.
    gurobi_code = query_llm(messages, model_name)
    print("【Python Gurobi Code】:\n", gurobi_code)
    is_solve_success, result = extract_and_execute_python_code(f"{gurobi_code}")

    print(f'Stage result: {is_solve_success}, {result}')

    return is_solve_success, result

def parse_args():
    """
    Read the benchmark runner's command-line options from ``sys.argv``.

    Options:
        --agent      flag; run the two-stage agent instead of the single-shot baseline.
        --model      LLM model name (default "o3"); "claude-..." names use Anthropic.
        --data_path  dataset file (default "data/datasets/IndustryOR.json").

    Returns:
        argparse.Namespace: The parsed options.
    """
    cli = argparse.ArgumentParser(
        description='Run optimization problem solving with LLMs'
    )
    cli.add_argument(
        '--agent',
        action='store_true',
        help='Use the agent. If not specified, directly use the model to solve the problem',
    )
    cli.add_argument(
        '--model',
        type=str,
        default='o3',
        help='Model name to use for LLM queries. Use "claude-..." for Claude models.',
    )
    cli.add_argument(
        '--data_path',
        type=str,
        default='data/datasets/IndustryOR.json',
        help='Path to the dataset JSON file (supports both JSONL and regular JSON formats)',
    )
    namespace = cli.parse_args()
    return namespace

def _is_jsonl_record(line):
    """Return True if *line* looks like one record of a question/answer JSONL file."""
    # Fast path: the key order used by IndustryOR.json / BWOR.json.
    if line.startswith(('{"en_question"', '{"cn_question"')):
        return True
    # Otherwise accept any standalone JSON object that carries a question
    # field, whatever its key order.
    try:
        record = json.loads(line)
    except json.JSONDecodeError:
        return False
    return isinstance(record, dict) and ('en_question' in record or 'cn_question' in record)

def load_dataset(data_path):
    """
    Load a dataset into a dict keyed by string id.

    Two layouts are supported:

    * JSONL (e.g. IndustryOR.json, BWOR.json): one JSON object per line with
      ``en_question``/``cn_question`` and ``en_answer``/``cn_answer``, plus
      optional ``difficulty`` and ``id``. Each record is normalised to
      ``{'question', 'answer', 'difficulty', 'id'}`` and stored under
      ``str(id)``. A missing id defaults to the 0-based line index. Lines
      that fail to parse are reported and skipped.
    * Regular JSON (legacy): the whole file is returned as parsed.

    The format is chosen by inspecting the first non-blank line.

    Args:
        data_path (str): Path to the dataset file.

    Returns:
        dict: Mapping of id -> item.
    """
    # utf-8-sig strips a leading BOM, which would otherwise defeat both the
    # format check and the JSON parser.
    with open(data_path, 'r', encoding='utf-8-sig') as f:
        text = f.read()

    # Split on '\n' only: str.splitlines() would also break on characters such
    # as U+2028, which may legally appear unescaped inside JSON strings.
    lines = text.split('\n')
    first_line = next((ln.strip() for ln in lines if ln.strip()), '')

    if not _is_jsonl_record(first_line):
        # Regular JSON format (legacy)
        return json.loads(text)

    dataset = {}
    for line_num, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            continue
        try:
            item = json.loads(line)
        except json.JSONDecodeError:
            print(f"Warning: Could not parse line {line_num}: {line}")
            continue
        # Convert to expected format
        dataset_item = {
            'question': item.get('en_question', item.get('cn_question', '')),
            'answer': item.get('en_answer', item.get('cn_answer', '')),
            'difficulty': item.get('difficulty', 'Unknown'),
            'id': item.get('id', line_num - 1)
        }
        # Use id as string key
        dataset[str(dataset_item['id'])] = dataset_item

    return dataset

if __name__ == "__main__":
    # Benchmark driver: solve every dataset problem with the chosen strategy,
    # score it against the ground truth, and print per-problem and total tallies.
    cli_args = parse_args()

    problems = load_dataset(cli_args.data_path)
    model_name = cli_args.model
    # Choose the solving strategy once rather than branching per problem.
    solve = or_llm_agent if cli_args.agent else gpt_code_agent_simple

    pass_count = correct_count = 0
    error_datas = []
    for key, item in problems.items():
        print(f"=============== num {key} ==================")
        user_question = item['question']
        answer = item['answer']
        print(user_question)
        print('-------------')

        is_solve_success, llm_result = solve(user_question, model_name)
        if not is_solve_success:
            print("Failed to execute code.")
        else:
            print(f"Successfully executed code, optimal solution value: {llm_result}")
        print('------------------')

        pass_flag, correct_flag = eval_model_result(is_solve_success, llm_result, answer)
        # bool counts as 0/1 when added to an int.
        pass_count += bool(pass_flag)
        correct_count += bool(correct_flag)
        if not (pass_flag and correct_flag):
            error_datas.append(key)

        print(f'solve: {is_solve_success}, llm: {llm_result}, ground truth: {answer}')
        print(f'[Final] run pass: {pass_flag}, solve correct: {correct_flag}')
        print(' ')

    print(f'[Total {len(problems)}] run pass: {pass_count}, solve correct: {correct_count}')
    print(f'[Total fails {len(error_datas)}] error datas: {error_datas}')