import argparse
import sys
import os
import time

# Project root; it and its ``forge`` subdirectory are appended to sys.path so
# the project imports below resolve when this script is run directly.
base = "/research/2025_mip/"
sys.path.append(base)
sys.path.append(os.path.join(base, 'forge'))
from forge.forge import Forge
from forge.utils import *  # NOTE(review): presumably provides gp, np, torch, pkl, quicksum, Callback — confirm

# Prefer the project's Gurobi onboarding helper; if it is missing or fails,
# fall back to a locally started default Gurobi environment.
try:
    from gurobi_onboarder import init_gurobi
    gurobi_venv, GUROBI_FOUND = init_gurobi.initialize_gurobi()
except Exception as err:  # deliberate best-effort fallback, but no longer swallows Ctrl-C / SystemExit
    print("gurobi_onboarder unavailable (" + repr(err) + "); using a local Gurobi environment.",
          file=sys.stderr)
    # An Env created with empty=True must be start()ed before models can be
    # built from it; parameters set later on a started Env still apply to
    # models created afterwards.
    gurobi_venv = gp.Env(empty=True)
    gurobi_venv.start()

# Torch device handle pinned to the CPU; it is not referenced again in this
# script's visible code.
device = torch.device('cpu')

parser = argparse.ArgumentParser(
    description="Compare default Gurobi against Gurobi with a Forge-predicted "
                "objective cut on one test instance.")
parser.add_argument('--prob_type', type=str, default="SC",
                    help="problem family, e.g. SC, GISP, MVC or CA")
# inst_idx and difficulty have no sensible default: without them the instance
# directory/path cannot be built, so let argparse report a usage error
# instead of crashing later with a TypeError.
parser.add_argument('--inst_idx', type=int, required=True,
                    help="index into the listing of the instance directory")
parser.add_argument('--difficulty', type=str, required=True,
                    help="suffix of the data/test/<prob_type>-<difficulty>/ directory")
parser.add_argument('--threads', type=int,
                    help="thread count passed to Gurobi and to Forge")
parser.add_argument('--timelimit', type=int,
                    help="Gurobi TimeLimit parameter (seconds)")
parser.add_argument('--debug', type=int, default=0,
                    help="Gurobi OutputFlag; nonzero also enables extra debug prints")

args = parser.parse_args()

# Per-run metric lists: the "_og" lists receive the plain Gurobi baseline run,
# the "_ws" lists the run solved with the predicted objective cut added.
run_time_ws, run_time_og = [], []        # wall-clock seconds
node_count_ws, node_count_og = [], []    # Gurobi NodeCount after optimize()
pg_ws, pg_og = [], []                    # primal_gap read from each run's Callback

# Instances live in data/test/<prob_type>-<difficulty>/. The trailing slash is
# kept on purpose: later code concatenates file names directly onto base_data.
base_data = os.path.join(base, 'data/test/') + args.prob_type + '-' + args.difficulty + '/'

# NOTE(review): os.listdir order is arbitrary per its documentation, so the
# inst_idx -> file mapping relies on the listing being stable between runs.
instances = os.listdir(base_data)

# Header line of this instance's entry in the text log.
print_string = " - ".join([args.prob_type, args.difficulty, str(args.inst_idx)]) + '\n'
inst = instances[args.inst_idx]

# ---- Baseline: default Gurobi on the unmodified instance ----

gurobi_venv.setParam("TimeLimit", args.timelimit)
gurobi_venv.setParam("OutputFlag", args.debug)
gurobi_venv.setParam("Threads", args.threads)
callback_og = Callback()
m = gp.read(os.path.join(base_data, inst), env=gurobi_venv)
s = time.time()
m.optimize(callback_og)
# ObjVal is only available once at least one feasible solution exists; a
# time-limit hit without an incumbent would otherwise raise here and the run's
# results would never be written. NaN marks "no solution" in the log.
obj = m.ObjVal if m.SolCount > 0 else float('nan')

run_time_og.append(np.round(time.time() - s, 3))
node_count_og.append(m.NodeCount)
pg_og.append(callback_og.primal_gap)

print_string += ("Gurobi Time : " + str(run_time_og[-1]) + "s\nObjective : " + str(obj)
                 + "\nNode Count : " + str(m.NodeCount) + '\n')

########################################################
# ---- Forge: predict the objective cut value for this instance ----

# NOTE(review): the drawn value is discarded, so this only advances NumPy's
# global RNG state; if a fixed seed was intended, np.random.seed is the call.
np.random.randint(300)
s_ = time.time()

model = Forge(prob_head=True, cut_head=True)
model.load_model(os.path.join(base, 'models/lp_gap_model_pq.pkl'), model_type='lp_gap')
cut_val, cut_ratio, lp_obj = model.mip_to_lp_cut(
    mip_instance_path=os.path.join(base_data, inst),
    prob_type=args.prob_type,
    return_metadata=True,
    threads=args.threads,
)

# Wall-clock cost of model loading plus cut prediction; it is added to the
# cut run's reported solve time.
preprocess_time = time.time() - s_

# ---- Gurobi with the predicted objective cut ----
gurobi_venv.setParam("MIPFocus", 0)
gurobi_venv.setParam("TimeLimit", args.timelimit)
gurobi_venv.setParam("OutputFlag", args.debug)
gurobi_venv.setParam("Threads", args.threads)
m = gp.read(os.path.join(base_data, inst), env=gurobi_venv)
callback_perc = Callback()
s = time.time()

# Bound the objective expression sum(c_j * x_j) by the predicted cut value.
# Direction per family is as used by the experiments (presumably a lower bound
# for the minimisation families and an upper bound for CA — confirm).
# NOTE(review): despite its name, 'lazy' is an ordinary constraint here; its
# Lazy attribute is not set.
if args.prob_type in ['SC', 'GISP', 'MVC']:
    m.addConstr(quicksum([x.Obj * x for x in m.getVars()]) >= cut_val, name='lazy')  # SC / GISP

elif args.prob_type in ['CA']:
    m.addConstr(quicksum([x.Obj * x for x in m.getVars()]) <= cut_val, name='lazy')  # CA

else:
    # Without a cut this run is just a second baseline; say so rather than
    # silently reporting it as a "Warm Start" result.
    print("Warning: no objective cut defined for prob_type " + repr(args.prob_type)
          + "; solving without a cut.", file=sys.stderr)

m.update()
if args.debug:
    print("Starting Final Optimization\n")
m.optimize(callback_perc)
run_time_ws.append(preprocess_time + (np.round(time.time() - s, 3)))
node_count_ws.append(m.NodeCount)
pg_ws.append(callback_perc.primal_gap)

# A too-aggressive cut can make the model infeasible, and the time limit can
# leave no incumbent; ObjVal is unavailable in both cases. Record NaN instead
# of crashing before the results are saved, and avoid dividing by a zero objective.
ws_obj = m.ObjVal if m.SolCount > 0 else float('nan')
true_ratio = lp_obj / ws_obj if ws_obj != 0 else float('nan')

print_string += ("Warm Start Time : " + str(run_time_ws[-1]) + "s\nWarm Start Objective : " + str(ws_obj)
                 + "\nWarm Start Node Count : " + str(m.NodeCount) + '\nPreprocess Time : '
                 + str(preprocess_time) + '\n')

print_string += "Cut Ratio : " + str(cut_ratio) + "\n"
print_string += "Cut Value : " + str(cut_val) + "\n"
print_string += "True Ratio : " + str(true_ratio) + "\n"
print_string += " ----\n\n"

# ---- Persist results ----
# Raw metric lists go to a per-instance pickle. The human-readable summary is
# appended to one log per (prob_type, difficulty) that all instances share.
# Both directories are created if missing, so a finished solve is not lost to
# a FileNotFoundError at the very end.
run_tag = args.prob_type + '-' + args.difficulty
temp_dir = os.path.join(base, 'data/temp')
log_dir = os.path.join(base, 'data/log')
os.makedirs(temp_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

with open(os.path.join(temp_dir, 'cutres_' + run_tag + '-' + str(args.inst_idx) + '.pkl'), 'wb') as file:
    pkl.dump([run_time_ws, run_time_og, node_count_ws, node_count_og, pg_ws, pg_og], file)

with open(os.path.join(log_dir, 'cutlog_' + run_tag + '.txt'), 'a') as file:
    file.write(print_string)
