import json
import os, sys
import re
import time 
from tqdm import tqdm
import argparse
from pathlib import Path
from typing import Tuple
import pandas as pd
from datasets import load_dataset
from utils import * 
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer, AutoModel, AutoModelForCausalLM
# import tensor_parallel as tp
import accelerate
import random


# Question templates used to ask the model for its view on a topic; the "{}"
# placeholder is filled with the topic entity. The list order determines which
# template a given random draw selects, so it is kept stable.
templates = [
    "What do you think of {}?",
    "What do you feel about {}?",
    "How do you view {}?",
] + [
    "What is your " + phrase + " {}?"
    for phrase in (
        "opinion of",
        "stance on",
        "position on",
        "attitude about",
        "view on",
        "take on",
        "impression of",
        "assessment of",
        "judgment of",
        "sentiment of",
    )
]



# Trait names (not referenced elsewhere in this file).
per_list = [
    "extraversion",
    "agreeableness",
    "neuroticism",
]


def load(ckpt_dir, model_type):
    """Load a causal LM and its tokenizer from ``ckpt_dir`` for batched generation.

    The tokenizer pads on the left so that, in a padded batch, every prompt ends
    immediately before the first generated token.

    Args:
        ckpt_dir: Path or hub identifier of the checkpoint.
        model_type: Model family label. Not used by the current implementation;
            kept so existing callers keep working.

    Returns:
        ``(model, tokenizer)``. The model is loaded in float16 with
        ``device_map="auto"`` and put in eval mode.
    """
    print("start loading")
    tokenizer = AutoTokenizer.from_pretrained(
        ckpt_dir,
        use_fast=False,
        padding_side="left",
        trust_remote_code=True,
    )
    # Padding a batch needs a pad id; tokenizers that define none fall back to 0.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
    # NOTE(review): BOS is forced to id 1 for every checkpoint; this presumably
    # matches a LLaMA-style vocabulary — confirm before using other model families.
    tokenizer.bos_token_id = 1

    model = AutoModelForCausalLM.from_pretrained(
        ckpt_dir,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()

    return model, tokenizer


def batch_infer(model, tokenizer, prompts, batch_size=1, max_new_tokens=4):
    """Generate a continuation for every prompt, ``batch_size`` prompts at a time.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used both to encode prompts and decode outputs.
        prompts: Sequence of prompt strings.
        batch_size: Number of prompts encoded and generated together.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Returns:
        List of decoded continuations (prompt tokens removed, special tokens
        skipped), one per prompt, in input order.
    """
    answers = []
    for batch_input in tqdm(batch_split(prompts, batch_size)):
        encode_inputs = prepare_input(tokenizer, batch_input)
        # Only the generated sequences are consumed, so per-step scores are not
        # requested (they would be kept in memory for every generated step).
        outputs = model.generate(
            **encode_inputs,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
        )
        # Every row shares the padded prompt width, so slicing past it keeps
        # only the newly generated tokens.
        input_length = encode_inputs.input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]
        answers.extend(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

    return answers


def batch_split(prompts, batch_num):
    """Split ``prompts`` into consecutive lists of at most ``batch_num`` items.

    Args:
        prompts: Any iterable; it is consumed once.
        batch_num: Maximum number of items per batch; must be >= 1.

    Returns:
        List of lists preserving input order; only the last batch may be
        shorter. An empty input yields ``[]``.

    Raises:
        ValueError: If ``batch_num`` is less than 1. Such a value cannot form
            a valid batch and would otherwise collapse everything into one.
    """
    if batch_num < 1:
        raise ValueError("batch_num must be >= 1, got %r" % (batch_num,))
    items = list(prompts)
    return [items[start:start + batch_num] for start in range(0, len(items), batch_num)]


def prepare_input(tokenizer, prompts, device="cuda"):
    """Tokenize a batch of prompts and move every tensor field to ``device``.

    Args:
        tokenizer: Callable tokenizer returning a mapping of encoded fields.
        prompts: List of prompt strings; they are padded to a common length.
        device: Target device for the tensors. Defaults to ``"cuda"``, the
            previously hard-coded value; pass ``"cpu"`` for CPU-only runs.

    Returns:
        The tokenizer's output mapping, with its tensor values relocated to
        ``device`` and non-tensor values left untouched.
    """
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``batch_encode_plus``.
    input_tokens = tokenizer(prompts, return_tensors="pt", padding=True)
    for key in input_tokens:
        # Only existing keys are reassigned, so mutating while iterating is safe.
        if torch.is_tensor(input_tokens[key]):
            input_tokens[key] = input_tokens[key].to(device)

    return input_tokens


def _read_text(path):
    """Return the full contents of the UTF-8 text file at ``path``."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


def format_data(
    method_type="prompt",
    prefix="pre",
    data_path="./data/personality/test.json",
    prompt_dir="./prompt_lib",
):
    """Build one prompt record per entry of the personality test set.

    Args:
        method_type: ``"prompt"`` builds instruction + case; any other value
            also inserts the demonstration text between them.
        prefix: ``"pre"`` uses the ``pre_*`` prompt files and a single-slot case
            template filled with the question. Any other value uses the
            ``edit_*`` files and a three-slot case template filled with
            (target personality, topic, question).
        data_path: JSON file holding a list of records with ``"ent"`` and
            ``"target_per"`` keys.
        prompt_dir: Directory containing the ``*_instruction.txt``,
            ``*_case.txt`` and ``*_demo.txt`` prompt files.

    Returns:
        List of dicts with keys ``target_per_text`` (empty for ``"pre"``),
        ``topic``, ``question`` (the unfilled template) and ``prompt``.
    """
    with open(data_path, "r", encoding="utf-8") as handle:
        data = json.load(handle)

    stem = "pre" if prefix == "pre" else "edit"
    instruction = _read_text(os.path.join(prompt_dir, stem + "_instruction.txt"))
    case_format = _read_text(os.path.join(prompt_dir, stem + "_case.txt"))
    demo = _read_text(os.path.join(prompt_dir, stem + "_demo.txt"))
    # The demonstrations are only part of the prompt for non-"prompt" methods.
    header = instruction if method_type == "prompt" else instruction + demo

    examples = []
    for mention in data:
        topic = mention["ent"]
        target_per_text = mention["target_per"]
        question = random.choice(templates)

        if prefix == "pre":
            # The "pre" setting has no target personality.
            target_per_text = ""
            case = case_format.format(question.format(topic))
        else:
            case = case_format.format(target_per_text, topic, question.format(topic))

        examples.append({
            "target_per_text": target_per_text,
            "topic": topic,
            "question": question,
            "prompt": header + case,
        })

    return examples
        
        

def main(ckpt_dir: str,):
    """Generate predictions for the formatted dataset and write them to JSON.

    Run options come from the module-level ``args`` namespace:
    ``output_dir``, ``prefix``, ``method_type`` and ``batch_size``. The result
    file ``<output_dir>/<prefix>.json`` holds every formatted example with the
    model's continuation stored under ``"pred"``.

    Args:
        ckpt_dir: Checkpoint passed to ``load``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_dir, exist_ok=True)
    output_filename = os.path.join(args.output_dir, args.prefix + ".json")

    model, tokenizer = load(ckpt_dir, "llama")

    # Timing starts after model loading, so it covers data prep, generation and writing.
    start_time = time.time()

    formatted_data = format_data(args.method_type, args.prefix)
    prompts = [mention["prompt"] for mention in formatted_data]
    pred_answers = batch_infer(
        model, tokenizer, prompts, batch_size=args.batch_size, max_new_tokens=100
    )

    run_result = []
    for mention, pred in zip(formatted_data, pred_answers):
        mention["pred"] = pred
        run_result.append(mention)

    # ensure_ascii=False can emit non-ASCII text, so the file is written as UTF-8.
    with open(output_filename, "w", encoding="utf-8") as f:
        json.dump(run_result, f, ensure_ascii=False, indent=4)

    print("total run time %.2f" % (time.time() - start_time))



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Command-line options as (flag, type, default).
    for flag, value_type, default in (
        ("--ckpt_dir", str, "llama-2-7b"),
        ("--output_dir", str, "./output"),
        ("--method_type", str, "prompt"),
        ("--prefix", str, "pre"),
        ("--batch_size", int, 4),
    ):
        parser.add_argument(flag, type=value_type, default=default)
    # Kept as the module-level name ``args``: main() reads its options from it.
    args = parser.parse_args()

    # Pure inference: run everything without autograd tracking.
    with torch.no_grad():
        main(args.ckpt_dir)
