from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_from_disk

# Load the saved dataset (using its "train" split) and the tokenizer stored
# alongside the DPO teacher model.
dataset_dict = load_from_disk("data/llama3.2-1b-deita-dpomix/adpa_dataset_0epoch")
train_ds = dataset_dict["train"]
tokenizer = AutoTokenizer.from_pretrained("./model/dpo_teacher")

# Choose which training sample to inspect.
sample_index = 0
row = train_ds[sample_index]
# Each element is read below through its "indices" and "values" entries.
# NOTE(review): presumably one entry per token position, holding vocab ids
# and their log-prob margins -- confirm against the dataset builder.
margin_logps = row["rejected_margin_logp_every"]


def _decode_single(token_id):
    """Return the stripped text for one vocab id, or "[UNK]" if decoding raises."""
    try:
        return tokenizer.decode([token_id]).strip()
    except Exception:
        # Best-effort: any id the tokenizer cannot decode gets a placeholder.
        return "[UNK]"


# Flatten every step's (vocab id, value) pairs into two parallel lists:
# the decoded token text and the value paired with it.
decoded_tokens = []
logp_values = []

for step in margin_logps:
    step_ids = step["indices"]
    step_margins = step["values"]
    for token_id, margin in zip(step_ids, step_margins):
        decoded_tokens.append(_decode_single(token_id))
        logp_values.append(margin)

# Print a short preview of the decoded tokens with their values.
preview_count = 20
preview_tokens = decoded_tokens[:preview_count]
preview_margins = logp_values[:preview_count]
for text, margin in zip(preview_tokens, preview_margins):
    print(f"{text}: {margin:.3f}")

# Plot a histogram of all collected values, save it to disk, then display it.
fig, ax = plt.subplots(figsize=(12, 5))
ax.hist(logp_values, bins=100)
ax.set_title("Distribution of log(prob_dpoteacher / prob_refteacher)")
ax.set_xlabel("log prob margin")
ax.set_ylabel("Count")
ax.grid(True)
fig.savefig('test_vis.png')
plt.show()