import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch.nn.functional as F

# Preferred compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def load_llm():
    """Load the first causal LM (with its tokenizer) that can be instantiated.

    Checkpoints are tried in order: Llama-2-7b-chat first, then
    TinyLlama-1.1B-Chat as the fallback. A tokenizer that has no pad token
    reuses its EOS token for padding. Weights are requested in float16 when
    the module-level ``device`` is ``"cuda"`` and in float32 otherwise, with
    ``device_map="auto"``.

    Returns:
        A ``(checkpoint_name, tokenizer, model)`` tuple for the first
        checkpoint that loads; the model has already been put in eval mode.

    Raises:
        Exception: whatever the last attempted checkpoint raised, when no
            checkpoint could be loaded.
    """
    checkpoints = (
        "meta-llama/Llama-2-7b-chat-hf",
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    )
    failure = None
    for checkpoint in checkpoints:
        try:
            tok = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token
            if device == "cuda":
                weight_dtype = torch.float16
            else:
                weight_dtype = torch.float32
            llm = AutoModelForCausalLM.from_pretrained(
                checkpoint,
                dtype=weight_dtype,
                device_map="auto",
            )
            llm.eval()
        except Exception as err:
            # Remember the failure and move on to the next checkpoint.
            failure = err
            continue
        return checkpoint, tok, llm
    raise failure

# Load the model once, at import time. The helper functions below read
# `tokenizer` and `model` as module-level globals.
model_name, tokenizer, model = load_llm()

def get_texts_from_rwku(max_docs=300):
    """Return up to ``max_docs`` text entries, preferring the RWKU dataset.

    Sources are tried in order:

    1. ``zjunlp/RWKU`` with no config name,
    2. ``zjunlp/RWKU`` with the ``"forget"`` config,
    3. ``wikitext-2-raw-v1`` as a last resort, keeping only non-blank strings.

    In every case the ``"text"`` column is used when present, otherwise the
    dataset's first column.

    Args:
        max_docs: Upper bound on the number of entries returned.

    Returns:
        A list of column values (at most ``max_docs`` long).

    Raises:
        Exception: whatever the wikitext fallback raises if it fails too.
    """
    for config in ((), ("forget",)):
        try:
            ds = load_dataset("zjunlp/RWKU", *config, split="train")
            col = _pick_text_column(ds)
            limit = min(max_docs, len(ds))
            return [row[col] for row in ds.select(range(limit))]
        except Exception:
            # Best-effort fallback to the next source. `Exception` (not a bare
            # `except`) so Ctrl-C / SystemExit still abort the run instead of
            # silently switching datasets.
            continue

    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    col = _pick_text_column(ds)
    # wikitext contains many blank separator lines; drop them before truncating.
    texts = [
        value
        for value in (row[col] for row in ds)
        if isinstance(value, str) and value.strip()
    ]
    return texts[:max_docs]


def _pick_text_column(ds):
    """Return ``"text"`` if the dataset has that column, else its first column."""
    return "text" if "text" in ds.column_names else ds.column_names[0]

def embed_text_mean_hidden(text):
    """Embed ``text`` as the unit-length mean of the model's last hidden layer.

    The text is tokenized (truncated to 1024 tokens), run through the global
    ``model`` without gradient tracking, and the final-layer hidden states are
    averaged over the sequence axis and L2-normalised.

    Args:
        text: Input passed straight to the tokenizer; intended to be a single
            string (the leading batch axis is squeezed away on return).

    Returns:
        A float32 tensor holding the normalised embedding.
    """
    with torch.no_grad():
        encoded = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=1024
        )
        encoded = encoded.to(model.device)
        outputs = model(**encoded, output_hidden_states=True)
        # Cast to float32 before pooling so fp16 models don't lose precision.
        last_layer = outputs.hidden_states[-1].float()
        pooled = last_layer.mean(dim=1)
        unit = F.normalize(pooled, p=2, dim=-1)
    return unit.squeeze(0)

def embed_ids_mean_hidden(input_ids):
    """Embed token-id sequences as unit-length, padding-aware mean hidden states.

    Positions equal to ``tokenizer.pad_token_id`` are masked out of attention
    AND excluded from the mean, so padded batches give the same embedding per
    row as the unpadded sequence would. (Previously padded positions were
    masked in attention but still averaged into the result.)

    NOTE: ``load_llm`` may set the pad token to the EOS token; in that case
    genuine EOS tokens are treated as padding here as well.

    Args:
        input_ids: Integer tensor of token ids — assumed ``(batch, seq)``.

    Returns:
        A float32 tensor with one L2-normalised embedding per input row.
    """
    with torch.no_grad():
        keep = input_ids != tokenizer.pad_token_id
        out = model(input_ids=input_ids, attention_mask=keep, output_hidden_states=True)
        h = out.hidden_states[-1].float()
        # Zero out padded positions and divide by the number of real tokens.
        weights = keep.to(device=h.device, dtype=h.dtype).unsqueeze(-1)
        summed = (h * weights).sum(dim=1)
        # clamp avoids 0/0 for a row that consists solely of padding.
        counts = weights.sum(dim=1).clamp(min=1.0)
        return F.normalize(summed / counts, p=2, dim=-1)

def cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` along their last axis."""
    similarity = F.cosine_similarity(a, b, dim=-1)
    return similarity

def build_context(top_texts, max_tokens=512):
    """Concatenate the token ids of ``top_texts`` into one context sequence.

    Texts are tokenized in order, without special tokens, and appended until
    at least ``max_tokens`` ids have been collected; the result is then cut to
    exactly ``max_tokens``.

    Args:
        top_texts: Iterable of strings, most relevant first.
        max_tokens: Maximum number of token ids in the returned context.

    Returns:
        An int64 tensor of shape ``(1, n)`` on ``model.device`` with
        ``n <= max_tokens``. ``n`` is 0 when ``top_texts`` is empty.
    """
    ids = []
    for text in top_texts:
        ids.extend(tokenizer.encode(text, add_special_tokens=False))
        if len(ids) >= max_tokens:
            break
    # Explicit dtype: an empty list would otherwise become a float32 tensor,
    # which an embedding lookup rejects.
    return torch.tensor([ids[:max_tokens]], dtype=torch.long, device=model.device)

def generate_text_with_embeds(prompt_ids, context_ids, delta=None, max_new_tokens=128):
    """Greedy-decode from the embedded prompt followed by the embedded context.

    Args:
        prompt_ids: Token ids of the prompt — assumed ``(1, P)``.
        context_ids: Token ids of the context — assumed ``(1, C)``.
        delta: Optional perturbation added to the context embeddings; must be
            broadcastable to their shape.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens stripped.
    """
    # Pure inference: no autograd graph is needed for the embedding lookups.
    with torch.no_grad():
        emb = model.get_input_embeddings()
        p_emb = emb(prompt_ids)
        c_emb = emb(context_ids)
        if delta is not None:
            c_emb = c_emb + delta
        inputs_embeds = torch.cat([p_emb, c_emb], dim=1)
        # Every position is a real token, so the mask is all ones.
        attn = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=model.device)
        gen = model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attn,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(gen[0], skip_special_tokens=True)

def forward_logits_hidden(prompt_ids, context_ids, delta=None):
    """Run one forward pass over the embedded prompt followed by the context.

    Gradient tracking is left as the caller set it, so the outputs can be
    differentiated with respect to ``delta``.

    Args:
        prompt_ids: Token ids of the prompt.
        context_ids: Token ids of the context.
        delta: Optional perturbation added to the context embeddings.

    Returns:
        ``(logits, hidden)``: the model logits and the last-layer hidden
        states, both cast to float32.
    """
    embed = model.get_input_embeddings()
    context_emb = embed(context_ids)
    if delta is not None:
        context_emb = context_emb + delta
    joined = torch.cat([embed(prompt_ids), context_emb], dim=1)
    # All positions are real tokens, hence an all-ones attention mask.
    mask = torch.ones(joined.shape[:-1], dtype=torch.long, device=model.device)
    outputs = model(
        inputs_embeds=joined,
        attention_mask=mask,
        output_hidden_states=True,
    )
    return outputs.logits.float(), outputs.hidden_states[-1].float()

def avg_dir(logits):
    """Return the mean cosine similarity between logits at adjacent positions.

    Each position's logit vector is L2-normalised, then compared with the
    vector at the next position along axis 1; the similarities are averaged
    over every pair and every batch row into a single scalar.
    """
    unit = F.normalize(logits, p=2, dim=-1)
    # The two shifted views always have the same length (seq_len - 1).
    earlier = unit[:, :-1, :]
    later = unit[:, 1:, :]
    pairwise = F.cosine_similarity(earlier, later, dim=-1)
    return pairwise.mean()

def mean_hidden(hidden):
    """Average ``hidden`` over axis 1 and L2-normalise along the last axis."""
    pooled = hidden.mean(dim=1)
    return F.normalize(pooled, p=2, dim=-1)

def pgd_update(prompt_ids, context_ids, base_hidden, steps=3, lr=5e-2, epsilon=1.0):
    """Optimise a bounded perturbation of the context embeddings with PGD.

    Each step minimises ``softplus(N - S)`` where

    * ``S`` is the cosine similarity between the mean-pooled hidden states of
      the perturbed forward pass and those of ``base_hidden``, and
    * ``N`` is ``avg_dir`` of the perturbed logits.

    After every gradient step the perturbation is projected back onto the L2
    ball of radius ``epsilon`` (norm taken per batch row over all remaining
    dimensions).

    Args:
        prompt_ids: Token ids of the prompt.
        context_ids: Token ids of the context whose embeddings are perturbed.
        base_hidden: Hidden states of the unperturbed forward pass; treated as
            a constant target.
        steps: Number of gradient steps.
        lr: Step size.
        epsilon: Radius of the L2 ball constraining the perturbation.

    Returns:
        The detached perturbation, shaped like the context embeddings. It is
        all zeros when ``steps`` is 0.
    """
    emb = model.get_input_embeddings()
    with torch.no_grad():
        c_emb = emb(context_ids).detach()
    delta = torch.zeros_like(c_emb, requires_grad=True)

    # The target is loop-invariant and must not carry an autograd graph, so
    # pool it once up front and detach it.
    target = mean_hidden(base_hidden.detach())

    for _ in range(steps):
        logits_u, hidden_u = forward_logits_hidden(prompt_ids, context_ids, delta)
        similarity = cosine(mean_hidden(hidden_u), target).mean()
        drift = avg_dir(logits_u)
        loss = F.softplus(drift - similarity)
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            stepped = delta - lr * grad
            # Project onto the epsilon-ball: shrink only if the norm exceeds it.
            flat = stepped.reshape(stepped.size(0), -1)
            norm = flat.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
            scale = (epsilon / norm).clamp(max=1.0)
            delta = (flat * scale).view_as(stepped)
        # Fresh leaf tensor for the next step's gradient.
        delta.requires_grad_(True)
    return delta.detach()

def retrieve_topk(texts, query_vec, k):
    """Return the ``k`` texts most similar to ``query_vec``, with embeddings.

    Every text is embedded with ``embed_text_mean_hidden`` and ranked by
    cosine similarity to ``query_vec``. If ``k`` exceeds the number of texts,
    all of them are returned.

    Returns:
        ``(top_texts, top_vecs)``: the selected texts in descending order of
        similarity, and the matching rows of the embedding matrix.
    """
    mat = torch.stack([embed_text_mean_hidden(entry) for entry in texts], dim=0)
    query_rows = query_vec.unsqueeze(0).expand_as(mat)
    sims = cosine(mat, query_rows)
    top = torch.topk(sims, k=min(k, sims.numel()))
    chosen = top.indices.tolist()
    best_texts = [texts[pos] for pos in chosen]
    return best_texts, mat[chosen]

def gate_trigger(y_text_vec, forget_vecs, tau=0.3):
    """Return True when any forget vector is at least ``tau``-similar.

    Args:
        y_text_vec: Embedding to test.
        forget_vecs: Matrix of forget-set embeddings, one per row.
        tau: Cosine-similarity threshold (inclusive).
    """
    tiled = y_text_vec.unsqueeze(0).expand_as(forget_vecs)
    best = cosine(forget_vecs, tiled).max().item()
    return best >= tau

def run_unre_once(query, forget_texts, K=5, tau=0.3, steps=3, lr=5e-2, epsilon=1.0, max_ctx_tokens=512, max_new_tokens=128):
    """Run the gated perturb-and-generate pipeline for a single query.

    Steps:
      1. Greedy-generate a baseline answer for ``query`` and embed the
         decoded output.
      2. Gate: continue only if that embedding is at least ``tau``-similar to
         some forget-text embedding.
      3. Retrieve the ``K`` forget texts closest to the query embedding and
         build a token context from them.
      4. Optimise a perturbation of the context embeddings with ``pgd_update``.
      5. Generate from the prompt plus the perturbed context.

    Args:
        query: The query string.
        forget_texts: Sequence of texts that make up the forget set.
        K: Number of forget texts to retrieve for the context.
        tau: Gate threshold on cosine similarity.
        steps, lr, epsilon: Forwarded to ``pgd_update``.
        max_ctx_tokens: Token budget for the retrieved context.
        max_new_tokens: Token budget for the final generation.

    Returns:
        The generated text, or ``None`` when the gate does not trigger
        (including when ``forget_texts`` is empty).
    """
    # With nothing to forget the gate cannot trigger; skip the model calls.
    if len(forget_texts) == 0:
        return None

    with torch.no_grad():
        encoded = tokenizer(query, return_tensors="pt").to(model.device)
        y0 = model.generate(**encoded, max_new_tokens=64, do_sample=False, pad_token_id=tokenizer.eos_token_id)
        y_text = tokenizer.decode(y0[0], skip_special_tokens=True)
    y_vec = embed_text_mean_hidden(y_text)

    forget_mat = torch.stack([embed_text_mean_hidden(t) for t in forget_texts], dim=0)
    if not gate_trigger(y_vec, forget_mat, tau=tau):
        return None

    # Rank against the embeddings computed above rather than embedding every
    # forget text a second time.
    query_vec = embed_text_mean_hidden(query)
    sims = cosine(forget_mat, query_vec.unsqueeze(0).expand_as(forget_mat))
    top_idx = torch.topk(sims, k=min(K, sims.numel())).indices.tolist()
    top_texts = [forget_texts[i] for i in top_idx]

    context_ids = build_context(top_texts, max_tokens=max_ctx_tokens)
    prompt_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)

    # The baseline hidden states are a fixed target: no autograd graph needed.
    with torch.no_grad():
        _, base_hidden = forward_logits_hidden(prompt_ids, context_ids, None)

    delta = pgd_update(prompt_ids, context_ids, base_hidden, steps=steps, lr=lr, epsilon=epsilon)
    return generate_text_with_embeds(prompt_ids, context_ids, delta, max_new_tokens=max_new_tokens)

if __name__ == "__main__":
    # Load the corpus once: the forget set and the candidate queries come from
    # the same call with the same arguments, so a second load only repeats I/O.
    forget_texts = get_texts_from_rwku(300)
    candidates = forget_texts

    # Try candidates in order until one passes the gate and yields an output.
    result = None
    for q in candidates:
        result = run_unre_once(q, forget_texts, K=5, tau=0.3, steps=3, lr=5e-2, epsilon=1.0, max_ctx_tokens=256, max_new_tokens=128)
        if result is not None:
            break

    # Print an empty line when no candidate triggered the gate.
    print(result if result is not None else "")
