import gc
import time
import torch
import argparse
import numpy as np
import datasets
import asyncio

from tokenizer_conversion.benchmarking.utils.bechmarking_utils import (
    get_config_for_benchmark, 
    load_transducer, 
    set_seed, 
    clear_transducedlm_function_caches, 
    safe_load_pickle,
    file_lock, 
    atomic_pickle_dump
)
from tokenizer_conversion.benchmarking.utils.data_utils import load_wikitext_paragraphs_bytes, load_fasta
from tokenizer_conversion.machines.utils.properties import calculate_num_states_arcs
from tokenizer_conversion.machines.utils.config import _LMBackend
from tokenizer_conversion.machines.genlm_realpha import GenLMRealpha

from genlm.backend import load_model_by_name

# Set the `datasets` library's default for `trust_remote_code` to True, so
# dataset loading scripts hosted with a dataset repo are executed.
# NOTE(review): presumably required by the wikitext loader in data_utils —
# confirm. This runs third-party code; only point it at trusted datasets.
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True


def _truncate_paragraphs(paragraphs, max_len):
    """Keep a prefix of ``paragraphs`` whose combined length is at most ``max_len``.

    The last kept paragraph is cut so the running total reaches ``max_len``
    exactly when enough text is available. Paragraphs lying entirely past the
    budget are dropped rather than kept as empty sequences.

    Returns:
        ``(kept_paragraphs, total_length)``.
    """
    kept = []
    seen = 0
    for pg in paragraphs:
        # `>=` (not `>`): each cut keeps `seen <= max_len`, so a strict
        # comparison would never stop and would append empty slices forever.
        if seen >= max_len:
            break
        cut_pg = pg[:max_len - seen]
        kept.append(cut_pg)
        seen += len(cut_pg)
    return kept, seen


def _empty_payload(metadata):
    """Return a new, empty results payload (fresh containers on every call).

    A fresh dict is built each time in case ``safe_load_pickle`` hands the
    default object back by identity (its behavior is not visible here); that
    way a mutated default is never reused.
    """
    return {'metadata': metadata, 'stats': {}, 'p_nexts': {}}


async def _release_genlm_realpha(T, use_vllm):
    """Free the ``GenLMRealpha`` helper attached to ``T``; no-op if none is attached."""
    realpha = getattr(T, "genlm_realpha", None)
    if realpha is None:
        # Creation failed before assignment: return quietly so the original
        # exception propagates instead of being masked by an AttributeError.
        return
    realpha.empty_cache()
    realpha.llm.clear_cache()
    if not use_vllm:
        realpha.llm.clear_kv_cache()
    await realpha.root_beam.cleanup()
    del T.genlm_realpha


def _save_paragraph_result(output_file, metadata, ths, p_item, times,
                           backtrack_calls, backtrack_times):
    """Merge one paragraph's results into ``output_file`` under a file lock.

    The file is re-read inside the lock, so results written by other writers
    since the last read are preserved. It is then rewritten via
    ``atomic_pickle_dump``.
    """
    with file_lock(output_file + ".lock"):
        prev = safe_load_pickle(output_file, _empty_payload(metadata))
        meta_out = prev.get('metadata', metadata)
        stats_out = prev.get('stats', {})
        pnext_out = prev.get('p_nexts', {})

        # Create per-threshold containers on first use, then append.
        pnext_out.setdefault(ths, []).append(p_item)
        th_stats = stats_out.setdefault(
            ths, {'times': [], 'backtrack_calls': [], 'backtrack_times': []}
        )
        th_stats['times'].extend(times)
        th_stats['backtrack_calls'].append(backtrack_calls)
        th_stats['backtrack_times'].append(backtrack_times)

        atomic_pickle_dump({'metadata': meta_out, 'stats': stats_out, 'p_nexts': pnext_out},
                           output_file)


async def run_eval(
    model_name,
    split,
    prune_threshold,
    candidate_threshold,
    prune_threshold_alpha,
    max_prune_mass,
    max_candidates,
    output_file,
    max_context_len,
    transducer_name,
    paragraphs=4,
    seed=80808,
    max_len=None,
    use_vllm=False
):
    """Benchmark ``T.sequence_logp_next`` over a text corpus for each prune threshold.

    For every threshold in ``prune_threshold`` and every input paragraph, a
    fresh transducer is loaded and the paragraph is scored. Timing,
    backtracking statistics and per-position distributions are appended to
    ``output_file`` after each paragraph. A rerun therefore resumes where a
    previous run stopped.

    Args:
        model_name: LM name passed to ``load_model_by_name``/``load_transducer``.
        split: Dataset split for the wikitext loader (unused for ``hf_dna2aa``).
        prune_threshold: Iterable of thresholds, each evaluated in turn.
        candidate_threshold, prune_threshold_alpha, max_prune_mass,
            max_candidates: Forwarded to ``get_config_for_benchmark``.
        output_file: Pickle path holding ``{'metadata', 'stats', 'p_nexts'}``;
            ``stats``/``p_nexts`` are keyed by threshold.
        max_context_len: Only recorded in the metadata.
        transducer_name: Transducer to load; ``"hf_dna2aa"`` uses FASTA data.
        paragraphs: Number of wikitext paragraphs to load.
        seed: RNG seed passed to ``set_seed``.
        max_len: If not None, truncate the wikitext paragraphs to at most this
            many elements in total (presumably bytes — confirm with loader).
        use_vllm: Use the default (vLLM) backend instead of HF; KV-cache
            clearing calls are skipped in that case.
    """
    set_seed(seed)

    metadata = {
        'model_name': model_name,
        'split': split,
        'max_context_len' : max_context_len,
        'paragraphs': paragraphs,
        'prune_threshold': prune_threshold,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'transducer_name': transducer_name,
    }

    if use_vllm:
        llm = load_model_by_name(model_name, llm_opts={"engine_opts": {"dtype": "float16"}})
    else:
        llm = load_model_by_name(model_name, backend="hf", llm_opts={"hf_opts": {"torch_dtype": torch.float16}})

    # This transducer is only used to load/encode the evaluation data; a fresh
    # one is built for every paragraph below.
    T = load_transducer(transducer_name, llm, model_name)
    print("Stats: ", calculate_num_states_arcs(T.fst))

    if transducer_name == "hf_dna2aa":
        fasta_file = "src/tokenizer_conversion/benchmarking/dna_data/uniprotkb_accession_A0A0A0MT78_OR_access_2025_08_20.fasta"
        texts, total_len = load_fasta(T, fasta_file)
    else:
        texts, _original, total_len = load_wikitext_paragraphs_bytes(
            T, split, n=paragraphs, transducer_name=transducer_name
        )
        if max_len is not None:
            texts, total_len = _truncate_paragraphs(texts, max_len)

    print(f"Total Length: {total_len}")
    metadata['text_length'] = total_len

    saved = safe_load_pickle(output_file, _empty_payload(metadata))
    # Preserve existing metadata if present.
    metadata = saved.get('metadata', metadata)
    # Only the number of finished paragraphs per threshold is needed to resume;
    # drop the (potentially large) loaded payload right away.
    done_counts = {ths: len(items) for ths, items in saved.get('p_nexts', {}).items()}
    del saved

    for ths in prune_threshold:
        already = done_counts.get(ths, 0)
        if already >= len(texts):
            print(f"Skipping ths={ths}, already processed ({already} paragraphs).")
            continue

        print(f"Processing ths={ths} starting at paragraph index {already} …")

        for para in texts[already:]:
            cfg = get_config_for_benchmark(transducer_name, ths, candidate_threshold, prune_threshold_alpha, max_prune_mass, max_candidates)

            llm.clear_cache()
            if not use_vllm:
                llm.clear_kv_cache()

            # Fresh transducer per paragraph so no state leaks between runs.
            T = load_transducer(transducer_name, llm, model_name)

            # Precompute the epsilon-closure output symbols for every state.
            for s in T.fst.states():
                T._state_closure_output_syms[s] = T.input_epsilon_closure_track_trg(s)
            print(f"Config: {cfg}")

            uses_genlm_bytes = cfg.backend is _LMBackend.GENLM_BYTES
            try:
                if uses_genlm_bytes:
                    T.genlm_realpha = await GenLMRealpha.create(model_name, llm=llm, K=8, prune_threshold=0.001)
                print(f"Running T.sequence_logp_next for: {para}")

                res = await T.sequence_logp_next(cfg, para)
                print("Backtracking stats: ", T.backtracking_stats)
                backtrack_calls = T.backtracking_stats.get("calls", 0.0)
                backtrack_times = T.backtracking_stats.get("time", 0.0)
            finally:
                if uses_genlm_bytes:
                    await _release_genlm_realpha(T, use_vllm)

            clear_transducedlm_function_caches(T)
            T.empty_cache()
            del T
            gc.collect()
            torch.cuda.empty_cache()

            # Intermediate logging
            print(f"Mean time {np.mean(res['times'])}")
            print(f"Throughput {len(para) / np.sum(res['times'])} bytes/sec")
            p_item = [list(byte_dist.values()) for byte_dist in res["byte_level_log_distribution"]]
            del res["byte_level_log_distribution"]

            # Per-paragraph incremental merge + save (same top-level structure).
            _save_paragraph_result(
                output_file, metadata, ths, p_item, res['times'],
                backtrack_calls, backtrack_times,
            )

            # Free big temporaries ASAP.
            del p_item, res
            gc.collect()
            torch.cuda.empty_cache()

        gc.collect()
        

def _parse_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"``) to ``True``, so an explicit parser is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


async def main():
    """Parse command-line arguments and run ``run_eval`` with them."""
    parser = argparse.ArgumentParser(
        description='Evaluate character-level probability distributions'
    )
    parser.add_argument('--model', default='gpt2-large',
                      help='Model name to use')
    parser.add_argument('--split', default='test',
                      help='Dataset split to use')
    parser.add_argument('--paragraphs', type=int, default=3,
                      help='Number of paragraphs to process')
    # PRUNING PARAMETERS
    parser.add_argument('--prune-threshold', type=float, nargs='+',
                      default=[0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001, 0.000003, 0.000001],
                      help='Prune threshold')
    parser.add_argument('--candidate_threshold', type=int, default=100,
                      help='Number of candidates where we start pruning')
    parser.add_argument('--prune_threshold_alpha', type=float, default=0.7,
                      help='Steepness of threshold increase')
    parser.add_argument('--max_prune_mass', type=float, default=0.4,
                      help='Maximum pruning mass')
    parser.add_argument('--max_candidates', type=int, default=None,
                      help='Maximum number of candidates in each iteration')

    parser.add_argument('--output', default='results/scoring_results.pkl',
                      help='Output file')
    parser.add_argument('--max-context-len', type=int, help='Maximum context length')
    parser.add_argument('--transducer', type=str, help='Name of Transducer', default="hf_realpha")
    parser.add_argument('--max-len', type=int, default=None)
    # Accepts a bare `--use_vllm` as well as an explicit `--use_vllm true/false`
    # (the old `type=bool` turned `--use_vllm False` into True).
    parser.add_argument('--use_vllm', type=_parse_bool, nargs='?', const=True, default=False,
                      help='Use the vLLM backend instead of HF')

    args = parser.parse_args()

    await run_eval(
        model_name=args.model,
        split=args.split,

        # Pruning
        prune_threshold=args.prune_threshold,
        candidate_threshold=args.candidate_threshold,
        prune_threshold_alpha=args.prune_threshold_alpha,
        max_prune_mass=args.max_prune_mass,
        max_candidates=args.max_candidates,

        output_file=args.output,
        max_context_len=args.max_context_len,
        paragraphs=args.paragraphs,
        transducer_name=args.transducer,
        max_len=args.max_len,
        use_vllm=args.use_vllm
    )

    print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    # Script entry point: drive the async `main` coroutine on a new event loop.
    asyncio.run(main())