# bbh_pipeline.py
import os
import json
from tqdm import tqdm
import argparse
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from gpt_generator import generate_synthetic_samples
from perplexity_filter import filter_by_perplexity
from get_model_outputs import generate_model_outputs

# The 27 BIG-Bench Hard tasks (Suzgun et al., 2022), named as in the
# official BBH release. Each name is also used as the seed-file stem
# (``<bbh_seed_dir>/<task>.txt``) and the per-task output sub-directory.
BBH_TASKS = [
    "boolean_expressions",
    "causal_judgement",
    "date_understanding",
    "disambiguation_qa",
    "dyck_languages",
    "formal_fallacies",
    "geometric_shapes",
    "hyperbaton",
    "logical_deduction_five_objects",
    "logical_deduction_seven_objects",
    "logical_deduction_three_objects",
    "movie_recommendation",
    "multistep_arithmetic_two",
    "navigate",
    "object_counting",
    "penguins_in_a_table",
    "reasoning_about_colored_objects",
    "ruin_names",
    "salient_translation_error_detection",
    "snarks",
    "sports_understanding",
    "temporal_sequences",
    "tracking_shuffled_objects_five_objects",
    "tracking_shuffled_objects_seven_objects",
    "tracking_shuffled_objects_three_objects",
    "web_of_lies",
    "word_sorting",
]

# (tokenizer, model) pairs keyed by model path. The model paths come from the
# CLI and are identical for every task, so caching means each model is loaded
# once per process instead of once per task.
_MODEL_CACHE = {}


def _load_quantized_model(model_path):
    """Return a ``(tokenizer, model)`` pair for *model_path*, loading it on first use.

    The model is loaded with 4-bit ``BitsAndBytesConfig`` quantization and
    ``device_map="auto"`` and switched to eval mode. Later calls with the same
    path return the cached pair.
    """
    cached = _MODEL_CACHE.get(model_path)
    if cached is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        model.eval()
        cached = (tokenizer, model)
        _MODEL_CACHE[model_path] = cached
    return cached


def run_task_pipeline(task_name, args, openai_key, max_stalled_iterations=5):
    """Build a perplexity-filtered distillation set for a single task.

    Each iteration:
      1. generates ``args.batch_size`` synthetic samples from the task's seed file;
      2. collects base-model and LoRA-model outputs for those samples;
      3. keeps the samples accepted by ``filter_by_perplexity``, appends them to
         the running set (capped at ``args.samples_per_task``) and rewrites
         ``<output_dir>/<task_name>/distillation_set.json``.

    Per-iteration intermediate files (``generated_<i>.json``,
    ``filtered_<i>.json``, ``base_outputs_<i>.json``, ``lora_outputs_<i>.json``)
    are written to the same task directory.

    Args:
        task_name: Task name; also the stem of the seed file
            ``<args.bbh_seed_dir>/<task_name>.txt``.
        args: Parsed CLI namespace produced by ``main``.
        openai_key: OpenAI API key forwarded to the generator (may be ``None``).
        max_stalled_iterations: Stop early after this many consecutive
            iterations that add no new samples, so an over-strict threshold or
            a failing generator cannot loop forever. ``None`` disables the guard.

    Returns:
        The collected distillation set (at most ``args.samples_per_task`` items).
    """
    print(f"\n=== Running Task: {task_name} ===")
    task_dir = os.path.join(args.output_dir, task_name)
    os.makedirs(task_dir, exist_ok=True)

    seed_file = os.path.join(args.bbh_seed_dir, f"{task_name}.txt")
    target_size = args.samples_per_task
    result_path = os.path.join(task_dir, "distillation_set.json")

    base_tokenizer, base_model = _load_quantized_model(args.base_model_path)
    lora_tokenizer, lora_model = _load_quantized_model(args.lora_model_path)
    hf_fallback_tokenizer, hf_fallback_model = _load_quantized_model(args.hf_fallback_model)

    distillation_set = []
    iteration = 0
    stalled = 0

    while len(distillation_set) < target_size:
        iteration += 1
        print(f"--- Iteration {iteration} ---")

        gen_file = os.path.join(task_dir, f"generated_{iteration}.json")
        filtered_file = os.path.join(task_dir, f"filtered_{iteration}.json")
        base_out = os.path.join(task_dir, f"base_outputs_{iteration}.json")
        lora_out = os.path.join(task_dir, f"lora_outputs_{iteration}.json")

        # Step 1: Generate synthetic samples from the task's seed examples
        generate_synthetic_samples(
            seed_txt=seed_file,
            output_path=gen_file,
            n_samples=args.batch_size,
            model=args.openai_model,
            openai_key=openai_key,
            hf_fallback=hf_fallback_model,
            hf_fallback_tokenizer=hf_fallback_tokenizer,
        )

        # Step 2: Get model outputs
        generate_model_outputs(base_model, base_tokenizer, gen_file, base_out)
        generate_model_outputs(lora_model, lora_tokenizer, gen_file, lora_out)

        # Step 3: Filter by perplexity using those outputs
        filtered = filter_by_perplexity(
            input_file=gen_file,
            output_file=filtered_file,
            base_output_file=base_out,
            lora_output_file=lora_out,
            lora_model=lora_model,
            lora_tokenizer=lora_tokenizer,
            threshold=args.threshold,
        )

        before = len(distillation_set)
        # Tolerate a None/empty result from the filter instead of crashing.
        distillation_set.extend(filtered or [])
        del distillation_set[target_size:]
        added = len(distillation_set) - before

        # Checkpoint after every iteration so progress survives interruptions.
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(distillation_set, f, indent=2)

        if added:
            stalled = 0
            continue
        stalled += 1
        if max_stalled_iterations is not None and stalled >= max_stalled_iterations:
            print(
                f"! Stopping task {task_name}: no new samples in {stalled} consecutive "
                f"iterations ({len(distillation_set)}/{target_size} collected)"
            )
            break

    if len(distillation_set) >= target_size:
        print(f"✓ Completed distillation for task: {task_name}")
    return distillation_set

def main():
    """Parse CLI arguments and run the distillation pipeline for each requested task.

    ``load_dotenv()`` runs first so that ``OPENAI_API_KEY`` can come from a
    ``.env`` file when ``--openai_key`` is not given.
    """
    load_dotenv()
    parser = argparse.ArgumentParser(
        description="Build perplexity-filtered distillation sets for BIG-Bench Hard tasks."
    )
    parser.add_argument("--output_dir", type=str, default="bbh_distilled")
    # Not read anywhere in this script; kept (optional) so existing invocations still work.
    parser.add_argument("--bbh_reference_file", type=str, default=None,
                        help="Currently unused by the pipeline")
    parser.add_argument("--bbh_seed_dir", type=str, required=True,
                        help="Directory containing <task>.txt seed files")
    parser.add_argument("--openai_model", type=str, default="gpt-4o")
    parser.add_argument("--openai_key", type=str, default=None,
                        help="Defaults to the OPENAI_API_KEY environment variable")
    parser.add_argument("--hf_fallback_model", type=str, default="tiiuae/falcon-7b-instruct")
    parser.add_argument("--base_model_path", type=str, required=True)
    parser.add_argument("--lora_model_path", type=str, required=True)
    parser.add_argument("--threshold", type=float, default=1.5)
    parser.add_argument("--batch_size", type=int, default=20,
                        help="Synthetic samples generated per iteration")
    parser.add_argument("--tasks", nargs="*", default=BBH_TASKS)
    parser.add_argument("--samples_per_task", type=int, default=100, help="How many samples to generate per BBH task")
    args = parser.parse_args()

    # A non-positive batch size can never grow the distillation set.
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.samples_per_task < 0:
        parser.error("--samples_per_task must be >= 0")

    openai_key = args.openai_key or os.getenv("OPENAI_API_KEY")
    if not openai_key:
        print("! No OpenAI API key found (--openai_key / OPENAI_API_KEY).")

    if not args.tasks:
        print("! No tasks selected; nothing to do.")
    unknown = [task for task in args.tasks if task not in BBH_TASKS]
    if unknown:
        print(f"! Warning: tasks not in BBH_TASKS: {', '.join(unknown)}")

    for task in args.tasks:
        run_task_pipeline(task, args, openai_key)


if __name__ == "__main__":
    main()
