import os
import argparse
import json
from dotenv import load_dotenv
from gpt_generator import generate_synthetic_samples
from perplexity_filter import filter_by_perplexity
from get_model_outputs import generate_model_outputs
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

def _load_quantized_model(model_path, bnb_config):
    """Load the tokenizer and causal LM stored at ``model_path``.

    The model is loaded with the given quantization config and
    ``device_map="auto"``, then switched to eval mode.

    Args:
        model_path: Model identifier or local path passed to ``from_pretrained``.
        bnb_config: ``BitsAndBytesConfig`` used as ``quantization_config``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, quantization_config=bnb_config, device_map="auto"
    )
    model.eval()
    return tokenizer, model


def main() -> None:
    """Build a distillation dataset by repeated generate/filter rounds.

    Each iteration generates a batch of synthetic samples, collects the base
    and LoRA model outputs for that batch, filters the batch with
    ``filter_by_perplexity`` and appends the surviving samples to the
    accumulated set, which is rewritten to ``distillation_set.json`` in the
    output directory after every iteration.

    The loop ends once the accumulated set has at least as many entries as the
    JSON document in ``--reference_file``, or after ``--max_iterations``
    iterations when that optional limit is supplied.
    """
    load_dotenv()

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--seed_path", type=str, required=True)
    parser.add_argument("--reference_file", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="outputs/")
    parser.add_argument("--openai_model", type=str, default="gpt-4o")
    parser.add_argument("--openai_key", type=str, default=None, help="Optional OpenAI API key for generation")
    parser.add_argument("--hf_fallback_model", type=str, default="tiiuae/falcon-7b-instruct")
    parser.add_argument("--base_model_path", type=str, required=True)
    parser.add_argument("--lora_model_path", type=str, required=True)
    parser.add_argument("--threshold", type=float, default=1.5)
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument(
        "--max_iterations",
        type=int,
        default=None,
        help="Optional upper bound on generate/filter iterations (default: unlimited)",
    )

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Get target size: the number of top-level entries in the reference JSON.
    with open(args.reference_file, "r", encoding="utf-8") as f:
        target_data = json.load(f)
    target_size = len(target_data)

    print(f"Target distillation set size: {target_size}")
    distillation_set = []
    iteration = 0
    # The command-line key takes precedence over the OPENAI_API_KEY environment
    # variable (which load_dotenv() may have populated from a .env file).
    openai_key = args.openai_key or os.getenv("OPENAI_API_KEY")

    bnb_config = BitsAndBytesConfig(load_in_4bit=True)

    base_tokenizer, base_model = _load_quantized_model(args.base_model_path, bnb_config)
    lora_tokenizer, lora_model = _load_quantized_model(args.lora_model_path, bnb_config)
    hf_fallback_tokenizer, hf_fallback_model = _load_quantized_model(
        args.hf_fallback_model, bnb_config
    )

    distillation_path = os.path.join(args.output_dir, "distillation_set.json")

    while len(distillation_set) < target_size:
        # Without a bound, a filter that rejects every sample would keep
        # generating (and paying for) new batches forever.
        if args.max_iterations is not None and iteration >= args.max_iterations:
            print(
                f"\n! Stopped after {iteration} iterations with "
                f"{len(distillation_set)}/{target_size} samples."
            )
            break

        iteration += 1
        print(f"\n--- Iteration {iteration} ---")

        gen_file = os.path.join(args.output_dir, f"generated_batch_{iteration}.json")
        filtered_file = os.path.join(args.output_dir, f"filtered_batch_{iteration}.json")
        base_out = os.path.join(args.output_dir, f"base_outputs_{iteration}.json")
        lora_out = os.path.join(args.output_dir, f"lora_outputs_{iteration}.json")

        # Step 1: Generate samples
        generate_synthetic_samples(
            seed_txt=args.seed_path,
            output_path=gen_file,
            n_samples=args.batch_size,
            model=args.openai_model,
            openai_key=openai_key,
            hf_fallback=hf_fallback_model,
            hf_fallback_tokenizer=hf_fallback_tokenizer
        )

        # Step 2: Get model outputs
        generate_model_outputs(base_model, base_tokenizer, gen_file, base_out)
        generate_model_outputs(lora_model, lora_tokenizer, gen_file, lora_out)

        # Step 3: Filter by perplexity using those outputs
        filtered = filter_by_perplexity(
            input_file=gen_file,
            output_file=filtered_file,
            base_output_file=base_out,
            lora_output_file=lora_out,
            lora_model=lora_model,
            lora_tokenizer=lora_tokenizer,
            threshold=args.threshold
        )

        size_before = len(distillation_set)
        # The final batch may push the set past target_size; it is not truncated.
        distillation_set.extend(filtered)
        if len(distillation_set) == size_before:
            print("! No samples passed the perplexity filter in this iteration.")
        print(f"→ Accumulated distillation samples: {len(distillation_set)}")

        # Rewrite the full set every iteration so progress survives an interruption.
        with open(distillation_path, "w", encoding="utf-8") as f:
            json.dump(distillation_set, f, indent=2)

    if len(distillation_set) >= target_size:
        print("\n✓ Completed distillation dataset generation.")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
