import sys
import torch
torch.set_grad_enabled(False)
from text_conditioning import *
import tqdm.auto as tqdm

# Raise Python's maximum recursion depth to 10000 for this process.
# NOTE(review): presumably required by deeply recursive code in
# text_conditioning or its callers — confirm which code path needs it.
sys.setrecursionlimit(10000)

# Proxy-tuning presets accepted by ``load_sampler``.
# Each entry maps a preset name to
#   (base model, expert model, anti-expert model, use_answer_template).
# The expert is always wrapped in BytewiseInstructFactory. When
# use_answer_template is False, base and anti-expert are wrapped in
# BytewiseQAFactory; when True, in
# BytewisePromptTemplateFactory(sampler, None, "\nAnswer: ").
# NOTE(review): the 'pt-l8o1' / 'pt-l8o1a' presets are identical to
# 'pt-q8o1' / 'pt-q8o1a' (all use Qwen/Qwen3-8B-Base as the base model).
# The differing "l" in the name suggests a different 8B base model was
# intended — confirm and update the base model id.
_PROXY_TUNING_PRESETS = {
    "pt-q8o1": (
        "Qwen/Qwen3-8B-Base",
        "allenai/OLMo-2-0425-1B-Instruct",
        "allenai/OLMo-2-0425-1B",
        False,
    ),
    "pt-q8o1a": (
        "Qwen/Qwen3-8B-Base",
        "allenai/OLMo-2-0425-1B-Instruct",
        "allenai/OLMo-2-0425-1B",
        True,
    ),
    "pt-l8o1": (
        "Qwen/Qwen3-8B-Base",
        "allenai/OLMo-2-0425-1B-Instruct",
        "allenai/OLMo-2-0425-1B",
        False,
    ),
    "pt-l8o1a": (
        "Qwen/Qwen3-8B-Base",
        "allenai/OLMo-2-0425-1B-Instruct",
        "allenai/OLMo-2-0425-1B",
        True,
    ),
}


def _build_proxy_tuning(base_name, expert_name, antiexpert_name, use_answer_template):
    """Load the three models of a proxy-tuning preset and combine them.

    Returns a ``BytewiseProxyTuningFactory`` with ``alpha=1`` built from the
    wrapped base sampler, the instruct-wrapped expert sampler, and the wrapped
    anti-expert sampler (see ``_PROXY_TUNING_PRESETS`` for the wrapping rule).
    """
    # Models are loaded in the order base, expert, anti-expert.
    tcs_base = TextConditionedSampler(base_name)
    tcs_expert = TextConditionedSampler(expert_name)
    tcs_antiexpert = TextConditionedSampler(antiexpert_name)

    if use_answer_template:
        def wrap(tcs):
            return BytewisePromptTemplateFactory(tcs, None, "\nAnswer: ")
    else:
        wrap = BytewiseQAFactory

    return BytewiseProxyTuningFactory(
        wrap(tcs_base),
        BytewiseInstructFactory(tcs_expert),
        wrap(tcs_antiexpert),
        alpha=1,
    )


def load_sampler(model_name_or_path):
    """Build a sampler (or sampler factory) from a model spec string.

    Accepted forms of ``model_name_or_path``:

    * A proxy-tuning preset name (a key of ``_PROXY_TUNING_PRESETS``): loads
      the preset's three models and returns a ``BytewiseProxyTuningFactory``.
    * A single model name ending in ``"Instruct"``: returns a
      ``BytewiseInstructFactory`` wrapping its ``TextConditionedSampler``.
    * A single other model name: returns its ``TextConditionedSampler``.
    * Several comma-separated model names: returns an
      ``EnsembleBytewiseSamplerFactory`` over their ``TextConditionedSampler``
      objects with ``mode="product"``.

    Whitespace around each comma-separated name is ignored, as are empty
    entries (e.g. from a trailing comma).

    Raises:
        ValueError: if the spec contains no model name at all.
    """
    preset = _PROXY_TUNING_PRESETS.get(model_name_or_path)
    if preset is not None:
        return _build_proxy_tuning(*preset)

    models = [name.strip() for name in model_name_or_path.split(",")]
    models = [name for name in models if name]
    if not models:
        raise ValueError(f"no model name given in {model_name_or_path!r}")

    if len(models) == 1:
        (model,) = models
        # The suffix is checked only after splitting, so a comma-separated
        # ensemble spec is never passed whole as a single model id.
        if model.endswith("Instruct"):
            return BytewiseInstructFactory(TextConditionedSampler(model))
        return TextConditionedSampler(model)

    tcss = [TextConditionedSampler(model) for model in models]
    return EnsembleBytewiseSamplerFactory(tcss, mode="product")
    

def batched_generate(
    prompts,
    bsf,
    do_sample=False,
    max_new_bytes=50,
    batch_size=32,
    **kwargs
):
    """Run ``generate_batched`` over ``prompts`` in chunks of ``batch_size``.

    Args:
        prompts: Iterable of prompts; iterated once, with a tqdm progress bar.
        bsf: Sampler / sampler factory forwarded to ``generate_batched``.
        do_sample: Forwarded to ``generate_batched``.
        max_new_bytes: Forwarded to ``generate_batched``.
        batch_size: Maximum number of prompts per ``generate_batched`` call;
            the final batch may be smaller.
        **kwargs: Extra keyword arguments forwarded to ``generate_batched``.
            ``display`` defaults to ``False`` but may be overridden here.

    Returns:
        A flat list of the per-batch results of ``generate_batched``, in
        prompt order.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    # Local import: the module header never imports itertools itself, and
    # islice-based chunking also works before Python 3.12 (itertools.batched).
    import itertools

    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # The tqdm bar below already reports progress, so generate_batched's own
    # display is off unless the caller explicitly asks for it.
    kwargs.setdefault("display", False)

    results = []
    prompt_iter = iter(tqdm.tqdm(prompts))
    # Each batch is a tuple of up to batch_size prompts, matching the tuples
    # produced by itertools.batched.
    while batch := tuple(itertools.islice(prompt_iter, batch_size)):
        results.extend(
            generate_batched(
                bsf,
                batch,
                do_sample=do_sample,
                max_new_bytes=max_new_bytes,
                **kwargs,
            )
        )
    return results