"""
Inference with bad questions as inputs
"""

import sys, os
sys.path.append(os.path.abspath('../..'))

import csv

import fire
import torch
import evaluate
import json
from collections import defaultdict

from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
from utils.load_ft_data import load_ft_dataset
from llama2.safety_evaluation.eval_utils.model_utils import load_model, load_peft_model
from llama2.safety_evaluation.eval_utils.prompt_utils import apply_prompt_template




def _left_pad(sequences, pad_id=0):
    """Left-pad token-id sequences to the length of the longest one.

    Args:
        sequences: non-empty list of token-id sequences (lists of ints).
        pad_id: id written into the padding positions.

    Returns:
        ``(input_ids, attention_mask)`` as lists of equal-length lists; the mask is
        1 on real tokens and 0 on padding.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids = [[pad_id] * (max_len - len(seq)) + list(seq) for seq in sequences]
    attention_mask = [[0] * (max_len - len(seq)) + [1] * len(seq) for seq in sequences]
    return input_ids, attention_mask


def _select_beams(groups, k):
    """Keep the ``k`` lowest-``'ppl'`` records of every group.

    Args:
        groups: mapping ``sample_id -> list of record dicts``, each holding a ``'ppl'`` value.
        k: number of records kept per group.

    Returns:
        List of ``(sample_id, record)`` pairs: groups in mapping iteration order, records
        in ascending perplexity (ties keep insertion order because ``sorted`` is stable).
    """
    return [
        (sample_id, record)
        for sample_id, records in groups.items()
        for record in sorted(records, key=lambda d: d['ppl'])[:k]
    ]


def main(
    model_name,
    data_name: str='<to set: fibe-tuning data only>',
    
    peft_model: str=None,
    quantization: bool=False,
    max_new_tokens = 512, #The maximum numbers of tokens to generate
    seed: int=42, #seed value for reproducibility
    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
    min_length: int=None, #Unused here. The minimum length of the sequence to be generated, input prompt + min_new_tokens
    use_cache: bool=True,  #[optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
    top_p: float=0, # Unused here: the sweep uses `top_p_grid` instead. Kept for CLI compatibility.
    temperature: float=1.0, # Unused here: the sweep uses `temperature_grid` instead. Kept for CLI compatibility.
    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. 
    enable_azure_content_safety: bool=False, # Unused here. Enable safety check with Azure content safety api
    enable_sensitive_topics: bool=False, # Unused here. Enable check for sensitive topics using AuditNLG APIs
    enable_salesforce_content_safety: bool=True, # Unused here. Enable safety check with Salesforce safety flan t5
    max_padding_length: int=None, # Unused here. The max padding length to be used with tokenizer padding the prompts.
    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
    output_dir: str = None,
    batch_size: int = 4,
    n_iteration: int = 3,
    beam_search_k: int = 1,
    temperature_grid: tuple = (0.1, 0.25, 0.5, 1.0), # temperatures swept in every iteration
    top_p_grid: tuple = (0.1, 0.25, 0.5), # top_p values swept for every temperature
    max_samples: int = 100, # number of dataset samples to use (None = all)
    **kwargs
):
    """Iteratively regenerate answers for fine-tuning Q&A pairs, keeping the lowest-perplexity ones.

    Every iteration builds prompts (template ``'regression'``) from each tracked entry's question
    and current answer joined by a newline, samples a new answer for every combination in
    ``temperature_grid`` x ``top_p_grid``, scores "question + newline + new answer" with the
    ``evaluate`` perplexity metric run with ``model_name`` on CPU, and keeps the ``beam_search_k``
    lowest-perplexity candidates per original sample. The kept answers seed the next iteration.

    Args:
        model_name: model path/id used for generation, for the tokenizer and for perplexity scoring.
        data_name: dataset name passed to ``load_ft_dataset`` (expected to yield ``(question, answer)`` pairs).
        peft_model: optional PEFT adapter applied on top of the loaded model.
        quantization: forwarded to ``load_model``.
        batch_size: number of prompts per ``generate`` call.
        n_iteration: number of regenerate-and-select rounds.
        beam_search_k: candidates kept per sample after each round.
        output_dir: if set, round ``it`` writes ``<output_dir>/<it>.jsonl`` with one JSON record
            (question, org_answer, answer, ppl, temperature, top_p, iteration) per kept candidate.
        **kwargs: forwarded to ``model.generate``.
    """
    ## Set the seeds for reproducibility
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    model = load_model(model_name, quantization)
    if peft_model:
        model = load_peft_model(model, peft_model)
    model.eval()

    if use_fast_kernels:
        # BetterTransformer may use Flash Attention or xformers memory-efficient kernels,
        # depending on hardware, to speed up batched inference.
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

    tokenizer = LlamaTokenizer.from_pretrained(model_name)

    clean_qa = load_ft_dataset(data_name)  # list[(q_text, a_text)]
    # Each entry: (sample_id, question, original answer, text fed to the prompt template).
    # Carrying the question with its prompt keeps generations aligned with their source
    # sample even when several beams per sample are tracked (beam_search_k > 1).
    entries = [
        (sample_id, q, a, q + '\n' + a)
        for sample_id, (q, a) in enumerate(list(clean_qa)[:max_samples])
    ]

    perplexity = evaluate.load("perplexity", module_type="metric")

    for it in range(n_iteration):
        if not entries:
            break

        # Apply prompt template
        chats = apply_prompt_template('regression', [e[3] for e in entries], tokenizer)

        pending = []     # (sample_id, record) for every generated candidate, ppl filled later
        ppl_inputs = []  # text scored for each pending candidate (same order)
        with torch.no_grad():
            for temp in temperature_grid:
                for tp in top_p_grid:
                    for start in range(0, len(chats), batch_size):
                        batch_entries = entries[start:start + batch_size]
                        ids_list, mask_list = _left_pad(chats[start:start + batch_size])
                        input_ids = torch.tensor(ids_list, dtype=torch.long, device="cuda")
                        # Without a mask the model would attend to the left padding.
                        attention_mask = torch.tensor(mask_list, dtype=torch.long, device="cuda")

                        outputs = model.generate(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=max_new_tokens,
                            do_sample=do_sample,
                            top_p=tp,
                            temperature=temp,
                            use_cache=use_cache,
                            top_k=top_k,
                            repetition_penalty=repetition_penalty,
                            length_penalty=length_penalty,
                            **kwargs
                        )
                        # Keep only the newly generated tokens (drop the padded prompt).
                        texts = tokenizer.batch_decode(
                            outputs[:, input_ids.shape[1]:], skip_special_tokens=True)

                        for (sample_id, question, org_answer, _), text in zip(batch_entries, texts):
                            pending.append((sample_id, {
                                'question': question,
                                'org_answer': org_answer,
                                'answer': text.strip(),
                                'ppl': None,
                                'temperature': temp,
                                'top_p': tp,
                                'iteration': it,
                            }))
                            ppl_inputs.append(question + '\n' + text)

        if not pending:
            break

        # The evaluate perplexity metric loads `model_id` on every compute() call, so
        # score the whole iteration at once instead of once per generation batch.
        ppls = perplexity.compute(model_id=model_name,
                                  add_start_token=False,
                                  device='cpu',
                                  predictions=ppl_inputs,
                                  )['perplexities']

        groups = defaultdict(list)
        for (sample_id, record), ppl in zip(pending, ppls):
            record['ppl'] = ppl
            groups[sample_id].append(record)

            print('\n\n\n')
            print('>>> sample - %d' % sample_id)
            print('question = ', record['question'])
            print('answer = ', record['answer'])
            print('ppl = ', ppl)
            print('iteration =', it, ' temperature =', record['temperature'], ' top_p =', record['top_p'])

        # Beam selection runs whether or not results are saved, so later iterations
        # always have prompts to work on.
        selected = _select_beams(groups, beam_search_k)
        entries = [
            (sample_id, rec['question'], rec['org_answer'], rec['question'] + '\n' + rec['answer'])
            for sample_id, rec in selected
        ]

        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            with open(os.path.join(output_dir, '%d.jsonl' % it), 'w') as f:
                for _, rec in selected:
                    f.write(json.dumps(rec))
                    f.write("\n")


if __name__ == "__main__":
    # Expose `main` as a command-line interface: Fire maps CLI flags onto its keyword arguments.
    fire.Fire(component=main)