#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
QLoRA fine-tuning of a causal LLM for issue summarization on CSV data.

- Loads a base model in 4-bit NF4, adds LoRA adapters (rank=32, alpha=32, dropout=0.05).
- Trains on (input_col -> target_col) pairs with prompt+target formatting.
- Masks loss on the prompt; optimizes only over target tokens.
- Paper-aligned defaults: 2 epochs, lr=1e-4, per-device batch size=4, grad_accum=8,
  AdamW, gradient checkpointing, bfloat16 compute, deterministic eval.
- Reports ROUGE-1/2/L and BERTScore (F1) on the validation split.

Example:
  python lora_finetune.py \
    --csv train_data.csv \
    --input-col case_text \
    --target-col issue_summary \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --output-dir ./qlora_ckpt

Requirements:
  pip install "transformers>=4.41" peft bitsandbytes accelerate "datasets>=2.19" pandas \
              scikit-learn evaluate rouge_score "bert-score>=0.3.13"
"""

import argparse
import os
import random
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    set_seed,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# Metrics library for ROUGE / BERTScore (imported unconditionally, so it is required)
import evaluate

# ---------------------------
# Args
# ---------------------------

def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options; defaults follow the paper's setup.

    Boolean features that are on by default (bf16 compute, gradient
    checkpointing) get a matching ``--no-...`` flag, because a ``store_true``
    option whose default is already True can never be turned off.

    Args:
        argv: Arguments to parse. ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option (dashes -> underscores).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="Input CSV with training pairs.")
    ap.add_argument("--input-col", default="case_text", help="Source text column name.")
    ap.add_argument("--target-col", default="issue_summary", help="Target summary column name.")
    ap.add_argument("--model", default="meta-llama/Llama-3.1-8B-Instruct",
                    help="HF model id or local path of the base model.")
    ap.add_argument("--output-dir", default="./qlora_ckpt", help="Where to save adapter and tokenizer.")
    # Paper defaults
    ap.add_argument("--epochs", type=int, default=2, help="Training epochs (paper uses 2).")
    ap.add_argument("--lr", type=float, default=1e-4, help="Learning rate (paper uses 1e-4).")
    ap.add_argument("--per-device-train-batch-size", type=int, default=4)
    ap.add_argument("--per-device-eval-batch-size", type=int, default=4)
    ap.add_argument("--gradient-accumulation-steps", type=int, default=8)
    ap.add_argument("--warmup-ratio", type=float, default=0.03)
    ap.add_argument("--lora-r", type=int, default=32)
    ap.add_argument("--lora-alpha", type=int, default=32)
    ap.add_argument("--lora-dropout", type=float, default=0.05)
    ap.add_argument("--max-source-tokens", type=int, default=512)
    ap.add_argument("--max-target-tokens", type=int, default=96)
    ap.add_argument("--eval-steps", type=int, default=500)
    ap.add_argument("--save-steps", type=int, default=0, help="0 -> do not save checkpoints during training.")
    ap.add_argument("--seed", type=int, default=42)
    # On by default; the paired --no-* flag (same dest) is the only way to disable it.
    ap.add_argument("--bf16", action="store_true", default=True,
                    help="Use bfloat16 compute if CUDA available (default: on).")
    ap.add_argument("--no-bf16", dest="bf16", action="store_false",
                    help="Disable bfloat16 compute.")
    ap.add_argument("--logging-steps", type=int, default=50)
    ap.add_argument("--gradient-checkpointing", action="store_true", default=True,
                    help="Enable gradient checkpointing (default: on).")
    ap.add_argument("--no-gradient-checkpointing", dest="gradient_checkpointing",
                    action="store_false", help="Disable gradient checkpointing.")
    ap.add_argument("--eval-limit", type=int, default=2000, help="Max val samples for faster eval (0=all).")
    ap.add_argument("--report-to", default="none", help='e.g., "wandb", "tensorboard", or "none".')
    return ap.parse_args(argv)

# ---------------------------
# Prompt formatting
# ---------------------------

# Instruction block shown before every case: one rule per line, each line
# newline-terminated.
INSTR_HEADER = "".join(
    rule_line + "\n"
    for rule_line in (
        "You are a support assistant for enterprise server products.",
        "Summarize the technical issue in one sentence.",
        "Rules:",
        "1) Include product/model if available.",
        "2) Specify affected component and symptom.",
        "3) Retain exact error messages, codes, or event IDs.",
        "4) Exclude resolutions, troubleshooting, customer identifiers, names, emails, phone numbers, addresses, dates, signatures, and case IDs.",
    )
)

# str.format-style template; "{source}" is its only placeholder.
PROMPT_TEMPLATE = (
    "<|start_of_prompt|>\n"
    + INSTR_HEADER
    + "\n\n"
    + "Case Description:\n{source}\n\n"
    + "Summary:"
)


def build_prompt(source_text: str) -> str:
    """Return the full model prompt for one case description.

    Leading/trailing whitespace of ``source_text`` is stripped before it is
    substituted into ``PROMPT_TEMPLATE``; the prompt ends with ``Summary:`` so
    the model's continuation is the summary.
    """
    cleaned_source = source_text.strip()
    return PROMPT_TEMPLATE.format_map({"source": cleaned_source})

# ---------------------------
# Loss masking collator
# ---------------------------

@dataclass
class CausalLMCollator:
    """Right-pad tokenized examples into tensors for causal-LM training.

    Each example must carry ``input_ids`` and ``labels`` as integer sequences.
    ``input_ids`` are padded with the tokenizer's pad id, ``labels`` with -100
    (ignored by the loss), and ``attention_mask`` is 1 exactly over each
    example's real tokens.
    """

    tokenizer: "AutoTokenizer"  # only ``pad_token_id`` is read
    mlm: bool = False  # for Trainer compatibility; not used by this collator
    pad_to_multiple_of: Optional[int] = None  # round the padded length up to a multiple of this

    def __call__(self, batch: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """Collate ``batch`` into ``input_ids`` / ``labels`` / ``attention_mask`` tensors."""
        pad_id = self.tokenizer.pad_token_id
        n_rows = len(batch)
        width = max(
            (max(len(ex["input_ids"]), len(ex["labels"])) for ex in batch),
            default=0,
        )
        if self.pad_to_multiple_of and width % self.pad_to_multiple_of:
            width += self.pad_to_multiple_of - width % self.pad_to_multiple_of

        input_ids = torch.full((n_rows, width), pad_id, dtype=torch.long)
        labels = torch.full((n_rows, width), -100, dtype=torch.long)
        attention_mask = torch.zeros((n_rows, width), dtype=torch.long)
        for row, ex in enumerate(batch):
            ids = torch.as_tensor(ex["input_ids"], dtype=torch.long)
            lab = torch.as_tensor(ex["labels"], dtype=torch.long)
            input_ids[row, : ids.numel()] = ids
            labels[row, : lab.numel()] = lab
            # Build the mask from lengths rather than ``input_ids != pad_id``:
            # the pad id can equal the EOS id (main() reuses EOS as pad), and
            # genuine EOS tokens must not be masked out.
            attention_mask[row, : ids.numel()] = 1
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

# ---------------------------
# Data prep
# ---------------------------

def prepare_datasets(df: pd.DataFrame, input_col: str, target_col: str,
                     tokenizer: AutoTokenizer, max_src: int, max_tgt: int) -> Dataset:
    """Tokenize (source, target) rows into causal-LM training examples.

    Each example is ``prompt_ids + target_ids + [eos]``. ``labels`` is -100
    over the prompt, so the loss covers only the summary and the closing EOS
    token (appended when the tokenizer defines one).

    Args:
        df: Frame holding ``input_col`` and ``target_col``; NaN cells become "".
        input_col: Column with the source (case) text.
        target_col: Column with the reference summary.
        tokenizer: HF tokenizer used to encode (and, when truncating, decode).
        max_src: Token budget for the prompt. Over-long sources are shortened
            inside the "Case Description" section so the instruction header
            and the trailing "Summary:" cue are always kept.
        max_tgt: Token budget for the labelled part (summary tokens + EOS).

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``labels`` columns.
    """
    eos_id = tokenizer.eos_token_id

    def _ids(text: str) -> List[int]:
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    # Tokens taken by the fixed template (header + cue), computed once.
    template_len = len(_ids(build_prompt("")))

    def _encode_prompt(src: str) -> List[int]:
        prompt_ids = _ids(build_prompt(src))
        if len(prompt_ids) <= max_src:
            return prompt_ids
        # Right-truncating the whole prompt would drop the "Summary:" cue, so
        # shorten only the source. Decoding and re-tokenizing keeps the prompt
        # identical to tokenizing the same string at inference; merges at the
        # boundaries can shift the count by a token or two, hence the retry.
        src_ids = _ids(src.strip())
        budget = max(max_src - template_len, 0)
        while True:
            kept = tokenizer.decode(src_ids[:budget], skip_special_tokens=True)
            prompt_ids = _ids(build_prompt(kept))
            overflow = len(prompt_ids) - max_src
            if overflow <= 0 or budget == 0:
                # budget == 0: the template alone exceeds max_src; keep it whole
                # rather than cut the instructions or the cue.
                return prompt_ids
            budget = max(budget - overflow, 0)

    # Reserve one labelled slot for EOS so the model learns where to stop.
    tgt_budget = max(max_tgt - 1, 0) if eos_id is not None else max_tgt

    def _encode(row):
        src = str(row[input_col]) if pd.notna(row[input_col]) else ""
        tgt = str(row[target_col]) if pd.notna(row[target_col]) else ""

        prompt_ids = _encode_prompt(src)
        # Leading space separates the summary from the "Summary:" cue; explicit
        # right-truncation to the target budget.
        target_ids = _ids(" " + tgt.strip())[:tgt_budget]
        if eos_id is not None:
            target_ids = target_ids + [eos_id]

        return {
            "input_ids": prompt_ids + target_ids,
            # Ignore loss on the prompt; learn on target (+ EOS).
            "labels": [-100] * len(prompt_ids) + target_ids,
        }

    ds = Dataset.from_pandas(df[[input_col, target_col]].reset_index(drop=True))
    ds = ds.map(_encode, remove_columns=ds.column_names, desc="Tokenizing")
    return ds

# ---------------------------
# Evaluation helpers
# ---------------------------

def generate_predictions(trainer: "Trainer", ds: "Dataset", tokenizer: "AutoTokenizer",
                         max_new_tokens: int = 96) -> List[str]:
    """Greedy-decode one summary per example of ``ds`` with ``trainer.model``.

    Only the prompt is fed to the model. Training examples store
    ``prompt + target`` in ``input_ids``; the prompt length is recovered from
    the leading ``-100`` entries of ``labels``. Feeding the full ``input_ids``
    would hand the reference summary to the model. Examples without
    ``labels`` are treated as prompt-only.

    Args:
        trainer: Object whose ``.model`` provides ``eval``, ``parameters`` and ``generate``.
        ds: Iterable of dicts with ``input_ids`` and optionally ``labels`` / ``attention_mask``.
        tokenizer: Supplies pad/eos ids and decodes generated tokens.
        max_new_tokens: Upper bound on generated tokens per example.

    Returns:
        The decoded, whitespace-stripped continuation for each example, in order.
    """
    model = trainer.model
    model.eval()
    device = next(model.parameters()).device
    preds = []
    with torch.no_grad():
        for ex in ds:
            prompt_ids = list(ex["input_ids"])
            labels = ex.get("labels")
            if labels is not None:
                prompt_len = next(
                    (i for i, lab in enumerate(labels) if lab != -100), len(labels)
                )
                prompt_ids = prompt_ids[:prompt_len]

            inp = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            if "attention_mask" in ex:
                mask = list(ex["attention_mask"])[: inp.shape[1]]
                attn = torch.tensor([mask], dtype=torch.long, device=device)
            else:
                # Single unpadded sequence: every position is real.
                attn = torch.ones_like(inp)

            out = model.generate(
                input_ids=inp,
                attention_mask=attn,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding -> deterministic output
                use_cache=True,  # training may disable the KV cache (gradient checkpointing)
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            gen_ids = out[0][inp.shape[1]:]
            preds.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
    return preds

def compute_text_metrics(preds: List[str], refs: List[str]) -> Dict[str, Union[float, str]]:
    """Score predictions against references with ROUGE-1/2/L and BERTScore F1.

    ROUGE is required. BERTScore is best-effort: if it cannot be loaded or
    computed, ``bertscore_f1_mean`` is NaN and ``bertscore_note`` says why, so
    a finished training run still reports its ROUGE numbers.

    Args:
        preds: Generated summaries.
        refs: Reference summaries, aligned index-by-index with ``preds``.

    Returns:
        Dict with ``rouge1``, ``rouge2``, ``rougeL``, ``bertscore_f1_mean`` and,
        on BERTScore failure, ``bertscore_note``.

    Raises:
        ValueError: if ``preds`` and ``refs`` differ in length.
    """
    if len(preds) != len(refs):
        raise ValueError(f"Got {len(preds)} predictions but {len(refs)} references")

    rouge = evaluate.load("rouge")
    rouge_scores = rouge.compute(predictions=preds, references=refs, use_aggregator=True)
    metrics: Dict[str, Union[float, str]] = {
        key: float(rouge_scores.get(key, 0.0)) for key in ("rouge1", "rouge2", "rougeL")
    }

    try:
        bertscore = evaluate.load("bertscore")
        bs = bertscore.compute(predictions=preds, references=refs, lang="en")
    except ImportError:
        # evaluate raises ImportError when the bert_score package is missing.
        metrics["bertscore_f1_mean"] = float("nan")
        metrics["bertscore_note"] = "Install with: pip install evaluate bert-score"
    except Exception as err:  # best-effort: download/OOM errors must not discard ROUGE
        metrics["bertscore_f1_mean"] = float("nan")
        metrics["bertscore_note"] = f"BERTScore failed: {type(err).__name__}: {err}"
    else:
        metrics["bertscore_f1_mean"] = float(np.mean(bs["f1"]))
    return metrics

# ---------------------------
# Main
# ---------------------------

def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels."""
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _load_pairs(csv_path: str, input_col: str, target_col: str) -> pd.DataFrame:
    """Read the CSV, check both columns exist, and drop rows without a target.

    Raises:
        ValueError: if a column is missing or fewer than two usable rows remain.
    """
    df = pd.read_csv(csv_path)
    for col in (input_col, target_col):
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in {csv_path}")

    # A missing/blank target gives the model nothing to learn and scores 0 in
    # ROUGE, so such rows are excluded from both training and evaluation.
    has_target = df[target_col].notna() & (df[target_col].astype(str).str.strip() != "")
    n_dropped = int((~has_target).sum())
    if n_dropped:
        print(f"Dropping {n_dropped} row(s) with an empty '{target_col}'.")
    df = df[has_target]
    if len(df) < 2:
        raise ValueError(f"Need at least 2 rows with a non-empty '{target_col}' in {csv_path}")
    return df


def main():
    """Fine-tune QLoRA adapters on CSV pairs, then evaluate on a held-out split.

    Writes to ``--output-dir``: the PEFT adapter and tokenizer,
    ``val_metrics.json`` (ROUGE / BERTScore) and ``val_predictions.csv``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    _seed_everything(args.seed)

    df = _load_pairs(args.csv, args.input_col, args.target_col)

    # Plain random 90/10 split. Stratifying on a free-text summary column is
    # not meaningful: unique summaries make sklearn refuse to stratify, and
    # repeated summaries would be deliberately placed in both splits.
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=args.seed)

    # Tokenizer; if it has no pad token, reuse EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    use_bf16 = args.bf16 and torch.cuda.is_available()

    # 4-bit NF4 quantization config (QLoRA)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    # Load base model in 4-bit and prepare for k-bit training
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
    )
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # required when using gradient checkpointing
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=args.gradient_checkpointing
    )

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],  # common for LLaMA
        task_type="CAUSAL_LM",
        bias="none",
    )
    model = get_peft_model(model, lora_config)

    train_ds = prepare_datasets(
        train_df, args.input_col, args.target_col, tokenizer,
        args.max_source_tokens, args.max_target_tokens
    )
    val_ds_full = prepare_datasets(
        val_df, args.input_col, args.target_col, tokenizer,
        args.max_source_tokens, args.max_target_tokens
    )
    data_collator = CausalLMCollator(tokenizer=tokenizer)

    # Pass the literal string "none": on transformers 4.x, report_to=None means
    # "every installed integration", which would silently enable e.g. wandb.
    report_to = "none" if args.report_to.strip().lower() == "none" else args.report_to

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        weight_decay=0.0,
        logging_steps=args.logging_steps,
        # `evaluation_strategy` was renamed to `eval_strategy` in transformers 4.41
        # and later removed.
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_strategy="no" if args.save_steps == 0 else "steps",
        save_steps=args.save_steps,
        bf16=use_bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=2,
        report_to=report_to,
        ddp_find_unused_parameters=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds_full,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    trainer.train()

    # Save PEFT adapters and tokenizer (the artifact reviewers will load)
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # ---------------------------
    # Evaluation (greedy decoding) on at most --eval-limit validation rows
    # ---------------------------
    n_eval = len(val_ds_full)
    if args.eval_limit and args.eval_limit > 0:
        n_eval = min(n_eval, args.eval_limit)
    if n_eval < len(val_ds_full):
        val_ds = val_ds_full.select(range(n_eval))
    else:
        val_ds = val_ds_full
    # val_ds_full was built from val_df in row order, so a positional prefix stays aligned.
    base_val_df = val_df.iloc[:n_eval].reset_index(drop=True)

    preds = generate_predictions(trainer, val_ds, tokenizer, max_new_tokens=args.max_target_tokens)
    refs = [str(x) if pd.notna(x) else "" for x in base_val_df[args.target_col].tolist()]
    metrics = compute_text_metrics(preds, refs)

    print("Validation metrics:", json.dumps(metrics, indent=2))
    with open(os.path.join(args.output_dir, "val_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    # Save predictions for transparency
    out_df = pd.DataFrame({
        "reference": refs,
        "prediction": preds
    })
    out_df.to_csv(os.path.join(args.output_dir, "val_predictions.csv"), index=False)
    print(f"Saved LoRA adapters and eval artifacts to: {args.output_dir}")

if __name__ == "__main__":
    # Script entry point: run the train + evaluate pipeline; main() returns
    # None, which SystemExit maps to exit status 0.
    raise SystemExit(main())
