#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, re, json, argparse, random, wandb
from datetime import datetime

import torch
from datasets import Dataset
from transformers import (
    GPT2Tokenizer, GPT2Config, DataCollatorForLanguageModeling, GPT2LMHeadModel
)
from trl import SFTTrainer, SFTConfig


# ----- 유틸 ---------------------------------------------------------
def set_seed(seed: int):
    """Seed the RNGs used by this script.

    Seeds Python's ``random`` module, torch's CPU generator, and the generators
    of every CUDA device (in that order) with the same ``seed``.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def load_texts(path: str):
    """Load training texts from the JSON file at ``path``.

    The file must contain a JSON list. Each element is handled independently:
    a dict contributes ``str(element["text"])``; any other value contributes
    ``str(element)``. An empty list yields an empty result.

    Args:
        path: Path to a UTF-8 encoded JSON file.

    Returns:
        A list of strings, one per element of the JSON list.

    Raises:
        ValueError: If the top-level JSON value is not a list, or if a dict
            element has no ``"text"`` key.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("입력 JSON은 리스트 형태여야 합니다.")

    texts = []
    for i, item in enumerate(data):
        # Classify every element on its own: mixed lists then work, and a dict
        # without "text" is reported instead of being stringified into the
        # training data as its repr.
        if isinstance(item, dict):
            if "text" not in item:
                raise ValueError(f"{i}번째 항목(dict)에 'text' 키가 없습니다.")
            texts.append(str(item["text"]))
        else:
            texts.append(str(item))
    return texts

def stem(path: str):
    """Return a filesystem-safe name built from ``path``'s file name.

    The directory part and the last extension are dropped. Each run of
    characters outside ``[A-Za-z0-9._-]`` is collapsed into a single ``_``.
    If nothing is left, ``"dataset"`` is returned instead.
    """
    base_name, _ext = os.path.splitext(os.path.basename(path))
    cleaned = re.sub(r"[^A-Za-z0-9._\-]+", "_", base_name)
    if cleaned:
        return cleaned
    return "dataset"


# ----- 메인 ---------------------------------------------------------
def _parse_args():
    """Parse the command-line options for one training run."""
    p = argparse.ArgumentParser(description="Simple GPT-2 training script (single JSON).")
    p.add_argument("--data", type=str, required=True, help="학습 데이터(.json): 리스트[str] 또는 리스트[{'text':...}]")
    p.add_argument("--output_root", type=str, default="./gpt2_runs_0816", help="출력 루트")
    p.add_argument("--max_steps", type=int, default=16000)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--grad_accum", type=int, default=1)
    p.add_argument("--lr", type=float, default=4e-4)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--logging_steps", type=int, default=10)
    p.add_argument("--save_steps", type=int, default=100)
    # Frequent checkpoints add up on disk; None keeps every checkpoint.
    p.add_argument("--save_total_limit", type=int, default=None,
                   help="보관할 최대 체크포인트 수 (기본: 무제한)")
    p.add_argument("--max_seq_len", type=int, default=512)
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def _pick_kwarg(target, preferred: str, legacy: str) -> str:
    """Return ``preferred`` if ``target``'s signature accepts it, else ``legacy``.

    Newer TRL releases renamed some arguments, for example ``tokenizer`` to
    ``processing_class`` on SFTTrainer and ``max_seq_length`` to ``max_length``
    on SFTConfig. Probing the signature lets the script run on either side of
    a rename.
    """
    import inspect

    try:
        params = inspect.signature(target).parameters
    except (TypeError, ValueError):
        return legacy
    return preferred if preferred in params else legacy


def _build_model(tok, max_seq_len: int):
    """Create a randomly initialised GPT-2 LM sized to ``tok``'s vocabulary.

    Architecture: 8 layers, 8 heads, hidden size 512, MLP size 2048.
    The context length is ``max_seq_len``.
    """
    cfg = GPT2Config(
        n_embd=512, n_layer=8, n_head=8, n_inner=2048,
        n_positions=max_seq_len, n_ctx=max_seq_len,
        pad_token_id=tok.pad_token_id, vocab_size=len(tok),
    )
    model = GPT2LMHeadModel(cfg)
    model.resize_token_embeddings(len(tok))
    return model


def _build_sft_config(args, out_dir: str):
    """Translate the parsed CLI ``args`` into an SFTConfig writing to ``out_dir``."""
    # bf16=True makes the Trainer refuse to start on hardware without bf16
    # support; fall back to fp32 there instead of crashing.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if not use_bf16:
        print("bf16 not supported on this device; training in fp32.")

    length_kwarg = _pick_kwarg(SFTConfig, "max_length", "max_seq_length")
    return SFTConfig(
        output_dir=out_dir,
        overwrite_output_dir=True,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        weight_decay=args.weight_decay,
        bf16=use_bf16,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        eval_strategy="no",          # evaluation disabled: train only
        report_to=["wandb"],
        # The Trainer re-seeds with this value on construction. Without it, the
        # default (42) would override --seed for data order and dropout.
        seed=args.seed,
        dataset_text_field="text",   # column created from load_texts()
        packing=True,                # pack several samples into each sequence
        **{length_kwarg: args.max_seq_len},
    )


def _build_trainer(model, tok, train_ds, sft_args):
    """Create the SFTTrainer, passing ``tok`` under whichever keyword TRL accepts."""
    tok_kwarg = _pick_kwarg(SFTTrainer, "processing_class", "tokenizer")
    return SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=train_ds,
        data_collator=None,  # None: let SFTTrainer choose its own collator
        **{tok_kwarg: tok},
    )


def main():
    """Train a small GPT-2 from scratch on one JSON file and save the result.

    Steps:
        1. Parse CLI args.
        2. Load and validate the texts.
        3. Start a W&B run.
        4. Build the tokenizer and model.
        5. Train with TRL's SFTTrainer.
        6. Save model and tokenizer to
           ``<output_root>/<data-stem>_<timestamp>/_final``.

    Raises:
        ValueError: If the data file is malformed or contains no texts.
    """
    args = _parse_args()

    # Load and validate the data before starting a W&B run, so a bad input
    # file fails fast instead of leaving a crashed run behind.
    texts = load_texts(args.data)
    if not texts:
        raise ValueError(f"학습 데이터가 비어 있습니다: {args.data}")
    train_ds = Dataset.from_dict({"text": texts})

    # The output folder and the W&B run share one name. The full data path is
    # still recorded in the W&B config via vars(args).
    stamp = datetime.now().strftime("%y%m%d-%H%M%S")
    run_name = f"{stem(args.data)}_{stamp}"
    out_dir = os.path.join(args.output_root, run_name)
    os.makedirs(out_dir, exist_ok=True)  # also creates output_root

    wandb.init(project="gpt2-0816", name=run_name, config=vars(args))

    # Seed before building the model so weight initialisation is reproducible.
    set_seed(args.seed)

    # ----- tokenizer / model ----------------------------------------
    tok = GPT2Tokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # the GPT-2 tokenizer has no pad token
    model = _build_model(tok, args.max_seq_len)

    trainer = _build_trainer(model, tok, train_ds, _build_sft_config(args, out_dir))

    print("Start training …")
    trainer.train()
    print("Training done.")

    final_dir = os.path.join(out_dir, "_final")
    trainer.save_model(final_dir)
    tok.save_pretrained(final_dir)
    print(f"Saved model to: {final_dir}")

    wandb.finish()
 

# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
