from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import argparse
import json
from tqdm import tqdm


def read_jsonl(file_path):
    """Load a JSON Lines file into a list, one parsed value per non-blank line.

    The file is decoded as UTF-8 explicitly rather than with the locale's
    default encoding, so non-ASCII content loads the same on every platform.
    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of making ``json.loads`` raise.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def read_json(file_path):
    """Parse a single JSON document from ``file_path`` and return it.

    Decoded as UTF-8 explicitly so the result does not depend on the
    platform's locale encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def write_jsonl(file_path, data):
    """Write each item of ``data`` to ``file_path`` as one JSON object per line.

    Items are serialized with ``json.dumps`` defaults (non-ASCII characters are
    escaped), each followed by a newline; an existing file is overwritten.
    """
    with open(file_path, 'w') as out:
        out.writelines(json.dumps(record) + '\n' for record in data)

def write_json(file_path, data):
    """Serialize ``data`` into ``file_path`` as JSON indented by four spaces.

    Any existing file at ``file_path`` is overwritten.
    """
    with open(file_path, mode='w') as out:
        json.dump(data, out, indent=4)


def parse_args(argv=None):
    """Parse the command-line options for scoring generations.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list lets callers
            and tests parse without touching the process arguments.

    Returns:
        argparse.Namespace with ``model_name``, ``generation_file``,
        ``config_file``, ``question_key``, ``answer_key`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Math Verify")
    # The script cannot do anything without a model or an input file, so reject
    # a missing value here instead of failing later inside from_pretrained()/open().
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Path to the model directory or model name from Hugging Face Hub."
    )
    parser.add_argument("--generation_file", type=str, required=True, help="Path to the generation file.")
    parser.add_argument("--config_file", type=str, default=None, help="Path to the generation config file.")
    parser.add_argument("--question_key", type=str, default="problem", help="Key for the question in the input JSON.")
    # NOTE(review): answer_key is not read by this script's scoring loop; kept for CLI compatibility.
    parser.add_argument("--answer_key", type=str, default="answer", help="Key for the answer in the input JSON.")
    # This is a directory: the outputs are joined onto it as separate files, so a
    # file-like default such as "generations.jsonl" is misleading. Use the CWD.
    parser.add_argument("--output_dir", type=str, default=".", help="Directory to save the output.")
    return parser.parse_args(argv)

# The two completion tokens compared at each step-tag position: the model's
# logits are restricted to these, softmaxed, and the first (good) probability
# is taken as the score.
good_token, bad_token = '+', '-'
# Marker appended after each response; scores are read at its position.
step_tag = 'ки'

def _score_response(model, tokenizer, question, response, candidate_tokens, step_tag_id):
    """Return the model's P(good_token) at the final step tag of one response.

    Builds ``"{question} {response} {step_tag}"`` and runs one forward pass.
    It then restricts the logits to ``candidate_tokens`` ([good, bad]) and
    softmaxes them, so index 0 is the probability of ``good_token`` relative to
    ``bad_token``.

    Only the LAST position holding ``step_tag_id`` is used. That is the tag
    appended here, assuming it tokenizes to end with ``step_tag_id``. If the
    question or response already contains the tag, the match yields several
    positions, and a plain ``.item()`` would raise.

    Raises:
        ValueError: if ``step_tag_id`` never occurs in the tokenized prompt.
    """
    prompt = f"{question} {response.strip()} {step_tag}"
    input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
    with torch.no_grad():
        logits = model(input_ids).logits[:, :, candidate_tokens]
        scores = logits.softmax(dim=-1)[:, :, 0]  # good_token probability per position
    step_scores = scores[input_ids == step_tag_id]
    if step_scores.numel() == 0:
        raise ValueError(
            f"Step tag id {step_tag_id} not found in the tokenized input; "
            "check that the tokenizer matches the scoring model."
        )
    return step_scores[-1].item()


def main():
    """Score every response in ``--generation_file`` and write the results.

    For each record, ``critique_scores`` is set to one ``[score]`` entry per
    item of ``record["response"]``, with the score rounded to two decimals.
    Records are written to
    ``<output_dir>/generation_verified_by_<model>.jsonl``. If ``--config_file``
    is given, a copy of it with ``generation_file`` pointing at that output is
    written to ``<output_dir>/config.json``.
    """
    args = parse_args()

    # Read all inputs and create the output directory before the slow model
    # load, so a bad path fails immediately rather than after scoring.
    generations = read_jsonl(args.generation_file)
    config = read_json(args.config_file) if args.config_file is not None else None
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # [1:] drops the BOS token the tokenizer is expected to prepend; the id pair
    # is e.g. [648, 387] for the tokenizer this script was written against.
    # Without a BOS, [1:] would drop good_token itself and leave one candidate,
    # making every softmax score 1.0. Check for that instead of silently
    # producing constant scores.
    candidate_tokens = tokenizer.encode(f"{good_token} {bad_token}")[1:]
    if len(candidate_tokens) != 2:
        raise ValueError(
            f"Expected 2 candidate token ids for {good_token!r}/{bad_token!r} after "
            f"dropping the first token, got {candidate_tokens}; the tokenizer may "
            "not prepend a BOS token."
        )
    print(f"Candidate token IDs: {candidate_tokens}")
    step_tag_id = tokenizer.encode(f"{step_tag}")[-1]  # e.g. 12902
    print(f"Step tag ID: {step_tag_id}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map='auto',
    ).eval()

    for gen in tqdm(generations, desc="Scoring generations"):
        question = gen[args.question_key]
        critique_scores = []
        for resp in gen["response"]:
            score = _score_response(model, tokenizer, question, resp, candidate_tokens, step_tag_id)
            # Each response's score is stored as a one-element list.
            critique_scores.append([round(score, 2)])
        gen["critique_scores"] = critique_scores

    # rstrip so a local path given with a trailing "/" still yields its last
    # component instead of an empty name.
    model_name = args.model_name.rstrip("/").split("/")[-1]
    output_file = os.path.join(args.output_dir, f"generation_verified_by_{model_name}.jsonl")
    write_jsonl(output_file, generations)

    if args.config_file is not None:
        config["generation_file"] = output_file
        write_json(os.path.join(args.output_dir, "config.json"), config)

    print(f"Critique scores added to the generations and saved to {output_file}.")


if __name__ == "__main__":
    main()
