#!/usr/bin/env python3
import argparse
import os
import json
import torch
import dill
from counting_sampling_v2 import SamplerFactory, DEVICE, DTYPE, TGT_SYMBOL


def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Recognized options:
        --input_dir: string, defaults to ``"data"``.
        --output_dir: string, defaults to ``"data"``.
        --max_occ_count: integer, defaults to ``2000``.
        --seed: integer, ``None`` when omitted.
        --at_least_once_semiring: flag, ``False`` unless given.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(
        description="Arguments for lifting the weighted automaton",
    )

    # Both directory options share the same type and default.
    for dir_flag in ("--input_dir", "--output_dir"):
        cli.add_argument(dir_flag, default="data", type=str)

    cli.add_argument("--max_occ_count", default=2000, type=int)
    cli.add_argument(
        "--seed",
        default=None,
        type=int,
        help="General seed (for backwards compatibility)",
    )
    # ``store_true`` already defaults to False when the flag is absent.
    cli.add_argument("--at_least_once_semiring", action="store_true")

    return cli.parse_args()


def _read_json(path):
    """Return the object decoded from the JSON file at ``path`` (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(path, payload):
    """Serialize ``payload`` as JSON into the file at ``path`` (written as UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def main(args=None) -> None:
    """Build a sampler from a stored lifted automaton and save it to disk.

    Reads ``sample_args.json``, ``intervention_info.json``, ``seed_info.json``
    and ``lifted_automaton.pkl`` from ``args.input_dir``; creates the sampler
    via ``SamplerFactory.create_sampler``; then writes ``sampler.pkl``,
    ``sampling_params.json`` and an updated ``seed_info.json`` (with an added
    ``"intervention_seed"`` entry) to ``args.output_dir``.

    Args:
        args: Namespace providing ``input_dir``, ``output_dir``,
            ``max_occ_count``, ``seed`` and ``at_least_once_semiring``.
            When ``None``, the options are parsed from the command line
            with ``parse_args()``.

    Raises:
        FileNotFoundError: If one of the input files does not exist.
        KeyError: If ``intervention_info.json`` lacks ``intervention_type``,
            ``target_state``, ``target_transition`` or ``target_symbol``.
    """
    if args is None:
        args = parse_args()
    print(args)

    # Inputs written by earlier steps. JSON is read explicitly as UTF-8 so the
    # result does not depend on the locale's default encoding.
    sample_args = _read_json(os.path.join(args.input_dir, "sample_args.json"))
    intervention_info = _read_json(
        os.path.join(args.input_dir, "intervention_info.json")
    )
    seed_info = _read_json(os.path.join(args.input_dir, "seed_info.json"))

    # Record the seed used by this step next to the seeds already stored.
    # It is None when --seed was not supplied.
    intervention_seed = args.seed
    seed_info["intervention_seed"] = intervention_seed

    # Get the intervention type and targets
    intervention_type = intervention_info["intervention_type"]
    target_state = intervention_info["target_state"]
    target_transition = intervention_info["target_transition"]
    target_symbol = intervention_info["target_symbol"]

    # Load the lifted automaton from the previous step.
    # NOTE: unpickling with dill can execute arbitrary code; only point
    # --input_dir at trusted output of this pipeline.
    lifted_automaton_path = os.path.join(args.input_dir, "lifted_automaton.pkl")
    with open(lifted_automaton_path, "rb") as f:
        lifted_automaton = dill.load(f)

    # Create the sampler from the lifted automaton. The symbol from the
    # intervention file is used only for "symbol" interventions; every other
    # type falls back to the module-level TGT_SYMBOL.
    sampler = SamplerFactory.create_sampler(
        lifted_automaton=lifted_automaton,
        tgt_symbol=target_symbol if intervention_type == "symbol" else TGT_SYMBOL,
        dtype=DTYPE,
        device=DEVICE,
        seed=intervention_seed,
        num_states=sample_args.get("num_states", 20),
        alphabet_size=sample_args.get("alphabet_size", 10),
        max_occ_count=args.max_occ_count,
        accept_prob=sample_args.get("accept_prob", 0.1),
        automaton_name=sample_args.get("automaton_name"),
        tgt_arc=(intervention_type == "arc"),
        tgt_state=(intervention_type == "state"),
        tgt_arc_idx=target_transition,
        tgt_state_idx=target_state,
        at_least_once_semiring=args.at_least_once_semiring,
    )

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Save the sampler for the next step
    sampler_path = os.path.join(args.output_dir, "sampler.pkl")
    with open(sampler_path, "wb") as f:
        dill.dump(sampler, f)

    # Save sampling parameters
    sampling_params = {
        "intervention_type": intervention_type,
        "target_state": target_state,
        "target_transition": target_transition,
        "target_symbol": target_symbol,
    }
    _write_json(
        os.path.join(args.output_dir, "sampling_params.json"), sampling_params
    )

    # Save updated seed information. When input_dir and output_dir are the
    # same directory this overwrites the seed_info.json that was read above.
    _write_json(os.path.join(args.output_dir, "seed_info.json"), seed_info)

    print(f"Pushed weighted automaton prepared and saved to {args.output_dir}")


# Run this pipeline step only when the file is executed as a script; importing
# the module leaves main() available without triggering any file I/O.
if __name__ == "__main__":
    main()