import argparse
import sys
import os

base = "research/2025_mip/"
sys.path.append(base)
sys.path.append(os.path.join(base, 'forge'))

import time
from forge.forge import Forge
from forge.utils import *

# Obtain the Gurobi environment used by every model in this script.
# The project-local `gurobi_onboarder` helper is preferred when importable;
# otherwise fall back to a plain gurobipy environment. `gp` and `torch` are
# expected to be provided by the `forge.utils` star import above.
try:
    from gurobi_onboarder import init_gurobi
    gurobi_venv, GUROBI_FOUND = init_gurobi.initialize_gurobi()
except Exception:
    # Deliberately broad best-effort fallback (missing helper module or any
    # error raised by it), but no longer swallows KeyboardInterrupt/SystemExit
    # like the previous bare `except:` did.
    # An `empty=True` environment must be started before a model can be
    # built on it; parameters set later via `setParam` still apply to models
    # created afterwards.
    gurobi_venv = gp.Env(empty=True)
    gurobi_venv.start()

device = torch.device('cpu')

# Command-line interface. Arguments that the rest of the script uses
# unconditionally (path building, Gurobi parameters) are marked required so
# that a missing flag fails immediately with a usage message instead of a
# confusing TypeError / solver error later on.
parser = argparse.ArgumentParser(
    description="Compare a default Gurobi solve against a Forge warm-started "
                "solve on a single test instance.")
parser.add_argument('--prob_type', type=str, default="SC",
                    help="problem family; selects data/test/<prob_type>-<difficulty>/")
parser.add_argument('--inst_idx', type=int, required=True,
                    help="index of the instance file within the test directory listing")
parser.add_argument('--difficulty', type=str, required=True,
                    help="difficulty suffix of the test directory")
parser.add_argument('--threads', type=int, required=True,
                    help="Gurobi Threads parameter")
parser.add_argument('--timelimit', type=int, required=True,
                    help="Gurobi TimeLimit parameter (seconds)")
parser.add_argument('--debug', type=int, default=0,
                    help="non-zero enables Gurobi output and progress prints")

args = parser.parse_args()

# Result accumulators for this run: *_ws = warm-started solve,
# *_og = default Gurobi solve. Each receives one entry below.
run_time_ws = []
run_time_og = []

node_count_ws = []
node_count_og = []

pg_ws = []
pg_og = []

# Test instances live in <base>/data/test/<prob_type>-<difficulty>/ .
base_data = os.path.join(base, 'data/test/', args.prob_type + '-' + args.difficulty + '/')
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# that a given --inst_idx always refers to the same file across machines/runs.
instances = sorted(os.listdir(base_data))

print_string = args.prob_type + " - " + args.difficulty + " - " + str(args.inst_idx) + '\n'

try:
    inst = instances[args.inst_idx]
except IndexError:
    sys.exit("inst_idx %d is out of range: %d instance(s) in %s"
             % (args.inst_idx, len(instances), base_data))

# Default Gurobi Solution (baseline, no hints).
# Solver parameters are set on the shared environment so they are inherited by
# this model and by the warm-start model read later.
gurobi_venv.setParam("TimeLimit", args.timelimit)
gurobi_venv.setParam("OutputFlag", args.debug)
gurobi_venv.setParam("Threads", args.threads)
callback_og = Callback()
m = gp.read(os.path.join(base_data, inst), env=gurobi_venv)
# Timing covers optimize() only (model reading is excluded).
s = time.time()
m.optimize(callback_og)

run_time_og.append(np.round(time.time() - s, 3))
node_count_og.append(m.NodeCount)
pg_og.append(callback_og.primal_gap)

# ObjVal is only available when at least one feasible solution was found
# (e.g. not when the time limit hits first); reading it otherwise raises and
# would abort the script before any results are written.
obj_og_str = str(m.ObjVal) if m.SolCount > 0 else "N/A (no feasible solution)"

print_string += "Gurobi Time : " + str(run_time_og[-1]) + "s\nObjective : " + obj_og_str + "\nNode Count : " + str(
    m.NodeCount) + '\n'

######################################################
# Start Warm Start Section
#
# Re-read the instance, ask the Forge model for variable hints, attach them as
# Gurobi VarHintVal/VarHintPri, then solve and record the same metrics as the
# baseline.
#
# NOTE(review): preprocess_time starts before the instance is re-read and the
# Forge model is loaded, so the warm-start wall time includes file I/O and
# model loading that the baseline timing (optimize() only) does not. Confirm
# this asymmetry is intended.

s_ = time.time()
m = gp.read(os.path.join(base_data, inst), env=gurobi_venv)
variables = m.getVars()

model = Forge(prob_head=True, cut_head=False)
model.load_model(os.path.join(base, 'models/warm_start_model_pq.pkl'), model_type='warm_start')
hint_ones, hint_zeros, hint_pri_ones, hint_pri_zeros = model.mip_to_hint(
    mip_instance_path=os.path.join(base_data, inst),
    prob_type=args.prob_type)

# hint_ones / hint_zeros are indices into `variables` (getVars() order);
# presumably this matches the variable order Forge uses when reading the same
# file -- TODO confirm. hint_pri_* are the matching hint priorities.
# `variables[idx]` already is the Var object, so no name lookup is needed.
for idx, pri in zip(hint_ones, hint_pri_ones):
    var = variables[idx]
    var.VarHintVal = 1
    var.VarHintPri = pri

# GISP/CA seem to benefit by also setting a start value of 0 on zero-hinted vars.
set_zero_start = args.prob_type in ['GISP', 'CA']
for idx, pri in zip(hint_zeros, hint_pri_zeros):
    var = variables[idx]
    var.VarHintVal = 0
    var.VarHintPri = pri
    if set_zero_start:
        var.Start = 0

preprocess_time = time.time() - s_

callback_perc = Callback()
s = time.time()
# Flush the pending hint/start attribute changes before solving.
m.update()
if args.debug:
    print("Starting Final Optimization\n")
m.optimize(callback_perc)
# Round the total once, matching the 3-decimal rounding of run_time_og
# (previously only the solve part was rounded).
run_time_ws.append(np.round(preprocess_time + (time.time() - s), 3))
node_count_ws.append(m.NodeCount)
pg_ws.append(callback_perc.primal_gap)

# ObjVal raises when no feasible solution was found; don't lose the run.
obj_ws_str = str(m.ObjVal) if m.SolCount > 0 else "N/A (no feasible solution)"

print_string += "Warm Start Time : " + str(run_time_ws[-1]) + "s\nWarm Start Objective : " + \
    obj_ws_str + "\nWarm Start Node Count : " + str(m.NodeCount) + '\nPreprocess Time : ' + str(preprocess_time) + '\n'

# Persist raw metrics for later aggregation, then append a human-readable
# summary to the per-(prob_type, difficulty) log file. The output directories
# are created on demand so results of a long solve are not lost to a missing
# folder.
temp_dir = os.path.join(base, 'data/temp')
log_dir = os.path.join(base, 'data/log')
os.makedirs(temp_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

res_path = os.path.join(
    temp_dir, 'res_' + args.prob_type + '-' + args.difficulty + '-' + str(args.inst_idx) + '.pkl')
with open(res_path, 'wb') as file:
    # Order: ws/og run times, ws/og node counts, ws/og primal gaps.
    pkl.dump([run_time_ws, run_time_og, node_count_ws, node_count_og, pg_ws, pg_og], file)

# Number of variables hinted to 1 | hinted to 0.
print_string += ' ---- ' + str(len(hint_ones)) + ' | ' + str(len(hint_zeros)) + " ----\n\n"

log_path = os.path.join(log_dir, 'log_' + args.prob_type + '-' + args.difficulty + '.txt')
with open(log_path, 'a', encoding='utf-8') as file:
    file.write(print_string)
