import argparse
import json
import os
import sys
import torch
from vllm import LLM
from vllm.sampling_params import SamplingParams
from tqdm import tqdm
from datasets import load_from_disk
from transformers.generation.logits_process import LogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
from eval_utils import vllm_configs, eval_score, choice_count_dict, stop_tokens, stop_tokens_ids
import multiprocessing
import numpy as np


class SuppressTokensLogitsProcessorText(LogitsProcessor):
    r"""Mask out every token id at or above `start_index`.

    All scores whose vocabulary index is `>= start_index` are overwritten with the most negative finite value of
    `torch_dtype` (not a true `-inf`), so those tokens are effectively never sampled. In this script `start_index`
    is the first image-token id, which restricts generation to the text part of the vocabulary.

    Args:
        start_index: First vocabulary index to suppress; everything from here to the end of the vocabulary is masked.
        torch_dtype: Floating-point dtype whose `torch.finfo(...).min` is used as the masking value.
    """

    def __init__(self, start_index, torch_dtype):
        self.start_index = start_index
        self.min = torch.finfo(torch_dtype).min

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask `scores` in place and return the same tensor.

        `input_ids` is accepted for interface compatibility and is not used.
        """
        # Slice the last (vocabulary) axis rather than axis 0: for a 1-D logits row this is the same
        # operation as before, and for a batched (batch, vocab) tensor it masks columns instead of rows.
        scores[..., self.start_index:] = self.min
        return scores

class SuppressTokensLogitsProcessorImage(LogitsProcessor):
    r"""Mask out every token id below `end_index`.

    All scores whose vocabulary index is `< end_index` are overwritten with the most negative finite value of
    `torch_dtype` (not a true `-inf`), so those tokens are effectively never sampled. In this script it is used
    for image generation, where the lower part of the vocabulary must not be produced.

    Args:
        end_index: Exclusive upper bound of the suppressed range; indices `0 .. end_index - 1` are masked.
        torch_dtype: Floating-point dtype whose `torch.finfo(...).min` is used as the masking value.
    """

    def __init__(self, end_index, torch_dtype):
        self.end_index = end_index
        self.min = torch.finfo(torch_dtype).min

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask `scores` in place and return the same tensor.

        `input_ids` is accepted for interface compatibility and is not used.
        """
        # Slice the last (vocabulary) axis rather than axis 0: identical for a 1-D logits row,
        # and correct (columns, not rows) for a batched (batch, vocab) tensor.
        scores[..., :self.end_index] = self.min
        return scores

class SuppressTokensLogitsProcessorChoices(LogitsProcessor):
    r"""Mask out an explicit list of token ids so that only the remaining "choice" tokens can be sampled.

    The scores at every index in `suppress_ids` are overwritten with the most negative finite value of
    `torch_dtype` (not a true `-inf`), so those tokens are effectively never sampled.

    Args:
        suppress_ids: Vocabulary indices to mask. The list is converted to an index tensor once per device on
            first use, so it should not be mutated after the processor has been called.
        torch_dtype: Floating-point dtype whose `torch.finfo(...).min` is used as the masking value.
        choices: Human-readable description of the tokens that stay allowed; stored only for logging.
    """

    def __init__(self, suppress_ids, torch_dtype, choices):
        self.suppress_ids = suppress_ids
        self.choices = choices
        self.min = torch.finfo(torch_dtype).min
        # device -> LongTensor of suppress_ids. The list can hold nearly the whole vocabulary, so
        # rebuilding an index tensor from it on every decoding step is needlessly expensive.
        self._index_cache = {}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask `scores` in place and return the same tensor.

        `input_ids` is accepted for interface compatibility and is not used.
        """
        device = scores.device
        index = self._index_cache.get(device)
        if index is None:
            index = torch.as_tensor(self.suppress_ids, dtype=torch.long, device=device)
            self._index_cache[device] = index
        # Index the last (vocabulary) axis: identical for a 1-D logits row, and correct
        # (columns, not rows) for a batched (batch, vocab) tensor.
        scores[..., index] = self.min
        return scores

def parse_args():
    """Parse the command line and normalise a few values for later use.

    Besides parsing, this function:
      * creates `--result_dir` if it does not exist;
      * replaces the `--torch_dtype` string with the matching `torch.dtype`;
      * replaces the `--early_stopping` strings "True"/"False" with booleans (the value "never" stays a string).

    Returns:
        argparse.Namespace: The parsed and normalised arguments.

    Raises:
        ValueError: If `--torch_dtype` is not one of the supported names.
    """
    parser = argparse.ArgumentParser()

    # Model settings
    parser.add_argument("--model_name_or_path", type=str, default='YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf')
    parser.add_argument("--visual_codebook", type=str, default='YOUR_ROOT_PATH/model/LaVIT-7B-v2', help="Path to pretrained visual codebook.")
    parser.add_argument("--checkpoint_path", type=str, default='YOUR_ROOT_PATH/model/checkpoint/MLLM/adjust_OIv3_lora_custom_only_ic_e3_512_2e_ls1_uni/last_1392')
    parser.add_argument('--src_path', type=str, default='YOUR_ROOT_PATH/MLLM/src', help='path to src code')
    parser.add_argument("--tokenizer", type=str, default='YOUR_ROOT_PATH/model/checkpoint/MLLM/tokenizer', help="Path to tokenizer directory.")
    parser.add_argument("--vl_vocab_size", type=int, default=48386, help="The vocab size of vision-language vocab.")
    parser.add_argument("--image_start_token_id", type=int, default=32000, help="The start token id of image tokens.")
    parser.add_argument("--expand_vocab", type=str, default="normal", help="How to expand the language vocab to vision-language vocab.", choices=["normal", "random", "factorized"])
    parser.add_argument("--factorized_linear_mlp", action="store_true", help="Whether to use mlp as factorized linear.")
    parser.add_argument("--use_lora", action="store_true", help="Whether to use LoRA.")
    parser.add_argument("--use_slow_tokenizer", action="store_true", help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--num_gpu", type=int, default=1, help="The number of gpus to use.")
    parser.add_argument("--dataset_shard_index", type=int, default=0, help="The shard index of the dataset.")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=['float16', 'bfloat16', 'float32'])

    # Dataset settings
    parser.add_argument('--dataset_name', type=str, default='COCO', help='dataset name')
    parser.add_argument('--dataset_dir', type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation', help='path to dataset dir')
    parser.add_argument("--result_dir", type=str, default='YOUR_ROOT_PATH/data/MLLM/Evaluation/results', help='path to output dir')
    parser.add_argument('--prompt_setting', type=str, default='zero_shot', help='prompt setting')
    parser.add_argument('--template_index', type=int, default=0, help='template index')
    parser.add_argument('--swap_space', type=int, default=20, help='swap space in GB')
    parser.add_argument("--from_hf", action="store_true", help="Whether to use huggingface datasets.")
    parser.add_argument("--generation_mode", type=str, default="text", choices=["text", "image", "image-text"], help="The generation mode.")
    parser.add_argument("--use_config", action="store_true", help="Whether to use config.")
    parser.add_argument('--process_batch_size', type=int, default=200, help='process batch size')
    parser.add_argument('--process_num_workers', type=int, default=multiprocessing.cpu_count(), help='preprocessing num workers')
    parser.add_argument('--debug', type=int, default=0, help="Whether to debug.")

    # Generation settings
    parser.add_argument("--num_return_sequences", type=int, default=1, help="Number of output sequences to return for the given prompt.")
    parser.add_argument("--best_of", type=int, default=None, help="Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. This is treated as the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`.")
    parser.add_argument("--presence_penalty", type=float, default=0.0, help="Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.")
    parser.add_argument("--frequency_penalty", type=float, default=0.0, help="Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Float value controlling randomness in boltzmann distribution. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.")
    parser.add_argument("--top_p", type=float, default=1.0, help="Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens.")
    parser.add_argument("--top_k", type=int, default=-1, help="Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.")
    parser.add_argument("--min_p", type=float, default=0.0, help="Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this.")
    parser.add_argument("--use_beam_search", action="store_true", help="Whether to use beam search instead of sampling.")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="Float that penalizes sequences based on their length. Used in beam search.")
    parser.add_argument("--early_stopping", type=str, default="False", choices=["True", "False", "never"], help="Controls the stopping condition for beam search. It accepts the following values: `True`, where the generation stops as soon as there are `best_of` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `'never'`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).")
    parser.add_argument("--stop", nargs='+', default=stop_tokens, help="List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.")
    # type=int: token ids given on the command line must be integers, otherwise they arrive as strings
    # and can never match a generated token id. The (non-string) default is passed through unchanged.
    parser.add_argument("--stop_token_ids", nargs='+', type=int, default=stop_tokens_ids, help=" List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens.")
    parser.add_argument("--include_stop_str_in_output", action="store_true", help="Whether to include the stop strings in output text. Defaults to False.")
    parser.add_argument("--ignore_eos", action="store_true", help="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.")
    parser.add_argument("--max_tokens", type=int, default=36, help="Maximum number of tokens to generate per output sequence.")
    parser.add_argument("--skip_special_tokens", action="store_true", help="Whether to skip special tokens in the output.")
    parser.add_argument("--spaces_between_special_tokens", action="store_true", help="Whether to add spaces between special tokens in the output.")
    # Sampling options that are intentionally not exposed on the command line:
    # logprobs: Number of log probabilities to return per output token. Note that the implementation follows the OpenAI API: The return result includes the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. The API will always return the log probability of the sampled token, so there may be up to `logprobs+1` elements in the response.
    # prompt_logprobs: Number of log probabilities to return per prompt token.
    # logits_processors: List of functions that modify logits based on previously generated tokens.

    args = parser.parse_args()
    os.makedirs(args.result_dir, exist_ok=True)

    # Turn the dtype name into the actual torch dtype object.
    dtype_by_name = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    try:
        args.torch_dtype = dtype_by_name[args.torch_dtype]
    except KeyError:
        raise ValueError(f"Invalid torch dtype: {args.torch_dtype}") from None

    # "True"/"False" become real booleans; "never" is kept as the string it is.
    args.early_stopping = {"True": True, "False": False}.get(args.early_stopping, args.early_stopping)

    return args

def main():
    """Generate responses for one dataset shard with vLLM, or merge shard results and score them.

    Behaviour depends on `--dataset_shard_index`:
      * any value other than -1: run generation on that contiguous shard of the evaluation dataset and write
        the responses to `<result>_shard_<index>.jsonl`;
      * -1: read the shard files of all `--num_gpu` shards, sort the records by `input_index`, write the merged
        result file, delete the shard files and call `eval_score`.

    Raises:
        ValueError: If the freshly built tokenizer does not place `<image>` at `--image_start_token_id`, if the
            generation mode is unknown, or if the prompt setting / dataset type combination is unsupported.
        NotImplementedError: For the "image-text" generation mode.
    """
    args = parse_args()

    # Project modules live outside the default import path, so extend it before importing them.
    sys.path.append(args.src_path)
    from merge_model import merge_to_base_model
    from utils import read_with_orjsonl, write_with_orjsonl
    sys.path.append(args.src_path + '/../data')
    from Evaluation import get_dataset_type

    # NOTE(review): everything down to dataset loading (including loading the model) also runs in merge mode
    # (--dataset_shard_index -1), where it looks unnecessary. Left as is because `eval_score` receives
    # `args` and may depend on values set along the way - confirm before reordering.
    if args.use_lora or args.expand_vocab == "factorized":
        merged_path = merge_to_base_model(args, device_map={"": "cuda"})
    else:
        merged_path = args.checkpoint_path

    if args.tokenizer and not os.path.exists(args.tokenizer):
        # just for vllm, since it will assert if the input token ids exceed the tokenizer vocab size
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, legacy=False, use_fast=not args.use_slow_tokenizer)
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']}, replace_additional_special_tokens=False)
        tokenizer.add_special_tokens({'additional_special_tokens': ['</image>']}, replace_additional_special_tokens=False)
        image_start_token_id = tokenizer.additional_special_tokens_ids[0]
        if image_start_token_id != args.image_start_token_id:
            raise ValueError(
                f"Tokenizer assigned id {image_start_token_id} to '<image>' but --image_start_token_id is {args.image_start_token_id}"
            )
        # Remaining vocabulary after the base tokens and the two image delimiters
        # (48386 - 32000 - 2 = 16384 with the default arguments).
        num_image_tokens = args.vl_vocab_size - args.image_start_token_id - 2
        tokenizer.add_tokens([f"<image_{i}>" for i in range(num_image_tokens)])
        tokenizer.save_pretrained(args.tokenizer)
        print(tokenizer)

    llm = LLM(
        model=merged_path,
        tokenizer=args.tokenizer if args.tokenizer else args.checkpoint_path,
        tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        dtype=args.torch_dtype,
        seed=args.seed,
        swap_space=args.swap_space,
    )
    tokenizer = llm.get_tokenizer()

    eval_path = os.path.join(args.dataset_dir, args.dataset_name, 'eval')
    args.dataset_type = get_dataset_type(args.dataset_name)
    os.makedirs(os.path.join(args.result_dir, f"{args.dataset_name}"), exist_ok=True)

    is_choices_ppl = args.prompt_setting.endswith('_choices_ppl')

    # prepare sampling params
    # use config setting in eval_utils.py
    if args.use_config:
        if is_choices_ppl:
            specific_config = vllm_configs['multi-choice-ppl']
        else:
            specific_config = vllm_configs[args.dataset_name]
        print(f"Use specific config for {args.dataset_name}:")
        print(specific_config)
        for key, value in specific_config.items():
            setattr(args, key, value)

    sampling_params = SamplingParams(
        n=args.num_return_sequences,
        best_of=args.best_of,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        repetition_penalty=args.repetition_penalty,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        min_p=args.min_p,
        use_beam_search=args.use_beam_search,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        stop=args.stop,
        stop_token_ids=args.stop_token_ids,
        include_stop_str_in_output=args.include_stop_str_in_output,
        ignore_eos=args.ignore_eos,
        max_tokens=args.max_tokens,
        skip_special_tokens=args.skip_special_tokens,
        spaces_between_special_tokens=args.spaces_between_special_tokens,
        logits_processors=None,  # filled in below, once the right processor is known
    )

    # prepare logits processors
    if is_choices_ppl:
        token_id_array = np.arange(args.vl_vocab_size)
        if args.dataset_type == 'Y/N':
            # prepare a yes_or_no logits processor: everything except these ids is suppressed
            yes_or_no_ids = [3582, 3869, 4874, 8241, 21143, 22483, 694, 1217, 1939, 3782, 6632, 11698]
            suppress_ids = np.delete(token_id_array, yes_or_no_ids).tolist()
            sampling_params.logits_processors = [SuppressTokensLogitsProcessorChoices(suppress_ids=suppress_ids, torch_dtype=args.torch_dtype, choices=tokenizer.decode(yes_or_no_ids))]
        elif args.dataset_type == 'multi-choice':
            choice_ids_all = [350, 315, 360, 382, 383, 402, 379, 306] # B C D E F G H I
            choice_ids_all_single = [29933, 29907, 29928, 29923, 29943, 29954, 29950, 29902] # B C D E F G H I without space
            used_choice_ids = [319, 29909] # A, with and without space
            # One processor per number of choices: entry 0 allows A-B, entry 1 allows A-C, and so on.
            choice_logits_processors = []
            for choice_id, choice_id_single in zip(choice_ids_all, choice_ids_all_single):
                used_choice_ids.append(choice_id)
                used_choice_ids.append(choice_id_single)
                suppress_ids = np.delete(token_id_array, used_choice_ids).tolist()
                choice_logits_processors.append(SuppressTokensLogitsProcessorChoices(suppress_ids=suppress_ids, torch_dtype=args.torch_dtype, choices=tokenizer.decode(used_choice_ids)))
    else:
        if args.generation_mode == "text":
            sampling_params.logits_processors = [SuppressTokensLogitsProcessorText(start_index=args.image_start_token_id, torch_dtype=args.torch_dtype)]
        elif args.generation_mode == "image":
            sampling_params.logits_processors = [SuppressTokensLogitsProcessorImage(end_index=args.image_start_token_id + 1, torch_dtype=args.torch_dtype)]
        elif args.generation_mode == "image-text":
            # TODO: implement a logitsprocessor for image-text interleaved generation
            raise NotImplementedError
        else:
            raise ValueError(f"Invalid generation mode: {args.generation_mode}")

    # load dataset
    result_path = os.path.join(args.result_dir, f"{args.dataset_name}/{args.prompt_setting}_{args.template_index}_{'_'.join(args.checkpoint_path.split('/')[-2:])}.jsonl")
    # Path without the ".jsonl" extension; shard files are named after it.
    result_stem = os.path.splitext(result_path)[0]
    combined_dataset = load_from_disk(os.path.join(eval_path, f"{args.prompt_setting}_{args.template_index}"))
    results = []

    def shard_result_path(shard_index):
        """Return the result file path used for the given shard."""
        return f"{result_stem}_shard_{shard_index}.jsonl"

    def detokenize_text_part(input_tokens, image_start_token_id):
        """Decode only the text tokens of a prompt, joining separate text runs with "[image]"."""
        text_runs = []
        current_run = None
        for token_id in input_tokens:
            if token_id >= image_start_token_id:
                # An image token closes the text run that is currently open, if any.
                if current_run is not None:
                    text_runs.append(current_run)
                    current_run = None
            else:
                if current_run is None:
                    current_run = []
                current_run.append(token_id)

        if current_run is not None:
            text_runs.append(current_run)

        return "[image]".join(tokenizer.decode(run) for run in text_runs)

    def generate_into_results(dataset):
        """Generate for every prompt in `dataset` and append the first output of each to `results`."""
        responses = llm.generate(sampling_params=sampling_params, prompt_token_ids=dataset["input_tokens"], use_tqdm=False)
        for input_index, response in zip(dataset['input_index'], responses):
            results.append({
                "input_index": input_index,
                "response": response.outputs[0].text.strip(),
            })

    # print the first 100 examples for checking
    # (slice the rows first so that only those rows of the column are read)
    check_examples = [detokenize_text_part(example, args.image_start_token_id) for example in combined_dataset[:100]["input_tokens"]]
    print(f"Final Input Examples: {check_examples}")

    if args.dataset_shard_index != -1:
        if args.debug:
            # debug for VQAv2_VAL: cap at 20000 rows, but never ask for more rows than the dataset has
            debug_rows = min(20000, combined_dataset.num_rows)
            source_dataset = combined_dataset.select(range(debug_rows))
        else:
            source_dataset = combined_dataset
        cur_combined_dataset = source_dataset.shard(num_shards=args.num_gpu, index=args.dataset_shard_index, contiguous=True)

        if is_choices_ppl and args.dataset_type == 'multi-choice':
            choice_count_list = choice_count_dict[args.dataset_name]
            for choice_count in choice_count_list:
                # choice_logits_processors[0] is the two-choice processor, hence the offset of 2.
                sampling_params.logits_processors = [choice_logits_processors[choice_count - 2]]
                print(f"Count {choice_count}\tChoices {sampling_params.logits_processors[0].choices}\tSampling params: {sampling_params}")
                cur_combined_dataset_subset = cur_combined_dataset.filter(lambda example: example['choice_count'] == choice_count)
                if cur_combined_dataset_subset.num_rows > 0:
                    generate_into_results(cur_combined_dataset_subset)
        elif args.prompt_setting == "zero_shot" or "few_shot" in args.prompt_setting or (args.dataset_type == 'Y/N' and is_choices_ppl):
            generate_into_results(cur_combined_dataset)
        else:
            raise ValueError(f"Invalid prompt setting and dataset_type: {args.prompt_setting}, {args.dataset_type}")

        write_with_orjsonl(results, shard_result_path(args.dataset_shard_index))
    else: # time to merge
        for shard_index in range(args.num_gpu):
            results.extend(read_with_orjsonl(shard_result_path(shard_index)))
        results.sort(key=lambda x: x['input_index'])
        write_with_orjsonl(results, result_path)
        # remove the per-shard files now that the merged file is written
        for shard_index in range(args.num_gpu):
            os.remove(shard_result_path(shard_index))
        eval_score(args, results)

# Script entry point: run the evaluation only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
