import gc
import time
import torch
import pickle
import argparse
import numpy as np
import random
from collections import defaultdict
import datasets
import asyncio

from tokenizer_conversion.benchmarking.utils.bechmarking_utils import get_config_for_benchmark, load_transducer, set_seed, clear_transducedlm_function_caches
from tokenizer_conversion.benchmarking.utils.data_utils import load_wikitext_paragraphs_bytes
from tokenizer_conversion.machines.utils.properties import calculate_num_states_arcs
from tokenizer_conversion.machines.utils.config import _LMBackend
from tokenizer_conversion.machines.genlm_realpha import GenLMRealpha

from genlm.backend import load_model_by_name



# NOTE(review): presumably opts in to executing dataset-provided loading code when
# the `datasets` library fetches the benchmark corpus, so the load does not stop to
# ask for confirmation — verify against the installed `datasets` version.
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

async def run_eval(
    model_name,
    split,
    steps,
    fraction_us,
    prune_threshold,
    candidate_threshold,
    prune_threshold_alpha,
    max_prune_mass,
    max_candidates,
    output_file,
    max_context_len,
    transducer_name,
    paragraphs=4,
    seed=80808,
    text_length=None,
    use_vllm=False,
    num_runs=5,
):
    """Benchmark the transducer while progressively dropping universal states.

    For each drop count ``K`` (evenly spaced between 0 and
    ``len(universal_states) / fraction_us``), the benchmark is repeated
    ``num_runs`` times. Each repetition samples ``K`` universal states at
    random, marks them as non-universal on a freshly loaded transducer, and
    scores every paragraph with ``T.sequence_logp_next``. Timings,
    backtracking statistics and the per-step byte distributions are
    accumulated per ``K`` and checkpointed to ``output_file`` after each
    ``K`` finishes. Drop counts already present in an existing
    ``output_file`` are skipped, so an interrupted run can be resumed.

    Args:
        model_name: Name passed to ``load_model_by_name`` / ``load_transducer``.
        split: Dataset split forwarded to ``load_wikitext_paragraphs_bytes``.
        steps: Number of points requested from ``np.linspace`` for the drop
            counts (duplicates after integer rounding are removed).
        fraction_us: Divisor applied to the number of universal states to get
            the largest drop count. Must be >= 1.
        prune_threshold, candidate_threshold, prune_threshold_alpha,
        max_prune_mass, max_candidates: Forwarded to
            ``get_config_for_benchmark``.
        output_file: Pickle file used both for resuming and for saving.
        max_context_len: Only recorded in the metadata by this function.
        transducer_name: Name of the transducer to load.
        paragraphs: Number of paragraphs to request from the data loader.
        seed: Seed passed to ``set_seed``.
        text_length: Forwarded to the data loader; the metadata entry is then
            overwritten with the total length the loader reports.
        use_vllm: If true, load the model with the default backend; otherwise
            use the ``"hf"`` backend (and additionally clear its KV cache
            between paragraphs).
        num_runs: Number of repetitions per drop count.

    Returns:
        Tuple ``(metadata, stats, p_nexts)`` where ``stats`` and ``p_nexts``
        are keyed by drop count.

    Raises:
        ValueError: If ``fraction_us`` is smaller than 1.
    """
    import os

    # A divisor below 1 would either divide by zero or request more states
    # than exist, which only surfaces later inside random.sample.
    if fraction_us < 1:
        raise ValueError(f"fraction_us must be >= 1, got {fraction_us!r}")

    # Create the output directory up front so a missing directory does not
    # discard the results of the first (expensive) drop count at save time.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    set_seed(seed)

    metadata = {
        'model_name': model_name,
        'split': split,
        'max_context_len' : max_context_len,
        'paragraphs': paragraphs,
        'prune_threshold': prune_threshold,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'transducer_name': transducer_name,
        'fraction_us': fraction_us,
        'steps': steps,
        'text_length': text_length,
        'num_runs': num_runs,
    }
    if use_vllm:
        llm = load_model_by_name(model_name, llm_opts={"engine_opts": {"dtype": "float16"}})
    else:
        # HF backend
        llm = load_model_by_name(model_name, backend="hf", llm_opts={"hf_opts": {"torch_dtype": torch.float16}})

    T = load_transducer(transducer_name, llm, model_name)
    print("Stats: ", calculate_num_states_arcs(T.fst))
    paragraphs, original, total_len = load_wikitext_paragraphs_bytes(T, split, n=paragraphs, transducer_name=transducer_name, text_length=text_length)

    print("TLength: ", total_len)
    metadata['text_length'] = total_len

    # Resume from a previous checkpoint if one exists.
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # output_file at files produced by this script.
    try:
        with open(output_file, 'rb') as f:
            saved_data = pickle.load(f)
    except FileNotFoundError:
        saved_data = {}
    metadata = saved_data.get('metadata', metadata)
    stats = saved_data.get('stats', {})
    p_nexts = defaultdict(list, saved_data.get('p_nexts', {}))
    completed_configs = set(p_nexts.keys())

    universal_states = {s for s in T.universal_states if T.universal_states[s]}
    print(f"Number of universal states: {len(universal_states)}")

    max_K = int(len(universal_states)/fraction_us)
    drop_counts = np.unique(np.linspace(0, max_K, steps, dtype=int)).tolist()

    metadata['org_num_universal_states'] = len(universal_states)
    metadata['drop_counts'] = drop_counts

    print(f"Drop counts: {drop_counts}")

    # random.sample needs a sequence; the set never changes, so build it once.
    universal_list = list(universal_states)

    for K in drop_counts:
        if K in completed_configs:
            print(f"Skipping K={K}, already processed")
            continue

        stats[K] = {
            'times': [],
            'backtrack_calls': [],
            'backtrack_times': []
        }

        # Average over multiple runs, each with its own random drop set.
        for run_idx in range(num_runs):
            print(f"Processing K={K}... i={run_idx}")

            drop_univ = random.sample(universal_list, K)
            print(f"Drop {K} universal states: {drop_univ}")
            new_universals = universal_states - set(drop_univ)
            print(f"New universal states: {len(new_universals)}")

            for para in paragraphs:
                cfg = get_config_for_benchmark(transducer_name, prune_threshold, candidate_threshold, prune_threshold_alpha, max_prune_mass, max_candidates)
                llm.clear_cache()
                if not use_vllm:
                    llm.clear_kv_cache()

                # Start every paragraph from a freshly loaded transducer.
                T = load_transducer(transducer_name, llm, model_name)
                T._universal_set_cache = {}

                # Compute closure
                for s in T.fst.states():
                    T._state_closure_output_syms[s] = T.input_epsilon_closure_track_trg(s)
                print(f"Config: {cfg}")

                # Set the new universal states
                T.universal_states = {s: (s in new_universals) for s in T.universal_states}

                # Only tear down the byte-level LM if it was actually attached;
                # otherwise a failure in create() would be masked by an error
                # raised from the cleanup itself.
                realpha_attached = False
                try:
                    if cfg.backend is _LMBackend.GENLM_BYTES:
                        T.genlm_realpha = await GenLMRealpha.create(model_name, llm=llm, K=8, prune_threshold=0.001)
                        realpha_attached = True
                    print(f"Running T.sequence_logp_next for: {para}")

                    res = await T.sequence_logp_next(cfg, para)
                    print("Backtracking stats: ", T.backtracking_stats)
                    backtrack_calls = T.backtracking_stats.get("calls", 0.0)
                    backtrack_times = T.backtracking_stats.get("time", 0.0)
                finally:
                    if realpha_attached:
                        T.genlm_realpha.empty_cache()
                        await T.genlm_realpha.root_beam.cleanup()
                        del T.genlm_realpha

                # Release the transducer and cached GPU memory before the next paragraph.
                clear_transducedlm_function_caches(T)
                T.empty_cache()
                del T
                gc.collect()
                torch.cuda.empty_cache()

                # Intermediate logging
                stats[K]['times'].extend(res["times"])
                stats[K]['backtrack_calls'].append(backtrack_calls)
                stats[K]['backtrack_times'].append(backtrack_times)
                p_nexts[K].append([list(byte_dist.values()) for byte_dist in res["byte_level_log_distribution"]])
                del res["byte_level_log_distribution"]

        # Checkpoint after each K. Write to a temporary file and rename it so
        # an interrupted write cannot corrupt the file used for resuming.
        tmp_file = f"{output_file}.tmp"
        with open(tmp_file, 'wb') as f:
            pickle.dump({
                'metadata': metadata,
                'stats': stats,
                'p_nexts': dict(p_nexts),
                'drop_counts': drop_counts,
                'num_universal_states': len(universal_states),
            }, f)
        os.replace(tmp_file, output_file)

    return metadata, stats, p_nexts


def _str_to_bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` cannot be used with argparse because any non-empty string,
    including ``"False"``, is truthy.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _build_parser():
    """Build the command-line parser for the drop-universal-states benchmark."""
    parser = argparse.ArgumentParser(
        description='Evaluate character-level probability distributions'
    )
    parser.add_argument('--model', default='gpt2-large',
                      help='Model name to use')
    parser.add_argument('--split', default='test',
                      help='Dataset split to use')
    parser.add_argument('--paragraphs', type=int, default=1,
                      help='Number of paragraphs to process')

    # PRUNING PARAMETERS
    parser.add_argument('--prune-threshold', type=float,
                      default=0.0001,
                      help='Prune threshold')
    parser.add_argument('--candidate_threshold', type=int, default=100,
                      help='Number of candidates where we start pruning')
    parser.add_argument('--prune_threshold_alpha', type=float, default=0.7,
                      help='Steepness of threshold increase')
    parser.add_argument('--max_prune_mass', type=float, default=0.4,
                      help='Maximum pruning mass')
    parser.add_argument('--max_candidates', type=int, default=None,
                      help='Maximum number of candidates in each iteration')

    # DROP UNIVERSALS PARAMETERS
    parser.add_argument('--steps', type=int, default=11,
                      help='Number of steps removing universal states')
    # Accepted for backward compatibility; not forwarded to run_eval.
    parser.add_argument('--length', type=int, default=256,
                      help='Currently unused; use --text_length instead')
    parser.add_argument('--fraction_us', type=int, default=2,
                      help='Divisor for the number of universal states; at most '
                           'len(universal_states)/fraction_us states are dropped')

    parser.add_argument('--output', default='results/scoring_results.pkl',
                      help='Output file')
    parser.add_argument('--max-context-len', type=int, help='Maximum context length')
    parser.add_argument('--text_length', type=int, help='Maximum text length', default=256)

    parser.add_argument('--transducer', type=str, help='Name of Transducer', default="hf_realpha")
    # Accepted for backward compatibility; not forwarded to run_eval.
    parser.add_argument('--transducer-path', type=str,
                      help='Load Transducer from dir (currently unused)', default=None)
    # Works as a bare flag (--use_vllm) or with an explicit value
    # (--use_vllm True / --use_vllm False).
    parser.add_argument('--use_vllm', type=_str_to_bool, nargs='?', const=True, default=False,
                      help='Load the model with the vLLM backend instead of HF')
    return parser


async def main():
    """Parse the command line, run the benchmark and report the output path.

    Example usage:
    python <this_script>.py --model gpt2-large --split test --paragraphs 4 --steps 11 --fraction_us 2
    """
    args = _build_parser().parse_args()

    await run_eval(
        model_name=args.model,
        split=args.split,

        # Pruning
        prune_threshold=args.prune_threshold,
        candidate_threshold=args.candidate_threshold,
        prune_threshold_alpha=args.prune_threshold_alpha,
        max_prune_mass=args.max_prune_mass,
        max_candidates=args.max_candidates,

        # Drop universals
        steps=args.steps,
        fraction_us=args.fraction_us,

        output_file=args.output,
        max_context_len=args.max_context_len,
        paragraphs=args.paragraphs,
        transducer_name=args.transducer,
        text_length=args.text_length,
        use_vllm=args.use_vllm
    )

    print(f"Results saved to: {args.output}")


# Script entry point: drive the async CLI with asyncio.run when executed directly.
if __name__ == '__main__':
    asyncio.run(main())