# -*- coding: utf-8 -*-

# ============================
# Force single-GPU training to avoid DataParallel scatter.
# These variables only take effect if they are set before CUDA is initialised;
# setting them BEFORE importing torch/transformers is the simplest way to
# guarantee that.
# ============================
import os
# Expose only GPU 0, so Trainer sees a single device and never wraps the model
# in DataParallel (no scatter of batches across busy GPUs).
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES the
# caller exported; change the index here if GPU 0 is the busy one.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Optional: let the CUDA caching allocator use expandable segments to reduce
# fragmentation-related OOMs. setdefault keeps any value already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


# Fine-tune Llama 3.1 (the base checkpoint by default; see BASE_MODEL) with LoRA/QLoRA on
# CSV data (case text -> one-sentence issue summary), then save the adapter, reload it,
# and evaluate on the test CSV with ROUGE.
# -----------------------------------------------------------
# Prereqs (run once in your env):
# pip install "transformers>=4.41" "datasets>=2.19" "accelerate>=0.30" peft bitsandbytes evaluate rouge_score pandas scikit-learn

import gc
import random
import re
from dataclasses import dataclass
from typing import Dict, List, Union, Optional

import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
import evaluate

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

# -----------------------------
# CONFIG
# -----------------------------
TRAIN_INPUT_PATH = "train_llm.csv"  # CSV with columns TEXT_COL_INPUT, TEXT_COL_TARGET and label_col_name
TEST_INPUT_PATH = "test_llm.csv"  # CSV with columns TEXT_COL_INPUT and TEXT_COL_TARGET
TEXT_COL_INPUT = "text"  # case text that is inserted into the prompt
label_col_name = "issue"  # used to stratify the train/validation split
# Reference summary the model is trained to produce.
# NOTE(review): the column name suggests these summaries were generated by
# Llama-3.3-70B-Instruct (i.e. distillation) — confirm with the data owner.
TEXT_COL_TARGET = "llama-3-3-70b-instruct"

# Llama 3.1 checkpoint. The base (non-Instruct) model is used; the prompt is
# plain text rather than a chat template, so swapping in the Instruct variant
# works but does not use its chat format.
#BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
BASE_MODEL = "meta-llama/Llama-3.1-8B"

# Receives Trainer checkpoints, the final LoRA adapter, the tokenizer and
# test_predictions.csv.
OUTPUT_DIR = "Llama-3.1-8B_lora"

# Tokenizer & training limits
MAX_LENGTH = 2048          # training examples are truncated/padded to this many tokens
GEN_MAX_NEW_TOKENS = 25    # cap on generated tokens per test case; raise it if summaries get cut off
SEED = 42

# LoRA / QLoRA
USE_4BIT = True            # load the base model in 4-bit NF4 (QLoRA); False -> bf16 LoRA
LORA_R = 32
LORA_ALPHA = 32            # LoRA scaling is alpha / r, so 32 / 32 = 1 (updates not rescaled)
LORA_DROPOUT = 0.05
# All attention and MLP projection layers of the Llama blocks.
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# Train hyperparams (good starting points; tune to your setup)
NUM_EPOCHS = 2            # 2–3 passes are usually enough; more risks memorizing a small dataset
LR = 1e-4                 # Conservative LR for stability on a small dataset
PER_DEVICE_TRAIN_BATCH = 4
PER_DEVICE_EVAL_BATCH = 4
GRAD_ACCUM_STEPS = 8      # Effective batch = 32 sequences per update (4 × 8 on the single visible GPU)
LOGGING_STEPS = 10        # Log every 10 optimizer steps
EVAL_STEPS = 50           # Validate every 50 optimizer steps (~1,600 training sequences)
SAVE_STEPS = 100          # Checkpoint every 100 optimizer steps (~3,200 training sequences)
WARMUP_RATIO = 0.1        # First 10% of optimizer steps ramp the LR up before the cosine decay

# -----------------------------
# REPRO
# -----------------------------
# Seed PyTorch's and Python's global RNGs before the model is built, so work
# that draws from them before training (e.g. LoRA weight initialisation) is
# repeatable. Trainer additionally seeds from TrainingArguments.seed.
torch.manual_seed(SEED)
random.seed(SEED)

# -----------------------------
# LOAD DATA (no cleaning beyond NaN -> "" in the text columns)
# -----------------------------
df = pd.read_csv(TRAIN_INPUT_PATH)
test_df = pd.read_csv(TEST_INPUT_PATH)

# Missing input/target text becomes "" so the tokenizer never receives NaN floats.
for frame in (df, test_df):
    for col in (TEXT_COL_INPUT, TEXT_COL_TARGET):
        frame[col] = frame[col].fillna("").astype(str)

# Hold out 10% of the TRAINING CSV for validation; the test split comes from
# TEST_INPUT_PATH. Stratifying on the label keeps class proportions equal in
# both parts (sklearn raises ValueError if any label has fewer than 2 rows).
train_df, val_df = train_test_split(
    df,
    test_size=0.1,
    random_state=SEED,  # follow the global seed instead of a separately hard-coded 42
    stratify=df[label_col_name],
)

# Build HF datasets
ds = DatasetDict(
    train=Dataset.from_pandas(train_df, preserve_index=False),
    validation=Dataset.from_pandas(val_df, preserve_index=False),
    test=Dataset.from_pandas(test_df, preserve_index=False),
)

# -----------------------------
# PROMPT TEMPLATE
# -----------------------------
# Plain-text task description shared by training and inference, so the model
# sees exactly the same prompt in both phases.
INSTRUCTION = (
    "You are a support assistant for enterprise server products.\n"
    "Summarize the technical issue in ONE sentence.\n"
    "Rules:\n"
    "1) Include product/model if available.\n"
    "2) Specify affected component and symptom.\n"
    "3) Retain exact error messages, codes, or event IDs verbatim.\n"
    "4) EXCLUDE resolutions, troubleshooting, customer identifiers, names, emails, phone numbers, "
    "addresses, dates, signatures, and case IDs.\n"
    "Only output the single-sentence summary."
)


def build_prompt(case_text: str) -> str:
    """Wrap *case_text* in the fixed instruction / case / answer template.

    Surrounding whitespace of the case is stripped. The prompt ends with the
    "Answer:" cue followed by a newline, so whatever the model generates next
    is the summary itself.
    """
    case_body = case_text.strip()
    return f"Instruction:\n{INSTRUCTION}\nCase:\n{case_body}\n\nAnswer:\n"

# -----------------------------
# TOKENIZER / MODEL (QLoRA) -- TRAIN
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# The tokenizer may ship without a pad token; reuse EOS so batches can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 weights with double quantisation, matmuls computed in bfloat16.
# Stays None when USE_4BIT is off, in which case the model loads in bf16.
bnb_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    if USE_4BIT
    else None
)

# device_map is deliberately not passed: no sharding during training, which
# together with the single visible GPU avoids DataParallel scatter.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,  # already None unless USE_4BIT
    torch_dtype=None if USE_4BIT else torch.bfloat16,
    low_cpu_mem_usage=True,
)
# The KV cache is only useful for generation; disabling it avoids the
# conflict with gradient checkpointing and its memory overhead.
base_model.config.use_cache = False

if USE_4BIT:
    # PEFT helper that readies a quantised model for training (see PEFT docs).
    base_model = prepare_model_for_kbit_training(base_model)

lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
    bias="none",
)
model = get_peft_model(base_model, lora_cfg)
# Sanity check: only the LoRA adapter weights should be reported as trainable.
model.print_trainable_parameters()

# -----------------------------
# TOKENIZATION WITH LOSS MASKING
# (loss only on the answer tokens)
# -----------------------------
def _fit_prompt_ids(case_text: str, answer_len: int) -> List[int]:
    """Tokenise the prompt for *case_text*, shortening the case so *answer_len* tokens still fit.

    Only the case text is cut, never the instruction or the trailing "Answer:"
    cue, so truncated training prompts keep the shape used at inference time.
    The returned ids include the tokenizer's special prefix (BOS), as at inference.
    """
    budget = MAX_LENGTH - answer_len
    prompt_ids = tokenizer(build_prompt(case_text))["input_ids"]
    overflow = len(prompt_ids) - budget
    if overflow <= 0:
        return prompt_ids

    case_ids = tokenizer(case_text.strip(), add_special_tokens=False)["input_ids"]
    kept = case_ids[: max(len(case_ids) - overflow, 0)]
    shortened = tokenizer.decode(kept, skip_special_tokens=True)
    prompt_ids = tokenizer(build_prompt(shortened))["input_ids"]
    # Decoding and re-tokenising can shift the count by a token or two; clip as a
    # last resort so the answer tokens themselves are never truncated.
    return prompt_ids[: max(budget, 0)]


def tokenize_with_labels(example):
    """Turn one raw row into a fixed-length causal-LM training example.

    Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``, each
    exactly MAX_LENGTH long and right-padded. Labels are -100 (ignored by the
    loss) on the prompt and on padding, so the loss covers only the answer
    tokens plus a final EOS token.
    """
    answer_ids = tokenizer(example[TEXT_COL_TARGET], add_special_tokens=False)["input_ids"]
    # Append EOS so the model learns where the summary ends; without it generation
    # never stops on its own. It is appended as an id and kept in the labels.
    # pad_token == eos_token here, so padding is excluded via the attention mask
    # below, not by token id.
    answer_ids = answer_ids + [tokenizer.eos_token_id]

    # Prompt and answer are tokenised separately, so the mask boundary is exact
    # rather than relying on prompt tokens being a prefix of the joint tokenisation.
    prompt_ids = _fit_prompt_ids(example[TEXT_COL_INPUT], len(answer_ids))

    input_ids = (prompt_ids + answer_ids)[:MAX_LENGTH]
    labels = ([-100] * len(prompt_ids) + answer_ids)[:MAX_LENGTH]
    attention_mask = [1] * len(input_ids)

    pad = MAX_LENGTH - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * pad
    attention_mask += [0] * pad
    labels += [-100] * pad
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Map each split on its own: the train and test CSVs need not share a column set,
# and DatasetDict.map would apply a single remove_columns list to every split.
tokenized = DatasetDict({
    split: split_ds.map(tokenize_with_labels, batched=False, remove_columns=split_ds.column_names)
    for split, split_ds in ds.items()
})

# -----------------------------
# DATA COLLATOR
# Stacks pre-tokenised features into LongTensors, right-padding ragged rows and
# trimming columns that are padding in every row of the batch.
# -----------------------------
@dataclass
class DataCollatorForCausalLM:
    """Collate ``input_ids`` / ``attention_mask`` / ``labels`` features into a batch.

    Rows shorter than the longest one are right-padded with
    ``tokenizer.pad_token_id`` (ids), 0 (attention mask) and -100 (labels).
    Trailing columns where no row has ``attention_mask == 1`` are then dropped.
    Any other feature keys are ignored.
    """

    # A Hugging Face tokenizer instance; only ``pad_token_id`` is read.
    # String annotation: dataclasses never evaluates it at class creation.
    tokenizer: "AutoTokenizer"

    def __call__(self, features):
        pad_values = {
            "input_ids": self.tokenizer.pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        width = max(len(f[key]) for f in features for key in pad_values)
        batch = {
            key: torch.tensor(
                [list(f[key]) + [fill] * (width - len(f[key])) for f in features],
                dtype=torch.long,
            )
            for key, fill in pad_values.items()
        }

        # Examples arrive padded to MAX_LENGTH, so most columns are often pure
        # padding. For right-padded causal inputs those trailing positions cannot
        # influence earlier tokens, and (assuming labels are -100 wherever the
        # attention mask is 0) they carry no loss. Trimming them leaves the loss
        # unchanged while skipping compute on the padding. Left-padded batches
        # end on an active column, so they are left untouched.
        active_cols = batch["attention_mask"].any(dim=0).nonzero()
        if active_cols.numel():
            end = int(active_cols[-1]) + 1
            batch = {key: tensor[:, :end] for key, tensor in batch.items()}
        return batch

data_collator = DataCollatorForCausalLM(tokenizer=tokenizer)

# -----------------------------
# TRAIN
# -----------------------------
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=LOGGING_STEPS,
    # eval_steps only takes effect with the "steps" eval strategy. The default
    # strategy is "no", so without this line the validation split was never
    # evaluated. (`eval_strategy` is the argument name since transformers 4.41,
    # matching the version floor in the prereqs.)
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=2,                # older checkpoints beyond the last two are deleted
    lr_scheduler_type="cosine",
    bf16=not USE_4BIT,                 # if 4bit, bnb compute dtype is set above
    gradient_checkpointing=True,       # trade recomputation for activation memory
    optim="paged_adamw_8bit" if USE_4BIT else "adamw_torch",
    report_to="none",
    remove_unused_columns=False,       # safer with custom features
    ddp_find_unused_parameters=False,  # helpful if you later switch to DDP + PEFT
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=data_collator,
    # NOTE(review): newer transformers releases rename this to `processing_class`;
    # `tokenizer=` is kept so the script still runs on the >=4.41 floor.
    tokenizer=tokenizer,
)

# Optional quick device sanity check:
# dl = trainer.get_train_dataloader()
# batch = next(iter(dl))
# for k, v in batch.items():
#     print(k, getattr(v, "device", "cpu"))

trainer.train()

# Save LoRA adapter + tokenizer
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# -----------------------------
# RELOAD FOR INFERENCE
# -----------------------------
# Free the training-time objects before loading a second copy of the base model.
# Only GPU 0 is visible (CUDA_VISIBLE_DEVICES at the top of the file), so keeping
# the trained model, the trainer and its optimizer state alive can exhaust GPU
# memory. None of them is used after this point.
del trainer, model, base_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Rebuild the base model with the same quantisation as training and attach the
# adapter saved in OUTPUT_DIR, exactly as a standalone inference script would.
gen_bnb_config = bnb_config  # reuse
gen_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16 if not USE_4BIT else None,
    quantization_config=gen_bnb_config if USE_4BIT else None,
    device_map="auto",                 # only one GPU is visible, so nothing is sharded across GPUs
    low_cpu_mem_usage=True,
)
gen_model = PeftModel.from_pretrained(gen_base, OUTPUT_DIR)
gen_model.eval()  # inference mode: disables LoRA dropout

gen_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR, use_fast=True)
if gen_tokenizer.pad_token is None:
    gen_tokenizer.pad_token = gen_tokenizer.eos_token

# -----------------------------
# EVALUATE ON TEST (ROUGE-L + few samples)
# -----------------------------
rouge = evaluate.load("rouge")  # computes ROUGE-1/2/L; requires the `rouge_score` package

def generate_issue(
    case_text: str,
    max_new_tokens: int = GEN_MAX_NEW_TOKENS,
    repetition_penalty: float = 1.1,
) -> str:
    """Greedily generate the one-sentence issue summary for *case_text*.

    Args:
        case_text: Raw case text; it is wrapped with ``build_prompt``.
        max_new_tokens: Upper bound on generated tokens.
        repetition_penalty: Passed to ``generate``. It also penalises tokens that
            appear in the prompt, which can discourage copying error codes
            verbatim; set it to 1.0 to disable.

    Returns:
        The generated continuation only (prompt excluded), with special tokens
        removed and surrounding whitespace stripped.
    """
    prompt = build_prompt(case_text)
    inputs = gen_tokenizer(prompt, return_tensors="pt").to(gen_model.device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        out = gen_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy -> deterministic eval; temperature/top_p would be ignored
            repetition_penalty=repetition_penalty,
            pad_token_id=gen_tokenizer.eos_token_id,
            eos_token_id=gen_tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens. Splitting the full decoded text on
    # "Answer:" breaks whenever the case text itself contains that string.
    new_tokens = out[0, prompt_len:]
    return gen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Run the fine-tuned model over every test case, keeping inputs, references and
# predictions index-aligned for scoring, printing and export.
case_texts: List[str] = []
references: List[str] = []
predictions: List[str] = []

for record in ds["test"]:
    source_text = record[TEXT_COL_INPUT]
    case_texts.append(source_text)
    predictions.append(generate_issue(source_text))
    references.append(record[TEXT_COL_TARGET])

rouge_scores = rouge.compute(predictions=predictions, references=references)
print("ROUGE scores on test:", rouge_scores)

# Eyeball the first few predictions next to their references.
for shown, (source_text, target, pred) in enumerate(
    zip(case_texts[:5], references, predictions), start=1
):
    print("\n--- Example", shown, "---")
    print("CASE:", source_text[:600].replace("\n", " "))
    print("TARGET:", target)
    print("PRED  :", pred)

# Optional: persist every prediction for offline inspection.
predictions_path = os.path.join(OUTPUT_DIR, "test_predictions.csv")
pd.DataFrame(
    {
        "case_text": case_texts,
        "target_issue_text": references,
        "pred_issue_text": predictions,
    }
).to_csv(predictions_path, index=False)

print("\nDone. Adapter + tokenizer saved in:", OUTPUT_DIR)
