#!/usr/bin/env python3
import argparse
import os
import json
import torch
import dill
from counting_sampling_v2 import (InterventionPreparation, SamplerFactory, DEVICE, 
                            DTYPE, TGT_SYMBOL, get_random_generator_and_seed)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the automaton-lifting step.

    Options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``. Options
        without an explicit default (``--input_dir``, ``--output_dir``) are
        ``None`` when omitted.
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        # Input / output locations.
        ("--input_dir", {"type": str}),
        ("--output_dir", {"type": str}),
        # Counts.
        ("--intervention_count", {"type": int, "default": 100}),
        ("--validation_num_occurrences", {"type": int, "default": 100}),
        # Which kind of intervention to prepare, and its optional target.
        ("--intervention_type", {"type": str, "default": "arc",
                                 "help": "arc, symbol, state, or vanilla"}),
        ("--target_state", {"type": int, "default": None}),
        ("--target_transition", {"type": int, "default": None}),
        ("--target_symbol", {"type": int, "default": None}),
        # Seeding.
        ("--intervention_seed", {"type": int, "default": None,
                                 "help": "Seed for intervention process"}),
        ("--seed", {"type": int, "default": None,
                    "help": "General seed (for backwards compatibility)"}),
        # Boolean switch, off unless the flag is present.
        ("--at_least_once_semiring", {"action": "store_true", "default": False}),
    )

    cli = argparse.ArgumentParser(description="Arguments for lifting the weighted automaton")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def main(args=None):
    """Lift the weighted automaton for the requested intervention and save it.

    Loads the automaton pickle and ``sample_args.json`` from
    ``args.input_dir``, builds an ``InterventionPreparation`` and calls the
    ``prepare_*`` method selected by ``args.intervention_type``, then writes
    four files into ``args.output_dir``:

    * ``lifted_automaton.pkl``  - the prepared automaton (dill),
    * ``intervention_info.json`` - intervention type, targets and description,
    * ``seed_info.json``        - the seed(s) used,
    * ``sample_args.json``      - the sample args copied from the input dir.

    Args:
        args: Parsed options as returned by ``parse_args``. When ``None``,
            ``parse_args()`` is called to read them from the command line.

    Raises:
        FileNotFoundError: If neither automaton pickle nor
            ``sample_args.json`` can be opened in ``args.input_dir``.
    """
    if args is None:
        args = parse_args()
    print(args)

    # Load the automaton from the previous step; fall back to the alternative
    # file name "automaton.pkl" when "weighted_automaton.pkl" is absent.
    automaton_path = os.path.join(args.input_dir, "weighted_automaton.pkl")
    if not os.path.exists(automaton_path):
        automaton_path = os.path.join(args.input_dir, "automaton.pkl")
    # NOTE(security): dill.load can execute arbitrary code while unpickling;
    # only point --input_dir at directories whose contents you trust.
    with open(automaton_path, "rb") as f:
        automaton = dill.load(f)

    # Load sample args
    sample_args_path = os.path.join(args.input_dir, "sample_args.json")
    with open(sample_args_path, "r", encoding="utf-8") as f:
        sample_args = json.load(f)

    # --intervention_seed is the dedicated seed for this step; --seed is kept
    # as the backwards-compatible fallback. getattr tolerates namespaces built
    # by callers that predate the --intervention_seed option.
    intervention_seed = getattr(args, "intervention_seed", None)
    effective_seed = args.seed if intervention_seed is None else intervention_seed

    seed_info = {"seed": args.seed}
    if intervention_seed is not None:
        # Only recorded when supplied, so legacy runs produce the same file.
        seed_info["intervention_seed"] = intervention_seed

    # Create the intervention preparation object
    prep = InterventionPreparation(
        automaton,
        args.intervention_count,
        seed=effective_seed,
        at_least_once_semiring=args.at_least_once_semiring
    )

    intervention_type = args.intervention_type

    # Prepare the automaton based on intervention type. The description is the
    # bare type name when no explicit target was given, "<type>_<target>"
    # otherwise.
    if intervention_type == "symbol":
        lifted_automaton = prep.prepare_for_symbol(args.target_symbol)
        target_description = "symbol" if args.target_symbol is None else f"symbol_{args.target_symbol}"
    elif intervention_type == "arc":
        lifted_automaton = prep.prepare_for_arc(args.target_transition)
        target_description = "arc" if args.target_transition is None else f"arc_{args.target_transition}"
    elif intervention_type == "state":
        lifted_automaton = prep.prepare_for_state(args.target_state)
        target_description = "state" if args.target_state is None else f"state_{args.target_state}"
    else:  # vanilla
        # NOTE(review): any unrecognised intervention_type also lands here and
        # is recorded verbatim in intervention_info; kept for compatibility.
        lifted_automaton = prep.prepare_vanilla()
        target_description = "vanilla"

    # Store the intervention type and targets
    intervention_info = {
        "intervention_type": intervention_type,
        "target_state": args.target_state,
        "target_transition": args.target_transition,
        "target_symbol": args.target_symbol,
        "target_description": target_description
    }

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Save the lifted automaton
    lifted_automaton_path = os.path.join(args.output_dir, "lifted_automaton.pkl")
    with open(lifted_automaton_path, "wb") as f:
        dill.dump(lifted_automaton, f)

    # Save intervention info
    intervention_info_path = os.path.join(args.output_dir, "intervention_info.json")
    with open(intervention_info_path, "w", encoding="utf-8") as f:
        json.dump(intervention_info, f)

    # Save seed information
    seed_info_path = os.path.join(args.output_dir, "seed_info.json")
    with open(seed_info_path, "w", encoding="utf-8") as f:
        json.dump(seed_info, f)

    # Carry the sample args forward unchanged so the output directory is
    # self-contained.
    sample_args_path = os.path.join(args.output_dir, "sample_args.json")
    with open(sample_args_path, "w", encoding="utf-8") as f:
        json.dump(sample_args, f)

    print(f"Lifted weighted automaton prepared and saved to {args.output_dir}")

# Script entry point: parse the command line and run the lifting step.
if __name__ == "__main__":
    main(parse_args())