import os
import torch
import pandas as pd
from transformers import AutoTokenizer

# --- CONFIG ---
# NOTE(review): CSV_PATH_chat_base is not defined anywhere in this script; it must
# already exist in the executing namespace (e.g. an earlier notebook cell),
# otherwise this line raises NameError. CSV_PATH may be a .csv file, a path stem
# without ".csv", or a directory (see ensure_csv_path).
CSV_PATH = CSV_PATH_chat_base
# Tokenizer used only to decode token ids back into text pieces.
TOKENIZER_PATH = "//models/meta-llama/Llama-2-7b-hf"
# Output path is derived from the *raw* CSV_PATH, not from the path that
# ensure_csv_path() later resolves, so a directory CSV_PATH yields
# "<dir>.tokdiff.csv" next to that directory.
OUT_CSV = CSV_PATH + ".tokdiff.csv"


import os, json
import torch
import pandas as pd
from transformers import AutoTokenizer


# --- HELPERS ---
def add_pad_token(tok):
    """Give *tok* a padding token and make it pad on the left.

    When the tokenizer has no pad token, the EOS token (and its id) is reused
    as the pad token. The tokenizer is mutated in place and also returned so
    the call can be chained.
    """
    has_pad = tok.pad_token is not None
    if not has_pad:
        # Fall back to EOS so padding has something to use.
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "left"
    return tok

def decode_piece(tokenizer, tid: int) -> str:
    """Decode a single token id to text.

    Special tokens are kept (``skip_special_tokens=False``), and every
    SentencePiece word-boundary marker "▁" in the result becomes a space.
    """
    text = tokenizer.decode([tid], skip_special_tokens=False)
    return text.replace("▁", " ")

def process_pt(pt_path: str, tokenizer) -> list[dict]:
    """Return ordered list of {'token': str, 'p_diff': float} across all steps/results in the .pt.

    The loaded object should be a list (or tuple) of result dicts, each holding
    ``token_ids``, ``p_dexperts`` and ``p_base`` (tensors or plain sequences).
    For each aligned position the token id is decoded with ``decode_piece``,
    whitespace-only pieces are dropped, and ``p_diff = p_dexperts - p_base`` is
    recorded. Order and duplicate tokens are preserved.

    Returns ``[]`` when the file holds anything other than a non-empty
    list/tuple. Result entries whose three sequences differ in length are
    skipped (best-effort), the rest are still processed.

    Raises:
        Whatever ``torch.load`` raises for a missing/unreadable file, and
        ``KeyError`` if a result lacks one of the three required keys.
    """
    # NOTE(security): torch.load unpickles the file (without restrictions unless
    # weights_only=True, whose default depends on the PyTorch version); only
    # point this at trusted .pt files.
    results_list = torch.load(pt_path, map_location="cpu")
    # Tuples are ordered sequences of results just like lists; previously they
    # were silently treated as "no results".
    if not isinstance(results_list, (list, tuple)) or not results_list:
        return []

    # Memoize decoding per token id so repeated ids are decoded only once;
    # the decoded piece depends only on (tokenizer, id).
    piece_cache: dict[int, str] = {}

    out = []
    for res in results_list:
        token_ids  = res["token_ids"]
        p_dexperts = res["p_dexperts"]
        p_base     = res["p_base"]

        # Convert tensors to Python lists; probabilities are upcast to float32 first.
        if torch.is_tensor(token_ids):  token_ids  = token_ids.tolist()
        if torch.is_tensor(p_dexperts): p_dexperts = p_dexperts.to(torch.float32).tolist()
        if torch.is_tensor(p_base):     p_base     = p_base.to(torch.float32).tolist()

        if not (len(token_ids) == len(p_dexperts) == len(p_base)):
            # skip malformed entries but keep going
            continue

        for tid, pdexp, pbas in zip(token_ids, p_dexperts, p_base):
            tid = int(tid)
            piece = piece_cache.get(tid)
            if piece is None:
                piece = piece_cache[tid] = decode_piece(tokenizer, tid)
            if not piece.strip():
                continue
            out.append({"token": piece, "p_diff": float(pdexp - pbas)})

    return out

def ensure_csv_path(path: str) -> str:
    """Resolve *path* to the CSV file that should be read.

    Resolution order:
      1. A *path* ending in ".csv" is returned as-is (existence is not checked).
      2. "<path>.csv", if that file exists.
      3. If *path* is a directory: "logits_analysis" or "logits_analysis.csv"
         inside it, otherwise the alphabetically first regular file whose
         name ends in ".csv" (case-insensitive).
      4. *path* itself, if it is an existing file.

    Raises:
        FileNotFoundError: if none of the above matches.
    """
    if path.endswith(".csv"):
        return path
    # try "<path>.csv" first, then treat as a directory containing "logits_analysis" or one CSV
    if os.path.isfile(path + ".csv"):
        return path + ".csv"
    if os.path.isdir(path):
        # prefer a file literally named logits_analysis or logits_analysis.csv
        for cand in ("logits_analysis", "logits_analysis.csv"):
            p = os.path.join(path, cand)
            if os.path.isfile(p):
                return p
        # Last resort: first CSV file in the folder. os.listdir order is
        # arbitrary, so sort for a deterministic pick, and skip sub-directories
        # whose names happen to end in ".csv".
        for name in sorted(os.listdir(path)):
            p = os.path.join(path, name)
            if name.lower().endswith(".csv") and os.path.isfile(p):
                return p
        raise FileNotFoundError(f"No CSV found in directory: {path}")
    if os.path.isfile(path):
        return path
    raise FileNotFoundError(f"CSV not found: {path} (tried with/without .csv)")

def main(csv_path=None, tokenizer_path=None, out_csv=None):
    """Add a ``tokdiff_json`` column to the analysis CSV and write it out.

    For every row, the ``.pt`` file named in ``logits_path`` is turned into an
    ordered list of ``{"token", "p_diff"}`` records by ``process_pt`` and stored
    as compact JSON. Rows with an empty ``logits_path`` or an unreadable file
    get ``"[]"`` and a printed message (best-effort: one bad file does not abort
    the run).

    Args:
        csv_path: CSV file/stem/directory, resolved via ``ensure_csv_path``.
            Defaults to the module-level ``CSV_PATH``.
        tokenizer_path: tokenizer to load. Defaults to ``TOKENIZER_PATH``.
        out_csv: output CSV path. Defaults to ``OUT_CSV``, or
            "<resolved csv>.tokdiff.csv" when ``OUT_CSV`` is empty.

    Raises:
        FileNotFoundError: if the input CSV cannot be resolved.
        ValueError: if the CSV has no ``logits_path`` column.
    """
    csv_path = ensure_csv_path(CSV_PATH if csv_path is None else csv_path)
    df = pd.read_csv(csv_path)
    if "logits_path" not in df.columns:
        raise ValueError("CSV must contain a 'logits_path' column.")

    tok = AutoTokenizer.from_pretrained(
        TOKENIZER_PATH if tokenizer_path is None else tokenizer_path, use_fast=True
    )
    tok = add_pad_token(tok)

    tokdiff_json_col = []
    for pt_path in df["logits_path"]:
        # Empty cells come back from read_csv as NaN; don't hand them to torch.load.
        if pd.isna(pt_path) or not str(pt_path).strip():
            print(f"Skipping row with missing logits_path: {pt_path!r}")
            tokdiff_json_col.append("[]")
            continue
        try:
            items = process_pt(pt_path, tok)  # list of {"token": str, "p_diff": float}
        except Exception as e:
            # Top-level boundary: report and keep going with an empty result.
            print(f"Error reading {pt_path}: {e}")
            tokdiff_json_col.append("[]")
        else:
            # store as compact JSON (ordered list preserves order + duplicates)
            tokdiff_json_col.append(json.dumps(items, ensure_ascii=False, separators=(",", ":")))

    df["tokdiff_json"] = tokdiff_json_col

    if out_csv is None:
        out_csv = OUT_CSV or (csv_path + ".tokdiff.csv")
    df.to_csv(out_csv, index=False)
    print(f"Saved with token diffs JSON to {out_csv}")

# Run the CSV post-processing only when executed as a script, not on import.
if __name__ == "__main__":
    main()

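# NOTE: Legacy version of this script, kept for reference only (not executed).
# It read "tokens"/"p_expert" keys and stored the diffs in a dict keyed by the
# decoded token string, so repeated tokens overwrote each other; main() stores
# an ordered list instead.
# # Load tokenizer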
# tok = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
# if tok.pad_token is None:
#     tok.pad_token = tok.eos_token

# # Load base CSV
# df = pd.read_csv(CSV_PATH)

# def extract_tokdiff_dict(pt_path):
#     data = torch.load(pt_path, map_location="cpu")
#     # Expecting something like: {"tokens": [...], "p_expert": [...], "p_base": [...]}
#     tokens = data["tokens"]              # list[int]
#     p_expert = data["p_expert"]          # tensor or list[float]
#     p_base = data["p_base"]

#     # Convert to python list
#     if torch.is_tensor(p_expert): p_expert = p_expert.tolist()
#     if torch.is_tensor(p_base): p_base = p_base.tolist()

#     # Compute diffs and decode tokens
#     diffs = {}
#     for tid, pe, pb in zip(tokens, p_expert, p_base):
#         tok_str = tok.decode([tid])
#         diffs[tok_str] = pe - pb
#     return diffs

# # Add new column
# new_col = []
# for _, row in df.iterrows():
#     try:
#         tokdict = extract_tokdiff_dict(row["logits_path"])
#         new_col.append(tokdict)
#     except Exception as e:
#         print(f"Error reading {row['logits_path']}: {e}")
#         new_col.append({})
# df["tok_diffs"] = new_col

# # Save
# df.to_csv(OUT_CSV, index=False)
# print(f"Saved with token differences to {OUT_CSV}")
