import json
import argparse
from collections import defaultdict
from tqdm import tqdm
from vllm import LLM, SamplingParams
from itertools import combinations
import numpy as np
def _load_jsonl(path):
    """Read a JSON-Lines file and return one decoded object per non-blank line."""
    with open(path, 'r', encoding='utf-8') as f:
        # Skip whitespace-only lines; json.loads would raise on them.
        return [json.loads(line) for line in f if line.strip()]


def _pair_responses(outputs, num_response, truncation_mode):
    """Expand each prompt's sampled completions into every unordered response pair.

    Each returned row is ``[prompt, response_a, response_b, truncation_mode]``.
    With ``num_response`` completions per prompt this yields
    ``num_response * (num_response - 1) / 2`` rows per prompt, ordered as
    ``itertools.combinations`` orders them.
    """
    rows = []
    for output in outputs:
        responses = [output.outputs[res_id].text for res_id in range(num_response)]
        for first, second in combinations(responses, 2):
            rows.append([output.prompt, first, second, truncation_mode])
    return rows


def _update_log(log_file_path, generated_data_path):
    """Store ``generated_data_path`` under 'latest_generated_data' in the JSON log file.

    The log is read as a single JSON object from the file's first line and
    rewritten in full.
    """
    with open(log_file_path, 'r', encoding='utf-8') as f:
        logs = json.loads(f.readline())
    logs['latest_generated_data'] = generated_data_path
    with open(log_file_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(logs))


def main(dataset, method_name, model_id, iteration, original_data_dir, num_response, model_path, tensor_parallel_size, log_file_path, exp_name):
    """Sample responses with vLLM and save every pairwise combination as JSONL.

    Args:
        dataset: Dataset name; 'hh' selects 'keep_end' truncation, anything
            else selects 'keep_start'. Also used in the output file name.
        method_name: Method name used in the output file name.
        model_id: Tokenizer id passed to vLLM; its last path component is
            used in the output file name.
        iteration: Iteration number used in the output file name.
        original_data_dir: Path to the input JSONL file (despite the name,
            this is opened as a single file).
        num_response: Number of completions sampled per prompt.
        model_path: Model weights path passed to ``vllm.LLM``.
        tensor_parallel_size: Tensor-parallel degree passed to ``vllm.LLM``.
        log_file_path: JSON log file whose 'latest_generated_data' entry is
            updated with the output path. If empty, the log update is skipped.
        exp_name: Optional prefix for the output file name.
    """
    # The first row is dropped; presumably a header record like the
    # {'selection': False} row this script prepends to its own output
    # (TODO confirm for the input file).
    data_points = _load_jsonl(original_data_dir)[1:]

    llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size, tokenizer=model_id)
    sampling_params = SamplingParams(temperature=1, top_p=0.95, logprobs=1, n=num_response, max_tokens=1024)

    # Each data point is unpacked as a 4-item row; only the prompt is used.
    all_prompts = [prompt for prompt, _, _, _ in data_points]
    outputs = llm.generate(all_prompts, sampling_params)

    truncation_mode = 'keep_end' if dataset == 'hh' else 'keep_start'
    # The output keeps the same leading {'selection': False} header row.
    final_data_points = [{'selection': False}] + _pair_responses(outputs, num_response, truncation_mode)

    # Use only the last path component of the model id in the file name.
    # Strip a trailing '/' first so local paths such as 'ckpts/model/' do
    # not collapse to an empty name.
    model_name = model_id.rstrip('/').split('/')[-1]
    prefix = f'{exp_name}_' if exp_name else ''
    generated_data_path = f'{prefix}{method_name}_{model_name}-{dataset}-generated-iteration-{iteration}.jsonl'

    with open(generated_data_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(dp) + '\n' for dp in final_data_points)

    # --log_file_path defaults to ''. Opening '' raises FileNotFoundError,
    # which would crash the run after all generation work has finished.
    if log_file_path:
        _update_log(log_file_path, generated_data_path)
    else:
        print("No log file path given; skipping log update.")
    print(f"Saved {len(final_data_points)} data points to {generated_data_path}")

if __name__ == "__main__":
    # Command-line entry point: parse options and forward them to main().
    parser = argparse.ArgumentParser(description='Generate responses using LLM.')
    parser.add_argument('--method_name', type=str, default='ours', help='The name of the method')
    parser.add_argument('--dataset', type=str, default='tldr', help='Dataset name')
    parser.add_argument('--model_id', type=str, default='meta-llama/Llama-2-7b-hf', help='Model id')
    # Real int default instead of the string '0' that relied on argparse conversion.
    parser.add_argument('--iter', type=int, default=0, help='Iteration number')
    # main() opens this path as a single JSONL file, not a directory.
    parser.add_argument('--original_data_dir', type=str, default='tldr_train_data_10000.jsonl', help='Path to the original JSONL data file')
    parser.add_argument('--num_response', type=int, default=5, help='Number of responses to generate')
    parser.add_argument('--model_path', type=str, default='whole_model_unload', help='Path to the model')
    parser.add_argument('--tensor_parallel_size', type=int, default=4, help='Tensor parallel size')
    parser.add_argument('--log_file_path', type=str, default='', help='Log file path')
    parser.add_argument('--exp_name', type=str, default='', help='Name of the experiment')

    args = parser.parse_args()
    # Keyword arguments so the ten parameters cannot be silently misordered.
    main(
        dataset=args.dataset,
        method_name=args.method_name,
        model_id=args.model_id,
        iteration=args.iter,
        original_data_dir=args.original_data_dir,
        num_response=args.num_response,
        model_path=args.model_path,
        tensor_parallel_size=args.tensor_parallel_size,
        log_file_path=args.log_file_path,
        exp_name=args.exp_name,
    )