import os
import json
import openai
from openai import OpenAI

# Endpoint and model name for the RITS-hosted Llama 3.3 70B Instruct deployment.
# BASE_URL is handed to the `openai.OpenAI` client in setup_paste_api_client(),
# and get_response() always sends MODEL_NAME as the model.
BASE_URL = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/llama-3-3-70b-instruct/v1"
MODEL_NAME = "meta-llama/llama-3-3-70b-instruct"

# Build the OpenAI-SDK client used to reach the RITS inference endpoint.
def setup_paste_api_client(api_key=None):
    """Create an ``openai.OpenAI`` client pointed at ``BASE_URL``.

    The real key is sent in a ``RITS_API_KEY`` request header; the client's
    own ``api_key`` is only the placeholder ``"dummy"``.

    Args:
        api_key: The RITS API key. When omitted or empty, the value of the
            ``RITS_API_KEY`` environment variable is used.

    Returns:
        A configured ``openai.OpenAI`` client with a 300-second timeout.

    Raises:
        ValueError: If no key is passed and ``RITS_API_KEY`` is unset or empty.
    """
    # Only read the environment here. Assigning to os.environ first would
    # discard a key the user exported and make this function always raise.
    api_key = api_key or os.environ.get("RITS_API_KEY")
    if not api_key:
        raise ValueError("Please set RITS_API_KEY environment variable")

    return OpenAI(
        api_key="dummy",  # placeholder; the key travels in the header below
        base_url=BASE_URL,
        default_headers={"RITS_API_KEY": api_key},
        timeout=300,
    )

# Shared client used by get_response(). It is created at import time, so
# importing this module raises ValueError (from setup_paste_api_client) when
# no RITS API key is available.
paste_client = setup_paste_api_client()

def extract_json_from_end(text):
    """Parse the last JSON object embedded in free-form model output.

    First tries ``extract_json_from_end_backup``. If that fails, falls back
    to a lenient pass aimed at LaTeX-escaped replies (braces written with a
    leading backslash): all backslashes are removed, the text is cut after
    the last closing brace, and the object whose opening brace balances it
    is decoded. If the braces never balance, everything from the first
    opening brace is decoded instead.

    Args:
        text: Model output expected to contain a JSON object.

    Returns:
        The decoded JSON value (normally a ``dict``).

    Raises:
        ValueError: If no opening/closing brace is found, or the extracted
            span is not valid JSON.
    """
    try:
        return extract_json_from_end_backup(text)
    except (ValueError, IndexError):
        # The strict extractor reports "not found / not parseable" via
        # ValueError (json.JSONDecodeError included) or IndexError. Anything
        # else (including KeyboardInterrupt) must propagate, not be swallowed.
        pass

    json_start = text.find("{")
    if json_start == -1:
        raise ValueError("No JSON object found in the text.")

    # Remove every backslash: in this fallback they are LaTeX-style escapes
    # (e.g. "\{") that make the JSON undecodable. This also removes genuine
    # JSON string escapes, which is accepted for this lenient path.
    json_text = text[json_start:].replace("\\", "")

    # Drop anything after the last closing brace.
    end = json_text.rfind("}")
    if end == -1:
        raise ValueError("No closing '}' found for the JSON object.")
    json_text = json_text[: end + 1]

    # Walk backwards to the '{' that balances the final '}'. If the braces
    # never balance, keep everything from the first '{' (start stays 0).
    depth = 0
    start = 0
    for pos in range(end, -1, -1):
        ch = json_text[pos]
        if ch == "}":
            depth += 1
        elif ch == "{":
            depth -= 1
            if depth == 0:
                start = pos
                break
    json_text = json_text[start:]

    try:
        return json.loads(json_text)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to decode JSON: {e}") from e

def extract_json_from_end_backup(text):
    """Strictly parse the last JSON object in *text*.

    If the text contains a Markdown ```json fence, only the content of the
    first such fence is considered. The candidate object runs from the last
    ``}`` back to the ``{`` that balances it. The candidate is decoded as-is
    first; if that fails, ``//`` line comments are stripped from it and
    decoding is retried.

    Brace matching counts raw characters, so braces inside string values can
    still confuse it.

    Args:
        text: Model output expected to contain a JSON object.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If no balanced ``{ ... }`` span exists, or the span is not
            valid JSON even after comment stripping (``json.JSONDecodeError``
            is a subclass of ``ValueError``).
    """
    if "```json" in text:
        text = text.split("```json")[1]
        text = text.split("```")[0]

    end = text.rfind("}")
    if end == -1:
        raise ValueError("No closing '}' found in the text.")

    # Walk backwards from the last '}' to the '{' that balances it.
    depth = 0
    start = -1
    for pos in range(end, -1, -1):
        ch = text[pos]
        if ch == "}":
            depth += 1
        elif ch == "{":
            depth -= 1
            if depth == 0:
                start = pos
                break
    if start == -1:
        raise ValueError("Unbalanced braces: no matching '{' found.")

    candidate = text[start : end + 1]

    # Decode verbatim first, so "//" inside string values (e.g. URLs) is
    # never mistaken for a comment.
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass

    # Strip "//" line comments from the extracted span only. Doing this after
    # slicing keeps the slice boundaries valid, and the per-line split cannot
    # loop forever when a comment has no trailing newline.
    uncommented = "\n".join(line.split("//", 1)[0] for line in candidate.split("\n"))
    return json.loads(uncommented)

def extract_list_from_end(text):
    """Decode the last JSON array found in *text*.

    The span runs from the final ``]`` back to the ``[`` that balances it,
    and that substring is handed to ``json.loads``.

    Raises:
        IndexError: If *text* has no ``]`` or its brackets never balance
            (the backwards scan runs off the start of the string).
        json.JSONDecodeError: If the extracted span is not valid JSON.
    """
    close = len(text) - 1
    while text[close] != "]":
        close -= 1
    trimmed = text[: close + 1]

    pos = close - 1
    depth = 1
    while depth:
        ch = trimmed[pos]
        # +1 for a closing bracket, -1 for an opening one, 0 otherwise.
        depth += (ch == "]") - (ch == "[")
        pos -= 1

    return json.loads(trimmed[pos + 1 :])

# LLM call wrapper that goes through the module-level RITS client.
def get_response(prompt, model="llama3-70b-8192"):
    """Send *prompt* as a single user message and return the reply text.

    Up to three attempts are made. API errors and empty replies both trigger
    a retry, with the wait doubling from 5 seconds between attempts.

    Args:
        prompt: The user message content.
        model: Accepted for call-site compatibility only; the request always
            uses ``MODEL_NAME``.

    Returns:
        The content of the first choice, or the string
        ``"No response generated after multiple attempts."`` when every
        attempt returned an empty reply.

    Raises:
        Exception: Whatever the client raised on the final attempt, re-raised
            with its original traceback.
    """
    import time

    max_retries = 3
    retry_delay = 5

    for attempt in range(1, max_retries + 1):
        try:
            response = paste_client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=4096,
                temperature=0.1,
                top_p=0.9,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stream=False,
            )
            if response.choices and response.choices[0].message.content:
                return response.choices[0].message.content
            print(f"WARNING: Empty response on attempt {attempt}")
        except Exception as e:
            print(f"WARNING: API call failed on attempt {attempt}: {e}")
            if attempt == max_retries:
                raise

        # Back off before retrying; previously empty replies were retried
        # immediately, hammering the endpoint with identical requests.
        if attempt < max_retries:
            print(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    return "No response generated after multiple attempts."

def load_state(state_file):
    """Return the object stored as JSON in the file at *state_file*."""
    with open(state_file, "r") as fh:
        raw = fh.read()
    return json.loads(raw)

def save_state(state, dir):
    """Write *state* as 4-space-indented JSON to the file path *dir*.

    The object is encoded before the file is opened. A value that is not
    JSON-serializable therefore raises ``TypeError`` without truncating an
    existing state file at *dir*.

    Args:
        state: JSON-serializable object to persist.
        dir: Destination file path. The name shadows the builtin but is kept
            for caller compatibility.

    Raises:
        TypeError: If *state* contains values ``json`` cannot serialize.
    """
    payload = json.dumps(state, indent=4)
    with open(dir, "w") as f:
        f.write(payload)

def shape_string_to_list(shape_string):
    """Convert a shape description into a list of dimensions.

    A ``list`` is returned unchanged (same object) and a ``tuple`` is
    converted to a list. A string such as ``"[N, M, K, 19]"`` is split on
    commas: entries made only of decimal digits become ``int``, and
    everything else stays a stripped string, giving ``['N', 'M', 'K', 19]``.
    One enclosing pair of ``[]`` or ``()`` is removed if present. Empty
    entries (from ``"[]"`` or a trailing comma) are dropped.

    Args:
        shape_string: A ``list``, ``tuple`` or ``str`` shape description.

    Returns:
        The shape as a list of ``int`` and ``str`` items.
    """
    if isinstance(shape_string, list):
        return shape_string
    if isinstance(shape_string, tuple):
        return list(shape_string)

    body = shape_string.strip()
    # Strip only real brackets; dropping the first/last character
    # unconditionally corrupted unbracketed input such as "N, M".
    if len(body) >= 2 and (body[0], body[-1]) in (("[", "]"), ("(", ")")):
        body = body[1:-1]

    dims = []
    for item in body.split(","):
        item = item.strip()
        if not item:
            continue
        # isdecimal (not isdigit): characters like "²" satisfy isdigit but
        # make int() raise.
        dims.append(int(item) if item.isdecimal() else item)
    return dims

def extract_equal_sign_closed(text):
    """Return the stripped text enclosed between two ``=====`` delimiters.

    Delimiters longer than five ``=`` are consumed in full. If no opening
    delimiter is present, the whole *text* (stripped) is returned. If no
    closing delimiter follows, everything after the opening one is returned.

    Args:
        text: Model output that should contain a ``=====``-fenced section.

    Returns:
        The enclosed content with surrounding whitespace removed.
    """
    marker = "====="
    opening = text.find(marker)
    if opening == -1:
        # No fence at all: treat the whole reply as the payload instead of
        # slicing with a -1 index.
        return text.strip()

    # Start right after the marker (no blind +1 skip that could eat a content
    # character) and consume any extra '=' from a longer delimiter.
    start = opening + len(marker)
    while start < len(text) and text[start] == "=":
        start += 1

    closing = text.find(marker, start)
    if closing == -1:
        closing = len(text)
    return text[start:closing].strip()

class Logger:
    """Minimal text logger that writes to a single file path.

    The file is opened on every call, so no handle is held between writes.
    """

    def __init__(self, file):
        # Path of the log file; opened on demand by log() and reset().
        self.file = file

    def _write(self, mode, payload):
        # Open the log file in the given mode and write payload once.
        with open(self.file, mode) as handle:
            handle.write(payload)

    def log(self, text):
        """Append *text* plus a trailing newline to the log file."""
        self._write("a", text + "\n")

    def reset(self):
        """Truncate the log file to empty, creating it if it does not exist."""
        self._write("w", "")

# Build the initial pipeline state from a problem_description.md file,
# without interactive prompts.
def create_state_from_problem_description(problem_description_path, run_dir):
    """Build the initial state dict from a problem description file.

    Reads the description as UTF-8 and extracts parameters with
    ``parameters.get_params(desc, check=False)``; per the original author,
    ``check=False`` skips user interaction. It also writes an empty JSON
    object to ``<run_dir>/data.json``.

    Returns:
        ``{"description": <file text>, "parameters": <get_params result>}``.
    """
    with open(problem_description_path, "r", encoding="utf-8") as src:
        description = src.read()

    from parameters import get_params
    extracted = get_params(description, check=False)

    # Placeholder data file; the description alone drives later stages.
    data_path = os.path.join(run_dir, "data.json")
    with open(data_path, "w") as out:
        json.dump({}, out, indent=4)

    return {"description": description, "parameters": extracted}

def get_labels(dir):
    """Return a fresh dict of fixed default labels.

    *dir* is accepted only for call-site compatibility and is not read.
    """
    labels = {
        "types": ["Mathematical Optimization"],
        "domains": ["Operations Management"],
    }
    return labels

if __name__ == "__main__":
    # Manual smoke test for extract_json_from_end on a LaTeX-escaped reply.
    # NOTE(review): the sample is truncated (its final "\}" is missing), so the
    # outermost object is never closed; what gets printed is whichever object
    # the extractors manage to balance.
    text = 'To maximize the number of successfully transmitted shows, we can introduce a new variable called "TotalTransmittedShows". This variable represents the total number of shows that are successfully transmitted.\n\nThe constraint can be formulated as follows:\n\n\\[\n\\text{{Maximize }} TotalTransmittedShows\n\\]\n\nTo model this constraint in the MILP formulation, we need to add the following to the variables list:\n\n\\{\n    "TotalTransmittedShows": \\{\n        "shape": [],\n        "type": "integer",\n        "definition": "The total number of shows transmitted"\n    \\}\n\\}\n\nAnd the following auxiliary constraint:\n\n\\[\n\\forall i \\in \\text{{NumberOfShows}}, \\sum_{j=1}^{\\text{{NumberOfStations}}} \\text{{Transmitted}}[i][j] = \\text{{TotalTransmittedShows}}\n\\]\n\nThe complete output in the requested JSON format is:\n\n\\{\n    "FORMULATION": "",\n    "NEW VARIABLES": \\{\n        "TotalTransmittedShows": \\{\n            "shape": [],\n            "type": "integer",\n            "definition": "The total number of shows transmitted"\n        \\}\n    \\},\n    "AUXILIARY CONSTRAINTS": [\n        ""\n    ]\n\\'

    # Show the parsed result; previously it was computed and discarded.
    print(json.dumps(extract_json_from_end(text), indent=4))