import torch
import numpy as np
import math
import random
from lbrayuela.string_sampling.finite_automaton_weight_pushing import (
    push_finite_automaton_weights
)
import tqdm
from collections import defaultdict
from utils import *
from vanilla_sampler import vanilla_sampler

# Keep all the original constants
# Module-wide defaults, used as default arguments by the builders and samplers below.
SEED = 0                 # default random seed
NUMBER_OF_STATES = 100   # default number of states for randomly sampled automata
MAX_OCC_COUNT = 1000     # default maximum occurrence count
ALPHABET_SIZE = 10       # default alphabet size

# Use the GPU whenever torch reports one, otherwise stay on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DTYPE = torch.float32    # default dtype for automaton weights
TGT_SYMBOL = 0           # default target symbol for symbol interventions


def get_random_generator_and_seed(random_seed):
    """Return a seeded ``random.Random`` together with the seed that was used.

    Args:
        random_seed: Seed to use, or None to have ``get_random_seed`` draw one.

    Returns:
        Tuple ``(generator, seed)``, where ``seed`` is the value the generator
        was actually seeded with.
    """
    resolved_seed = get_random_seed(random_seed)
    rng = random.Random(resolved_seed)
    return rng, resolved_seed


def get_random_seed(random_seed):
    """Return ``random_seed`` unchanged, or a fresh 32-bit seed when it is None.

    Args:
        random_seed: An explicit seed, or None.

    Returns:
        The given seed, or ``random.getrandbits(32)`` drawn from the global
        ``random`` module state when no seed was given.
    """
    if random_seed is not None:
        return random_seed
    return random.getrandbits(32)


class AutomatonBuilder:
    """Produces the base automaton: a registered one by name, or a random one."""

    def __init__(self, num_states=NUMBER_OF_STATES, alphabet_size=ALPHABET_SIZE,
                 seed=SEED, accept_prob=0.1, automaton_name=None):
        """
        Store the generation parameters and create the seeded random generator.

        Args:
            num_states: Number of states in the automaton
            alphabet_size: Size of the alphabet
            seed: Random seed for reproducibility
            accept_prob: Probability of a state being an accept state
            automaton_name: Name of a predefined automaton from the register;
                when given, the random-generation parameters are not used by build()
        """
        self.automaton_name = automaton_name
        self.accept_prob = accept_prob
        self.seed = seed
        self.alphabet_size = alphabet_size
        self.num_states = num_states
        # Only the generator is kept; the resolved seed is discarded.
        self.generator = get_random_generator_and_seed(seed)[0]

    def build(self):
        """
        Return the automaton topology.

        Returns:
            The registry entry for ``automaton_name`` when a name was given,
            otherwise the result of ``sample_finite_automaton``.

        Raises:
            ValueError: If ``automaton_name`` is set but not in the registry.
        """
        if self.automaton_name is None:
            # No predefined automaton requested: sample a random topology.
            return sample_finite_automaton(
                self.num_states, self.alphabet_size, self.generator, self.accept_prob)

        # Imported lazily, only when a registered automaton is requested.
        from automata_register import AUTOMATA_REGISTER
        registered = AUTOMATA_REGISTER.get(self.automaton_name, None)
        if registered is None:
            raise ValueError(f"Automaton '{self.automaton_name}' not in registry.")
        return registered


class InterventionPreparation:
    """Prepares an already-weighted automaton for one kind of intervention target."""

    def __init__(self, weighted_automaton, max_count, seed=None, at_least_once_semiring=False):
        """
        Validate and store the weighted automaton and targeting parameters.

        Args:
            weighted_automaton: A weighted automaton that already has weights assigned
            max_count: Forwarded to ``modify_weights_for_target`` as its second argument
            seed: Random seed for generating random targets (None is treated as 0)
            at_least_once_semiring: Forwarded to ``modify_weights_for_target`` by the
                symbol / arc / state preparations

        Raises:
            ValueError: If the automaton has no ``transition_weights`` method or
                that method returns a falsy value.
        """
        # Reject plain (unweighted) automata up front.
        has_weights = (hasattr(weighted_automaton, 'transition_weights')
                       and weighted_automaton.transition_weights())
        if not has_weights:
            raise ValueError("The provided automaton must be a WeightedFiniteAutomaton with weights already assigned")

        self.weighted_automaton = weighted_automaton
        self.max_count = max_count
        self.seed = 0 if seed is None else seed
        self.generator = get_random_generator_and_seed(self.seed)[0]
        self.at_least_once_semiring = at_least_once_semiring

    def _retarget(self, **target_kwargs):
        """Call ``modify_weights_for_target`` on the stored automaton with shared arguments."""
        # NOTE(review): the stored automaton itself is passed, not a copy; whether
        # modify_weights_for_target mutates its input is not visible here.
        return modify_weights_for_target(
            self.weighted_automaton,
            self.max_count,
            generator=self.generator,
            **target_kwargs
        )

    def prepare_for_symbol(self, tgt_symbol=TGT_SYMBOL):
        """Return the automaton configured for an intervention on ``tgt_symbol``."""
        return self._retarget(
            target_symbol=tgt_symbol,
            at_least_once_semiring=self.at_least_once_semiring,
        )

    def prepare_for_arc(self, tgt_arc_idx=None):
        """Return the automaton configured for an arc intervention (``tgt_arc_idx`` optional)."""
        return self._retarget(
            target_arc=True,
            tgt_arc_idx=tgt_arc_idx,
            at_least_once_semiring=self.at_least_once_semiring,
        )

    def prepare_for_state(self, tgt_state_idx=None):
        """Return the automaton configured for a state intervention (``tgt_state_idx`` optional)."""
        return self._retarget(
            target_state=True,
            tgt_state_idx=tgt_state_idx,
            at_least_once_semiring=self.at_least_once_semiring,
        )

    def prepare_vanilla(self):
        """
        Return the automaton prepared for sampling without a meaningful target.

        The automaton is still passed through ``modify_weights_for_target`` with
        symbol 0 as a placeholder target; the target is irrelevant for vanilla
        sampling. ``at_least_once_semiring`` is not forwarded here, so the
        callee's default applies.
        """
        return self._retarget(target_symbol=0)


class SamplerFactory:
    """Factory that wraps a lifted automaton into a ready-to-use InterventionSampling."""

    @staticmethod
    def create_sampler(lifted_automaton, tgt_symbol=None, dtype=DTYPE,
                       device=DEVICE, seed=SEED, num_states=NUMBER_OF_STATES,
                       alphabet_size=ALPHABET_SIZE, max_occ_count=MAX_OCC_COUNT,
                       accept_prob=0.1, automaton_name=None, tgt_arc=False,
                       tgt_state=False, tgt_arc_idx=None, tgt_state_idx=None, at_least_once_semiring=False):
        """
        Push the weights of a lifted automaton and build a sampler around it.

        Args:
            lifted_automaton: Automaton prepared for the specific intervention
            tgt_symbol: Target symbol for intervention (if applicable)
            dtype: Data type for weights
            device: Device to use for tensor operations
            seed: Random seed
            num_states: Number of states in the automaton
            alphabet_size: Size of the alphabet
            max_occ_count: Maximum occurrence count
            accept_prob: Probability of a state being an accept state
            automaton_name: Name of the automaton from the register
            tgt_arc: Whether arc intervention is targeted
            tgt_state: Whether state intervention is targeted
            tgt_arc_idx: Index of target arc (if applicable)
            tgt_state_idx: Index of target state (if applicable)
            at_least_once_semiring: Selects ``counting_sample_alo`` instead of
                ``counting_sample`` as the sampling function, and is passed on
                to the sampler

        Returns:
            An InterventionSampling holding both the lifted and the pushed automaton.
        """
        pushed = push_finite_automaton_weights(lifted_automaton, dtype, device)

        # The pushed automaton does not carry the target markers; copy them over.
        pushed.target_state = lifted_automaton.target_state
        pushed.target_transition = lifted_automaton.target_transition

        # Install the sampling function as a plain attribute on the pushed automaton.
        pushed.counting_sample = (
            counting_sample_alo if at_least_once_semiring else counting_sample
        )

        return InterventionSampling(
            num_states=num_states,
            alphabet_size=alphabet_size,
            max_occ_count=max_occ_count,
            seed=seed,
            device=device,
            dtype=dtype,
            tgt_symbol=tgt_symbol,
            tgt_arc=tgt_arc,
            tgt_state=tgt_state,
            accept_prob=accept_prob,
            automaton_name=automaton_name,
            tgt_arc_idx=tgt_arc_idx,
            tgt_state_idx=tgt_state_idx,
            lifted_automaton=lifted_automaton,
            pushed_automaton=pushed,
            at_least_once_semiring=at_least_once_semiring,
        )


class InterventionSampling:
    """Main class for intervention sampling, refactored to support modular approach.

    Holds a lifted automaton (``A_l``) and its weight-pushed counterpart
    (``A_p``) and samples strings from them, either with a prescribed number
    of target occurrences or without interventions.
    """

    def __init__(self, num_states=NUMBER_OF_STATES, alphabet_size=ALPHABET_SIZE,
                 max_occ_count=MAX_OCC_COUNT, seed=SEED, device=DEVICE, dtype=DTYPE,
                 tgt_symbol=None, tgt_arc=False, tgt_state=False, accept_prob=0.1,
                 automaton_name=None, tgt_arc_idx=None, tgt_state_idx=None,
                 lifted_automaton=None, pushed_automaton=None, at_least_once_semiring=False):
        """
        Initialize the sampler with parameters and optionally pre-configured automata

        This constructor supports both the original approach (building automata internally)
        and the new modular approach (receiving pre-configured automata)

        Args:
            num_states: Number of states for a randomly sampled automaton
            alphabet_size: Alphabet size for a randomly sampled automaton
            max_occ_count: Maximum occurrence count (clamped to at least 2)
            seed: Random seed; None draws a fresh seed
            device: Torch device for weight tensors
            dtype: Torch dtype for weight tensors
            tgt_symbol: Target symbol (if applicable)
            tgt_arc: Whether an arc is targeted
            tgt_state: Whether a state is targeted
            accept_prob: Probability of a state being an accept state
            automaton_name: Name of a registered automaton (optional)
            tgt_arc_idx: Index of the target arc (optional)
            tgt_state_idx: Index of the target state (optional)
            lifted_automaton: Pre-built lifted automaton (modular approach)
            pushed_automaton: Pre-built pushed automaton (modular approach)
            at_least_once_semiring: Forwarded to ``get_sample_lengths``
        """
        self.num_states = num_states
        self.alphabet_size = alphabet_size
        # Ensure max_occ_count is at least 2 to avoid utils.py:sample_weights error
        self.max_occ_count = max(2, max_occ_count)
        self.device = device
        self.dtype = dtype
        self.tgt_symbol = tgt_symbol
        self.tgt_arc = tgt_arc
        self.tgt_state = tgt_state
        self.accept_prob = accept_prob
        self.tgt_arc_idx = tgt_arc_idx
        self.tgt_state_idx = tgt_state_idx
        self.automaton_name = automaton_name
        self.at_least_once_semiring = at_least_once_semiring

        # Sets self.seed, self.generator and self.torch_generator.
        self.set_seed(seed)

        # Support both approaches: pre-configured automata or build internally
        if lifted_automaton is not None and pushed_automaton is not None:
            self.A_l = lifted_automaton
            self.A_p = pushed_automaton
        else:
            self._setup_machine()

    def set_seed(self, seed):
        """Re-seed both the Python and the torch random generators.

        Args:
            seed: Seed to use; None draws a fresh seed, which is then stored
                in ``self.seed`` so the run can be reproduced.
        """
        # Resolve the seed once so that both generators use the same value and
        # a None seed does not reach torch's manual_seed (which requires an int).
        self.generator, self.seed = get_random_generator_and_seed(seed)
        self.torch_generator = torch.Generator()
        self.torch_generator.manual_seed(self.seed)

    def _get_raw_machine(self):
        """Get the raw automaton (either from registry or generated)

        Raises:
            ValueError: If ``automaton_name`` is set but not in the registry.
        """
        if self.automaton_name is not None:
            from automata_register import AUTOMATA_REGISTER
            base_automaton = AUTOMATA_REGISTER.get(self.automaton_name, None)
            if base_automaton is None:
                raise ValueError(f"Automaton '{self.automaton_name}' not in registry.")
        else:
            base_automaton = sample_finite_automaton(self.num_states, self.alphabet_size,
                                                     self.generator, self.accept_prob)
        return base_automaton

    def _setup_machine(self):
        """Set up the automaton internally (original approach)"""
        self.A = self._get_raw_machine()

        # Lifts based on the target type, and samples the intervention sites!
        # This also adds properties to the machine, such as the target state or arc
        # A_l.target_transition and A_l.target_state
        self.A_l = lift_finite_automaton_random_weights(
            self.A, self.max_occ_count,
            self.dtype, self.device, self.tgt_symbol, self.tgt_arc,
            self.tgt_state, tgt_arc_idx=self.tgt_arc_idx, tgt_state_idx=self.tgt_state_idx,
            accept_prob=self.accept_prob, generator=self.generator)

        # does the backwards calculations! i.e. the pathsums
        self.A_p = push_finite_automaton_weights(self.A_l, self.dtype, self.device)

        # The pushed machine does not carry the target markers; copy them over.
        self.A_p.target_transition = self.A_l.target_transition
        self.A_p.target_state = self.A_l.target_state

        # Install the sampling function as a plain attribute on the pushed machine.
        self.A_p.counting_sample = counting_sample

    def _resized_weight(self, weight, w_size):
        """Return a length-``w_size`` weight filled with -inf except entries 0 and 1."""
        # Only the first two entries of the old weight are carried over.
        # NOTE(review): presumably lifted weights only populate counts 0 and 1 - confirm.
        resized = torch.full((w_size,), -math.inf, dtype=self.dtype, device=self.device)
        resized[0] = weight[0]
        resized[1] = weight[1]
        return resized

    def resize(self, num_count):
        """Resize the automaton to handle a different occurrence count

        Rebuilds every accept and transition weight of ``A_l`` with length
        ``num_count + 1``, pushes the weights again and re-attaches the target
        markers and the sampling function to the new pushed automaton.

        Args:
            num_count: New maximum occurrence count.
        """
        # store for new machine
        target_state = self.A_l.target_state
        target_transition = self.A_l.target_transition
        # Keep whichever sampling function is currently installed (the factory
        # may have installed the at-least-once variant) instead of always
        # resetting it to counting_sample.
        sampling_fn = getattr(self.A_p, "counting_sample", counting_sample)

        self.max_occ_count = num_count
        w_size = num_count + 1
        self.A_l._semiring.size = w_size

        # Values of existing keys are replaced in place; no keys are added or removed.
        for state in self.A_l._accept_weights:
            self.A_l._accept_weights[state] = self._resized_weight(
                self.A_l._accept_weights[state], w_size)
        for transition in self.A_l._transitions:
            self.A_l._transitions[transition] = self._resized_weight(
                self.A_l._transitions[transition], w_size)

        self.A_p = push_finite_automaton_weights(self.A_l, self.dtype, self.device)
        # reset these properties
        self.A_p.target_transition = target_transition
        self.A_p.target_state = target_state
        self.A_p.counting_sample = sampling_fn

    def sample(self, tgt_count):
        """Sample a string for a requested number of target occurrences.

        Args:
            tgt_count: Requested number of occurrences of the target.

        Returns:
            Whatever the installed ``counting_sample`` function returns; callers
            in this file read its ``"sampled_string"`` entry. The occurrence
            count of the result is not verified here.
        """
        return self.A_p.counting_sample(self.A_p, tgt_count, self.generator,
                                        include_log_probability=True,
                                        include_next_symbols=False,
                                        tgt_symbol=self.tgt_symbol)

    def sample_interventions(self, K, N, pbar=False):
        """Sample K strings with a total of N interventions

        Args:
            K: Number of strings requested.
            N: Total number of target occurrences requested.
            pbar: Show a tqdm progress bar while sampling.

        Returns:
            Tuple ``(samples, used_counts)``: the sampled strings and the
            per-string occurrence counts that were requested for them.
        """
        Z = self.A_p.total_length_weights
        get_exact_counts_max = 10

        # Dedicated torch generator, seeded from the Python generator.
        torch_gen = torch.Generator()
        torch_gen.manual_seed(self.generator.getrandbits(64))

        # Retry a bounded number of times until the counts sum to exactly N;
        # if that never happens, the last draw is used as-is.
        for _ in range(get_exact_counts_max):
            count_targets = get_sample_lengths(Z, K, N, torch_gen,
                                               at_least_once_semiring=self.at_least_once_semiring)
            if sum(count_targets) == N:
                break

        used_counts = []
        samples = []
        seen_occs = 0

        # Wrap only the loop iterable: count_targets itself is summed again in
        # the report below, and re-iterating a tqdm wrapper would draw a
        # second progress bar.
        progress = tqdm.tqdm(count_targets) if pbar else count_targets
        for tgt_count in progress:
            # we try to sample tgt_count occurrences in a single string
            samples.append(self.sample(tgt_count))
            used_counts.append(tgt_count)
            seen_occs += tgt_count

        if 0 < seen_occs < N:
            missing_samples = []
            tries_left = 10
            while missing_samples == [] and tries_left > 0:
                tries_left -= 1
                try:
                    # the recursion can fail for a few reasons, but this is a bit better
                    missing_samples, new_counts = self.sample_interventions(
                        K - len(samples), N - seen_occs)
                except Exception:
                    # Best-effort top-up: ignore the failure and retry, but let
                    # KeyboardInterrupt / SystemExit propagate.
                    missing_samples = []
                    new_counts = []

                samples = samples + missing_samples
                used_counts = used_counts + new_counts

        print(f" | Sample targets K={K}, N={N}. Saw K={len(samples)}, N={seen_occs}, sum count targets: {sum(count_targets)}")
        return samples, used_counts

    def sample_original(self, num_samples=1, tgt_arc=None, tgt_state=None, pbar=False):
        """Sample strings without interventions

        Args:
            num_samples: Number of strings to draw.
            tgt_arc: Unused; kept for interface compatibility.
            tgt_state: Unused; kept for interface compatibility.
            pbar: Show a tqdm progress bar while sampling.

        Returns:
            List with one ``vanilla_sampler`` result per requested sample.
        """
        samp_rng = range(num_samples)
        if pbar:
            samp_rng = tqdm.tqdm(samp_rng)
        return [
            vanilla_sampler(self.A_l, self.torch_generator,
                            include_log_probability=True,
                            include_next_symbols=False)
            for _ in samp_rng
        ]


# Keep all the original sampling functions with the same API
def sample_symbol(K, N, num_states=NUMBER_OF_STATES, seed=SEED, accept_prob=0.1,
                  name=None, tgt_symbol=TGT_SYMBOL, alphabet_size=ALPHABET_SIZE):
    """Sample strings with targeted symbol occurrences - original API

    Args:
        K: Number of strings to sample.
        N: Total number of target-symbol occurrences requested.
        num_states: Number of states for a randomly sampled automaton.
        seed: Random seed.
        accept_prob: Probability of a state being an accept state.
        name: Name of a registered automaton (optional).
        tgt_symbol: Symbol to target; None falls back to TGT_SYMBOL.
        alphabet_size: Size of the alphabet.

    Returns:
        Tuple ``(samples, sampler, counts, stats)`` where ``stats`` holds the
        total number of sampled symbols under ``"total_symbols"``.
    """
    if tgt_symbol is None:
        tgt_symbol = TGT_SYMBOL

    # Ensure max_occ_count is at least 2
    max_occ_count = max(2, N+1)
    sampler = InterventionSampling(num_states, alphabet_size, max_occ_count, automaton_name=name,
                                   seed=seed, device=DEVICE, dtype=DTYPE,
                                   tgt_symbol=tgt_symbol, accept_prob=accept_prob)

    samples, counts = sampler.sample_interventions(K, N, pbar=True)

    strings = [sample["sampled_string"] for sample in samples]
    occs = [string.count(tgt_symbol) for string in strings]
    lengths = [len(string) for string in strings]

    # Guard every statistic against an empty sample set: np.std([]) would
    # otherwise emit a RuntimeWarning and report nan.
    avg_occ = sum(occs) / len(occs) if occs else 0
    stdv_occ = np.std(occs) if occs else 0
    avg_length = sum(lengths) / len(lengths) if lengths else 0
    stdv_length = np.std(lengths) if lengths else 0

    print(f"Total symbols: {sum(lengths)}")
    print(f"Total samples: {len(samples)}")
    print(f"Total occurrences: {sum(occs)}")
    print(f"Average occurrences: {avg_occ}, stdv: {stdv_occ}")
    print(f"Average length: {avg_length}, stdv: {stdv_length}")

    stats = {
        "total_symbols": sum(lengths)
    }

    return samples, sampler, counts, stats


def sample_vanilla(K, N, num_states=NUMBER_OF_STATES, seed=SEED, accept_prob=0.1,
                   name=None, alphabet_size=ALPHABET_SIZE):
    """Sample strings without interventions - original API

    Args:
        K: Number of strings to sample.
        N: Only used to size the sampler's maximum occurrence count.
        num_states: Number of states for a randomly sampled automaton.
        seed: Random seed.
        accept_prob: Probability of a state being an accept state.
        name: Name of a registered automaton (optional).
        alphabet_size: Size of the alphabet.

    Returns:
        Tuple ``(samples, sampler, stats)`` where ``stats`` holds the total
        number of sampled symbols under ``"total_symbols"``.
    """
    # Ensure max_occ_count is at least 2
    max_occ_count = max(2, N+1)
    sampler = InterventionSampling(num_states, alphabet_size, max_occ_count, automaton_name=name,
                                   seed=seed, device=DEVICE, dtype=DTYPE,
                                   tgt_symbol=TGT_SYMBOL, accept_prob=accept_prob)

    original_samples = sampler.sample_original(num_samples=K)

    print("----")
    print("Total original samples: ", len(original_samples))

    # Histograms over all samples: how often each symbol / transition occurs.
    symbol_hist = defaultdict(int)
    transition_hist = defaultdict(int)
    orig_lengths = []
    for sample in original_samples:
        string = sample["sampled_string"]
        for symbol in string:
            symbol_hist[symbol] += 1
        for trans in sample["transitions"]:
            transition_hist[trans] += 1
        orig_lengths.append(len(string))

    num_samples = len(original_samples)

    occ_counts = list(symbol_hist.values())
    have_occ = bool(occ_counts) and bool(original_samples)
    mean_occ = sum(occ_counts) / len(occ_counts) / num_samples if have_occ else 0
    stdv_occ = np.std(occ_counts) / num_samples if have_occ else 0

    trans_counts = list(transition_hist.values())
    have_trans = bool(trans_counts) and bool(original_samples)
    mean_trans = sum(trans_counts) / len(trans_counts) / num_samples if have_trans else 0
    stdv_trans = np.std(trans_counts) / num_samples if have_trans else 0

    # Guard the length statistics as well: np.std([]) would otherwise emit a
    # RuntimeWarning and report nan.
    orig_avg_length = sum(orig_lengths) / len(orig_lengths) if orig_lengths else 0
    orig_stdv_length = np.std(orig_lengths) if orig_lengths else 0

    mean_symbol_occ = sum(occ_counts) / len(occ_counts) if occ_counts else 0
    mean_trans_occ = sum(trans_counts) / len(trans_counts) if trans_counts else 0

    print(f"Total symbols: {sum(occ_counts)}")
    print(f"Average symbol occurrences: {mean_symbol_occ}")
    print(f"Average symbol occurrences per string: {mean_occ}, stdv: {stdv_occ}")
    print(f"Average transitions: {mean_trans_occ}")
    print(f"Average transitions occurrences: {mean_trans}, stdv: {stdv_trans}")
    print(f"Average length: {orig_avg_length}, stdv: {orig_stdv_length}")

    stats = {
        "total_symbols": sum(occ_counts) if occ_counts else 0
    }

    return original_samples, sampler, stats


def sample_arc(K, N, num_states=NUMBER_OF_STATES, seed=SEED, accept_prob=0.1,
               name=None, tgt_transition=None, alphabet_size=ALPHABET_SIZE):
    """Sample strings with targeted arc occurrences - original API

    Args:
        K: Number of strings to sample.
        N: Total number of target-arc occurrences requested.
        num_states: Number of states for a randomly sampled automaton.
        seed: Random seed.
        accept_prob: Probability of a state being an accept state.
        name: Name of a registered automaton (optional).
        tgt_transition: Index of the arc to target (optional).
        alphabet_size: Size of the alphabet.

    Returns:
        Tuple ``(samples, sampler, counts, stats)`` where ``stats`` holds the
        total number of sampled symbols under ``"total_symbols"``.
    """
    sampler = InterventionSampling(num_states, alphabet_size, MAX_OCC_COUNT,
                                   automaton_name=name, seed=seed,
                                   device=DEVICE, dtype=DTYPE, tgt_arc=True,
                                   accept_prob=accept_prob, tgt_arc_idx=tgt_transition)

    samples, counts = sampler.sample_interventions(K, N, pbar=True)

    lengths = [len(sample["sampled_string"]) for sample in samples]

    # Guard every statistic against an empty sample set: np.std([]) would
    # otherwise emit a RuntimeWarning and report nan.
    avg_length = sum(lengths) / len(lengths) if lengths else 0
    stdv_length = np.std(lengths) if lengths else 0
    avg_occ = sum(counts) / len(counts) if counts else 0
    stdv_occ = np.std(counts) if counts else 0

    print(f"Total symbols: {sum(lengths)}")
    print(f"Total samples: {len(samples)}")
    print(f"Average length: {avg_length}, stdv: {stdv_length}")
    print(f"Total occurrences: {sum(counts)}")
    print(f"Average occurrences: {avg_occ}, stdv: {stdv_occ}")

    stats = {
        "total_symbols": sum(lengths)
    }

    return samples, sampler, counts, stats


def sample_state(K, N, num_states=NUMBER_OF_STATES, seed=SEED, accept_prob=0.1,
                 name=None, tgt_state=None, alphabet_size=ALPHABET_SIZE):
    """Sample strings with targeted state occurrences - original API

    Args:
        K: Number of strings to sample.
        N: Total number of target-state occurrences requested.
        num_states: Number of states for a randomly sampled automaton.
        seed: Random seed.
        accept_prob: Probability of a state being an accept state.
        name: Name of a registered automaton (optional).
        tgt_state: Index of the state to target (optional).
        alphabet_size: Size of the alphabet.

    Returns:
        Tuple ``(samples, sampler, counts, stats)`` where ``stats`` holds the
        total number of sampled symbols under ``"total_symbols"``.
    """
    sampler = InterventionSampling(num_states, alphabet_size, MAX_OCC_COUNT,
                                   automaton_name=name, seed=seed, device=DEVICE,
                                   dtype=DTYPE, tgt_symbol=TGT_SYMBOL, tgt_state=True,
                                   accept_prob=accept_prob, tgt_state_idx=tgt_state)

    samples, counts = sampler.sample_interventions(K, N, pbar=True)

    lengths = [len(sample["sampled_string"]) for sample in samples]

    # Guard every statistic against an empty sample set: np.std([]) would
    # otherwise emit a RuntimeWarning and report nan.
    avg_length = sum(lengths) / len(lengths) if lengths else 0
    stdv_length = np.std(lengths) if lengths else 0
    avg_occ = sum(counts) / len(counts) if counts else 0
    stdv_occ = np.std(counts) if counts else 0

    print(f"Total symbols: {sum(lengths)}")
    print(f"Total samples: {len(samples)}")
    print(f"Average length: {avg_length}, stdv: {stdv_length}")
    print(f"Total occurrences: {sum(counts)}")
    print(f"Average occurrences: {avg_occ}, stdv: {stdv_occ}")

    stats = {
        "total_symbols": sum(lengths)
    }

    return samples, sampler, counts, stats


def to_rayuela(sampler):
    """Convert the sampler's lifted automaton (``A_l``) via ``to_rayuela_fsa`` - original API"""
    return to_rayuela_fsa(sampler.A_l)


# Function for modular sampling approach
def modular_sample(intervention_type, K, N, num_states=NUMBER_OF_STATES, seed=SEED,
                   accept_prob=0.1, automaton_name=None, target_value=None,
                   alphabet_size=ALPHABET_SIZE):
    """
    Sample strings using the modular approach

    Args:
        intervention_type: "symbol", "arc", "state", or "vanilla"
        K: Number of strings to sample
        N: Total number of interventions
        num_states: Number of states in the automaton
        seed: Random seed
        accept_prob: Probability of a state being an accept state
        automaton_name: Name of predefined automaton
        target_value: Value to target (symbol, arc index, or state index);
            for "symbol", None falls back to TGT_SYMBOL
        alphabet_size: Size of the alphabet

    Returns:
        ``(samples, sampler, stats)`` for "vanilla", otherwise
        ``(samples, sampler, counts, stats)``. ``stats["total_symbols"]`` is the
        summed length of all sampled strings.
    """
    # Step 1: Build the automaton
    builder = AutomatonBuilder(num_states, alphabet_size, seed, accept_prob, automaton_name)
    automaton = builder.build()

    # Step 2: Prepare for the specific intervention
    # Ensure max_count is at least 2
    max_count = max(2, N+1) if intervention_type == "symbol" else MAX_OCC_COUNT

    generator, _ = get_random_generator_and_seed(seed)
    weighted_automaton = initialize_weighted_automaton(
        automaton,
        DTYPE,
        DEVICE,
        accept_prob=accept_prob,
        generator=generator,
    )

    prep = InterventionPreparation(weighted_automaton, max_count, seed)

    # Defaults shared by most branches; each branch overrides what differs.
    tgt_symbol = TGT_SYMBOL
    tgt_arc = False
    tgt_state = False
    tgt_arc_idx = None
    tgt_state_idx = None

    if intervention_type == "symbol":
        # Explicit None check so that a falsy but valid symbol is not replaced.
        tgt_symbol = TGT_SYMBOL if target_value is None else target_value
        lifted_automaton = prep.prepare_for_symbol(tgt_symbol)
    elif intervention_type == "arc":
        lifted_automaton = prep.prepare_for_arc(target_value)
        tgt_symbol = None
        tgt_arc = True
        tgt_arc_idx = target_value
    elif intervention_type == "state":
        lifted_automaton = prep.prepare_for_state(target_value)
        tgt_state = True
        tgt_state_idx = target_value
    else:  # vanilla
        # NOTE(review): any unrecognized intervention_type also lands here but
        # is then sampled with interventions in step 4 - confirm this is intended.
        lifted_automaton = prep.prepare_vanilla()

    # Step 3: Create the sampler
    sampler = SamplerFactory.create_sampler(
        lifted_automaton, tgt_symbol, DTYPE, DEVICE, seed,
        num_states, alphabet_size, max_count, accept_prob, automaton_name,
        tgt_arc=tgt_arc, tgt_state=tgt_state,
        tgt_arc_idx=tgt_arc_idx, tgt_state_idx=tgt_state_idx)

    # Step 4: Generate samples
    counts = None
    if intervention_type == "vanilla":
        samples = sampler.sample_original(num_samples=K)
    else:
        samples, counts = sampler.sample_interventions(K, N, pbar=True)

    # Same statistic the non-modular sample_* functions report.
    stats = {
        "total_symbols": sum(len(sample["sampled_string"]) for sample in samples)
    }

    if intervention_type == "vanilla":
        return samples, sampler, stats
    return samples, sampler, counts, stats


def main():
    """Test function to demonstrate both approaches

    Runs every sampling entry point (original and modular) on the registered
    automata listed below and asserts on the number of samples and on the
    total occurrence count being within 10 of N.
    """
    K = 50
    N = 10

    for autom in ["canonical_parity"]:
        print("NOW TARGETING AUTOMATON: ", autom)

        # Test the original approach
        print("=== TESTING ORIGINAL APPROACH ===")

        print("Sampling vanilla")
        # Named *_sampler_obj so the imported vanilla_sampler function is not shadowed.
        vanilla_samples, vanilla_sampler_obj, stats = sample_vanilla(K, N, name=autom)
        to_rayuela(vanilla_sampler_obj)  # exercises the conversion; result unused
        assert len(vanilla_samples) == K

        print("Sampling arcs")
        arc_samples, arc_sampler, counts, stats = sample_arc(K, N, name=autom, tgt_transition=4)
        to_rayuela(arc_sampler)
        arc_sampler.A_l.target_transition  # attribute must exist on the lifted machine
        assert len(arc_samples) == K
        assert abs(sum(counts)-N) < 10

        print("Sampling states")
        state_samples, state_sampler, counts, stats = sample_state(K, N, name=autom, tgt_state=4)
        to_rayuela(state_sampler)
        state_sampler.A_l.target_state  # attribute must exist on the lifted machine
        assert len(state_samples) == K
        assert abs(sum(counts)-N) < 10

        print("Sampling symbols")
        symb_samples, symbol_sampler, counts, symbol_stats = sample_symbol(K, N, name=autom, tgt_symbol=1)
        assert len(symb_samples) == K
        assert abs(sum(counts)-N) < 10

        # Test the modular approach
        print("\n=== TESTING MODULAR APPROACH ===")

        print("Sampling arcs (modular)")
        arc_samples_new, arc_sampler_new, counts_new, stats_new = modular_sample("arc", K, N, automaton_name=autom, target_value=4)
        assert len(arc_samples_new) == K
        assert abs(sum(counts_new)-N) < 10

        print("Sampling states (modular)")
        state_samples, state_sampler, counts, stats = modular_sample("state", K, N, automaton_name=autom, target_value=4)
        assert len(state_samples) == K
        assert abs(sum(counts)-N) < 10

        print("Sampling symbols (modular)")
        symb_samples_new, symbol_sampler_new, counts_new, stats = modular_sample("symbol", K, N, automaton_name=autom, target_value=1)
        assert len(symb_samples_new) == K
        assert abs(sum(counts_new)-N) < 10

        print("Sampling vanilla (modular)")
        vanilla_samples, vanilla_sampler_obj, stats = modular_sample("vanilla", K, N, automaton_name=autom)
        assert len(vanilla_samples) == K
        # The leftover breakpoint() was removed so the script also finishes
        # when run non-interactively.

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()