import torch
import os
import sys
import argparse
import json
import time
import random
import numpy as np
from tqdm import tqdm
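# Make the repo root importable so local modules (e.g., str_utils) resolve
# regardless of where the script is launched from.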
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import str_utils
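
# Illustrative invocation (the script filename below is a placeholder, not the
# actual name in the repo; adjust flags to your setup):
#   python gen_instructions.py --model_path meta-llama/Meta-Llama-3-8B-Instruct \
#       --total_prompts 1000 --n 200 --sanitize --logits_processor --device 0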

################
# Configurations
################
def get_args():
    # Experiment Settings
    parser = argparse.ArgumentParser(description="Instruction Generation Manager.")
    parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct",
                        help="We will support more models in the future.")
    # Generation Parameters
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--n", type=int, default=200, help="Number of samples to generate for one time.")
    parser.add_argument("--repeat", type=int, default=None, help="Number of times to repeat the instruction generation. Only available when total prompts is not specified.")
    parser.add_argument("--total_prompts", type=int, default=1000, help="Total number of prompts to generate. If specified, repeat will be ignored.")
    parser.add_argument("--max_tokens", type=int, default=2048)
    parser.add_argument("--max_model_len", type=int, default=4096)
    parser.add_argument("--early_stopping", type=bool, default=True, help="Stop generation when the \n is generated.")
    parser.add_argument("--disable_early_stopping", action="store_false", dest="early_stopping", help="Disable early stopping.")
    parser.add_argument("--system_prompt", action="store_true", help="Enable system prompt for extracting the input.")
    parser.add_argument("--sanitize", action="store_true", help="Sanitize the generated instructions.")
    parser.add_argument("--logits_processor", action="store_true", help="Enable logits processor for the generation.")
    parser.add_argument("--control_tasks", type=str, default=None, choices=[None, "translation", "code", "math"],  help="Control tasks for the generation. Currently only available for some models.")
    parser.add_argument("--shuffle", type=bool, default=True, help="Shuffle the outputs generated by vllm.")
    parser.add_argument("--skip_special_tokens", type=bool, default=True)
    parser.add_argument("--checkpoint_every", type=int, default=100, help="Save checkpoint every n repeats.")

    # System Settings
    parser.add_argument('--engine', default="vllm", type=str, choices=["vllm", "hf"])
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"])
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs to use for tensor parallelism. Only used for Llama 70B models.")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.95)
    parser.add_argument("--swap_space", type=float, default=2.0)
    parser.add_argument("--output_folder", type=str, default="../data")
    parser.add_argument("--job_name", type=str, default=None, help="Job Name. Get from the script.")
    parser.add_argument("--timestamp", type=int, default=int(time.time()), help="Timestamp for the job. Also used as the random seed.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")

    return parser.parse_args()

args = get_args()
print(f"Instruction Generation Manager. Arguments: {args}") # For logging

if args.total_prompts is None:
    if args.repeat is None:
        raise ValueError("Either total prompts or repeat should be specified.")
    args.total_prompts = args.repeat * args.n
else:
    # If total prompts is specified, repeat will be ignored
    args.repeat = int(np.ceil(args.total_prompts / args.n))
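    # Note: repeat * n may exceed total_prompts by up to n - 1 samples; the
    # surplus is kept rather than trimmed after generation.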

# Set the random seed for NumPy
if args.seed is not None:
    np.random.seed(args.seed)
    # Set the random seed for PyTorch
    torch.manual_seed(args.seed)
    # If you are using CUDA (i.e., a GPU), also set the seed for it
    torch.cuda.manual_seed_all(args.seed)

# Create output file / folder
output_filename = f"Magpie_{args.model_path.split('/')[-1]}_{args.total_prompts}_{args.timestamp}_ins.json"
if not args.job_name:
    output_folder = args.output_folder
else:
    output_folder = f"{args.output_folder}/{args.job_name}"
os.makedirs(output_folder, exist_ok=True)  # ensure the (job) folder exists before writing
output_dir = f"{output_folder}/{output_filename}"

# Set the device
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
# Set generation engine
if args.engine == "vllm":
    # Create vllm instance  
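    # Every request reuses the same pre-query template, so prefix caching lets
    # vLLM share the prompt's KV cache across all samples and rounds.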
    llm = LLM(model=args.model_path, 
            dtype=args.dtype,
            trust_remote_code=True,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_model_len=args.max_model_len,
            swap_space=args.swap_space,
            tensor_parallel_size=args.tensor_parallel_size,
            seed=args.seed if args.seed is not None else args.timestamp,
            enable_prefix_caching=True)
elif args.engine == "hf":
    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map={'':torch.cuda.current_device()},
        torch_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    )


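# The pre-query template is the model's chat template cut off right after the
# user-turn header; sampling completions from that prefix makes the aligned
# model emit plausible user instructions (the core Magpie trick).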
# Obtain config from configs/model_configs.json (resolved relative to this file
# rather than the current working directory)
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "configs", "model_configs.json")
with open(config_path, "r", encoding="utf-8") as f:
    model_configs = json.load(f)
    model_config = model_configs[args.model_path]
    if args.control_tasks:
        pre_query_template = model_config[f"pre_query_template_{args.control_tasks}"]
        print("Control task: {args.control_tasks}")
    elif args.system_prompt:
        pre_query_template = model_config["pre_query_template_with_system_prompt"]
        print("System prompt enabled. Warning: The system prompt may degrade the performance.")
    else:
        pre_query_template = model_config["pre_query_template"]
    stop_tokens = model_config["stop_tokens"]
    stop_tokens_assistant = model_config["stop_tokens_assistant"]
    stop_tokens += stop_tokens_assistant
    stop_token_ids = model_config["stop_token_ids"]

    # Process early stopping. We found that the LLM sometimes starts generating a response immediately after a "\n" token, so we also stop on "\n".
    if args.early_stopping:
        stop_tokens.append("\n")

    print(f"Pre-query template: {pre_query_template}")
    print(f"Stop tokens: {stop_tokens}")
    print(f"Stop token ids: {stop_token_ids}")


# Initialize logits processors for llama-3.1
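# The token ids below map to markdown header pieces ("#", "##", ...) in the
# Llama-3 tokenizer; they are tokenizer-specific and must be re-derived for
# any other model family.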
def de_md_logits_processor_for_llama3_1(token_ids, logits):
    # Only process the initial logits
    if len(token_ids) == 0:
        logits[2] = -9999.999 # "#": 2,
        logits[567] = -9999.999 # "##": 567,
        logits[14711] = -9999.999 # "###": 14711,
        logits[827] = -9999.999 # "####": 827,

    return logits

if args.logits_processor and "llama-3.1" in args.model_path.lower():
    logits_processor = de_md_logits_processor_for_llama3_1
    print(f"Logits processor applied: {logits_processor}")
else:
    logits_processor = None
    
# Define sampling parameters
sampling_params = SamplingParams(
    n=args.n,
    temperature=args.temperature,
    top_p=args.top_p,
    max_tokens=args.max_tokens,
    skip_special_tokens=args.skip_special_tokens,
    stop=stop_tokens,
    stop_token_ids=stop_token_ids,
    logits_processors=[logits_processor] if logits_processor else None
)

################
# Generate outputs
################
results = []
for rounds in tqdm(range(args.repeat)):
    # Generate outputs
    if args.engine == "vllm":
        output = llm.generate(pre_query_template, sampling_params)
        output_list = output[0].outputs
        if args.shuffle:
            random.shuffle(output_list)
    
    elif args.engine == "hf":
        input = tokenizer.encode(pre_query_template, add_special_tokens=False, return_tensors="pt").to(torch.cuda.current_device())
        # Gemma-2 bug, so we cannot set num_return_sequences > 1. 
        # Instead, we repeat the input n times.
        inputs = input.repeat(args.n, 1).to(torch.cuda.current_device())
        output = model.generate(inputs,
                                tokenizer=tokenizer, 
                                do_sample=True, 
                                temperature=args.temperature, 
                                top_p=args.top_p, 
                                max_length=args.max_tokens, 
                                num_return_sequences=1,
                                )
        # Remove the input from the output
        output_list = tokenizer.batch_decode(output[i][len(inputs[0]):] for i in range(args.n))
        # Stop on the first stop token
        for i, completion in enumerate(output_list):
            for stop_token in stop_tokens:
                if stop_token in completion:
                    output_list[i] = completion[:completion.index(stop_token)]
                                             
    # Save outputs
    for i, completion in enumerate(output_list):
        if args.engine == "vllm":
            instruction = completion.text.strip()
        elif args.engine == "hf":
            instruction = completion.strip()

        # Build the shared record once; sanitization only adds/overrides fields.
        result = {
            "id": rounds * args.n + i,
            "pre_query_template": pre_query_template,
            "instruction": instruction,
            "response": None,
            "created": int(time.time()),
            "gen_input_configs": {
                "temperature": args.temperature,
                "top_p": args.top_p,
                "input_generator": args.model_path,
                "seed": args.seed,
            },
            "gen_response_configs": None,
        }
        if args.sanitize:
            sanitized_instruction, class_num = str_utils.instruction_post_process(instruction, args.model_path)
            result["raw_instruction"] = instruction
            result["instruction"] = sanitized_instruction
            result["instruction_sanitize_class_num"] = class_num
        results.append(result)

    # Save the checkpoints every args.checkpoint_every rounds
    if rounds % args.checkpoint_every == 0:
        with open(output_dir, "w") as f:
            json.dump(results, f, indent=2)
        print(f"Checkpoint saved. Total prompts: {len(results)}")

# Save the final results
with open(output_dir, "w") as f:
    json.dump(results, f, indent=2)

print(f"Instruction generated from {args.model_path}. Total prompts: {len(results)}")