import argparse
import json
import os
import sys
import torch
from vllm import LLM
from vllm.sampling_params import SamplingParams
from tqdm import tqdm
from tqdm.contrib import tzip
import multiprocessing
import numpy as np
from transformers import (
    AutoTokenizer,
)
from datasets import load_from_disk, load_dataset, Dataset, DatasetDict
sys.path.append(os.path.join(os.path.dirname(__file__), '../WIT/pre_process'))
from utils import read_with_orjsonl, write_with_orjsonl, write_with_orjsonl_extend

# Set TOKENIZERS_PARALLELISM=false for this process and any child processes.
# NOTE(review): presumably this silences the HF tokenizers fork warning when
# the dataset is mapped with several worker processes — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Special-token strings referenced by the default stop list below.
DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN = "[PAD]", "</s>"
DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN = "<s>", "<unk>"

# Default value of the --stop command-line option.
stop_tokens = [DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN]

# Default value of the --stop_token_ids command-line option (none).
stop_tokens_ids = None

def parse_args():
    """Parse the command-line options and normalise a few of them.

    Returns:
        argparse.Namespace: the parsed options, where
        * ``torch_dtype`` has been converted from its string name to the
          matching ``torch`` dtype object, and
        * ``early_stopping`` has been converted to ``True`` / ``False`` when
          given as ``"True"`` / ``"False"`` (``"never"`` stays a string).
    """
    parser = argparse.ArgumentParser()

    # Model settings
    parser.add_argument("--model_name_or_path", type=str, default='YOUR_ROOT_PATH/model/CapsFus-LLaMA')
    parser.add_argument("--tokenizer", type=str, default='YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf', help="Path to our tokenizer directory.")
    parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--num_gpu", type=int, default=1, help="The number of gpus to use.")
    parser.add_argument("--dataset_shard_index", type=int, default=0, help="The shard index of the dataset.")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=['float16', 'bfloat16', 'float32'])

    # Dataset settings
    parser.add_argument('--dataset_name', type=str, default='Merged_new', help='dataset name')
    parser.add_argument('--dataset_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/IC', help='path to origin caption dir')
    parser.add_argument('--swap_space', type=int, default=20, help='swap space in GB')
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--process_num_workers', type=int, default=multiprocessing.cpu_count(), help='preprocessing num workers')

    # Generation settings
    # NOTE(review): the defaults (best_of=3, temperature=0.0) look tuned for
    # --use_beam_search; vLLM may reject best_of > 1 for plain greedy
    # sampling — confirm against the installed vLLM version.
    parser.add_argument("--num_return_sequences", type=int, default=1, help="Number of output sequences to return for the given prompt.")
    parser.add_argument("--best_of", type=int, default=3, help="Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. This is treated as the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`.")
    parser.add_argument("--presence_penalty", type=float, default=0.0, help="Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.")
    parser.add_argument("--frequency_penalty", type=float, default=0.0, help="Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.")
    parser.add_argument("--temperature", type=float, default=0.0, help="Float value controlling randomness in boltzmann distribution. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.")
    parser.add_argument("--top_p", type=float, default=1.0, help="Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens.")
    parser.add_argument("--top_k", type=int, default=-1, help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.")
    parser.add_argument("--min_p", type=float, default=0.0, help="Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this.")
    parser.add_argument("--use_beam_search", action="store_true", help="Whether to use beam search instead of sampling.")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="Float that penalizes sequences based on their length. Used in beam search.")
    parser.add_argument("--early_stopping", type=str, default="False", choices=["True", "False", "never"], help="Controls the stopping condition for beam search. It accepts the following values: `True`, where the generation stops as soon as there are `best_of` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `'never'`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).")
    parser.add_argument("--stop", nargs='+', default=stop_tokens, help="List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.")
    # type=int: without it the ids given on the command line stay strings,
    # while they are forwarded as token ids to the sampling parameters.
    parser.add_argument("--stop_token_ids", nargs='+', type=int, default=stop_tokens_ids, help=" List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens.")
    parser.add_argument("--include_stop_str_in_output", action="store_true", help="Whether to include the stop strings in output text. Defaults to False.")
    parser.add_argument("--ignore_eos", action="store_true", help="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.")
    parser.add_argument("--max_tokens", type=int, default=80, help="Maximum number of tokens to generate per output sequence.")
    parser.add_argument("--skip_special_tokens", action="store_true", help="Whether to skip special tokens in the output.")
    parser.add_argument("--spaces_between_special_tokens", action="store_true", help="Whether to add spaces between special tokens in the output.")
    # Options deliberately not exposed on the command line:
    # logprobs: Number of log probabilities to return per output token. Note that the implementation follows the OpenAI API: The return result includes the log probabilities on the `logprobs` most likely tokens, as well as the chosen tokens. The API will always return the log probability of the sampled token, so there may be up to `logprobs+1` elements in the response.
    # prompt_logprobs: Number of log probabilities to return per prompt token.
    # logits_processors: List of functions that modify logits based on previously generated tokens.

    args = parser.parse_args()

    # `choices` on --torch_dtype guarantees the name is one of these keys,
    # so the lookup cannot fail.
    dtype_by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    args.torch_dtype = dtype_by_name[args.torch_dtype]

    # "True"/"False" become real booleans; "never" is passed through as-is.
    if args.early_stopping in ("True", "False"):
        args.early_stopping = args.early_stopping == "True"

    return args

# Upper bound, in tokens as counted by the tokenizer call on the full prompt,
# above which the original caption gets shortened.
MAX_PROMPT_TOKENS = 256


def _shard_result_path(result_path, shard_index):
    """Return the per-shard JSONL path derived from the merged ``result_path``."""
    stem, ext = os.path.splitext(result_path)
    return f"{stem}_shard_{shard_index}{ext}"


def _wrap_prompt(ori_caption, synthetic_caption):
    """Build the caption-fusion instruction prompt for one caption pair."""
    return (
        "Please merge and refine the information from the two given sentences. "
        "Sentence 1 provides detailed real-world knowledge, "
        "yet it suffers from flaws in sentence structure and grammar. "
        "Sentence 2 exhibits nice sentence structure, "
        "but lacking in-depth real-world details and may contain false information. "
        "Please combine them into a new sentence, "
        "ensuring a well-structured sentence while retaining the detailed real-world information provided in Sentence 1. "
        "Avoid simply concatenating the sentences.\n\n"
        f"Sentence 1: {ori_caption}\n"
        f"Sentence 2: {synthetic_caption}\n"
        "New Sentence:"
    )


def _truncate_input_text(tokenizer, ori, syn, max_tokens=MAX_PROMPT_TOKENS):
    """Return the wrapped prompt, shortening ``ori`` if the prompt is too long.

    The number of tokens by which the full prompt exceeds ``max_tokens`` is
    removed from the end of the tokenised original caption; the synthetic
    caption is never shortened.
    """
    excess = len(tokenizer(_wrap_prompt(ori, syn)).input_ids) - max_tokens
    # Trim only when there is something to trim: a slice of the form
    # ids[:-0] is ids[:0], which would throw away the whole caption when the
    # prompt is exactly max_tokens long.
    if excess > 0:
        ori_ids = tokenizer(ori, add_special_tokens=False).input_ids
        ori = tokenizer.decode(ori_ids[:max(len(ori_ids) - excess, 0)])
    return _wrap_prompt(ori, syn)


def _merge_shard_results(result_path, num_shards):
    """Concatenate the per-shard JSONL files into ``result_path``.

    Rows are sorted by ``row_index`` before writing, and the shard files are
    deleted once the merged file has been written.
    """
    shard_paths = [_shard_result_path(result_path, i) for i in range(num_shards)]
    results = []
    for shard_path in shard_paths:
        results.extend(read_with_orjsonl(shard_path))
    results.sort(key=lambda x: x['row_index'])
    write_with_orjsonl(results, result_path)
    for shard_path in shard_paths:
        os.remove(shard_path)


def _merge_captions_into_dataset(args, result_path):
    """Attach the generated captions to the merged dataset and save it.

    Loads ``<dataset_dir>/<dataset_name>/image_token/merged``, writes every
    generated response into the ``caption_capsfusion`` column at its
    ``row_index``, drops rows that still have no caption, splits 99/1 into
    train/test and saves the result to ``.../image_token/merged_new``.

    Raises:
        ValueError: if a generated row targets a caption that is already set.
    """
    image_token_dir = os.path.join(args.dataset_dir, args.dataset_name, 'image_token')
    merged_dataset = load_from_disk(os.path.join(image_token_dir, 'merged'))
    print(merged_dataset)

    results = read_with_orjsonl(result_path)
    results.sort(key=lambda x: x['row_index'])

    if args.dataset_name == "Merged_new":
        # Copy into a plain list so the captions can be assigned by index.
        caption_capsfusion = list(merged_dataset['caption_capsfusion'])
        merged_dataset = merged_dataset.remove_columns(['caption_capsfusion'])
    else:
        caption_capsfusion = [None] * len(merged_dataset)
        assert len(merged_dataset) == len(results)

    for result in tqdm(results):
        row_index = result['row_index']
        existing = caption_capsfusion[row_index]
        if existing is not None and existing != '':
            raise ValueError(f"caption_capsfusion[{row_index}] already exists: {existing}")
        caption_capsfusion[row_index] = result['response']

    # Rows that still have no caption after the merge are dropped.
    drop_index = [
        caption_index
        for caption_index, caption in enumerate(tqdm(caption_capsfusion))
        if caption is None or caption == ''
    ]
    print(len(drop_index), len(caption_capsfusion), drop_index)
    if drop_index:
        dropped = set(drop_index)
        keep_index = [i for i in range(len(merged_dataset)) if i not in dropped]
        merged_dataset = merged_dataset.select(keep_index)
        caption_capsfusion = [caption_capsfusion[i] for i in keep_index]

    merged_dataset = merged_dataset.add_column('caption_capsfusion', caption_capsfusion)
    print(merged_dataset)
    merged_dataset = merged_dataset.train_test_split(test_size=0.01, seed=42, shuffle=True)
    print(merged_dataset)
    merged_dataset.save_to_disk(os.path.join(image_token_dir, 'merged_new'), max_shard_size="20GB")


def main():
    """Run one stage of the caption-fusion pipeline, chosen from the CLI/state.

    * ``--dataset_shard_index -1``: merge stage. If ``captions.jsonl`` does
      not exist yet, the per-shard result files are merged into it;
      otherwise the captions are merged into the dataset. Then return.
    * ``fused_caption`` directory missing: build the prompt dataset from the
      ``caption`` dataset, save it there and return.
    * otherwise: generate fused captions for shard ``dataset_shard_index`` of
      ``num_gpu`` shards and write them to the shard's JSONL file.

    The early-exit stages return normally, so the process exits with
    status 0 when they succeed.
    """
    args = parse_args()

    fused_caption_path = os.path.join(args.dataset_dir, args.dataset_name, "fused_caption")
    result_path = os.path.join(fused_caption_path, "captions.jsonl")

    if args.dataset_shard_index == -1:
        if os.path.exists(result_path):
            # merge to dataset
            _merge_captions_into_dataset(args, result_path)
        else:
            # merge generations
            _merge_shard_results(result_path, args.num_gpu)
        return

    # NOTE(review): the engine is created before the prompt-preparation
    # stage although that stage only uses its tokenizer — consider loading
    # just the tokenizer there.
    llm = LLM(
        model=args.model_name_or_path,
        tokenizer=args.model_name_or_path,
        tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        dtype=args.torch_dtype,
        seed=args.seed,
        swap_space=args.swap_space,
    )

    tokenizer = llm.get_tokenizer()

    def prepare_inputs(examples):
        """Map a batch of caption pairs to ``row_index`` / ``input_text``."""
        return {
            "row_index": examples['input_index'],
            "input_text": [
                _truncate_input_text(tokenizer, caption_origin, caption_coco)
                for caption_origin, caption_coco in zip(examples['caption_origin'], examples['caption_coco'])
            ],
        }

    # merge caption and select captions need to be generated
    if not os.path.exists(fused_caption_path):
        caption_dataset = load_from_disk(os.path.join(args.dataset_dir, args.dataset_name, "caption"))
        merged_dataset = caption_dataset.map(
            prepare_inputs,
            batched=True,
            batch_size=args.process_batch_size,
            num_proc=args.process_num_workers,
            remove_columns=caption_dataset.column_names,
            desc="Preparing inputs",
        )
        os.makedirs(fused_caption_path, exist_ok=True)
        merged_dataset.save_to_disk(fused_caption_path, max_shard_size="20GB")
        return

    sampling_params = SamplingParams(
        n=args.num_return_sequences,
        best_of=args.best_of,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        repetition_penalty=args.repetition_penalty,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        min_p=args.min_p,
        use_beam_search=args.use_beam_search,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        stop=args.stop,
        stop_token_ids=args.stop_token_ids,
        include_stop_str_in_output=args.include_stop_str_in_output,
        ignore_eos=args.ignore_eos,
        max_tokens=args.max_tokens,
        skip_special_tokens=args.skip_special_tokens,
        spaces_between_special_tokens=args.spaces_between_special_tokens,
        logits_processors=None,
    )
    print(f"Sampling params: {sampling_params}")

    merged_dataset = load_from_disk(fused_caption_path)
    cur_merged_dataset = merged_dataset.shard(num_shards=args.num_gpu, index=args.dataset_shard_index, contiguous=True)
    responses = llm.generate(sampling_params=sampling_params, prompts=cur_merged_dataset["input_text"])

    # Only the first returned sequence of each prompt is kept.
    results = [
        {
            "row_index": row_index,
            "response": response.outputs[0].text.strip(),
        }
        for row_index, response in tzip(cur_merged_dataset['row_index'], responses)
    ]
    write_with_orjsonl(results, _shard_result_path(result_path, args.dataset_shard_index))


if __name__ == '__main__':
    main()