import argparse
import math
import os

import yaml
from dotenv import load_dotenv
from pydantic import BaseModel

# OpenAI SDK v1
import openai

from util import CONFIGS_DIR, DATA_DIR
# Pull variables from a local .env file (if any) into the process environment
# before reading the API key.
load_dotenv()

openai.api_key = os.getenv("OPENAI_API_KEY")
if openai.api_key is None or openai.api_key == "":
    # Fail fast at import time instead of on the first API request.
    raise RuntimeError("OPENAI_API_KEY is not set. Please export it before running this script.")

# Module-wide client shared by generate_config().
client = openai.Client()


# ---- structured output ----
# One solver setting proposed by the LLM. This model is also part of the
# structured-output schema passed to OpenAI (see generate_config), so its field
# names and types shape what the model returns. Documented with comments rather
# than a docstring on purpose: pydantic uses a model's docstring as its JSON
# schema "description", which would change the request payload.
class Parameter(BaseModel):
    # Parameter name as written by the model; canonicalized afterwards by
    # normalize_and_validate_gurobi_config / normalize_and_validate_scip_config.
    name: str
    # Integer setting; snapped to the solver's legal values during normalization.
    value: int
    # Free-text rationale from the model; carried through to the saved YAML.
    explanation: str


# Top-level structured-output schema: the full list of settings for one
# generated configuration. (Comments, not a docstring, so the JSON schema sent
# to OpenAI stays unchanged.)
class Configuration(BaseModel):
    # Settings proposed by the model; may be empty or contain unknown names
    # until post-validated in generate_config.
    parameters: list[Parameter]


# ---------------------------
# Normalization & validation helpers
# ---------------------------

def _norm_key(s: str) -> str:
    s = (s or "").strip().lower()
    return "".join(ch for ch in s if ch.isalnum())


# ====== Gurobi canonical & aliases ======
# Canonical Gurobi cut-separator parameter names; every alias in _GRB_ALIASES
# resolves to one of these.
_GRB_CANON = {
    "BQPCuts", "CliqueCuts", "CoverCuts", "FlowCoverCuts", "FlowPathCuts", "GUBCoverCuts",
    "ImpliedCuts", "InfProofCuts", "LiftProjectCuts", "MIRCuts", "MixingCuts", "ModKCuts",
    "NetworkCuts", "ProjImpliedCuts", "PSDCuts", "RelaxLiftCuts", "StrongCGCuts",
    "SubMIPCuts", "ZeroHalfCuts",
}
# Only primalhint for Gurobi (no root_only)
# Names in this set are coerced to binary 0/1 by _coerce_grb_value instead of
# the separator range {-1, 0, 1, 2}.
_GRB_GLOBAL = {"primalhint"}

# Lookup table: _norm_key(alias) -> canonical Gurobi separator name.
# Every canonical name maps to itself (first line); the explicit entries add
# short forms and spelled-out variants an LLM is likely to produce. Some
# explicit entries normalize to the same key as a canonical name (e.g.
# "Clique Cuts" -> "cliquecuts"); those duplicates map to the same value and
# are harmless.
_GRB_ALIASES = {
    **{_norm_key(n): n for n in _GRB_CANON},
    _norm_key("BQP"): "BQPCuts",
    _norm_key("Clique"): "CliqueCuts",
    _norm_key("Clique Cuts"): "CliqueCuts",
    _norm_key("Cover"): "CoverCuts",
    _norm_key("Cover Cuts"): "CoverCuts",
    _norm_key("FlowCover"): "FlowCoverCuts",
    _norm_key("Flow Cover Cuts"): "FlowCoverCuts",
    _norm_key("FlowPath"): "FlowPathCuts",
    _norm_key("Flow Path Cuts"): "FlowPathCuts",
    _norm_key("GUBCover"): "GUBCoverCuts",
    _norm_key("GUB Cover Cuts"): "GUBCoverCuts",
    _norm_key("Implied"): "ImpliedCuts",
    _norm_key("ImpliedBound"): "ImpliedCuts",
    _norm_key("Implied Bound Cuts"): "ImpliedCuts",
    _norm_key("InfProof"): "InfProofCuts",
    _norm_key("Infeasibility Proof Cuts"): "InfProofCuts",
    _norm_key("LiftProject"): "LiftProjectCuts",
    _norm_key("Lift-and-Project"): "LiftProjectCuts",
    _norm_key("Lift and Project Cuts"): "LiftProjectCuts",
    _norm_key("MIR"): "MIRCuts",
    _norm_key("MixedIntegerRoundingCuts"): "MIRCuts",
    # NOTE(review): GMI (Gomory mixed-integer) cuts are usually a separate
    # family from MIR; confirm that folding them into MIRCuts is intended.
    _norm_key("GMI Cuts"): "MIRCuts",
    _norm_key("Mixing"): "MixingCuts",
    _norm_key("Mixing Cuts"): "MixingCuts",
    _norm_key("ModK"): "ModKCuts",
    _norm_key("Modk Cuts"): "ModKCuts",
    _norm_key("Network"): "NetworkCuts",
    _norm_key("Network Cuts"): "NetworkCuts",
    _norm_key("ProjImplied"): "ProjImpliedCuts",
    _norm_key("Projected Implied Bound Cuts"): "ProjImpliedCuts",
    _norm_key("PSD"): "PSDCuts",
    _norm_key("PSD Cuts"): "PSDCuts",
    _norm_key("RelaxLift"): "RelaxLiftCuts",
    _norm_key("Relax and Lift Cuts"): "RelaxLiftCuts",
    _norm_key("StrongCG"): "StrongCGCuts",
    _norm_key("Strong CG Cuts"): "StrongCGCuts",
    _norm_key("SubMIP"): "SubMIPCuts",
    # NOTE(review): Gurobi appears to document MIPSepCuts as its own parameter;
    # mapping it onto SubMIPCuts may be unintended — confirm.
    _norm_key("MIPSepCuts"): "SubMIPCuts",
    _norm_key("ZeroHalf"): "ZeroHalfCuts",
    _norm_key("Zero Half Cuts"): "ZeroHalfCuts",
}

# Aliases for Gurobi global controls (_norm_key(alias) -> canonical name).
# normalize_and_validate_gurobi_config consults this table before _GRB_ALIASES.
_GRB_GLOBAL_ALIASES = {
    _norm_key("primalhint"): "primalhint",
    _norm_key("primal_hint"): "primalhint",
    _norm_key("hint"): "primalhint",
}


def _coerce_grb_value(canon_name: str, v_raw) -> int | None:
    try:
        x = float(v_raw)
    except Exception:
        return None
    if canon_name in _GRB_GLOBAL:
        return 1 if x >= 0.5 else 0
    # separators: {-1, 0, 1, 2}
    if x >= 1.5:
        return 2
    if x >= 0.5:
        return 1
    if x <= -0.5:
        return -1
    return 0


def normalize_and_validate_gurobi_config(cfg: dict) -> dict | None:
    params = cfg.get("parameters", [])
    out = []
    seen = set()
    for p in params:
        raw = str(p.get("name", "")).strip()
        key = _norm_key(raw)
        canon = _GRB_GLOBAL_ALIASES.get(key) or _GRB_ALIASES.get(key)
        if canon is None:
            print(f"[llm] drop unknown parameter (gurobi): {raw}")
            continue
        v = _coerce_grb_value(canon, p.get("value", 0))
        if v is None:
            print(f"[llm] drop invalid value for {canon}: {p.get('value')}")
            continue
        if canon in seen:
            for q in out:
                if q["name"] == canon:
                    q["value"] = int(v)
                    q["explanation"] = p.get("explanation", "")
                    break
        else:
            out.append({"name": canon, "value": int(v), "explanation": p.get("explanation", "")})
            seen.add(canon)
    return {"parameters": out} if out else None


# ====== SCIP canonical & aliases ======
# Canonical SCIP separator names; every alias in _SCIP_ALIASES resolves to one
# of these. NOTE(review): these look like SCIP plugin names (as used in
# separating/<name>/... settings) — confirm against the consuming code.
_SCIP_CANON = {
    "closecuts", "disjunctive", "convexproj", "gauge", "impliedbounds", "intobj",
    "gomory", "cgmip", "strongcg", "aggregation", "clique", "zerohalf", "mcf", "eccuts",
    "oddcycle", "flowcover", "cmir", "rapidlearning",
}
# SCIP global controls (primal hint; root-node-only separation).
# NOTE(review): not referenced by _coerce_scip_value, which coerces every name
# to 0/1 regardless; kept for symmetry with _GRB_GLOBAL.
_SCIP_GLOBAL = {"primalhint", "root_only"}

# Lookup table: _norm_key(alias) -> canonical SCIP separator name.
# Every canonical name maps to itself (first line); the explicit entries add
# spelled-out variants and Gurobi-style names an LLM is likely to produce.
_SCIP_ALIASES = {
    **{_norm_key(n): n for n in _SCIP_CANON},
    _norm_key("disjunction"): "disjunctive",
    _norm_key("convexprojection"): "convexproj",
    _norm_key("convex projection"): "convexproj",
    _norm_key("integerobjective"): "intobj",
    _norm_key("integer objective"): "intobj",
    _norm_key("chvatalgomory"): "cgmip",
    _norm_key("cg"): "cgmip",
    _norm_key("zero half"): "zerohalf",
    _norm_key("zero-half"): "zerohalf",
    _norm_key("odd cycle"): "oddcycle",
    # Gurobi-style names folded onto the closest SCIP separator.
    # NOTE(review): confirm "mir" -> cmir and "flow path" -> mcf are the
    # intended equivalents.
    _norm_key("mir"): "cmir",
    _norm_key("mir cuts"): "cmir",
    _norm_key("flow path"): "mcf",
    _norm_key("flowpath"): "mcf",
    _norm_key("implied"): "impliedbounds",
}

# Aliases for SCIP global controls (_norm_key(alias) -> canonical name).
# normalize_and_validate_scip_config consults this table before _SCIP_ALIASES.
_SCIP_GLOBAL_ALIASES = {
    _norm_key("primalhint"): "primalhint",
    _norm_key("primal_hint"): "primalhint",
    _norm_key("hint"): "primalhint",
    _norm_key("root_only"): "root_only",
    _norm_key("root only cuts"): "root_only",
    _norm_key("rootonlycuts"): "root_only",
}


def _coerce_scip_value(canon_name: str, v_raw) -> int | None:
    try:
        x = float(v_raw)
    except Exception:
        return None
    # separators & globals: 0/1
    return 1 if x >= 0.5 else 0


def normalize_and_validate_scip_config(cfg: dict) -> dict | None:
    params = cfg.get("parameters", [])
    out = []
    seen = set()
    for p in params:
        raw = str(p.get("name", "")).strip()
        key = _norm_key(raw)
        canon = _SCIP_GLOBAL_ALIASES.get(key) or _SCIP_ALIASES.get(key)
        if canon is None:
            print(f"[llm] drop unknown parameter (scip): {raw}")
            continue
        v = _coerce_scip_value(canon, p.get("value", 0))
        if v is None:
            print(f"[llm] drop invalid value for {canon}: {p.get('value')}")
            continue
        if canon in seen:
            for q in out:
                if q["name"] == canon:
                    q["value"] = int(v)
                    q["explanation"] = p.get("explanation", "")
                    break
        else:
            out.append({"name": canon, "value": int(v), "explanation": p.get("explanation", "")})
            seen.add(canon)
    return {"parameters": out} if out else None


# ---------------------------
# Prompt helpers
# ---------------------------

def extract_cutting_planes(descriptions: dict, solver: str) -> list[str]:
    """List the separators available for *solver* as prompt bullet lines.

    Args:
        descriptions: Parsed descriptions YAML; each item of
            ``descriptions["separators"]`` must have a ``"solvers"`` mapping
            (upper-case solver name -> solver-specific separator name) and a
            ``"description"`` string.
        solver: Solver name, matched case-insensitively via upper-casing.

    Returns:
        Lines of the form ``"- <solver-specific name> : <description>"``, in
        file order, for every separator that lists this solver.
    """
    wanted = solver.upper()
    bullets = []
    for entry in descriptions["separators"]:
        per_solver = entry["solvers"]
        if wanted not in per_solver:
            continue
        bullets.append("- " + per_solver[wanted] + " : " + entry["description"])
    return bullets


def extract_value_instructions(solver: str) -> str:
    """Return the prompt sentence describing legal separator values for *solver*.

    Matching is case-insensitive; solvers other than Gurobi and SCIP get ``""``.
    """
    sentence_for = {
        "gurobi": "Values: 2 = aggressive, 1 = on, 0 = off for separators.",
        "scip": "Values: 1 = on, 0 = off for separators.",
    }
    return sentence_for.get(solver.lower(), "")


def _as_bool(x):
    if isinstance(x, bool):
        return x
    if isinstance(x, str):
        return x.strip().lower() in {"1", "true", "t", "yes", "y"}
    return bool(x)


def extract_default_instructions(allow_default) -> str:
    """Return the prompt sentence explaining what happens to unspecified separators.

    Args:
        allow_default: Flag (bool or string such as "True"/"False", parsed by
            ``_as_bool``). When true, omitted separators keep their default
            setting; when false, omitted separators are turned off.
    """
    if not _as_bool(allow_default):
        return "You only need to specify cutting planes to turn on; every other separator will be turned off."
    return (
        "You only need to specify cutting planes you are confident you want to turn on or off; "
        "all other cutting planes will be set to their default setting."
    )


def extract_controls_instructions(solver: str) -> str:
    """Return the prompt sentence describing the global (non-separator) controls.

    SCIP (case-insensitive) exposes ``primalhint`` and ``root_only``; every other
    solver, including Gurobi, gets the ``primalhint``-only sentence.
    """
    if solver.lower() != "scip":
        # Gurobi: only primalhint
        return "Global control (binary 0/1): primalhint enables cross-instance warm starts in re-optimization."
    return (
        "Global controls (binary 0/1): "
        "primalhint enables cross-instance warm starts in re-optimization; "
        "root_only restricts enabled separators to the root node (tree freq=0) instead of all nodes."
    )


def generate_system_prompt(args):
    """Assemble the system prompt for one configuration request.

    Reads ``config_generation/cutting_plane_descriptions.yaml`` (resolved
    relative to the current working directory) and concatenates three parts:
    a header aligned with the paper's meta prompt; structured guidance listing
    the solver's separators, value encoding, default-handling rule and global
    controls; and, for SCIP only, a short note on ``root_only``.

    Args:
        args: Parsed CLI namespace; reads ``solver`` and ``allow_default``.

    Returns:
        The full system prompt string.
    """
    with open("config_generation/cutting_plane_descriptions.yaml", "r") as fh:
        descriptions = yaml.safe_load(fh)

    solver = args.solver
    separators = extract_cutting_planes(descriptions, solver)
    values_text = extract_value_instructions(solver)
    defaults_text = extract_default_instructions(args.allow_default)
    controls_text = extract_controls_instructions(solver)

    # Header aligned with the paper's Meta Prompt.
    header = (
        "You are configuring for MILP re-optimization: a sequence of closely related instances derived from the same base "
        "model with small changes to objective function coefficients, variable bounds, constraint right-hand sides, or "
        "coefficients of the constraint matrix. You need to configure the following solver parameters and global control options. "
        "Any configuration key not explicitly included in your final output will be set to 0 (disabled/default). "
    )

    # Structured guidance (same format as before).
    # NOTE(review): `separators` is a list, so it is embedded as its Python repr
    # (e.g. "['- A : ...', '- B : ...']") rather than as newline bullets; kept
    # unchanged to preserve the established prompt format.
    guidance = (
        "Below is the list of cutting-plane separators and global controls you may configure: "
        f"{separators} {values_text} {defaults_text} {controls_text} "
        "Note: enabling more/stronger separators can tighten the LP relaxation and reduce the B&B tree, "
        "but increases separation overhead; disabling them saves cut time but may enlarge the tree. "
        "Empirically, larger instances (e.g., Number of variables ≥ 10000) often benefit from more cut families and aggressive cuts. "
    )

    # Concise, solver-specific tail; no full Primal Hint paragraph.
    tail = (
        "Use 'root_only' to restrict enabled separators to the root node (tree freq=0); "
        "when disabled, separators can run at all nodes."
        if solver.lower() == "scip"
        else ""
    )

    return "".join((header, guidance, tail))


def generate_user_prompt(args):
    """Build the user message from the instance family's natural-language description.

    Reads ``<DATA_DIR>/<args.instance_name>/description.txt`` and embeds its
    full contents in a fixed lead-in sentence.

    Args:
        args: Parsed CLI namespace; reads ``instance_name``.

    Returns:
        The user prompt string.

    Raises:
        OSError: If the description file cannot be opened.
    """
    path_to_description = os.path.join(DATA_DIR, args.instance_name, "description.txt")
    # Explicit UTF-8 so the prompt text does not depend on the platform's
    # locale-specific default encoding.
    with open(path_to_description, "r", encoding="utf-8") as file:
        problem_description = file.read()
    return f"Here is the description of the optimization problem: {problem_description}"


def generate_config(args):
    """Request one solver configuration from the LLM and return it normalized.

    Builds the system and user prompts, asks the OpenAI chat-completions
    ``parse`` helper for a structured ``Configuration``, then canonicalizes
    names and values with the solver-specific normalizer. Gurobi and SCIP are
    normalized; any other solver's output is passed through unchanged.

    Args:
        args: Parsed CLI namespace; reads ``solver`` here, and ``allow_default``
            and ``instance_name`` via the prompt helpers.

    Returns:
        A validated ``Configuration``, or ``None`` when the model produced no
        structured output (e.g. a refusal) or nothing survived normalization.
    """
    system_prompt = generate_system_prompt(args)
    user_prompt = generate_user_prompt(args)

    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format=Configuration,
    )
    message = completion.choices[0].message
    config = message.parsed
    if config is None:
        # `parsed` is None when the model refuses; without this guard the
        # `.dict()` call below raises AttributeError and aborts the whole run.
        print(f"[llm] skip saving: no structured output (refusal: {getattr(message, 'refusal', None)}).")
        return None
    # Pydantic v2 deprecates .dict() in favour of model_dump(); support both.
    cfg_dict = config.model_dump() if hasattr(config, "model_dump") else config.dict()

    # post-validate & normalize
    solver = args.solver.lower()
    if solver == "gurobi":
        fixed = normalize_and_validate_gurobi_config(cfg_dict)
    elif solver == "scip":
        fixed = normalize_and_validate_scip_config(cfg_dict)
    else:
        fixed = cfg_dict

    if fixed is None or not fixed.get("parameters"):
        print("[llm] skip saving: config empty after normalization/validation.")
        return None

    return Configuration(**fixed)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # required=True: without it a missing flag leaves instance_name=None and the
    # os.path.join below fails with an opaque TypeError.
    parser.add_argument("--instance_name", type=str, required=True, help="name of instance family")
    parser.add_argument("--config_name", type=str, help="name of generated configurations", default="baseline")
    parser.add_argument("--allow_default", type=str, default="False", help="allow LLM to pick default settings")
    parser.add_argument("--solver", type=str, help="which MILP solver to use", default="gurobi")
    parser.add_argument("--num_configs", type=int, default=20, help="number of configurations to generate")

    args = parser.parse_args()
    # Output layout: <CONFIGS_DIR>/<instance_name>/<solver>/<config_name>/<config_name>-<i>.yaml
    write_path = os.path.join(CONFIGS_DIR, args.instance_name, args.solver, args.config_name)
    os.makedirs(write_path, exist_ok=True)

    for i in range(args.num_configs):
        print(f"Generating config {i+1}/{args.num_configs} for {args.instance_name}")
        config = generate_config(args)
        if config is None:
            # Index i stays unused, so saved file numbering may have gaps.
            continue
        # Pydantic v2 deprecates .dict() in favour of model_dump(); support both.
        payload = config.model_dump() if hasattr(config, "model_dump") else config.dict()
        out_path = os.path.join(write_path, f"{args.config_name}-{i}.yaml")
        with open(out_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(payload, f, default_flow_style=False, sort_keys=False)
