# protocol_utils.py
""" Methods and structures for implementing dataset generation. Allows for train/test/validation splits, new seeds, and customized parameterization over defaults
This provides a considerably more customizable and uniform approach than the baseline generation, which can be run with:

```
import os
from symb.Problem import generate_problems_object
generate_problems_object(n=NUMBER_OF_PROBLEMS, path=PATH_TO_SAVE_GENERATED_PROBLEMS)
```
"""

import os, json, hashlib, time, random
import numpy as np
from dataclasses import dataclass, asdict
from typing import Dict, Any, Tuple, List, Optional
from symb.Problem import Problem, get_with_lazy_default
from experiment_analysis import op_method
from experiment_analysis import nlp4lp_method


# generate_problems_object(n=1000, path = os.path.join(os.getcwd(),'data','oproblems'))

# Configs & Seeding 
@dataclass
class TrancheConfig:
    """Configuration shared by every problem sampled into one tranche."""

    # High-level controls
    # Label for this configuration; make_tranche uses it in the default
    # tranche tag and echoes it as "config_name" in the tranche card.
    name: str
    # Keyword arguments that sample_problem copies and forwards to Problem(...).
    problem_kwargs: Dict[str, Any]
    # TODO: future customization for ranges and mixers for sampling strategies should go here.

def _seed_everything(seed: int) -> None:
    """Seed both global RNGs used here: Python's ``random`` and NumPy's ``np.random``."""
    for reseed in (random.seed, np.random.seed):
        reseed(seed)

def _hash_config(cfg: Dict[str, Any]) -> str:
    """Return a short, order-independent fingerprint of a config dictionary.

    The dict is serialized as key-sorted JSON (values JSON cannot encode are
    passed through ``str``), hashed with MD5, and the first 10 hex digits of
    the digest are returned.
    """
    canonical = json.dumps(cfg, sort_keys=True, default=str)
    digest = hashlib.md5(canonical.encode("utf-8"))
    return digest.hexdigest()[:10]

def _now_utc_compact() -> str:
    """Return the current UTC time as a compact ``YYYYMMDDTHHMMSSZ`` stamp."""
    utc_now = time.gmtime()
    return time.strftime("%Y%m%dT%H%M%SZ", utc_now)

# Problem sampling

def sample_problem(cfg: TrancheConfig, rng_seed: Optional[int] = None) -> Dict[str, Any]:
    """
    Create ONE Problem instance and return its dict form (``Problem.to_dict()``).

    Args:
        cfg: Tranche configuration; ``cfg.problem_kwargs`` is copied (never
            mutated) and forwarded to ``Problem(...)``. If the copy lacks
            ``resources_split_parameter``, a symmetric split of 0.5 is added.
        rng_seed: When not None, the global RNGs are re-seeded with this value
            before the Problem is built.

    Returns:
        The dictionary produced by ``Problem.to_dict()``.

    Raises:
        TypeError: Any constructor/``to_dict`` TypeError other than the
            specific "``_gen_resources_dict()`` missing 1 required positional
            argument" failure handled below.

    If Problem calls ``_gen_resources_dict`` with no arguments while the
    implementation requires one positional parameter, the class method is
    temporarily replaced by a wrapper that supplies the parameter, the
    construction is retried, and the original method is always restored.
    """
    if rng_seed is not None:
        _seed_everything(rng_seed)

    kwargs = dict(cfg.problem_kwargs)
    # Prefer caller's value; otherwise a symmetric split
    resources_param = kwargs.setdefault("resources_split_parameter", 0.5)

    # Try clean instantiation first
    try:
        return Problem(**kwargs).to_dict()
    except TypeError as exc:
        # Match on the bare function name: Python >= 3.10 reports the qualified
        # name ("Problem._gen_resources_dict() missing ..."), older versions
        # report only "_gen_resources_dict() missing ...". The bare form is a
        # substring of both, so the fallback works on either.
        marker = "_gen_resources_dict() missing 1 required positional argument"
        if marker not in str(exc):
            # Propagate unrelated constructor errors
            raise

        orig = Problem._gen_resources_dict

        def _patched(self):
            # An instance-level resources_split_parameter takes precedence;
            # otherwise fall back to the value resolved above (via closure).
            param = getattr(self, "resources_split_parameter", resources_param)
            return orig(self, param)

        # Assign the plain function; Python will bind it as a method automatically.
        Problem._gen_resources_dict = _patched  # type: ignore[assignment]
        try:
            # NOTE(review): the retry runs without re-seeding, so any random
            # draws made by the failed first attempt shift the RNG state the
            # retry sees — confirm this is intended.
            return Problem(**kwargs).to_dict()
        finally:
            # Always restore the original method, even if the retry fails.
            Problem._gen_resources_dict = orig  # type: ignore[assignment]


# def sample_problem(cfg: TrancheConfig, rng_seed: Optional[int] = None) -> Dict[str, Any]:
#     """
#     Create a Problem instance under a fixed RNG seed and return its dict using Problem.__init__ random pathways.
#     Control is only given over RNG state. 
#     If specific problems want to be tested, directly structure the problem per the Problem parameters guidance.
#     """
#     if rng_seed is not None:
#         _seed_everything(rng_seed)
#     kwargs = dict(cfg.problem_kwargs)
#     kwargs.setdefault("resources_split_parameter", 0.5) 
#     # Since Problem uses get_with_lazy_default, let's make sure we're handling this in the try-except blocks
#     try:

#         p = Problem(**kwargs)
#         return p.to_dict()
#     except TypeError as e:
#         # Attempt to handle exact failure mode seen in traceback
#         if "Problem._gen_resources_dict() missing 1 required positional argument" not in str(e):
#             raise

#         # Proposed monkey-patch -- wrap the original to supply the missing param when called with zero args ---
#         import types
#         orig = Problem._gen_resources_dict

#         def _patched(self):
#             # per pseudo-code diagram, let's fall back to attribute if the class later sets it. use the kwarg default otherwise
#             param = getattr(self, "resources_split_parameter", resources_param)
#             return orig(self, param)

#         try:
#             Problem._gen_resources_dict = types.FunctionType(_patched.__code__, globals(), "_patched")
#             # Now let's try it as a function attribute, since python will interpret it as a descriptor -> bound method
#             p = Problem(**kwargs)
#             return p.to_dict()
#         finally:
#             # Always restore the original
#             Problem._gen_resources_dict = orig

def make_tranche(
    cfg: TrancheConfig,
    seed: int,
    n: int,
    out_dir: str,
    as_json: bool = True,
    tranche_tag: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Make a tranche of size n deterministically from a seed.

    Writes ``out_dir/problems/p00000.json`` ... (one file per problem) plus
    ``out_dir/tranche_card.json`` for auditability.

    Args:
        cfg: Configuration forwarded to ``sample_problem`` for every problem.
        seed: Base seed; problem ``i`` is sampled with a seed derived from it.
        n: Number of problems to generate.
        out_dir: Output directory (created if missing).
        as_json: Kept for interface compatibility. JSON is currently the only
            format written, whatever this flag is set to.
        tranche_tag: Optional explicit tranche id; when empty/None a tag of the
            form ``{cfg.name}-{utc timestamp}-seed{seed}`` is generated.

    Returns:
        The tranche card payload (the same dict written to tranche_card.json).
    """
    os.makedirs(out_dir, exist_ok=True)
    _seed_everything(seed)

    tag = tranche_tag or f"{cfg.name}-{_now_utc_compact()}-seed{seed}"
    config_hash = _hash_config(asdict(cfg))

    problems_dir = os.path.join(out_dir, "problems")
    os.makedirs(problems_dir, exist_ok=True)

    for i in range(n):
        # Each problem gets its own seed derived from (seed, i), so problem i
        # is reproducible on its own; the mask keeps the value a non-negative
        # 31-bit integer.
        per_seed = (seed + i * 100003) & 0x7FFFFFFF
        problem_dict = sample_problem(cfg, rng_seed=per_seed)
        problem_path = os.path.join(problems_dir, f"p{i:05d}.json")
        # default=str stringifies any value JSON cannot encode natively.
        with open(problem_path, "w", encoding="utf-8") as f:
            json.dump(problem_dict, f, indent=2, default=str)

    tranche_card = {
        "tranche_id": tag,
        "config_name": cfg.name,
        "config_hash": config_hash,
        "seed": seed,
        "size": n,
        "problems_relpath": "problems/",
        "problem_format": "json",
        "problem_kwargs": cfg.problem_kwargs,   # echo for transparency
        "created_at": _now_utc_compact(),
        "generator_version": "Problem.py@"  # optionally include git SHA/version
    }
    with open(os.path.join(out_dir, "tranche_card.json"), "w", encoding="utf-8") as f:
        json.dump(tranche_card, f, indent=2, default=str)
    return tranche_card

# Protocol helpers (train/val/test & fresh tests)

@dataclass
class SplitSeeds:
    """Base seeds for the three dataset splits, one per split.

    make_train_val_test_splits passes each value to make_tranche as the seed
    of the corresponding split.
    """

    train: int  # seed for the "train" tranche
    val: int  # seed for the "val" tranche
    test: int  # seed for the "test" tranche

def make_train_val_test_splits(
    cfg: TrancheConfig,
    seeds: SplitSeeds,
    sizes: Tuple[int, int, int],
    base_dir: str,
    as_json: bool = True
) -> Dict[str, Dict[str, Any]]:
    """
    Generate the train, val and test tranches, each from its own seed.

    ``sizes`` is ``(n_train, n_val, n_test)``. Each split is written to
    ``base_dir/<split>`` via ``make_tranche``. Returns ``{split: tranche_card}``
    with keys "train", "val" and "test" (created in that order).
    """
    n_train, n_val, n_test = sizes
    os.makedirs(base_dir, exist_ok=True)

    cards = {}
    for split, count in (("train", n_train), ("val", n_val), ("test", n_test)):
        split_seed = getattr(seeds, split)
        split_dir = os.path.join(base_dir, split)
        cards[split] = make_tranche(cfg, split_seed, count, split_dir, as_json)
    return cards

def make_fresh_test_tranches(
    cfg: TrancheConfig,
    num_tranches: int,
    start_seed: int,
    size: int,
    base_dir: str,
    as_json: bool = True
) -> List[Dict[str, Any]]:
    """
    Generate ``num_tranches`` additional test tranches of ``size`` problems each.

    Tranche ``j`` uses seed ``start_seed + j`` and is written to
    ``base_dir/test_tranche_{j:02d}``; consecutive seeds keep the scheme simple
    and auditable. Returns the tranche cards in generation order.
    """
    os.makedirs(base_dir, exist_ok=True)
    return [
        make_tranche(
            cfg,
            start_seed + idx,
            size,
            os.path.join(base_dir, f"test_tranche_{idx:02d}"),
            as_json,
        )
        for idx in range(num_tranches)
    ]

# Auditing: attach model outputs
# The following helpers are meant for testing now and for benchmarking later.
# ---------------------------------------------------------
# Output layout conventions (must match experiment_analysis)
# ---------------------------------------------------------
# experiment_analysis.py expects, for OP problems:
#   - with_sym=False  -> PATH/{problem_id}_{llm_name}.py
#   - with_sym=True   -> PATH_WITHSYM/{problem_id}_{llm_name}.py
# default PATH is set as data/baseline_0shot and default PATH_WITHSYM is data/baseline_0shot_with_sym
#
# See experiment_analysis.py: op_method/nlp4lp_method & op_compare_problem/nlp4lp_compare_problem
# which scan those folders and filenames. 
#
# baseline_llm_problem_answer.py writes:
#   response.txt, .py, and (for sym) a "solution.json" alongside the .py.  
#
# This module preserves the same layout and filenames so analysis can just ingest them.

def _resolve_dataset_dirs(dataset_tag: str, 
                        with_sym: bool, 
                        output_save_path = os.path.join(os.getcwd(),'data'), 
                        base_leaf = "baseline_0shot", 
                        sym_leaf = "baseline_0shot_with_sym",
                        altbase_leaf="baseline_0shot_nlp4lp",
                        altsym_leaf="baseline_0shot_with_sym_nlp4lp",
                        original_tag="oproblems",
                        alt_tag="nlp4opt"
                        ) -> Tuple[str, str]:
    """
    Pick, create and return the output directory for (dataset_tag, with_sym).

    ``dataset_tag`` must equal ``original_tag`` or ``alt_tag``; each tag has a
    plain leaf directory and a "with sym" leaf directory under
    ``output_save_path``. The chosen directory is created if missing.

    Note: the default ``output_save_path`` is computed from ``os.getcwd()``
    once, when this function is defined, not on every call.

    Returns:
        ``(output_save_path, chosen_directory)``.

    Raises:
        ValueError: if ``dataset_tag`` matches neither tag.
    """
    if dataset_tag == original_tag:
        plain_leaf, symbolic_leaf = base_leaf, sym_leaf
    elif dataset_tag == alt_tag:
        plain_leaf, symbolic_leaf = altbase_leaf, altsym_leaf
    else:
        raise ValueError(f"Unknown dataset_tag: {dataset_tag}")

    chosen_leaf = symbolic_leaf if with_sym else plain_leaf
    target_dir = os.path.join(output_save_path, chosen_leaf)
    os.makedirs(target_dir, exist_ok=True)
    return output_save_path, target_dir

def _extract_from_raw_response(raw: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Pull the first ```json block and the first ```python block out of ``raw``.

    Returns ``(sym, code)``: the stripped text between each opening fence and
    the next ``` (or the end of the text when no closing fence follows).
    An entry is None when its opening fence is absent; both are None when
    ``raw`` is empty or None.
    """
    if not raw:
        return None, None

    def _fenced(tag: str) -> Optional[str]:
        # Text after the first opening fence for `tag`, up to the next fence.
        _, opener, remainder = raw.partition("```" + tag)
        if not opener:
            return None
        return remainder.partition("```")[0].strip()

    return _fenced("json"), _fenced("python")

def _json_or_text(text: str) -> Any:
    """Parse ``text`` as JSON; return ``text`` unchanged when it is not valid JSON."""
    try:
        return json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        return text


def save_model_outputs(
    *,
    dataset_tag: str,  # defaults are the oproblems and nlp4opt tags
    with_sym: bool,
    model_name: str, # e.g., "llama-3.3" | "llama-4" | "gemini" | "gptoss"
    outputs: Dict[str, Dict[str, Any]],
    # outputs[problem_id] may contain any of:
    #   "code": str                    # preferred (already extracted)
    #   "sym": (dict|str)              # optional, will be json-dumped if provided
    #   "raw_response": str            # optional raw text; we will parse if code/sym missing
    #   "response_text": str           # optional; if present we'll save *_response.txt (for parity)
    #
    # Optional knobs:
    write_response_txt: bool = True,   # keep parity with baseline writer
    ensure_dirs: bool = True,
    output_save_path = os.path.join(os.getcwd(),'data'), 
    base_leaf = "baseline_0shot", 
    sym_leaf = "baseline_0shot_with_sym",
    altbase_leaf="baseline_0shot_nlp4lp",
    altsym_leaf="baseline_0shot_with_sym_nlp4lp",
    original_tag="oproblems",
    alt_tag="nlp4opt"
) -> str:
    """
    Write per-problem model artifacts into the directory chosen by
    ``_resolve_dataset_dirs`` and return that directory.

    For each ``problem_id`` in ``outputs`` up to three files are written:
      * ``{problem_id}_{model_name}.py``            - the code, when non-empty;
      * ``{problem_id}_{model_name}_response.txt``  - the response text, when
        ``write_response_txt`` is set and the text is non-empty
        (``"response_text"``, falling back to ``"raw_response"``);
      * ``{problem_id}_{model_name}_solution.json`` - the sym payload, only
        when ``with_sym`` is set and a sym value is available.

    Missing/empty code (and, for ``with_sym``, a missing sym) is mined from
    ``"raw_response"`` via ``_extract_from_raw_response``.

    Raises:
        ValueError: if ``dataset_tag`` matches neither ``original_tag`` nor
            ``alt_tag`` (raised by ``_resolve_dataset_dirs``).
    """
    _, out_dir = _resolve_dataset_dirs(
        dataset_tag, with_sym, output_save_path,
        base_leaf, sym_leaf, altbase_leaf, altsym_leaf,
        original_tag, alt_tag,
    )
    if ensure_dirs:
        os.makedirs(out_dir, exist_ok=True)

    for pid, payload in outputs.items():
        # Prefer already-parsed fields; otherwise mine from raw response
        code: Optional[str] = payload.get("code")
        sym_obj = payload.get("sym")
        raw: Optional[str] = payload.get("raw_response")
        resp_txt: Optional[str] = payload.get("response_text", raw)

        # An empty "code" string is treated as missing, consistent with the
        # `code or parsed_code` fallback and the `if code:` write guard below.
        if not code or (with_sym and sym_obj is None):
            parsed_sym, parsed_code = _extract_from_raw_response(raw or "")
            code = code or parsed_code
            if sym_obj is None and parsed_sym is not None:
                # structured JSON when valid, otherwise keep the raw string
                sym_obj = _json_or_text(parsed_sym)

        # ---- Write artifacts matching baseline/analysis expectations ----
        # 1) Python code file; the filename must be {problem_id}_{llm_name}.py
        #    to match the layout described in this module's comments.
        if code:
            py_path = os.path.join(out_dir, f"{pid}_{model_name}.py")
            with open(py_path, "w", encoding="utf-8") as f:
                f.write(code)

        # 2) Optional: raw response text (for parity with baseline writer)
        if write_response_txt and resp_txt:
            txt_path = os.path.join(out_dir, f"{pid}_{model_name}_response.txt")
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(resp_txt)

        # 3) Optional: sym JSON (only for with_sym variants)
        if with_sym and sym_obj is not None:
            sym_path = os.path.join(out_dir, f"{pid}_{model_name}_solution.json")
            if isinstance(sym_obj, str):
                # caller-supplied strings that look like JSON are stored structured
                sym_obj = _json_or_text(sym_obj)
            # Serialize before opening the file so a non-serializable payload
            # never leaves a half-written file behind.
            try:
                sym_text = json.dumps(sym_obj, indent=2, ensure_ascii=False)
            except (TypeError, ValueError, RecursionError):
                # last resort: write as plain text
                sym_text = str(sym_obj)
            with open(sym_path, "w", encoding="utf-8") as f:
                f.write(sym_text)

    return out_dir

def run_op_analysis(with_sym: bool) -> dict:
    """
    Thin wrapper around experiment_analysis.op_method (no logic duplicated here).

    Passes ``with_sym`` straight through and returns op_method's result
    unchanged (presumably the nested results dict used elsewhere in the
    pipeline; confirm against experiment_analysis).
    """
    return op_method(with_sym)  # per the layout notes in this module, scans the folders save_model_outputs writes by default

def run_nlp4opt_analysis(with_sym: bool) -> dict:
    """
    Thin wrapper around experiment_analysis.nlp4lp_method for the NL4OPT
    analysis path (if used).

    Passes ``with_sym`` straight through and returns nlp4lp_method's result
    unchanged.
    """
    return nlp4lp_method(with_sym)