#!/usr/bin/env python
"""
train_llama3_8b_qlora.py ─ Fine-tune **LLaMA-3 8B-Instruct** on a TITAN RTX (24 GB)
with **4-bit QLoRA + LoRA adapters**, using your JSONL dataset that contains
`"prompt"` and `"code"` fields.

Key points adapted to your hardware 🖥️
────────────────────────────────────────────────────────────────────────────
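• Model            : meta-llama/Meta-Llama-3-8B-Instruct  (≈10 GB in 4-bit);
                     note that the --model_id default is
                     meta-llama/Llama-3.2-1B-Instruct, so pass --model_id
                     explicitly to train the 8B model
• GPU              : TITAN RTX 24 GB → batch 4 fits comfortably (<14 GB)
• Sequence length  : 512 recommended (your data tops out at ≤500); the
                     --max_seq_len default is 1204, so pass --max_seq_len 512
• Real batch       : 4, gradient_accum = 4 → effective batch 16
• Validation split : configurable via --val_split (default 20 %; 10 % used in
                     the example below)
• Precision        : fp16 (Turing GPUs do **not** have bfloat16 tensor cores)
• Optimizer        : paged_adamw_8bit  (8-bit optimizer states to save GPU
                     memory; paged to CPU RAM during GPU memory spikes)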

Usage example
─────────────
```bash
python train_llama3_8b_qlora.py \
  --model_id     meta-llama/Meta-Llama-3-8B-Instruct \
  --max_seq_len  512 \
  --val_split    0.1 \
  --dataset_path /path/to/dataset.jsonl \
  --output_dir   /home/$USER/llama3_8b_qlora_run1
```
"""
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is supplied through
# the HF_TOKEN environment variable. A credential hard-coded here (or the old
# "your token" placeholder) is either invalid or leaks with the source file.
# Without HF_TOKEN, huggingface_hub presumably falls back to a token cached by
# `huggingface-cli login`.
import os

_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(_hf_token)






import argparse, os, math, json, torch
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training





# ═══════════════════════════════════  CLI  ══════════════════════════════════

def parse_args(argv=None):
    """Parse and validate the command-line options for a fine-tuning run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is how ``main`` calls this function.

    Returns:
        An ``argparse.Namespace`` with ``model_id``, ``dataset_path``,
        ``output_dir``, ``epochs``, ``val_split``, ``max_seq_len``, ``batch``,
        ``grad_accum`` and ``lr``.

    Raises:
        SystemExit: (status 2, via ``ArgumentParser.error``) on unknown or
            malformed options, when ``--val_split`` is not strictly between 0
            and 1, or when ``--max_seq_len`` / ``--batch`` / ``--grad_accum``
            is below 1. Checking here fails fast, before any model download.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_id", default="meta-llama/Llama-3.2-1B-Instruct",
                    help="Hugging Face Hub id or local path of the base model")
    ap.add_argument("--dataset_path", required=True, help="JSONL with 'prompt' & 'code' fields")
    ap.add_argument("--output_dir", required=True,
                    help="Directory for checkpoints, adapter, tokenizer and metrics")
    ap.add_argument("--epochs", type=int, default=4)
    ap.add_argument("--val_split", type=float, default=0.20,
                    help="Fraction of rows held out for validation, in (0, 1)")
    # NOTE(review): 1204 looks like a typo for 1024, and the module docstring
    # recommends 512 for this data; kept unchanged so existing runs behave the same.
    ap.add_argument("--max_seq_len", type=int, default=1204,
                    help="Every example is truncated/padded to this many tokens")
    ap.add_argument("--batch", type=int, default=4, help="Per-device train batch size")
    ap.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps")
    ap.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    args = ap.parse_args(argv)

    # `not (0 < x < 1)` also rejects NaN.
    if not 0.0 < args.val_split < 1.0:
        ap.error(f"--val_split must be strictly between 0 and 1, got {args.val_split}")
    # These feed a division (steps per epoch) and the tokenizer's max_length.
    for flag in ("max_seq_len", "batch", "grad_accum"):
        value = getattr(args, flag)
        if value < 1:
            ap.error(f"--{flag} must be >= 1, got {value}")
    return args

# ═══════════════════════════════════  Helpers  ═════════════════════════════

def formatting_func(example):
    """Render one dataset row as an instruction/response training string.

    Args:
        example: mapping that provides a ``'prompt'`` and a ``'code'`` entry.

    Returns:
        The prompt under a ``### Instruction:`` header, a blank line, then the
        code under a ``### Response:`` header. This is a plain markdown-header
        template; the tokenizer's chat template is not applied here.
    """
    instruction = example['prompt']
    response = example['code']
    template = "### Instruction:\n{}\n\n### Response:\n{}"
    return template.format(instruction, response)

# ═══════════════════════════════════  Main  ════════════════════════════════
import matplotlib.pyplot as plt
import json

def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset, output_dir):
    """Save the token-length distribution of the train + validation examples.

    Writes ``sequence_lengths.json`` (one integer per example, train rows
    first, then validation rows) and ``sequence_length_distribution.png``
    (a 20-bin histogram) into *output_dir*.

    In this script examples are padded to ``max_length`` before they get here,
    so ``len(input_ids)`` is identical for every row. When a row carries an
    ``attention_mask``, the number of attended (non-padding) tokens is counted
    instead, which is the real post-truncation length. Rows without a mask fall
    back to ``len(input_ids)``.
    """
    def _real_length(row):
        mask = row.get("attention_mask")
        if mask is None:
            return len(row['input_ids'])
        # int() keeps the value JSON-serialisable even for tensor-formatted rows.
        return int(sum(mask))

    lengths = [_real_length(x) for x in tokenized_train_dataset]
    lengths += [_real_length(x) for x in tokenized_val_dataset]
    print(f"Total tokenized examples: {len(lengths)}")

    # Save raw lengths
    lengths_path = os.path.join(output_dir, "sequence_lengths.json")
    with open(lengths_path, "w") as f:
        json.dump(lengths, f, indent=2)
    print(f"? Saved lengths to {lengths_path}")

    # Plot and save histogram
    plot_path = os.path.join(output_dir, "sequence_length_distribution.png")
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(lengths, bins=20, alpha=0.7, color='blue')
        plt.xlabel('Length of input_ids')
        plt.ylabel('Frequency')
        plt.title('Distribution of Lengths of input_ids')
        plt.tight_layout()
        plt.savefig(plot_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    print(f"? Saved histogram to {plot_path}")

def _tokenize_dataset(raw_ds, tokenizer, max_seq_len):
    """Tokenize prompt/code rows into fixed-length causal-LM training examples."""
    # Fall back to appending nothing if the tokenizer defines no EOS token.
    eos = tokenizer.eos_token or ""

    def tok_fn(example):
        # Append EOS so every target sequence has an explicit end; without it
        # the fine-tuned model gets no training signal to stop generating.
        formatted = formatting_func(example) + eos
        enc = tokenizer(
            formatted,
            truncation=True,
            max_length=max_seq_len,
            padding="max_length",
        )
        # pad_token may share its id with eos_token (see main), so padding
        # cannot be recognised by id. Use the attention mask instead: padded
        # positions get -100 (ignored by the loss), while the real EOS
        # appended above remains a training target.
        enc["labels"] = [
            token if attended else -100
            for token, attended in zip(enc["input_ids"], enc["attention_mask"])
        ]
        return enc

    return raw_ds.map(tok_fn, remove_columns=raw_ds.column_names, num_proc=os.cpu_count())


def _make_training_args(args, n_train):
    """Build TrainingArguments that evaluate and checkpoint roughly once per epoch."""
    steps_per_epoch = math.ceil(n_train / (args.batch * args.grad_accum))
    run_name = f"llama3-8b-lora-{datetime.now():%Y%m%d_%H%M}"
    return TrainingArguments(
        output_dir=args.output_dir,
        run_name=run_name,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        warmup_steps=int(0.05 * args.epochs * steps_per_epoch),
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=steps_per_epoch,
        save_strategy="steps",
        save_steps=steps_per_epoch,
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
        fp16=True,
        bf16=False,
        optim="paged_adamw_8bit",
        gradient_checkpointing=True,
        report_to="none",
    )


def main():
    """Run the QLoRA fine-tuning pipeline configured by the CLI flags.

    Loads the tokenizer and JSONL dataset, tokenizes and splits it, loads the
    base model in 4-bit NF4, attaches LoRA adapters and trains with
    ``Trainer``. Then writes ``eval_results.json``, the LoRA adapter
    (``<output_dir>/lora-adapter``) and the tokenizer into ``--output_dir``.

    Raises:
        ValueError: if the dataset has fewer than two rows (a train/validation
            split needs at least one example on each side).
    """
    # default_data_collator just stacks the already-padded features, so the
    # per-token labels built in _tokenize_dataset are kept. (With mlm=False,
    # DataCollatorForLanguageModeling derives labels from input_ids with
    # pad-id tokens ignored, which also drops EOS when pad == eos.)
    from transformers import default_data_collator

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # 1️⃣  Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Pad on the right while training so real tokens start at index 0 of each
    # row; with left padding they are shifted by the pad count. (Assumes the
    # model derives position ids from token index when none are passed, as
    # HF Llama does. Confirm for other architectures.)
    tokenizer.padding_side = "right"

    # 2️⃣  Dataset → train / val
    raw_ds = load_dataset("json", data_files=args.dataset_path, split="train")
    if len(raw_ds) < 2:
        raise ValueError(
            f"{args.dataset_path} has {len(raw_ds)} row(s); at least 2 are "
            "needed for a train/validation split"
        )
    tokenised = _tokenize_dataset(raw_ds, tokenizer, args.max_seq_len)

    # Print one example to sanity-check the prompt template, EOS and padding.
    sample_idx = 1
    decoded = tokenizer.decode(tokenised[sample_idx]['input_ids'], skip_special_tokens=False)
    print(f"\n?? Decoded input_ids for sample {sample_idx}:\n{decoded}")

    print(f"\n?? Raw input_ids for sample {sample_idx}:\n{tokenised[sample_idx]['input_ids']}")
    split = tokenised.train_test_split(test_size=args.val_split, seed=42)
    train_ds, eval_ds = split["train"], split["test"]

    # 3️⃣  4-bit quantisation config
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    # 4️⃣  Base model
    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_id, quantization_config=bnb_cfg, device_map="auto"
    )

    plot_data_lengths(train_ds, eval_ds, args.output_dir)

    print("🚀 model lives on →", next(base_model.parameters()).device)
    base_model = prepare_model_for_kbit_training(base_model)

    # 5️⃣  LoRA adapters
    lora_cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base_model, lora_cfg)
    model.print_trainable_parameters()

    # 6️⃣  Training args
    training_args = _make_training_args(args, len(train_ds))

    # 7️⃣  Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=default_data_collator,
    )

    trainer.train()

    # 8️⃣  Save adapter + tokenizer + metrics
    metrics = trainer.evaluate()
    with open(os.path.join(args.output_dir, "eval_results.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    model.save_pretrained(os.path.join(args.output_dir, "lora-adapter"))
    # Save the tokenizer with left padding, as the script did before. Left
    # padding is what batched generation with a decoder-only model expects.
    tokenizer.padding_side = "left"
    tokenizer.save_pretrained(args.output_dir)

    print("✅  Training complete. Results written to eval_results.json")

if __name__ == "__main__":
    main()
