import re
import torch
from llm.prompt_utils import get_prompt_template
# Matches every "-<digits>" run in the model's answer; compiled once at import time.
_SCORE_PATTERN = re.compile(r"-(\d+)")

# Per-dataset prompt fragments: (prefix placed before the target node's text,
# header line introducing the neighbour texts). Wording is copied verbatim from
# the original prompts.
# NOTE(review): "relavant" / "psots" are typos, and the 'photo' neighbour header
# says "description and title" while node A is described by a "user review".
# They are deliberately left unchanged so prompts (and therefore model outputs)
# stay reproducible; fix intentionally if experiments are re-run.
_PAPER_PROMPT = (
    "###The abstract of paper A is: ",
    "###The abstract of relavant papers are: \n",
)
_BOOK_PROMPT = (
    "###The description and title of book A is: ",
    "###The description and title of relavant books are: \n",
)
_NODE_PROMPTS = {
    "cora": _PAPER_PROMPT,
    "citeseer": _PAPER_PROMPT,
    "pubmed": _PAPER_PROMPT,
    "instagram": (
        "###The personal profile of user A is: ",
        "###The personal profile of relavant users are: \n",
    ),
    "reddit": (
        "###The last 3 posts of user A is: ",
        "###The last 3 psots of relavant users are: \n",
    ),
    "wikics": (
        "###The content of webpage A is: ",
        "###The content of relavant webpages are: \n",
    ),
    "history": _BOOK_PROMPT,
    "children": _BOOK_PROMPT,
    "photo": (
        "###The user review of electronic product A is: ",
        "###The description and title of relavant electronic products are: \n",
    ),
    "amazonratings": (
        "###The description of product A is: ",
        "###The description of relavant products are: \n",
    ),
}


def _build_query(template_l, template_r, x_text, data_name, nei_text):
    """Assemble the full scoring prompt: left template, node A, neighbours, right template.

    Raises:
        ValueError: if ``data_name`` has no entry in ``_NODE_PROMPTS``.
    """
    try:
        node_prefix, nei_header = _NODE_PROMPTS[data_name]
    except KeyError:
        raise ValueError(
            f"Unsupported dataset for node scoring: {data_name!r}; "
            f"expected one of {sorted(_NODE_PROMPTS)}"
        ) from None
    if nei_text is None:
        nei_text = ()
    # Each neighbour text goes on its own line, directly under the header.
    nei_block = "".join(text + "\n" for text in nei_text)
    return template_l + node_prefix + x_text + "\n" + nei_header + nei_block + template_r


def get_nodescore(model, tokenizer, x_text, data_name, device, nei_text=None):
    """Ask the LLM to score node A given its own text and its neighbours' texts.

    Builds a dataset-specific prompt, with the templates returned by
    ``get_prompt_template(data_name, "classscore")`` wrapped around node A's
    text and the neighbour texts. It then runs one sampled generation and
    extracts every integer that directly follows a ``-`` in the answer.

    Args:
        model: Generation model exposing ``generate(**inputs, **kwargs)``
            (presumably a Hugging Face causal LM; confirm with callers).
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        x_text: Text of the target node A.
        data_name: Dataset key; must be one of the keys of ``_NODE_PROMPTS``.
        device: Device the tokenized prompt is moved to before generation.
        nei_text: Iterable of neighbour texts, each placed on its own line of
            the prompt. ``None`` (the default) means no neighbours.

    Returns:
        list[int]: The integers captured from every ``-<digits>`` match in the
        decoded answer, in order of appearance. The list is empty if there are
        no matches. Generation samples (``do_sample=True``), so repeated calls
        can return different lists.

    Raises:
        ValueError: If ``data_name`` has no prompt wording defined here.
    """
    template_l, template_r = get_prompt_template(data_name, "classscore")
    query = _build_query(template_l, template_r, x_text, data_name, nei_text)

    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": query}],
        add_generation_prompt=True,
        tokenize=True,
        # Over-long prompts (e.g. many neighbours) are cut to 8000 tokens, not rejected.
        truncation=True,
        max_length=8000,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = inputs.to(device)

    # Under Hugging Face semantics, max_length counts prompt + generated tokens.
    # With the prompt capped at 8000 tokens, at least 1000 new tokens remain.
    gen_kwargs = {
        "max_length": 9000,
        "do_sample": True,
        "top_p": 0.98,
        "temperature": 0.3,
        "repetition_penalty": 1.0,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    completion = outputs[:, inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(completion[0], skip_special_tokens=True)
    return [int(num) for num in _SCORE_PATTERN.findall(answer)]