# mmlu_pipeline.py
import os
import json
from tqdm import tqdm
import argparse
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from gpt_generator import generate_synthetic_samples
from perplexity_filter import filter_by_perplexity
from get_model_outputs import generate_model_outputs

# Task names processed when --tasks is not given on the command line.
# Each name is also used as the per-task output sub-directory and as the
# stem of the task's seed file (<mmlu_seed_dir>/<task>.txt).
MMLU_TASKS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Process-wide cache of loaded (model, tokenizer) pairs keyed by model path.
# run_task_pipeline is called once per task; without this cache every task
# would load all three quantized models again.
_MODEL_CACHE = {}

# How many back-to-back iterations may add no sample before a task is
# abandoned. Override by setting ``max_empty_iterations`` on ``args``
# (``None`` disables the guard).
_DEFAULT_MAX_EMPTY_ITERATIONS = 10


def _load_quantized_model(model_path):
    """Return ``(model, tokenizer)`` for *model_path*, loading it at most once.

    The model is loaded with a 4-bit ``BitsAndBytesConfig`` and
    ``device_map="auto"`` and switched to eval mode. Later calls with the
    same path return the cached pair.
    """
    cached = _MODEL_CACHE.get(model_path)
    if cached is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        model.eval()
        cached = (model, tokenizer)
        _MODEL_CACHE[model_path] = cached
    return cached


def run_task_pipeline(task_name, args, openai_key):
    """Build the distillation set for one task and write it to disk.

    Each iteration generates a batch of synthetic samples from the task's
    seed file, collects the base and LoRA model outputs for that batch, and
    passes both to ``filter_by_perplexity``. The value returned by the filter
    is appended to the running set, which is truncated to
    ``args.samples_per_task`` items and rewritten to
    ``<output_dir>/<task_name>/distillation_set.json`` after every iteration.

    Args:
        task_name: Task identifier; names the output sub-directory and the
            seed file ``<mmlu_seed_dir>/<task_name>.txt``.
        args: Parsed options. Reads ``output_dir``, ``mmlu_seed_dir``,
            ``samples_per_task``, ``batch_size``, ``openai_model``,
            ``threshold``, ``base_model_path``, ``lora_model_path`` and
            ``hf_fallback_model``. An optional ``max_empty_iterations``
            attribute bounds consecutive iterations that add nothing.
        openai_key: API key forwarded to ``generate_synthetic_samples``.

    Returns:
        None. Results are written to files under the task directory.
    """
    print(f"\n=== Running Task: {task_name} ===")
    task_dir = os.path.join(args.output_dir, task_name)
    os.makedirs(task_dir, exist_ok=True)

    target_size = args.samples_per_task
    if target_size <= 0:
        # Nothing to collect: skip loading three models for a loop that
        # would never execute.
        print(f"✓ Completed distillation for task: {task_name}")
        return

    seed_file = os.path.join(args.mmlu_seed_dir, f"{task_name}.txt")
    distillation_path = os.path.join(task_dir, "distillation_set.json")
    max_empty_iterations = getattr(
        args, "max_empty_iterations", _DEFAULT_MAX_EMPTY_ITERATIONS
    )

    base_model, base_tokenizer = _load_quantized_model(args.base_model_path)
    lora_model, lora_tokenizer = _load_quantized_model(args.lora_model_path)
    hf_fallback_model, hf_fallback_tokenizer = _load_quantized_model(
        args.hf_fallback_model
    )

    distillation_set = []
    iteration = 0
    empty_streak = 0

    while len(distillation_set) < target_size:
        iteration += 1
        print(f"--- Iteration {iteration} ---")

        gen_file = os.path.join(task_dir, f"generated_{iteration}.json")
        filtered_file = os.path.join(task_dir, f"filtered_{iteration}.json")
        base_out = os.path.join(task_dir, f"base_outputs_{iteration}.json")
        lora_out = os.path.join(task_dir, f"lora_outputs_{iteration}.json")

        # Step 1: Generate a batch of synthetic samples from the seed file
        generate_synthetic_samples(
            seed_txt=seed_file,
            output_path=gen_file,
            n_samples=args.batch_size,
            model=args.openai_model,
            openai_key=openai_key,
            hf_fallback=hf_fallback_model,
            hf_fallback_tokenizer=hf_fallback_tokenizer
        )

        # Step 2: Get model outputs
        generate_model_outputs(base_model, base_tokenizer, gen_file, base_out)
        generate_model_outputs(lora_model, lora_tokenizer, gen_file, lora_out)

        # Step 3: Filter by perplexity using those outputs
        # NOTE(review): the return value is assumed to be a list of kept
        # samples, since it is passed to list.extend below — confirm.
        filtered = filter_by_perplexity(
            input_file=gen_file,
            output_file=filtered_file,
            base_output_file=base_out,
            lora_output_file=lora_out,
            lora_model=lora_model,
            lora_tokenizer=lora_tokenizer,
            threshold=args.threshold
        )

        size_before = len(distillation_set)
        distillation_set.extend(filtered)
        distillation_set = distillation_set[:target_size]

        # Rewrite the running set every iteration so partial progress is on disk.
        with open(distillation_path, "w", encoding="utf-8") as f:
            json.dump(distillation_set, f, indent=2)

        if len(distillation_set) > size_before:
            empty_streak = 0
            continue

        # No progress this round. Without a bound, a filter that rejects
        # everything would keep this loop (and the generator calls) running
        # forever.
        empty_streak += 1
        if max_empty_iterations is not None and empty_streak >= max_empty_iterations:
            print(
                f"! Stopping task {task_name}: {empty_streak} consecutive iterations "
                f"added no samples ({len(distillation_set)}/{target_size} collected)"
            )
            return

    print(f"✓ Completed distillation for task: {task_name}")

def main():
    """Parse command-line options and run the pipeline for each requested task.

    Environment variables are loaded from a ``.env`` file first, so
    ``OPENAI_API_KEY`` can be supplied there instead of via ``--openai_key``.
    Tasks are processed one after another in the order given.
    """
    load_dotenv()
    parser = argparse.ArgumentParser(
        description="Build per-task distillation sets from synthetic MMLU-style samples."
    )
    parser.add_argument("--output_dir", type=str, default="mmlu_distilled",
                        help="Directory under which one sub-directory per task is created")
    # This option used to be required although nothing in this script reads
    # it; it is still accepted so existing command lines keep working.
    parser.add_argument("--mmlu_reference_file", type=str, default=None,
                        help="Accepted for backward compatibility; not read by this script")
    parser.add_argument("--mmlu_seed_dir", type=str, required=True,
                        help="Directory containing one <task>.txt seed file per task")
    parser.add_argument("--openai_model", type=str, default="gpt-4o",
                        help="Model name passed to the synthetic sample generator")
    parser.add_argument("--openai_key", type=str, default=None,
                        help="OpenAI API key; defaults to the OPENAI_API_KEY environment variable")
    parser.add_argument("--hf_fallback_model", type=str, default="tiiuae/falcon-7b-instruct",
                        help="Hugging Face model handed to the generator as its fallback")
    parser.add_argument("--base_model_path", type=str, required=True,
                        help="Path or hub id of the base model")
    parser.add_argument("--lora_model_path", type=str, required=True,
                        help="Path or hub id of the LoRA model")
    parser.add_argument("--threshold", type=float, default=1.5,
                        help="Threshold passed to the perplexity filter")
    parser.add_argument("--batch_size", type=int, default=20,
                        help="Number of synthetic samples requested per iteration")
    parser.add_argument("--tasks", nargs="*", default=MMLU_TASKS,
                        help="Tasks to process (default: all entries of MMLU_TASKS)")
    parser.add_argument("--samples_per_task", type=int, default=100, help="How many samples to generate per MMLU task")
    args = parser.parse_args()

    # An explicit flag wins over the environment / .env value.
    openai_key = args.openai_key or os.getenv("OPENAI_API_KEY")

    for task in args.tasks:
        run_task_pipeline(task, args, openai_key)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
