# adapted from storytelling/main.py
import argparse
import os
import torch
import psutil

from uncertainty_quantification.model_utils import compute_gpu_memory_utilization
from uncertainty_quantification.arg_utils import step1_forward_args
from uncertainty_quantification.manager import ForwardManager



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NudgingArgsParsing.')
    parser = step1_forward_args(parser, sample_counts=50, max_tokens=1024, log_probs=100, min_p=0.1, top_p=1.0, temperature=1.0, output_root_dir="response_storywriting")
    # mmlu-specific setup
    parser.add_argument("--constraint_level", type=int, default=0,
                        help="constraint level (in #(words) * multiplier/#(sentences))")
    parser.add_argument("--word_level_constraint", action="store_true", help="constraint level in word level")
    parser.add_argument("--word_level_constraint_multiplier", type=int, default=25, help="constraint level multiplier in word level")
    parser.add_argument("--max_constraint_level", type=int, default=10,)
    parser.add_argument("--input_file", type=str, help='input file', default=None)
    # nudging-specific setup
    parser.add_argument("--eval_model", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct",)
    parser.add_argument("--eval_log_probs", type=int, default=50, help="log probs")
    parser.add_argument("--eval_output_dir", type=str, default="nudging_experiments", help="log probs")
    args = parser.parse_args()
    output_root_dir = args.output_root_dir
    os.makedirs(output_root_dir, exist_ok=True)

    model = args.model
    chat_template_path = args.chat_template_path
    constraint_level = args.constraint_level
    assert constraint_level >= 0, "constraint level should be non-negative"

    # Get memory usage of the current process
    process = psutil.Process()
    memory_info = process.memory_info()

    # Print memory usage in bytes
    print(memory_info.rss)

    # Print memory usage in MB
    print("Memory usage in MB:", memory_info.rss / (1024 * 1024))

    file_name = "{}_response_n_{}_max_tokens_{}_log_probs_{}_min_p_{}_top_p_{}_seed{}{}{}.pt".format(
        os.path.basename(model),
        args.sample_counts,
        args.max_tokens, args.log_probs,
        args.min_p, args.top_p, args.seed,
        "_word_level_constraint_multiplier_{}".format(args.word_level_constraint_multiplier) if args.word_level_constraint else "",
        '_input_file_{}'.format(os.path.basename(args.input_file)) if args.input_file is not None else "")
    nudging_output_root_dir = args.eval_output_dir
    os.makedirs(nudging_output_root_dir, exist_ok=True)
    nudging_output_filename = file_name + ".{}_top_{}.nudging".format(os.path.basename(args.eval_model), args.eval_log_probs)
    nudging_output_filename = os.path.join(nudging_output_root_dir, nudging_output_filename)
    file_name = os.path.join(output_root_dir, file_name)
    metadata_filename = file_name.replace(".pt", ".metadata")
    assert os.path.exists(metadata_filename), f"metadata file ({metadata_filename}) not found"
    # prompts, roles, answers, options, all_constrained_prompts, metadata_args = torch.load(metadata_filename)
    prompts, roles, all_task_prompts, all_original_prompts, metadata_args = torch.load(metadata_filename)
    # nudging part tricks
    # trick: update the log_probs to the eval_log_probs
    args.log_probs = args.eval_log_probs
    args.model = args.eval_model
    eval_model = args.eval_model
    # trick #2: update min_p and top_p -- both params are only used in the sampling process, so we can safely ignore them
    args.min_p = 0
    args.top_p = 1.0
    manager = ForwardManager(args, ckpt_freq=args.ckpt_freq)
    gpu_memory_utilization = compute_gpu_memory_utilization(model)
    max_num_seqs = 16
    # as we use a40 for quick run, we need to check whether the model is 70B -- if it is, we have to use 0.8
    if "70b" in eval_model.lower() and "llama-3" in eval_model.lower():
        gpu_memory_utilization = 0.6
        max_num_seqs = 4
    if "8b" in eval_model.lower():
        gpu_memory_utilization = 0.6
        max_num_seqs = 4
    assert os.path.exists(file_name), f"file ({file_name}) not found"
    manager.setup_model(max_num_seqs=max_num_seqs, gpu_memory_utilization=gpu_memory_utilization)
    response = torch.load(file_name)
    print("Memory usage in MB:", memory_info.rss / (1024 * 1024))

    manager.fillin_logits_routine(response, file_name, max_num_seqs=max_num_seqs,
                                  gpu_memory_utilization=gpu_memory_utilization,
                                  try_reuse_existing_logprobs=False,
                                  update_filename=nudging_output_filename + ".spectrum",
                                  patch_filename=nudging_output_filename + ".patch",)
