import os
from symb.Problem import Problem
from LLMServices.LLMService import LLMService
from LLMServices.VllamaLLMService import VllamaLLMService
from LLMServices.OllamaLLMService import OllamaLLMService
from LLMServices.VertexLLMService import VertexLLMService
import pickle
import json
import sys

def _extract_fenced_block(text: str, lang: str) -> str:
    """Return the body of the last ```<lang> fenced block in ``text``.

    The body runs from just after the opening fence up to the next ``` (or to the
    end of ``text`` if the block is unterminated). If ``text`` contains no
    ```<lang> fence at all, fall back to the text following the last bare
    occurrence of ``lang`` (the historical heuristic), cut at the next ```.
    """
    fence = "```" + lang
    start = text.rfind(fence)
    if start == -1:
        body = text.split(lang)[-1]
    else:
        body = text[start + len(fence):]
    return body.split("```")[0]


def _build_baseline_prompt(problem, include_sym):
    """Build the zero-shot consultant prompt for ``problem``.

    When ``include_sym`` is true the prompt also asks for a symbolic formulation
    in a ```json block before the final ```python block.
    """
    prompt = f"""
    PURPOSE: I need you to solve an optimization problem, outputting gurobi code that captures the problem
    description and provides a solution, or otherwise indicates the problem is infeasible.
    CONTEXT: I have the following variables to consider: {problem.problem_variables}  which have the following resources/attributes that I need
    to deal with: {problem.nl_resources_dict}
    
    ROLE: You are a consulting team of business analysts, operations researchers, and programmers who will convert
    my natural language description of an optimization problem into functional gurobi code that answers my problem.
    INPUT: I need to {problem.goal} the following objective function : {problem.objective_function_statement}
    subject to the following constraints:\n"""
    prompt += "".join(f"* {constraint}\n" for constraint in problem.problem_constraints)

    prompt += "OUTPUT:"
    if include_sym:
        prompt += f""" 
        In order to convince me that the code you are producing is correct, I also need to have a symbolic 
        representation of the problem showing me that you have converted the description above into an appropriate
        symbolic representation of the optimization problem. This consists of a pairs of variables in symbolic notation
        for the first item in the pair of the form 'x1', 'x2', and so on, and the second item of the pair being the 
        natural language object appearing in the problem description; the objective function rendered as an algebraic
        term where all natural language objects are substituted for the corresponding symbolic variable; and the
        list of semi-algebraic constraints where the natural language object is substituted with its 
        symbolic variable counterpart.
        
        Return this solution in a code bloc encased as ```json 
        {dict({"sym_variables":[("x#i","object#i")],
        "objective_function":"objective function description with sym variables",
        "constraints":["constraint",]})} 
        ```
        """
    prompt += f"""
    Finally, please output gurobi code enclosed as ```python \n <CODE>\n```.
    * Do not have anything else AFTER this final block.
    * If you provide any reasoning for your final answer, you MUST put it before the final ```python 
    <CODE>
    ``` 
    bloc
    """
    return prompt


def Baseline_LLM_Problem_Answer(problem: Problem, llm_service: LLMService,problem_name:str, llm_name:str, include_sym : bool = False):
    """
    Ask an LLM, zero-shot, to solve ``problem`` with Gurobi code and save the artifacts.

    The prompt is built from the problem's variables, resources, goal, objective and
    constraints. Outputs go to ``<cwd>/data/<subdir>/`` (created if missing), where
    ``<subdir>`` is ``baseline_0shot_with_sym`` when ``include_sym`` is true and
    ``baseline_0shot`` otherwise:

    * ``<problem_name>_<llm_name>_response.txt`` -- the raw response;
    * ``<problem_name>_<llm_name>_solution.json`` -- only with ``include_sym``: the
      text of the last ```json block, stored as a JSON-encoded string (it is not
      parsed, so malformed LLM JSON is still saved);
    * ``<problem_name>_<llm_name>.py`` -- the text of the last ```python block.

    :param problem: Problem object providing ``problem_variables``, ``nl_resources_dict``,
        ``goal``, ``objective_function_statement`` and ``problem_constraints``.
    :param llm_service: LLMService object; only ``do_prompt_get_text`` is called.
    :param problem_name: prefix for every output file name.
    :param llm_name: model label appended to every output file name.
    :param include_sym: also request and save a symbolic (JSON) formulation.

    :returns: the raw LLM response text.
    """
    prompt = _build_baseline_prompt(problem, include_sym)
    response = llm_service.do_prompt_get_text(prompt, name="Consultant")

    subdir = 'baseline_0shot_with_sym' if include_sym else 'baseline_0shot'
    out_dir = os.path.join(os.getcwd(), 'data', subdir)
    os.makedirs(out_dir, exist_ok=True)
    stem = problem_name + "_" + llm_name

    # Explicit UTF-8: LLM output routinely contains non-ASCII characters, which the
    # platform default encoding may not be able to represent.
    with open(os.path.join(out_dir, stem + "_response.txt"), "w", encoding="utf-8") as outfile:
        outfile.write(response)

    if include_sym:
        sym_response = _extract_fenced_block(response, "json")
        with open(os.path.join(out_dir, stem + "_solution.json"), "w", encoding="utf-8") as f:
            json.dump(sym_response, f)

    # The prompt asks for the ```python block to be last, so take the last one.
    code = _extract_fenced_block(response, "python")
    with open(os.path.join(out_dir, stem + ".py"), "w", encoding="utf-8") as f:
        f.write(code)

    return response

def main():
    """
    Run the zero-shot baseline for every problem under ``<cwd>/data/oproblems``.

    Expects exactly one command-line argument, the model label: ``llama-3.3``
    (OllamaLLMService), ``gemini`` (VertexLLMService) or ``llama-4``
    (VllamaLLMService); unrecognised labels fall back to ``llama-4``. Each problem
    directory must contain a ``problem_dict.pkl`` whose dict is passed to
    ``Problem(**...)``; other entries are skipped. Each problem is run twice (with
    and without the symbolic formulation), and a run is skipped when its
    ``<problem>_<llm_name>_response.txt`` already exists in
    ``data/baseline_0shot_with_sym`` or ``data/baseline_0shot`` respectively.
    Errors for one problem/variant are printed and do not stop the batch.
    """
    if len(sys.argv) != 2:
        print("Usage: choose llm_name from : llama-3.3, llama-4, gemini")
        print(f"The sys args are {sys.argv}")
        if len(sys.argv) < 2:
            # Without a model label there is nothing to run.
            sys.exit(2)
    llm_name = sys.argv[1]
    if llm_name not in ('llama-3.3', 'llama-4', 'gemini'):
        print(f"Unknown llm_name {llm_name!r}; defaulting to llama-4")
        llm_name = "llama-4"
    if llm_name == "llama-3.3":
        llm_service = OllamaLLMService()
    elif llm_name == "gemini":
        llm_service = VertexLLMService()
    else:
        llm_service = VllamaLLMService()
    print(f"We are using {llm_name} which is the {llm_service}")

    data_dir = os.path.join(os.getcwd(), 'data')
    prob_dir = os.path.join(data_dir, 'oproblems')
    # (include_sym, output sub-directory) for the two baseline variants.
    variants = ((True, 'baseline_0shot_with_sym'), (False, 'baseline_0shot'))

    for prob_name in sorted(os.listdir(prob_dir)):
        pkl_path = os.path.join(prob_dir, prob_name, "problem_dict.pkl")
        if not os.path.isfile(pkl_path):
            # Skip stray files and directories without a pickled problem instead of
            # falling through with an unbound or stale problem_dict.
            print(f"Skipping {prob_name}: no problem_dict.pkl found")
            continue
        try:
            # NOTE: pickle.load can execute arbitrary code; only use trusted data here.
            with open(pkl_path, "rb") as f:
                problem = Problem(**pickle.load(f))
        except Exception as err:
            print(f"Ran into error: {err} for {prob_name}")
            continue

        for include_sym, out_subdir in variants:
            response_path = os.path.join(
                data_dir, out_subdir, prob_name + "_" + llm_name + "_response.txt"
            )
            if os.path.exists(response_path):
                continue  # produced by an earlier run
            try:
                Baseline_LLM_Problem_Answer(problem=problem,
                                            llm_service=llm_service,
                                            problem_name=prob_name,
                                            llm_name=llm_name,
                                            include_sym=include_sym)
            except Exception as err:
                # A failure in one variant must not skip the other or abort the batch.
                print(f"Ran into error: {err} for {prob_name}")
            
# Script entry point: run the baseline batch only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()