
import os, re, random, pickle
os.environ["CUDA_VISIBLE_DEVICES"] = "3"


from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx

from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import LoraConfig, get_peft_model, TaskType

# -------------------------
# Config
# -------------------------
@dataclass
class PPOHyper:
    """Hyper-parameters for the PPO + LoRA fine-tuning loop.

    Every field has a default, so ``PPOHyper()`` is the configuration this
    script runs with. Field order is part of the positional constructor.
    """

    # --- optimisation ---
    lr: float = 0.00002          # AdamW lr for the policy's trainable (LoRA) parameters
    vf_lr: float = 0.0001        # AdamW lr for the value head
    batch_size: int = 16         # prompts sampled per rollout
    mini_batch: int = 4          # rollout rows per PPO gradient step
    ppo_epochs: int = 2          # passes over each rollout batch
    clip_eps: float = 0.2        # PPO ratio clipping range (1 +/- clip_eps)
    kl_coef: float = 0.05        # weight of the KL-to-reference penalty in the shaped reward
    gamma: float = 1.0           # episodic end reward; not read elsewhere in this script
    lam: float = 0.95            # not read elsewhere in this script (advantage is r_hat - V_old)
    # --- generation ---
    max_prompt_len: int = 128    # tokenizer truncation length for prompts
    max_new_tokens: int = 32     # response length cap for generate()
    temperature: float = 1.0     # sampling temperature
    top_p: float = 0.95          # nucleus-sampling threshold
    # --- schedule ---
    steps: int = 20_000          # total rollout + update iterations
    log_every: int = 20          # print reward stats every this many steps

H = PPOHyper()


device = "cuda" if torch.cuda.is_available() else "cpu"
# Allow float32 matmuls to use faster reduced-precision internal math (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

# Seed Python's RNG (prompt sampling) and PyTorch's RNG (token sampling, minibatch shuffles).
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# -------------------------
# Paths and run settings (fill in base_model_path before running)
# -------------------------

n_layer = 6                 # only used to name output_dir
base_model_path = ''        # root folder holding the graph pickle and the per-graph data folder

# Presumably graph-generator parameters (confirm); in this script they only build file/folder names.
n = 1000
k_ratio = 0.1
p_in = 0.1
p_out = 0.01
method = 'path'             # selects '{method}_soft_train.pkl' and prefixes output_dir
hidden_size = 512           # only used to name output_dir
block_size = 128            # not read elsewhere in this script
max_length = 64             # not read elsewhere in this script
train_num_ratio = 0.5       # fraction of the corpus treated as the original training range
overlap_method = 'full'     # 'none' or 'full' or 'partial'
model_selected = 4          # index (by step) of the base checkpoint PPO starts from
new_train_ratio = 0.2       # fraction of the corpus used as the new PPO data slice

backbone_model = 'llama'    # 'llama', 'mixtral' or 'qwen'; also selects the LoRA target modules

graph_pkl = os.path.join(base_model_path, f"{n}_{k_ratio}_{p_in}_{p_out}.pkl")
# From here on base_model_path points at the per-graph data folder.
base_model_path = os.path.join(base_model_path, f"{n}_{k_ratio}_{p_in}_{p_out}")
tokenizer_path = os.path.join(base_model_path, "baby_tokenizer.json")

output_dir = f'{method}_{n}_{k_ratio}_{p_in}_{p_out}/{backbone_model}_{n_layer}_{hidden_size}_{train_num_ratio}/'

def pick_checkpoint(which=1, root=None):
    """Return the path of the ``which``-th base checkpoint directory.

    Only directories named ``checkpoint-<step>`` are considered. Entries ending
    in ``.pkl``, containing ``lora_finetune`` or ``ppo``, or whose suffix after
    the last ``-`` is not an integer are skipped. The remaining checkpoints are
    sorted by step, so ``which=0`` is the earliest and ``which=-1`` the latest.

    Args:
        which: Index into the step-sorted checkpoint list (negative allowed).
        root: Directory to scan; defaults to the module-level ``output_dir``.

    Returns:
        Full path of the selected checkpoint directory.

    Raises:
        IndexError: if ``which`` is out of range for the checkpoints found.
    """
    if root is None:
        root = output_dir

    found = []  # (step, full_path)
    for name in os.listdir(root):
        if name.endswith('.pkl') or not name.startswith('checkpoint-'):
            continue
        if 'lora_finetune' in name or 'ppo' in name:
            continue
        try:
            step = int(name.rsplit('-', 1)[-1])
        except ValueError:
            # e.g. 'checkpoint-final': not a step-numbered checkpoint, so it cannot be ordered.
            continue
        full = os.path.join(root, name)
        if os.path.isdir(full):
            found.append((step, full))

    found.sort()
    try:
        return found[which][1]
    except IndexError:
        raise IndexError(
            f"checkpoint index {which} out of range: found {len(found)} checkpoint(s) in {root!r}"
        ) from None

base_or_ckpt_dir = output_dir  # not read elsewhere in this script
ckpt_path = pick_checkpoint(which=model_selected)  # base checkpoint PPO starts from

# -------------------------
# Load tokenizer
# -------------------------
tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
# Batched tokenization/generation needs a pad token; reuse EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# -------------------------
# Load corpus + graph
# -------------------------
# The corpus pickle is a local artifact of this experiment; never point this at untrusted files.
with open(os.path.join(base_model_path, f'{method}_soft_train.pkl'), 'rb') as f:
    new_training_corpus = pickle.load(f)

# Drop the first 10% of the corpus (named valid_num; presumably the validation split of the
# base training run -- confirm). train_num is kept for reference but is not read below.
valid_num = int(len(new_training_corpus) * 0.1)
train_num = int(len(new_training_corpus) * train_num_ratio)
new_training_corpus = new_training_corpus[valid_num:]

# ori_train_num: size of the range treated as the original training data (the leading samples).
# new_train_num: size of the new PPO data slice.
ori_train_num = max(1, int(len(new_training_corpus) * train_num_ratio))
new_train_num = max(1, int(len(new_training_corpus) * new_train_ratio))

# Position the new slice relative to the end of the original training range:
#   'full'    -> entirely inside it (its last new_train_num samples)
#   'partial' -> straddling its end (~half inside, the rest after)
#   'none'    -> entirely after it
# Start indices are clamped at 0: a negative start would silently wrap to the end of the list.
if overlap_method == 'full':
    new_corpus = new_training_corpus[max(0, ori_train_num - new_train_num):ori_train_num]
elif overlap_method == 'partial':
    new_overlap_train_num = max(1, int(len(new_training_corpus) * (new_train_ratio / 2)))
    partial_start = max(0, ori_train_num - new_overlap_train_num)
    new_corpus = new_training_corpus[partial_start:partial_start + new_train_num]
elif overlap_method == 'none':
    new_corpus = new_training_corpus[ori_train_num:ori_train_num + new_train_num]
else:
    raise ValueError(
        f"unknown overlap_method {overlap_method!r}; expected 'full', 'partial' or 'none'"
    )
if len(new_corpus) == 0:
    raise ValueError("selected PPO data slice is empty; check train_num_ratio / new_train_ratio")

out_dir = os.path.join(output_dir, f"ppo_lora_{model_selected}_{overlap_method}_{new_train_ratio}")
os.makedirs(out_dir, exist_ok=True)
print("new_corpus size =", len(new_corpus))

# Split train/val: the first 10% (at least one sample) of the new slice is held out.
valid_num = max(1, int(len(new_corpus) * 0.1))
valid_corpus = new_corpus[:valid_num]
training_corpus = new_corpus[valid_num:]
corpus = training_corpus  # use training portion for PPO
# Prompts are drawn with random.sample(corpus, H.batch_size); fail early with a clear message.
if len(corpus) < H.batch_size:
    raise ValueError(
        f"PPO training corpus has {len(corpus)} samples, fewer than batch_size={H.batch_size}"
    )

with open(graph_pkl, "rb") as f:
    G = pickle.load(f)
# The reward only needs G.has_edge(u, v). Use an explicit check (asserts are stripped under -O).
if not hasattr(G, "has_edge"):
    raise TypeError(f"graph pickle {graph_pkl!r} holds a {type(G).__name__}, which has no has_edge()")
print("graph loaded:", type(G))

# -------------------------
# Build prompts
# -------------------------
ENDP = "END_P"
PATH = "PATH"


def build_prompt(text: str) -> str:
    """Cut a corpus sample down to the prompt the policy should complete.

    If the sample contains the ``PATH`` marker, the prompt is everything up to
    and including its first occurrence, whitespace-stripped. Otherwise the
    sample is tokenized (truncated to ``H.max_prompt_len``) and the first half
    of its tokens (at least 4) is decoded back to text.
    """
    marker_at = text.find(PATH)
    if marker_at >= 0:
        return text[:marker_at + len(PATH)].strip()

    token_ids = tokenizer(text, truncation=True, max_length=H.max_prompt_len)["input_ids"]
    keep = max(4, len(token_ids) // 2)
    return tokenizer.decode(token_ids[:keep], skip_special_tokens=False)


def sample_prompts(bs: int) -> List[str]:
    """Draw ``bs`` distinct training samples (without replacement) and turn each into a prompt."""
    return list(map(build_prompt, random.sample(corpus, bs)))

# -------------------------
# Reward: PATH legality
# -------------------------
PATH_RE = re.compile(r"\bPATH\b(.*?)\bEND_P\b", flags=re.DOTALL)
# A node token: plain ASCII digits with at most one leading minus sign.
_NODE_TOKEN_RE = re.compile(r"-?[0-9]+")


def parse_path_nodes(text: str) -> Optional[List[int]]:
    """Extract the integer node sequence between the first ``PATH`` and ``END_P``.

    Returns None when no ``PATH ... END_P`` span exists, when it holds fewer
    than two tokens, or when any whitespace-separated token is not an integer.
    The strict token check is needed because ``str.isdigit`` accepts inputs
    such as ``"²"`` (and ``"--5"`` after ``lstrip("-")``) that ``int()`` then
    rejects with ValueError.
    """
    match = PATH_RE.search(text)
    if match is None:
        return None
    toks = match.group(1).split()
    if len(toks) < 2:
        return None
    if not all(_NODE_TOKEN_RE.fullmatch(tok) for tok in toks):
        return None
    return [int(tok) for tok in toks]


def path_reward(full_text: str) -> float:
    """Return +1.0 if the parsed path uses only edges of ``G``, else -1.0.

    An unparsable path also scores -1.0.
    NOTE(review): only consecutive-edge validity is checked. Whether the path
    starts and ends at the nodes named in the prompt is not verified here.
    """
    nodes = parse_path_nodes(full_text)
    if nodes is None:
        return -1.0
    if all(G.has_edge(u, v) for u, v in zip(nodes, nodes[1:])):
        return 1.0
    return -1.0

# -------------------------
# Models: policy + ref (frozen)
# -------------------------
def _load_causal_lm(path):
    """Load a causal LM from ``path`` onto ``device`` with ``config.use_cache`` turned off.

    Weights are requested in bfloat16 when CUDA is available; otherwise
    from_pretrained's default dtype is used.
    """
    lm = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
    ).to(device)
    lm.config.use_cache = False
    return lm


# Trainable policy (LoRA adapters are attached below) and a frozen copy used as the KL reference.
policy = _load_causal_lm(ckpt_path)
ref = _load_causal_lm(ckpt_path)
ref.eval()
ref.requires_grad_(False)

# LoRA targets: LLaMA-style projection names for these backbones, else GPT-2-style c_attn/c_proj.
target_modules = (
    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    if backbone_model in ("llama", "qwen", "mixtral")
    else ["c_attn", "c_proj"]
)

lora_cfg = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=target_modules,
)

# Best effort: not every model implements enable_input_require_grads.
try:
    policy.enable_input_require_grads()
except Exception:
    pass
policy = get_peft_model(policy, lora_cfg)
policy.print_trainable_parameters()

# Value function: a bias-free linear head on the policy's last hidden state (no separate critic).
# (If you prefer a separate value_model, you can swap this out.)
v_head = nn.Linear(
    policy.base_model.model.config.hidden_size, 1, bias=False
).to(device)

# Separate optimizers: LoRA weights of the policy, and the value head.
policy_trainable = [p for p in policy.parameters() if p.requires_grad]
opt = torch.optim.AdamW(policy_trainable, lr=H.lr)
vf_opt = torch.optim.AdamW(v_head.parameters(), lr=H.vf_lr)

# -------------------------
# Helpers: logprobs, values
# -------------------------
def shift_logits_and_labels(logits, labels):
    """Return the log-probability assigned to each label token.

    No shifting happens here despite the name. Callers pass already-aligned
    ``logits[:, :-1]`` and ``labels = ids[:, 1:]``.

    Args:
        logits: [B, T, V] unnormalised scores (any float dtype).
        labels: [B, T] integer token ids.

    Returns:
        [B, T] float32 tensor with log p(labels[b, t]) under logits[b, t].
    """
    # Upcast first: bf16 log-probs are coarse, and PPO exponentiates sums of them in the ratio.
    logp = F.log_softmax(logits.float(), dim=-1)
    return logp.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

@torch.no_grad()
def compute_logprobs(model, input_ids, attn_mask, output_hidden_states=None):
    """Score ``input_ids`` under ``model`` without tracking gradients.

    Not called elsewhere in this script; ``rollout`` does the same scoring inline.

    Args:
        model: causal LM returning ``.logits`` of shape [B, T, V].
        input_ids: [B, T] token ids.
        attn_mask: [B, T] attention mask.
        output_hidden_states: forwarded to the model. ``None`` keeps the model
            config's default; pass ``True`` to get hidden states back.

    Returns:
        (tok_logp, mask, hidden_states): tok_logp is [B, T-1], the log-prob of
        each next token ``input_ids[:, 1:]``. mask is ``attn_mask[:, 1:]`` as
        float and covers every non-pad label, prompt included. hidden_states is
        ``out.hidden_states`` or None when the model did not return them.
    """
    out = model(
        input_ids=input_ids,
        attention_mask=attn_mask,
        output_hidden_states=output_hidden_states,
        use_cache=False,
    )
    tok_logp = shift_logits_and_labels(out.logits[:, :-1, :], input_ids[:, 1:])
    mask = attn_mask[:, 1:].float()
    return tok_logp, mask, getattr(out, "hidden_states", None)

def get_last_hidden(policy_out, attn_mask):
    """Return the final-layer hidden state at each row's last attended position.

    Args:
        policy_out: model output produced with ``output_hidden_states=True``;
            ``hidden_states[-1]`` is indexed as [B, T, H].
        attn_mask: [B, T] mask, nonzero on real tokens.

    Returns:
        [B, H] tensor.

    The position is the highest index with a nonzero mask, so right-padded,
    left-padded and non-contiguous masks all work. ``mask.sum() - 1`` would
    only be correct for right padding. A row with no attended token falls
    back to position 0.
    """
    hs = policy_out.hidden_states[-1]
    positions = torch.arange(attn_mask.size(1), device=attn_mask.device)
    last_idx = (positions * (attn_mask != 0)).amax(dim=1)  # [B]
    rows = torch.arange(hs.size(0), device=hs.device)
    return hs[rows, last_idx.to(hs.device)]

# -------------------------
# Rollout: generate, compute rewards, store old logp/value
# -------------------------
def _left_align_rollout(prompt_ids, prompt_mask, gen, pad_id, eos_id):
    """Repack left-padded ``generate`` output into left-aligned, right-padded rows.

    Each row becomes ``[prompt tokens, response tokens through the first EOS, pad...]``.
    Positions therefore start at 0 for every sequence and the attention mask is one
    contiguous run of ones.

    Returns:
        seqs: [B, T] repacked token ids.
        attn: [B, T] long mask, 1 on prompt and response tokens.
        resp_mask: [B, T-1] float mask over next-token labels ``seqs[:, 1:]``;
            1 only where the label is a generated (response) token.
        lengths: per-row count of valid tokens.
    """
    prompt_width = prompt_ids.size(1)
    rows, prompt_lens = [], []
    for p_ids, p_mask, out_ids in zip(prompt_ids, prompt_mask, gen):
        prompt = p_ids[p_mask.bool()]
        response = out_ids[prompt_width:]
        if eos_id is not None:
            eos_pos = (response == eos_id).nonzero(as_tuple=True)[0]
            if eos_pos.numel() > 0:
                # Keep the first EOS (a sampled action) and drop the padding generate() adds after it.
                response = response[: eos_pos[0].item() + 1]
        rows.append(torch.cat([prompt, response]))
        prompt_lens.append(prompt.numel())

    width = max(row.numel() for row in rows)
    seqs = gen.new_full((len(rows), width), pad_id if pad_id is not None else 0)
    attn = torch.zeros_like(seqs)
    resp_mask = torch.zeros((len(rows), max(width - 1, 0)), dtype=torch.float32, device=gen.device)
    lengths = []
    for b, (row, p_len) in enumerate(zip(rows, prompt_lens)):
        n_tok = row.numel()
        seqs[b, :n_tok] = row
        attn[b, :n_tok] = 1
        # Label j is token j+1, so response tokens [p_len, n_tok) sit at labels [p_len-1, n_tok-1).
        resp_mask[b, max(p_len - 1, 0):n_tok - 1] = 1.0
        lengths.append(n_tok)
    return seqs, attn, resp_mask, lengths


@torch.no_grad()
def rollout():
    """Sample prompts, generate responses with the current policy, and score them.

    Returns:
        gen: [B, T] left-aligned prompt+response ids, right-padded.
        attn: [B, T] attention mask (1 on prompt and response tokens).
        old_tok_logp: [B, T-1] policy log-probs of ``gen[:, 1:]`` at rollout time.
        ref_tok_logp: [B, T-1] reference-model log-probs of ``gen[:, 1:]``.
        old_mask: [B, T-1] 1.0 only where the label is a generated token, so the
            PPO ratio and KL ignore prompt tokens.
        old_values: [B] value-head outputs at each row's last valid token.
        rewards: [B] ``path_reward`` of each decoded prompt+response.
    """
    prompts = sample_prompts(H.batch_size)

    # Decoder-only batched generation must be left-padded. With right padding,
    # shorter prompts would continue after their pad tokens.
    saved_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        enc = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=H.max_prompt_len,
        ).to(device)
    finally:
        tokenizer.padding_side = saved_side
    enc.pop("token_type_ids", None)

    gen = policy.generate(
        **enc,
        max_new_tokens=H.max_new_tokens,
        do_sample=True,
        temperature=H.temperature,
        top_p=H.top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=False,
    )

    # Left-align so scoring forwards see positions from 0 and a contiguous mask.
    seqs, attn, old_mask, lengths = _left_align_rollout(
        enc["input_ids"], enc["attention_mask"], gen,
        tokenizer.pad_token_id, tokenizer.eos_token_id,
    )

    # Old log-probs under the current policy (the "old" snapshot for this batch).
    pol_out = policy(input_ids=seqs, attention_mask=attn, output_hidden_states=True, use_cache=False)
    old_tok_logp = shift_logits_and_labels(pol_out.logits[:, :-1, :], seqs[:, 1:])

    # Reference log-probs for the KL penalty.
    ref_out = ref(input_ids=seqs, attention_mask=attn, use_cache=False)
    ref_tok_logp = shift_logits_and_labels(ref_out.logits[:, :-1, :], seqs[:, 1:])

    old_values = v_head(get_last_hidden(pol_out, attn).float()).squeeze(-1)  # [B]

    # Decode only valid tokens so pad tokens never land inside the PATH ... END_P span.
    decoded = [
        tokenizer.decode(row[:n_tok].tolist(), skip_special_tokens=False)
        for row, n_tok in zip(seqs, lengths)
    ]
    rewards = torch.tensor([path_reward(t) for t in decoded], device=device, dtype=torch.float32)

    return seqs, attn, old_tok_logp, ref_tok_logp, old_mask, old_values, rewards

# -------------------------
# PPO update
# -------------------------
def ppo_update(gen, attn, old_logp, ref_logp, mask, old_values, rewards):
    """Run ``H.ppo_epochs`` clipped-PPO passes over one rollout batch.

    Args:
        gen: [B, T] token ids (prompt + response).
        attn: [B, T] attention mask for ``gen``.
        old_logp: [B, T-1] per-token log-probs of ``gen[:, 1:]`` at rollout time.
        ref_logp: [B, T-1] per-token log-probs under the frozen reference model.
        mask: [B, T-1] float mask selecting the label positions used in the KL and ratio sums.
        old_values: [B] value-head predictions at rollout time.
        rewards: [B] scalar task rewards.
    """
    B = gen.size(0)
    trainable = [p for p in policy.parameters() if p.requires_grad]

    # KL-shaped reward, fixed for this whole update: sequence-mean log-ratio of the rollout
    # policy to the reference. It is built from rollout log-probs, so it carries no gradient.
    # Otherwise the value loss below would backprop into the policy through the KL term.
    kl_seq = ((old_logp - ref_logp) * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)  # [B]
    r_hat = rewards - H.kl_coef * kl_seq                                              # [B]
    # Episodic end reward: advantage is the shaped return minus the rollout value baseline.
    adv = r_hat - old_values                                                          # [B]
    old_seq_logp = (old_logp * mask).sum(dim=1)                                       # [B]

    for _ in range(H.ppo_epochs):
        perm = torch.randperm(B, device=gen.device)  # reshuffle minibatches every epoch
        for start in range(0, B, H.mini_batch):
            mb = perm[start:start + H.mini_batch]
            mb_ids, mb_attn, mb_mask = gen[mb], attn[mb], mask[mb]

            out = policy(input_ids=mb_ids, attention_mask=mb_attn, output_hidden_states=True, use_cache=False)
            cur_logp = shift_logits_and_labels(out.logits[:, :-1, :], mb_ids[:, 1:])  # [b, T-1]

            # Sequence-level ratio (product of per-token ratios over masked positions).
            cur_seq_logp = (cur_logp * mb_mask).sum(dim=1)
            ratio = torch.exp(cur_seq_logp - old_seq_logp[mb])
            clipped = torch.clamp(ratio, 1 - H.clip_eps, 1 + H.clip_eps)
            mb_adv = adv[mb]
            pg_loss = -torch.mean(torch.minimum(ratio * mb_adv, clipped * mb_adv))

            # Value head regresses onto the (gradient-free) shaped return.
            v = v_head(get_last_hidden(out, mb_attn).float()).squeeze(-1)  # [b]
            vf_loss = F.mse_loss(v, r_hat[mb])

            opt.zero_grad(set_to_none=True)
            vf_opt.zero_grad(set_to_none=True)
            (pg_loss + vf_loss).backward()

            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            torch.nn.utils.clip_grad_norm_(v_head.parameters(), 1.0)

            opt.step()
            vf_opt.step()

# -------------------------
# Train loop
# -------------------------
from tqdm import tqdm

save_every = 1000  # checkpoint cadence, in PPO steps


def _save_ppo_checkpoint(step):
    """Write the LoRA adapter, value head and tokenizer for ``step`` under ``out_dir``."""
    policy.save_pretrained(os.path.join(out_dir, f'epoch-{step}'))
    torch.save(v_head.state_dict(), os.path.join(out_dir, f"{step}_value_head.pt"))
    tokenizer.save_pretrained(out_dir)
    print("Saved to:", out_dir)


for step in tqdm(range(1, H.steps + 1), total=H.steps):
    gen, attn, old_logp, ref_logp, mask, old_v, r = rollout()
    ppo_update(gen, attn, old_logp, ref_logp, mask, old_v, r)

    if step % H.log_every == 0:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"[{step}/{H.steps}] reward_mean={r.mean().item():.4f} reward_pos={(r>0).float().mean().item():.3f}")
    # Saves after steps 1, 1 + save_every, 1 + 2*save_every, ... (schedule kept so the
    # existing 'epoch-<step>' names do not change).
    if (step - 1) % save_every == 0:
        _save_ppo_checkpoint(step)

# The schedule above misses the final step unless it lines up, which would drop up to
# save_every - 1 updates. Save the final state explicitly.
if H.steps > 0 and (H.steps - 1) % save_every != 0:
    _save_ppo_checkpoint(H.steps)
