import os
import sys
import pickle
import json
import pandas as pd
import io
import sys
import subprocess
import concurrent.futures
import importlib.util
import multiprocessing
import traceback
from contextlib import redirect_stdout, redirect_stderr
from tqdm import tqdm
# LLM identifiers that are evaluated; each one is also the filename suffix of
# that model's proposal scripts (`{problem_id}_{llm_name}.py`).
model_names = [
    "llama-3.3",
    "llama-4",
    "gemini",
]


def collect_summary(data) -> dict:
    """Extract the summary fields of a single problem record.

    Args:
        data: Mapping describing one problem. The keys read are
            ``num_vars``, ``num_resources``, ``problem_constraints``,
            ``problem_type`` and ``problem_solution``.

    Returns:
        A dict holding any of ``nvars``, ``nres``, ``ncons``, ``ptype``,
        ``itype`` and ``solvable``. A field is left out when its source
        entry is ``None`` or absent.
    """
    summary = {}

    num_vars = data.get('num_vars')
    if num_vars is not None:
        summary['nvars'] = num_vars

    num_resources = data.get('num_resources')
    if num_resources is not None:
        summary['nres'] = num_resources

    constraints = data.get('problem_constraints')
    if constraints is not None:
        summary['ncons'] = len(constraints)

    problem_type = data.get('problem_type')
    if problem_type is not None:
        summary['ptype'] = problem_type.get('degree', 0)
        summary['itype'] = problem_type.get('is_integer')

    solution = data.get('problem_solution')
    if solution is not None:
        # Anything other than the literal marker 'Infeasible' counts as solvable.
        summary['solvable'] = solution.get('Value') != 'Infeasible'

    return summary

def grab_summary(df: pd.DataFrame) -> dict:
    """Aggregate per-problem summary rows into overall statistics.

    Args:
        df: DataFrame with the columns ``nvars``, ``nres``, ``ncons`` and
            ``solvable`` (one row per problem).

    Returns:
        A dict with the median and maximum of the three count columns, the
        number of rows whose ``solvable`` value equals ``True``, and the
        total number of rows.
    """
    return {
        'MedianNumVars': df.nvars.median(),
        'MaxNumVars': df.nvars.max(),
        'MedianNumResources': df.nres.median(),
        'MaxNumResources': df.nres.max(),
        'MedianNumConstraints': df.ncons.median(),
        'MaxNumConstraints': df.ncons.max(),
        # Element-wise comparison so that missing values are not counted.
        'Total Solvable': int(df.solvable.eq(True).sum()),
        'Total': len(df),
    }

def _run_script_worker(script_code, queue):
    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    exec_globals = {}
    exec_locals = {}

    with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
        try:
            exec(script_code, exec_globals, exec_locals)
        except Exception as e:
            print(f"Exception occurred during execution: {e}", file=stderr_buffer)
            traceback.print_exc(file=stderr_buffer)

    queue.put({
        "output": {
            "stdout": stdout_buffer.getvalue(),
            "stderr": stderr_buffer.getvalue()
        }
    })

def _run_corrected_worker(script_code, queue):
    # Patch import if needed
    if "gurobipy" not in script_code:
        code = script_code.replace("import gurobi", "import gurobipy")
    else:
        code = script_code

    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    exec_globals = {}
    exec_locals = {}

    with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
        try:
            exec(code, exec_globals, exec_locals)
        except Exception as e:
            print(f"Exception occurred during execution: {e}", file=stderr_buffer)
            traceback.print_exc(file=stderr_buffer)

    queue.put({
        "output": {
            "stdout": stdout_buffer.getvalue(),
            "stderr": stderr_buffer.getvalue()
        }
    })

def run_script_with_logging(script_path, timeout=10):
    """Run the script at *script_path* in a child process and capture its output.

    Args:
        script_path: Path of the Python file to execute (read as UTF-8).
        timeout: Seconds the child may run before it is terminated;
            ``None`` waits indefinitely.

    Returns:
        A dict ``{"input": <source>, "output": {"stdout": ..., "stderr": ...}}``.
        On timeout, stdout is empty and stderr holds a ``[Timeout]`` message;
        if the child ends without reporting, stderr is ``[Unknown failure]``.
    """
    # Local imports: neither name is imported at the top of this module.
    # `Empty` is the exception multiprocessing.Queue.get raises on timeout.
    import time
    from queue import Empty

    # Read script as string input
    with open(script_path, 'r', encoding='utf-8') as f:
        script_code = f.read()

    result_queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_run_script_worker, args=(script_code, result_queue))
    proc.start()

    # Drain the queue BEFORE joining: a child that has put data on a queue
    # does not exit until that data is consumed, so join-then-get can report
    # a script with large output as a timeout.
    deadline = None if timeout is None else time.monotonic() + timeout
    poll_interval = 0.05
    result = None
    while result is None:
        wait = poll_interval
        if deadline is not None:
            wait = min(wait, deadline - time.monotonic())
            if wait <= 0:
                break
        try:
            result = result_queue.get(timeout=wait)
        except Empty:
            if not proc.is_alive():
                # The child is gone; pick up a result it may have flushed
                # between the failed get() above and its exit.
                try:
                    result = result_queue.get(timeout=poll_interval)
                except Empty:
                    pass
                break

    proc.join(None if deadline is None else max(0.0, deadline - time.monotonic()))

    if proc.is_alive():
        proc.terminate()
        proc.join()
        print(f"[Timeout] Script '{script_path}' exceeded {timeout} seconds.", file=sys.stderr)
        return {
            "input": script_code,
            "output": {
                "stdout": "",
                "stderr": f"[Timeout] Script exceeded {timeout} seconds.\n"
            }
        }

    if result is None:
        result = {"output": {"stdout": "", "stderr": "[Unknown failure]"}}
    result["input"] = script_code
    return result

def run_corrected_script_with_logging(script_code,timeout=10):
    """Run *script_code* through the import-patching worker in a child process.

    Args:
        script_code: Python source code to execute.
        timeout: Seconds the child may run before it is terminated;
            ``None`` waits indefinitely.

    Returns:
        A dict ``{"input": <source>, "output": {"stdout": ..., "stderr": ...}}``.
        On timeout, stdout is empty and stderr holds a ``[Timeout]`` message;
        if the child ends without reporting, stderr is ``[Unknown failure]``.
    """
    # Local imports: neither name is imported at the top of this module.
    # `Empty` is the exception multiprocessing.Queue.get raises on timeout.
    import time
    from queue import Empty

    result_queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_run_corrected_worker, args=(script_code, result_queue))
    proc.start()

    # Drain the queue BEFORE joining: a child that has put data on a queue
    # does not exit until that data is consumed, so join-then-get can report
    # a script with large output as a timeout.
    deadline = None if timeout is None else time.monotonic() + timeout
    poll_interval = 0.05
    result = None
    while result is None:
        wait = poll_interval
        if deadline is not None:
            wait = min(wait, deadline - time.monotonic())
            if wait <= 0:
                break
        try:
            result = result_queue.get(timeout=wait)
        except Empty:
            if not proc.is_alive():
                # The child is gone; pick up a result it may have flushed
                # between the failed get() above and its exit.
                try:
                    result = result_queue.get(timeout=poll_interval)
                except Empty:
                    pass
                break

    proc.join(None if deadline is None else max(0.0, deadline - time.monotonic()))

    if proc.is_alive():
        proc.terminate()
        proc.join()
        print("[Timeout] Script execution exceeded timeout.", file=sys.stderr)
        return {
            "input": script_code,
            "output": {
                "stdout": "",
                "stderr": f"[Timeout] Script execution exceeded {timeout} seconds.\n"
            }
        }

    if result is None:
        result = {"output": {"stdout": "", "stderr": "[Unknown failure]"}}
    result["input"] = script_code
    return result
    
def run_problem_gurobi_code(gurobi_code):
        # proceed to run the gurobi code
        local_scope = {}
        try:
            exec(gurobi_code, globals())
            m = make_and_optimize_model()
        except Exception as E:
            print(f"{self.problem_statement}")
            print(f"{self.gurobi_code}")
            print(E)
            raise NameError
        solution = dict()
        v_solution = dict()
        if m.status == GRB.OPTIMAL:
            for v in m.getVars():
                try:
                    v_solution.__setitem__(v.VarName, v.X)
                except AttributeError:
                    try:
                        v_solution.__setitem__(v.VarName, v.Xn)
                    except Exception as E:
                        print(f" Error {E} on input {v}")
                        v_solution.__setitem__(v.VarName,0)
            try:
                solution.__setitem__('Value',m.ObjVal)
            except AttributeError:
                try:
                    solution.__setitem__('Value',m.ObjNVal)
                except Exception as E:
                    print(f" Error {E} on input {m}")
        else:
            solution.__setitem__('Value','Infeasible')
        solution.__setitem__("Optimal solution",v_solution)
        return solution

def get_problem_result(problem_pickle):
    """Return the ``problem_solution`` entry of a pickled problem record.

    Args:
        problem_pickle: Path to the pickle file holding the problem mapping.

    Returns:
        The stored ``problem_solution`` value, or an empty dict when the
        record has no such key.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # files from a trusted source.
    with open(problem_pickle, 'rb') as handle:
        problem_record = pickle.load(handle)
    return problem_record.get("problem_solution", dict())

def nlp4lp_compare_problem(datapath,problem_id:str,llm_name:str):
    """Score one LLM proposal for an nlp4lp problem against the stored solution.

    The ground truth is read from ``<datapath>/<problem_id>.pkl`` and the
    proposal from ``<datapath>/<problem_id>_<llm_name>.py``.

    Args:
        datapath: Directory holding both the problem pickles and the proposals.
        problem_id: Identifier of the problem.
        llm_name: Name of the model whose proposal is scored.

    Returns:
        A dict of 0/1 flags:
        ``ACC`` - the expected value appears in the proposal's stdout;
        ``ACC_adj`` - it appears in the original or the import-patched run;
        ``CE`` - the proposal is missing or stderr reports "invalid syntax";
        ``RE`` - the proposal wrote anything else to stderr;
        ``RE_adj`` - the import-patched rerun still wrote to stderr.
    """
    scores = {"ACC": 0, "ACC_adj": 0, "CE": 0, "RE": 0, "RE_adj": 0}

    # The stored solution may be None; treat that like "no ground truth".
    solution_dict = get_problem_result(os.path.join(datapath, f"{problem_id}.pkl")) or {}
    expected = solution_dict.get("Value")
    # Without a ground-truth value nothing can match (str(None) would
    # otherwise match any script that happens to print "None").
    expected_text = None if expected is None else str(expected)

    # Fail loudly on a bad directory, as the os.listdir-based check used to.
    if not os.path.isdir(datapath):
        raise FileNotFoundError(datapath)
    proposal_path = os.path.join(datapath, f"{problem_id}_{llm_name}.py")
    if not os.path.exists(proposal_path):
        scores["CE"] = 1
        return scores

    proposed_solution_log = run_script_with_logging(proposal_path)
    output = proposed_solution_log.get("output") or {}
    stdout = output.get("stdout") or ""
    stderr = output.get("stderr") or ""

    if expected_text is not None and expected_text in stdout:
        scores["ACC"] = 1
        scores["ACC_adj"] = 1

    if stderr != '':
        # NOTE(review): only the literal text "invalid syntax" counts as a
        # compile error; other syntax-error messages are scored as RE.
        if "invalid syntax" in stderr:
            scores["CE"] = 1
        else:
            scores["RE"] = 1
            # Retry once with the Gurobi import patched.
            amended = run_corrected_script_with_logging(proposed_solution_log.get("input"))
            amended_output = amended.get("output") or {}
            if expected_text is not None and expected_text in (amended_output.get("stdout") or ""):
                scores["ACC_adj"] = 1
            if (amended_output.get("stderr") or "") != '':
                scores["RE_adj"] = 1

    return scores

def op_compare_problem(datapath,problem_id:str,llm_name:str):
    """Score one LLM proposal for an oproblem against the stored solution.

    The ground truth is read from
    ``<cwd>/data/oproblems/<problem_id>/problem_dict.pkl`` and the proposal
    from ``<datapath>/<problem_id>_<llm_name>.py``.

    Args:
        datapath: Directory holding the proposal scripts.
        problem_id: Identifier of the problem.
        llm_name: Name of the model whose proposal is scored.

    Returns:
        A dict of 0/1 flags:
        ``ACC`` - the expected value appears in the proposal's stdout;
        ``ACC_adj`` - it appears in the original or the import-patched run;
        ``CE`` - the proposal is missing or stderr reports "invalid syntax";
        ``RE`` - the proposal wrote anything else to stderr;
        ``RE_adj`` - the import-patched rerun still wrote to stderr.
    """
    scores = {"ACC": 0, "ACC_adj": 0, "CE": 0, "RE": 0, "RE_adj": 0}

    problem_pickle = os.path.join(os.getcwd(), "data", "oproblems", problem_id, "problem_dict.pkl")
    # The stored solution may be None; treat that like "no ground truth".
    solution_dict = get_problem_result(problem_pickle) or {}
    expected = solution_dict.get("Value")
    # Without a ground-truth value nothing can match (str(None) would
    # otherwise match any script that happens to print "None").
    expected_text = None if expected is None else str(expected)

    # Fail loudly on a bad directory, as the os.listdir-based check used to.
    if not os.path.isdir(datapath):
        raise FileNotFoundError(datapath)
    proposal_path = os.path.join(datapath, f"{problem_id}_{llm_name}.py")
    if not os.path.exists(proposal_path):
        scores["CE"] = 1
        return scores

    proposed_solution_log = run_script_with_logging(proposal_path)
    output = proposed_solution_log.get("output") or {}
    stdout = output.get("stdout") or ""
    stderr = output.get("stderr") or ""

    if expected_text is not None and expected_text in stdout:
        scores["ACC"] = 1
        scores["ACC_adj"] = 1

    if stderr != '':
        # NOTE(review): only the literal text "invalid syntax" counts as a
        # compile error; other syntax-error messages are scored as RE.
        if "invalid syntax" in stderr:
            scores["CE"] = 1
        else:
            scores["RE"] = 1
            # Retry once with the Gurobi import patched.
            amended = run_corrected_script_with_logging(proposed_solution_log.get("input"))
            amended_output = amended.get("output") or {}
            if expected_text is not None and expected_text in (amended_output.get("stdout") or ""):
                scores["ACC_adj"] = 1
            if (amended_output.get("stderr") or "") != '':
                scores["RE_adj"] = 1

    return scores

def op_method(with_sym:bool):
    """Score every model on the oproblem benchmark (``prob_0`` .. ``prob_999``).

    Args:
        with_sym: When true, proposals are read from
            ``<cwd>/data/baseline_0shot_with_sym``; otherwise from
            ``<cwd>/data/baseline_0shot``.

    Returns:
        A nested dict ``{llm_name: {problem_id: scores}}`` where ``scores``
        is the dict returned by ``op_compare_problem``.
    """
    prob_names = [f"prob_{n}" for n in range(1000)]
    # The proposal directory depends only on `with_sym`, so build it once.
    folder = 'baseline_0shot_with_sym' if with_sym else 'baseline_0shot'
    path_name = os.path.join(os.getcwd(), 'data', folder)

    result = {}
    for llm_name in model_names:
        llm_result = {}
        for problem_id in tqdm(prob_names):
            llm_result[problem_id] = op_compare_problem(path_name, problem_id, llm_name)
        result[llm_name] = llm_result
    return result

def nlp4lp_method(with_sym):
    """Score every model on the nlp4lp problems available in both baselines.

    Problem ids are the first key of each record in
    ``<cwd>/dataset/nlp4opt/generation_data/train.jsonl``; a problem is kept
    only if ``<id>.pkl`` exists in both the plain and the with-sym baseline
    directory.

    Args:
        with_sym: When true, proposals are read from
            ``<cwd>/data/baseline_0shot_with_sym_nlp4lp``; otherwise from
            ``<cwd>/data/baseline_0shot_nlp4lp``.

    Returns:
        A nested dict ``{llm_name: {problem_id: scores}}`` where ``scores``
        is the dict returned by ``nlp4lp_compare_problem``.
    """
    path = os.path.join(os.getcwd(),'dataset','nlp4opt','generation_data')
    plain_dir = os.path.join(os.getcwd(), 'data', 'baseline_0shot_nlp4lp')
    sym_dir = os.path.join(os.getcwd(), 'data', 'baseline_0shot_with_sym_nlp4lp')

    data = []
    with open(os.path.join(path,'train.jsonl'), 'r', encoding='utf-8' ) as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
            if line.strip():
                data.append(json.loads(line))

    # List each directory once instead of twice per record.
    plain_files = set(os.listdir(plain_dir))
    sym_files = set(os.listdir(sym_dir))
    prob_names = []
    for item in data:
        problem_id = list(item.keys())[0]
        pickle_name = f'{problem_id}.pkl'
        if pickle_name in plain_files and pickle_name in sym_files:
            prob_names.append(problem_id)

    path_name = sym_dir if with_sym else plain_dir
    result = {}
    for llm_name in model_names:
        llm_result = {}
        for problem_id in tqdm(prob_names):
            llm_result[problem_id] = nlp4lp_compare_problem(path_name, problem_id, llm_name)
        result[llm_name] = llm_result
    return result

def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def main():
    """Run all evaluations and write results and summaries to ``<cwd>/analysis``.

    Files written:
        ``results.pkl`` - raw per-problem scores for every benchmark/model;
        ``summary_statistics.pkl`` - column means of those scores per
        model and benchmark;
        ``oproblem_summary_statistics.pkl`` - size statistics of the
        oproblems, grouped by ``itype``/``ptype``;
        ``nlp_summary_statistics.pkl`` - size statistics of the nlp4lp
        problems.
    """
    analysis_dir = os.path.join(os.getcwd(), 'analysis')
    data_dir = os.path.join(os.getcwd(), 'data')

    results_dict = {}
    print("\n\nMaking oproblem results...")
    results_dict["op_result"] = op_method(False)
    print("\n\nMaking oproblem results with sym...")
    results_dict["op_result_wsym"] = op_method(True)
    print("\n\nMaking nlp4lp results...")
    results_dict["nlp4lp_result"] = nlp4lp_method(False)
    print("\n\nMaking nlp4lp results with sym...")
    results_dict["nlp4lp_result_wsym"] = nlp4lp_method(True)

    print("\n\nSaving...")
    results_path = os.path.join(analysis_dir, 'results.pkl')
    _dump_pickle(results_dict, results_path)
    # Report the path that was actually written.
    print(f"Saved at {results_path}")

    new_data = {}
    for key, per_model in results_dict.items():
        for model, per_problem in per_model.items():
            new_data[f"{model}_{key}"] = pd.DataFrame.from_dict(per_problem, orient="index").mean()
    _dump_pickle(new_data, os.path.join(analysis_dir, 'summary_statistics.pkl'))

    # Oproblem summary
    oproblems_dir = os.path.join(data_dir, 'oproblems')
    problem_summaries = {}
    for pname in os.listdir(oproblems_dir):
        problem_dir = os.path.join(oproblems_dir, pname)
        # Skip stray files; every problem lives in its own directory.
        if not os.path.isdir(problem_dir):
            continue
        with open(os.path.join(problem_dir, 'problem_dict.pkl'), 'rb') as f:
            data = pickle.load(f)
        problem_summaries[pname] = collect_summary(data)
    psdf = pd.DataFrame.from_dict(problem_summaries, orient='index')

    # (label, itype value, ptype value) for each reported problem class.
    problem_classes = (
        ('IPLP', 1.0, 1.0),
        ('IPQP', 1.0, 2.0),
        ('MIPLP', 0.5, 1.0),
        ('MIPQP', 0.5, 2.0),
        ('LP', 0.0, 1.0),
        ('QP', 0.0, 2.0),
    )
    final_summ = {}
    for label, itype, ptype in problem_classes:
        final_summ[label] = grab_summary(psdf.loc[(psdf.itype == itype) & (psdf.ptype == ptype)])
    _dump_pickle(final_summ, os.path.join(analysis_dir, 'oproblem_summary_statistics.pkl'))

    # nlp4lp problem summary
    nlp_dir = os.path.join(data_dir, 'baseline_0shot_with_sym_nlp4lp')
    suffix = ".pkl"
    nlproblem_summaries = {}
    for pname in os.listdir(nlp_dir):
        if not pname.endswith(suffix):
            continue
        with open(os.path.join(nlp_dir, pname), 'rb') as f:
            data = pickle.load(f)
        nlproblem_summaries[pname[:-len(suffix)]] = collect_summary(data)
    nlpsdf = pd.DataFrame.from_dict(nlproblem_summaries, orient="index")
    nlp_final_summ = {'LP': grab_summary(nlpsdf)}
    _dump_pickle(nlp_final_summ, os.path.join(analysis_dir, 'nlp_summary_statistics.pkl'))
    


# Script entry point: main() runs only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
    # Ad-hoc checks of single problems, kept commented out:
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_856','llama-4'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_617','llama-4'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_618','llama-4'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_856','llama-3.3'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_617','llama-3.3'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_618','llama-3.3'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_856','gemini'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_617','gemini'))
    # print(op_compare_problem(os.path.join(os.getcwd(),'data','baseline_0shot_with_sym'), 'prob_618','gemini')) 