import os
from argparse import ArgumentParser

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


# Parameters left unperturbed when exploring the landscape (embedding matrix and LM head).
EXCLUDED_PARAM_NAMES = frozenset({"lm_head.weight", "model.embed_tokens.weight"})


def answer_nll(logits: torch.Tensor, input_ids: torch.Tensor, num_answer_tokens: int) -> float:
    """Return the mean negative log-likelihood of the last *num_answer_tokens* tokens.

    Args:
        logits: Model output of shape ``(1, seq_len, vocab)`` for ``input_ids``.
        input_ids: Token ids of shape ``(1, seq_len)``; the answer is its tail.
        num_answer_tokens: How many trailing tokens of ``input_ids`` to score.

    Returns:
        The average of ``-log p(token)`` over the answer tokens, as a Python float.

    Raises:
        ValueError: If ``num_answer_tokens`` is not in ``[1, seq_len)``.
    """
    seq_len = input_ids.shape[1]
    if not 0 < num_answer_tokens < seq_len:
        raise ValueError(f"num_answer_tokens must be in [1, {seq_len}), got {num_answer_tokens}")
    # logits[:, t] predicts input_ids[:, t + 1], so the answer tokens are predicted
    # by the num_answer_tokens positions just before the last one.
    pred = logits[0, -num_answer_tokens - 1 : -1, :].float()
    targets = input_ids[0, -num_answer_tokens:].to(pred.device)
    # log_softmax in float32 avoids the underflow of log(softmax(.)) in bf16.
    log_probs = torch.log_softmax(pred, dim=-1)
    return -log_probs.gather(-1, targets.unsqueeze(-1)).mean().item()


def scaled_random_direction(param: torch.Tensor) -> torch.Tensor:
    """Draw a Gaussian direction shaped like *param*, rescaled to ``||param||_2``.

    The noise is drawn from the global torch RNG on ``param``'s device and the
    norms are accumulated in float32 (layer-wise normalization). The result is
    returned on the CPU in ``param``'s dtype.
    """
    direction = torch.randn_like(param, dtype=torch.float32)
    scale = torch.linalg.vector_norm(param, dtype=torch.float32) / torch.linalg.vector_norm(direction)
    return direction.mul_(scale).to(device="cpu", dtype=param.dtype)


def set_perturbed_weights(
    params: list[tuple[str, torch.Tensor]],
    originals: dict[str, torch.Tensor],
    direction1: dict[str, torch.Tensor],
    direction2: dict[str, torch.Tensor],
    alpha: float,
    beta: float,
) -> None:
    """Overwrite each parameter with ``originals + alpha * direction1 + beta * direction2``.

    The new value is computed from the stored originals (in float32) rather than
    by adding and later subtracting a delta in place. A low-precision
    (e.g. bf16) add/sub round-trip is not exact, and the rounding error would
    accumulate across grid points and drift the weights. ``alpha = beta = 0``
    restores the originals bit-exactly.
    """
    with torch.no_grad():
        for name, param in params:
            device = param.device
            base = originals[name].to(device, torch.float32)
            # torch.add returns a new tensor, so the stored originals are never mutated.
            target = torch.add(base, direction1[name].to(device, torch.float32), alpha=alpha)
            target.add_(direction2[name].to(device, torch.float32), alpha=beta)
            param.copy_(target)


def main() -> None:
    """Compute and save a 2-D loss landscape of QwQ-32B on one AIME 2024 problem.

    The loss is the mean NLL of the reference answer appended after
    ``**Final Answer**\\n\\boxed{``. It is evaluated on a ``steps x steps`` grid
    spanned by two random, layer-normalized directions. The result is saved to
    ``qwq32b/cache/qwq32b_aime24_landscape_{i}.npy.npz`` with keys ``X``, ``Y``,
    ``Z`` (``Z[r, c]`` is the loss at ``(X[r, c], Y[r, c])``) and ``center_loss``.
    """
    parser = ArgumentParser(description=main.__doc__.splitlines()[0])
    parser.add_argument("-i", type=int, required=True, help="index of the AIME 2024 problem")
    parser.add_argument("--seed", type=int, default=0, help="RNG seed for the two random directions")
    parser.add_argument("--steps", type=int, default=60, help="grid points per axis")
    parser.add_argument("--radius", type=float, default=1.0, help="grid spans [-radius, radius] on each axis")
    args = parser.parse_args()
    i = args.i

    # np.savez only appends ".npz" when it is missing; spell out the file name the
    # script has always produced. Create the directory now, not after hours of compute.
    out_path = os.path.join("qwq32b", "cache", f"qwq32b_aime24_landscape_{i}.npy.npz")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model_name = "Qwen/QwQ-32B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
    example = dataset[i]
    prompt_template = "{question}\n\nPlease reason step by step, and put your final answer within \\boxed{{}}"
    messages = [{"role": "user", "content": prompt_template.format(question=example["problem"])}]
    # NOTE(review): add_generation_prompt=False means no assistant header precedes the
    # appended answer text; confirm this is the intended scoring context.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    prefix = prompt + "\n\n**Final Answer**\n\\boxed{"
    answer = example["answer"]

    # Take the answer tokens from the joint encoding so the scored targets are exactly
    # the tokens the model sees, and refuse to run if the prefix is not a token prefix.
    prefix_ids = tokenizer(prefix, add_special_tokens=False)["input_ids"]
    inputs = tokenizer(prefix + answer, add_special_tokens=False, return_tensors="pt")
    full_ids = inputs["input_ids"][0].tolist()
    if full_ids[: len(prefix_ids)] != prefix_ids:
        raise RuntimeError("prompt/answer boundary is not token-aligned; cannot isolate the answer tokens")
    num_answer_tokens = len(full_ids) - len(prefix_ids)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype="bfloat16",
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    model.eval()
    model.requires_grad_(False)
    inputs = inputs.to(model.device)

    with torch.no_grad():
        center_loss = answer_nll(model(**inputs).logits, inputs["input_ids"], num_answer_tokens)
        print(f"loss at the unperturbed weights: {center_loss:.6f}")

        torch.manual_seed(args.seed)
        params = [(n, p) for n, p in model.named_parameters() if n not in EXCLUDED_PARAM_NAMES]
        # CPU copies of the weights (copy=True so CPU-resident params are not aliased).
        # NOTE: this needs roughly one extra model's worth of host RAM, on top of the
        # two directions.
        originals, direction1, direction2 = {}, {}, {}
        for n, p in tqdm(params, desc="directions"):
            originals[n] = p.detach().to("cpu", copy=True)
            direction1[n] = scaled_random_direction(p)
            direction2[n] = scaled_random_direction(p)

        x = np.linspace(-args.radius, args.radius, args.steps)
        y = np.linspace(-args.radius, args.radius, args.steps)
        # Default 'xy' indexing: X[r, c] = x[c], Y[r, c] = y[r]; Z must follow the same layout.
        X, Y = np.meshgrid(x, y)
        Z = np.zeros_like(X)

        try:
            for row, beta in enumerate(tqdm(y, desc="y progress")):
                for col, alpha in enumerate(tqdm(x, desc="x progress", leave=False)):
                    set_perturbed_weights(params, originals, direction1, direction2, float(alpha), float(beta))
                    logits = model(**inputs).logits
                    Z[row, col] = answer_nll(logits, inputs["input_ids"], num_answer_tokens)
        finally:
            # Put the trained weights back exactly, even if the sweep is interrupted.
            for n, p in params:
                p.copy_(originals[n])

    np.savez(out_path, X=X, Y=Y, Z=Z, center_loss=center_loss)


if __name__ == "__main__":
    main()
