import argparse
from pathlib import Path
import traceback
from types import SimpleNamespace

from llm_utils.textgen_api.textgen_api import TextGenApi
from construct_action_models import main as construct_action_models_main
from correct_action_models import main as correct_action_models_main


def main(args):
    """Construct and then correct action models for ``args.n_samples`` samples.

    Each sample ``i`` (1-based) is handled in ``<out_dir>/sample-<i>``. A sample
    whose ``details`` sub-directory already exists and is non-empty is skipped,
    which lets an interrupted run be resumed. Otherwise the construction and
    correction steps are run using one shared ``TextGenApi`` client.

    Files written per processed sample (inside ``sample-<i>``):
      * ``details/`` -- passed as ``out_dir`` to both pipeline steps.
      * ``error.txt`` -- exception message plus traceback, only if a step raised
        an ``Exception``; processing then moves on to the next sample.
      * ``textgen-api-usage.json`` -- LLM usage for this sample, written whether
        or not the sample succeeded.

    Args:
        args: Parsed CLI namespace providing ``llm``, ``out_dir``, ``domain``
            and ``n_samples`` attributes.
    """
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    textgen_api = TextGenApi.default(args.llm)
    for i in range(1, args.n_samples + 1):
        # Usage counters are reported per sample, so clear what the previous
        # sample accumulated.
        textgen_api.usage.reset()
        sample_dir = out_dir / f"sample-{i}"
        details_dir = sample_dir / "details"
        # A non-empty details directory marks the sample as done; any() stops at
        # the first entry instead of materialising the whole listing.
        if details_dir.is_dir() and any(details_dir.iterdir()):
            print(f"Already processed {sample_dir}")
            continue

        sample_dir.mkdir(parents=True, exist_ok=True)
        error_path = sample_dir / "error.txt"
        # Remove an error report left by an earlier failed attempt so that a
        # successful re-run is not mistaken for a failure.
        if error_path.exists():
            error_path.unlink()

        print(f"Processing {sample_dir}")
        n_args = SimpleNamespace(engine=args.llm, idx=i, domain=args.domain)
        try:
            # Construct action models
            construct_action_models_main(n_args, textgen_api=textgen_api, domain=args.domain, out_dir=details_dir)
            # Correct action models
            correct_action_models_main(n_args, textgen_api=textgen_api, domain=args.domain, out_dir=details_dir)
            # Test planning (currently disabled)
            # test_planning_main(args, domain=args.domain, out_dir=details_dir)
        except Exception as e:
            # Best-effort batch: record the failure and carry on with the next
            # sample rather than aborting the whole run.
            print(f"Error processing {sample_dir}: {e}")
            # Newline keeps the message and the traceback on separate lines.
            error_path.write_text(f"{e}\n{traceback.format_exc()}", encoding="utf-8")
        finally:
            # Always persist usage, including for failed samples.
            (sample_dir / "textgen-api-usage.json").write_text(textgen_api.usage.to_dumps(), encoding="utf-8")


if __name__ == "__main__":
    # Command-line entry point: all four options are mandatory.
    cli = argparse.ArgumentParser(description="Construct action models using LLMs.")
    for flag, flag_type in (
        ("--llm", str),
        ("--out-dir", str),
        ("--domain", str),
        ("--n-samples", int),
    ):
        cli.add_argument(flag, type=flag_type, required=True)
    parsed_args = cli.parse_args()
    main(parsed_args)
