import numpy as np
import torch

import os
import pickle
from contextlib import contextmanager

from tokenizer_conversion.machines.utils.config import Config, _LMBackend

from tokenizer_conversion.machines.hf_realpha import build_hf_fst_bytes
from tokenizer_conversion.machines.ptb import load_ptb
from tokenizer_conversion.machines.dna2aa import build_dna_to_aa_fst_bytes

def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators.

    Calls ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all`` with the same value, in that order.
    Python's built-in ``random`` module is not touched.

    Args:
        seed: Seed value forwarded unchanged to every seeding function.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def load_transducer(transducer_name, llm, model_name):
    """Build or load the transducer selected by ``transducer_name``.

    Supported names and what they dispatch to:

    * ``"hf_realpha"`` -> ``build_hf_fst_bytes(hf_tokenizer=model_name, llm=llm)``
    * ``"ptb"``        -> ``load_ptb("fsts/ptb_bytes_AUG", model_name=model_name)``
    * ``"hf_dna2aa"``  -> ``build_dna_to_aa_fst_bytes(llm_name=model_name, llm=llm)``

    Args:
        transducer_name: Key choosing which transducer to produce.
        llm: Language-model object forwarded to the ``hf_*`` builders; it is
            not used for ``"ptb"``.
        model_name: Model identifier forwarded to the selected builder.

    Returns:
        Whatever the selected builder/loader returns.

    Raises:
        ValueError: If ``transducer_name`` is not one of the supported names.
    """
    if transducer_name == "hf_realpha":
        return build_hf_fst_bytes(hf_tokenizer=model_name, llm=llm)

    if transducer_name == "ptb":
        # Fixed location string; presumably resolved relative to the current
        # working directory by load_ptb -- confirm against that helper.
        return load_ptb("fsts/ptb_bytes_AUG", model_name=model_name)

    if transducer_name == "hf_dna2aa":
        return build_dna_to_aa_fst_bytes(llm_name=model_name, llm=llm)

    raise ValueError(f"Unknown transducer name: {transducer_name}")

def clear_transducedlm_function_caches(T):
    """Best-effort reset of the memoisation caches hanging off ``T``.

    For each attribute name in a fixed list, the attribute is looked up on
    ``T``; if it exists, is not ``None`` and exposes a ``cache_clear``
    attribute (as ``functools.lru_cache`` wrappers do), ``cache_clear()`` is
    called. Attributes that are missing or have no ``cache_clear`` are
    skipped, and any ``Exception`` raised while clearing is ignored so one
    failing cache does not stop the rest from being cleared.

    Args:
        T: Object whose cached callables should be reset.
    """
    cached_attr_names = (
        'apply',
        'apply_deterministic_with_eps',
        'input_epsilon_closure',
        'input_epsilon_closure_track_trg',
        'first_symbol',
        '_vectorized_arcs',
        'get_vectorized',
        '_arcs',
        '_ps2id',
        '_id2ps',
        '_next_by_tok',
        '_next_prefilled',
        '_univ_arr',
        '_ctx_cache',
        '_local_cache',
        '_first',
        '_next',
        '_next_pset_after_with_y_dict',
    )

    for attr_name in cached_attr_names:
        target = getattr(T, attr_name, None)
        if target is None or not hasattr(target, 'cache_clear'):
            continue
        try:
            target.cache_clear()
        except Exception:
            # Deliberately swallowed: clearing is best-effort.
            continue

def get_config_for_benchmark(
        transducer_name, 
        ths, 
        candidate_threshold, 
        prune_threshold_alpha, 
        max_prune_mass,
        max_candidates
    ):
    """Assemble the benchmark ``Config`` for a given transducer.

    The transducer name selects the LM backend and two execution flags:

    * ``"ptb"``, ``"lowercase"``, ``"bad_ungood"``: ``GENLM_BYTES`` backend,
      unbatched, with powerstate tracking.
    * ``"hf_realpha"``, ``"hf_dna2aa"``: ``GENLM_ASYNC`` backend, batched,
      without powerstate tracking.

    All remaining ``Config`` fields are either taken from the arguments or
    fixed to the constants written below.

    Args:
        transducer_name: Name of the transducer being benchmarked.
        ths: Value stored as ``prune_threshold``.
        candidate_threshold: Value stored as ``candidate_threshold``.
        prune_threshold_alpha: Value stored as ``prune_threshold_alpha``.
        max_prune_mass: Value stored as ``max_prune_mass``.
        max_candidates: Value stored as ``max_candidates``.

    Returns:
        A ``Config`` instance populated as described above.

    Raises:
        ValueError: If ``transducer_name`` is not one of the names listed.
    """
    # Every supported transducer currently runs with this flag enabled.
    use_no_symloop = True

    if transducer_name in ("ptb", "lowercase", "bad_ungood"):
        backend, batched, track_powerstates = _LMBackend.GENLM_BYTES, False, True
    elif transducer_name in ("hf_realpha", "hf_dna2aa"):
        backend, batched, track_powerstates = _LMBackend.GENLM_ASYNC, True, False
    else:
        raise ValueError("Invalid transducer name.")

    settings = {
        "prune_threshold": ths,
        "ignore_remainder": True,
        "use_beam_cache": True,
        "verbose": True,
        "cover_opt": False,
        "candidate_threshold": candidate_threshold,
        "prune_threshold_alpha": prune_threshold_alpha,
        "max_prune_mass": max_prune_mass,
        "expand_threshold": 3,
        "track_powerstates": track_powerstates,
        "backend": backend,
        "batched": batched,
        "ngram_top_p": 1.0,  # Disabled
        "use_no_symloop": use_no_symloop,
        "max_candidates": max_candidates,
    }
    return Config(**settings)


def safe_load_pickle(path, default):
    """Unpickle the object stored at ``path``, falling back to ``default``.

    ``default`` is returned when the file does not exist, or when unpickling
    fails with ``pickle.UnpicklingError`` or ``EOFError`` (e.g. an empty or
    partially written file). Any other exception propagates to the caller.

    Note: unpickling can execute arbitrary code, so ``path`` should only
    ever point at trusted files.

    Args:
        path: Location of the pickle file to read.
        default: Value returned when the file is missing or unreadable.

    Returns:
        The unpickled object, or ``default``.
    """
    recoverable = (FileNotFoundError, pickle.UnpicklingError, EOFError)
    try:
        with open(path, 'rb') as stream:
            payload = pickle.load(stream)
    except recoverable:
        # Missing, empty or truncated file: behave as if nothing was stored.
        return default
    return payload

@contextmanager
def file_lock(lock_path):
    """Hold an exclusive lock represented by the file ``lock_path``.

    The lock is taken by creating ``lock_path`` with ``O_CREAT | O_EXCL``.
    If the file already exists, ``os.open`` raises (``FileExistsError``)
    immediately; there is no waiting or retrying, and the existing file is
    left untouched. On exit from the ``with`` body, whether normal or via an
    exception, the descriptor is closed and the lock file removed; a lock
    file that has already disappeared is tolerated.

    Args:
        lock_path: Path of the lock file to create.

    Yields:
        None.
    """
    # If acquisition fails nothing was created, so there is nothing to clean up.
    handle = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
    try:
        yield
    finally:
        os.close(handle)
        try:
            os.unlink(lock_path)
        except FileNotFoundError:
            pass

def atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` without exposing a partially written file.

    The object is first written to a sibling temporary file named
    ``<path>.tmp.<pid>``, flushed and fsynced, and then moved over ``path``
    with ``os.replace``. If anything fails along the way (for example the
    object cannot be pickled), the temporary file is removed and the
    original exception is re-raised, so failed dumps do not leave stray
    ``.tmp.<pid>`` files behind and ``path`` keeps its previous content.

    Args:
        obj: Object to serialise with ``pickle.HIGHEST_PROTOCOL``.
        path: Destination file path.

    Raises:
        Exception: Whatever ``open``, ``pickle.dump``, ``os.fsync`` or
            ``os.replace`` raises is propagated unchanged.
    """
    tmp = f"{path}.tmp.{os.getpid()}"
    try:
        with open(tmp, 'wb') as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
            # Push the bytes to disk before the rename makes them visible.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Clean up the half-written temp file; never mask the real error.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise