import argparse
import json
import os
import time

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
# Command-line interface, declared as (flag, type, default, help text) rows.
# NOTE(review): --node_id is not read anywhere in the visible script, and
# --file_path is also used as the directory the input JSON files are read from.
_CLI_OPTIONS = (
    ('--model_name', str, "/zfsauton2/home/wentsec/Meta-Llama-3.1-8B-Instruct", 'name of the language model'),
    ('--node_id', str, "26", 'id of the gpu node, used as data prefix'),
    ('--file_path', str, "", 'where to save the trajectory'),
    ('--alpha', float, 1.0, 'alpha for kl constraint'),
    ('--temperature', float, 1.0, 'temperature for sampling'),
    ('--batch_size', int, 256, 'number batch size for each iteration'),
    ('--seed', int, 42, 'random seed'),
)
parser = argparse.ArgumentParser()
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)
args = parser.parse_args()

# Seed the global torch and numpy RNGs from --seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Data preparation: load the agent trajectories and the matching verbal
# feedback from --file_path. os.path.join keeps the empty default relative to
# the working directory; plain "+ '/agent.json'" would point at the
# filesystem root ("/agent.json").
file_name = os.path.join(args.file_path, "agent.json")
with open(file_name, "r", encoding="utf-8") as f:
    agent_traj = json.load(f)
file_name = os.path.join(args.file_path, "feedback.json")
with open(file_name, "r", encoding="utf-8") as f:
    verbal_fb = json.load(f)

def deepcopy_list(list_msgs):
    """Return a new list of new message dicts holding only ``role`` and ``content``.

    The returned dicts can be mutated (e.g. appending to ``content``) without
    affecting the caller's messages. Any other keys on the input messages are
    dropped, and a message missing ``role`` or ``content`` raises ``KeyError``.
    The values themselves are not copied.
    """
    return [
        {"role": message["role"], "content": message["content"]}
        for message in list_msgs
    ]

# Build paired prompts for every trajectory that received verbal feedback:
#   wo_feedback - the trajectory without its final message,
#   w_feedback  - the same prompt with the advice appended to its last message,
#   origin_data - the untouched trajectory, written to the result file later.
# The loop is clamped to the records actually loaded, so a --batch_size larger
# than the JSON files no longer raises IndexError.
w_feedback = []
wo_feedback = []
origin_data = []
num_examples = min(args.batch_size, len(agent_traj), len(verbal_fb))
for idx in range(num_examples):
    feedback = verbal_fb[idx]
    if feedback is None:
        continue
    trajectory = agent_traj[idx]
    origin_data.append(trajectory)
    wo_feedback.append(deepcopy_list(trajectory[:-1]))
    # A second, independent copy so the in-place += below does not leak into
    # the wo_feedback prompt or the original trajectory.
    msg = deepcopy_list(trajectory[:-1])
    # NOTE(review): the advice is concatenated with no separator in front of
    # it; confirm the preceding content already ends with whitespace/newline.
    msg[-1]["content"] += "Here is the advice you should follow: " + feedback + "\n"
    w_feedback.append(msg)

# Create a sampling params object.
# Only the first generated token's logprob distribution is consumed
# (get_logprobs reads logprobs[0]), so max_tokens=1 avoids generating the
# extra tokens vLLM's default max_tokens would produce. Each of those tokens
# would also carry 10000 logprobs.
sampling_params = SamplingParams(
    temperature=args.temperature,
    max_tokens=1,
    logprobs=10000,
)
# max_logprobs caps the per-request logprobs value, so it must be >= 10000.
# --seed is forwarded so the engine's sampling RNG also follows the CLI flag.
llm = LLM(model=args.model_name, dtype='bfloat16', max_logprobs=10000, seed=args.seed)

def get_logprobs(outputs, actions=("A", "B", "C", "D", "E"), missing_logprob=-100):
    """Extract the first-token log-probability of each candidate action token.

    For every request output, the top-k entries of the first generated token
    are searched for tokens whose decoded text equals an action exactly. When
    several entries decode to the same text, the first one in iteration order
    wins.

    Args:
        outputs: sequence of request outputs (as returned by ``llm.chat``).
            Only ``outputs[i].outputs[0].logprobs[0]`` is read: a mapping whose
            values expose ``decoded_token`` and ``logprob``.
        actions: candidate token strings, in output-column order. Pass
            ``("A", "B", "C", "D", "E", "F")`` to score a sixth option.
        missing_logprob: value used when an action is absent from the top-k.

    Returns:
        ``np.float32`` array of shape ``(len(outputs), len(actions))``.
    """
    rows = []
    for request_output in outputs:
        first_token_logprobs = request_output.outputs[0].logprobs[0]
        # Build the text -> logprob lookup once per output instead of
        # rescanning every top-k entry for each action. setdefault keeps the
        # FIRST entry per decoded text.
        by_token = {}
        for entry in first_token_logprobs.values():
            by_token.setdefault(entry.decoded_token, entry.logprob)
        rows.append([by_token.get(action, missing_logprob) for action in actions])
    # float32 matches the precision of the previous torch round-trip, but is
    # now consistent even when a row has no match at all. reshape keeps the
    # result 2-D for empty input.
    return np.array(rows, dtype=np.float32).reshape(len(rows), len(actions))

# obj_list = ["box", "ball", "key"]
# color_list = ["green", "blue", "red", "yellow", "grey", "purple"]

# Score the candidate actions without and then with the appended feedback.
outputs = llm.chat(wo_feedback, sampling_params)
logprobs_wo_fb = get_logprobs(outputs)

outputs = llm.chat(w_feedback, sampling_params)
logprobs_w_fb = get_logprobs(outputs)

# Write one JSON object per line. "a" mode appends to any existing
# result.json. os.path.join avoids writing to "/result.json" when
# --file_path is left empty.
file_name = os.path.join(args.file_path, "result.json")
with open(file_name, "a", encoding="utf-8") as f:
    for messages, logp_wo, logp_w in zip(origin_data, logprobs_wo_fb, logprobs_w_fb):
        data = {
            "messages": messages,
            "pi_sampling": [str(logp) for logp in logp_wo],
            "pi_better": [str(logp) for logp in logp_w],
            "alpha": str(args.alpha),
        }
        json.dump(data, f)
        f.write("\n")
    

