#!/usr/bin/env python3
import argparse
import os
import json
import torch
import dill
from counting_sampling_v2 import DEVICE, DTYPE, get_random_generator_and_seed, initialize_weighted_automaton


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the automaton-weighting step.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before; pass a
            list to drive the parser programmatically (e.g. from tests).

    Returns:
        Namespace with ``input_dir``, ``output_dir``, ``accept_prob`` and
        ``seed`` attributes. ``accept_prob`` and ``seed`` are ``None`` when
        not given on the command line.
    """
    parser = argparse.ArgumentParser(description="Arguments for weighting the automaton")
    parser.add_argument(
        "--input_dir", type=str, default="data",
        help="Directory holding sample_args.json, seed_info.json and "
             "automaton.pkl from the previous step",
    )
    parser.add_argument(
        "--output_dir", type=str, default="data",
        help="Directory to write the weighted automaton and metadata to",
    )
    parser.add_argument(
        "--accept_prob", type=float, default=None,
        help="Passed as accept_prob to initialize_weighted_automaton",
    )
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for weighting the automaton")
    return parser.parse_args(argv)


def _read_json(path):
    """Load and return the JSON document stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(obj, path):
    """Serialize *obj* as JSON to *path*, overwriting any existing file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f)


def main(args=None):
    """Weight a previously built automaton and save it with its metadata.

    Reads ``sample_args.json``, ``seed_info.json`` and ``automaton.pkl`` from
    ``args.input_dir`` and weights the automaton with
    ``initialize_weighted_automaton``. It then writes
    ``weighted_automaton.pkl``, the updated ``seed_info.json`` and an
    unchanged copy of ``sample_args.json`` to ``args.output_dir``, creating
    that directory if needed.

    Args:
        args: Parsed CLI namespace with ``input_dir``, ``output_dir``,
            ``accept_prob`` and ``seed`` attributes. When ``None``, the
            arguments are parsed from the command line.
    """
    if args is None:
        args = parse_args()
    print(args)

    # Metadata from the previous pipeline step. sample_args is passed
    # through unchanged so the output directory is self-contained.
    sample_args = _read_json(os.path.join(args.input_dir, "sample_args.json"))
    seed_info = _read_json(os.path.join(args.input_dir, "seed_info.json"))

    # NOTE: dill (like pickle) can execute arbitrary code while loading, so
    # input_dir must only contain files produced by a trusted previous step.
    automaton_path = os.path.join(args.input_dir, "automaton.pkl")
    with open(automaton_path, "rb") as f:
        automaton = dill.load(f)

    # args.seed may be None. Presumably the helper then chooses a seed itself
    # and returns it as the second element -- TODO confirm in
    # counting_sampling_v2.
    generator, used_seed = get_random_generator_and_seed(args.seed)
    weighted_automaton = initialize_weighted_automaton(
        automaton,
        DTYPE,
        DEVICE,
        accept_prob=args.accept_prob,
        generator=generator,
    )

    # Record the seed actually used, not just the CLI value (which may be
    # None), so the weighting can be reproduced. int() normalizes numpy/torch
    # integer scalars, which json cannot serialize.
    seed_info["seed"] = None if used_seed is None else int(used_seed)

    os.makedirs(args.output_dir, exist_ok=True)

    weighted_automaton_path = os.path.join(args.output_dir, "weighted_automaton.pkl")
    with open(weighted_automaton_path, "wb") as f:
        dill.dump(weighted_automaton, f)

    _write_json(seed_info, os.path.join(args.output_dir, "seed_info.json"))
    _write_json(sample_args, os.path.join(args.output_dir, "sample_args.json"))

    print(f"Weighted automaton prepared and saved to {args.output_dir}")
    

# Script entry point: main() is called without a namespace, so it parses the
# command-line arguments itself.
if __name__ == "__main__":
    main()