"""
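Layer-exchange experiment: swap selected parameters between two causal LMs
(``bn_model_name`` and ``ft_model_name``), then run inference with each swapped
model on a set of questions (by default, harmful instructions) and optionally
save per-layer activations.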
"""

import sys
sys.path.append('./')

import csv

import fire
import torch
import os
import warnings
import time
from typing import List
from collections import defaultdict

from peft import PeftModel, PeftConfig
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
from eval_utils.model_utils import load_model, load_peft_model
from eval_utils.prompt_utils import apply_prompt_template
import json
import copy
from eval_utils.chat_utils import question_read


def get_model(model_name: str, 
              peft_model: str=None,
              use_fast_kernels: bool = False,
              quantization: bool=False,
):
    """Load a causal LM and its Llama tokenizer, with the model in eval mode.

    Args:
        model_name: Model path/name passed to ``load_model`` and to
            ``LlamaTokenizer.from_pretrained``.
        peft_model: Optional PEFT adapter path; when truthy it is applied on
            top of the base model with ``load_peft_model``.
        use_fast_kernels: When True, try to convert the model with optimum's
            ``BetterTransformer``. If ``optimum`` is not installed, a message
            is printed and the unconverted model is returned.
        quantization: Forwarded to ``load_model``.

    Returns:
        ``(model, tokenizer)``.
    """
    model = load_model(model_name, quantization)
    if peft_model:
        model = load_peft_model(model, peft_model)

    model.eval()

    if use_fast_kernels:
        # BetterTransformer enables Flash Attention or xFormers memory-efficient
        # kernels depending on the hardware; this mainly speeds up batched inference.
        try:
            from optimum.bettertransformer import BetterTransformer
        except ImportError:
            print("Module 'optimum' not found. Please install 'optimum' before proceeding.")
        else:
            # Kept outside the try so a failure inside transform() is not
            # misreported as a missing module.
            model = BetterTransformer.transform(model)

    # No pad token is added and the embeddings are not resized here.
    tokenizer = LlamaTokenizer.from_pretrained(model_name)

    return model, tokenizer


def _activation_path(output_file: str) -> str:
    """Return the pickle path used to store activations for ``output_file``.

    ``foo.jsonl`` maps to ``foo_act.pkl``. Only the final extension is
    replaced, so a non-``.jsonl`` name can never map back onto ``output_file``
    itself, and directory names are left untouched.
    """
    root, _ext = os.path.splitext(output_file)
    return root + '_act.pkl'


def _register_activation_hooks(model, activation: dict) -> list:
    """Attach forward hooks recording each decoder layer's MLP and post-attention-layernorm output.

    Each call of a hooked sub-module overwrites ``activation['layer.<i>.mlp']``
    or ``activation['layer.<i>.post_attention_layernorm']`` with its detached
    output. Returns the hook handles so the caller can remove them.
    """
    def make_hook(name):
        def hook(module, inputs, output):
            activation[name] = output.detach()
        return hook

    handles = []
    for l, layer in enumerate(model.model.layers):
        handles.append(layer.mlp.register_forward_hook(make_hook('layer.%d.mlp' % l)))
        handles.append(layer.post_attention_layernorm.register_forward_hook(
            make_hook('layer.%d.post_attention_layernorm' % l)))
    return handles


def inference(
    model,
    tokenizer,
    output_file: str = None,
    max_new_tokens = 512, # The maximum number of tokens to generate
    prompt_file: str='openai_finetuning/customized_data/manual_harmful_instructions.csv',
    prompt_template_style: str='base',
    seed: int=42, # seed value for reproducibility
    do_sample: bool=True, # Whether or not to use sampling; use greedy decoding otherwise.
    min_length: int=None, # [unused here] The minimum length of the sequence to be generated, input prompt + min_new_tokens
    use_cache: bool=True,  # [optional] Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.
    top_p: float=0.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
    repetition_penalty: float=1.0, # The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int=1, # [optional] Exponential penalty to the length that is used with beam-based generation.
    enable_azure_content_safety: bool=False, # [unused here] Enable safety check with Azure content safety api
    enable_sensitive_topics: bool=False, # [unused here] Enable check for sensitive topics using AuditNLG APIs
    enable_salesforce_content_safety: bool=True, # [unused here] Enable safety check with Salesforce safety flan t5
    max_padding_length: int=None, # [unused here] the max padding length to be used with tokenizer padding the prompts.
    use_fast_kernels: bool = False, # [unused here] Enable using SDPA from PyTorch Accelerated Transformers (Flash Attention / xFormers kernels)
    save_activation: bool = True,
    **kwargs
):
    """Generate one answer per prompt in ``prompt_file`` and optionally persist them.

    Prompts are read with ``question_read``, formatted with
    ``apply_prompt_template`` and passed one at a time (batch of 1, on CUDA)
    to ``model.generate``. Only the newly generated tokens are decoded.

    Args:
        model: Causal LM with a ``generate`` method. When ``save_activation``
            is True it must also expose ``model.model.layers``, each with
            ``mlp`` and ``post_attention_layernorm`` sub-modules.
        tokenizer: Tokenizer used by the prompt template and for decoding.
        output_file: When not None, answers are written here as JSON lines.
        save_activation: Record per-layer activations and pickle them to
            ``<output_file root>_act.pkl``. Skipped, with a warning, when
            ``output_file`` is None.
        kwargs: Forwarded to ``model.generate``.
        Sampling and generation arguments are passed straight to ``generate``.
        NOTE(review): with ``do_sample=True``, the default ``top_p=0.0``
        likely keeps only the single most probable token in HF's top-p
        filtering (near-greedy). Confirm this is intended.

    Returns:
        A list of ``{'prompt': ..., 'answer': ...}`` dicts in prompt order.

    Activations are stored as
    ``{'mlp' | 'post_attention_layernorm': {prompt_idx: [tensor per layer]}}``.
    Every forward call overwrites the hooked value, so each prompt keeps the
    output of the last forward call made by ``generate``. With ``use_cache``
    that is presumably only the final decoding step; confirm if
    full-sequence activations are needed. Hooks are removed before returning.
    """
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    question_dataset = question_read(prompt_file)

    # Apply prompt template
    chats = apply_prompt_template(prompt_template_style, question_dataset, tokenizer)

    out = []
    act_all = defaultdict(dict)
    activation = {}
    hook_handles = []
    num_layers = 0
    if save_activation:
        num_layers = len(model.model.layers)
        hook_handles = _register_activation_hooks(model, activation)

    try:
        with torch.no_grad():
            for idx, chat in enumerate(chats):
                tokens = torch.tensor(chat).long().unsqueeze(0).to("cuda")
                input_token_length = tokens.shape[1]

                outputs = model.generate(
                    input_ids = tokens,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    top_p=top_p,
                    temperature=temperature,
                    use_cache=use_cache,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    **kwargs
                )

                # Drop the prompt tokens; decode only the continuation.
                output_text = tokenizer.decode(outputs[0][input_token_length:], skip_special_tokens=True).strip()

                if save_activation:
                    act_all['mlp'][idx] = [activation['layer.%d.mlp' % l] for l in range(num_layers)]
                    act_all['post_attention_layernorm'][idx] = [
                        activation['layer.%d.post_attention_layernorm' % l] for l in range(num_layers)
                    ]

                out.append({'prompt': question_dataset[idx], 'answer': output_text})
                print('\n\n\n')
                print('>>> sample - %d' % idx)
                print('prompt = ', question_dataset[idx])
                print('answer = ', output_text)
    finally:
        # Detach hooks so repeated calls on the same model do not stack them.
        for handle in hook_handles:
            handle.remove()

    if output_file is not None:
        with open(output_file, 'w') as f:
            for li in out:
                f.write(json.dumps(li))
                f.write("\n")

    if save_activation:
        if output_file is None:
            warnings.warn("save_activation=True but output_file is None; activations were not saved.")
        else:
            import pickle

            act_file = _activation_path(output_file)
            with open(act_file, 'wb') as handle:
                pickle.dump(act_all, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('activation saved at', act_file)

    return out


def _exchange_param_names(param_names, exchange_layers: str) -> set:
    """Return the parameter names selected by a '.'-separated ``exchange_layers`` spec.

    A numeric item ``i`` selects every name containing ``'layers.<i>.'``; the
    trailing dot keeps ``1`` from matching ``11``. Any other non-empty item is
    matched as a plain substring. Empty items (e.g. from ``'1..2'``) are
    ignored, because an empty substring would match every parameter.
    """
    keywords = ['layers.%s.' % item if item.isnumeric() else item
                for item in exchange_layers.split('.') if item]
    return {name for name in param_names if any(kw in name for kw in keywords)}


def _tagged_output_file(output_file, tag: str):
    """Derive a per-model output path from ``output_file``.

    Every ``'exchange'`` in the path is replaced by ``tag``. If the path has
    no ``'exchange'``, ``_<tag>`` is inserted before the extension instead, so
    the two models never write to the same file. ``None`` passes through.
    """
    if output_file is None:
        return None
    if 'exchange' in output_file:
        return output_file.replace('exchange', tag)
    root, ext = os.path.splitext(output_file)
    return '%s_%s%s' % (root, tag, ext)


def main(
    bn_model_name: str,
    ft_model_name: str,
    exchange_layers: str,
    output_file: str = None,
    peft_model: str=None,
    use_fast_kernels: bool = False,
    quantization: bool=False,
    **kwargs,
):
    """Swap selected parameters between two models and run ``inference`` with each.

    Both models are loaded once with ``get_model``, sharing ``peft_model``,
    ``use_fast_kernels`` and ``quantization``. Every parameter selected by
    ``exchange_layers`` is exchanged between the two models. The swap is
    verified by comparing parameter sums before and after ``load_state_dict``.
    The selected parameter names are pickled to
    ``./temp/exchange_layer_fullname.pkl``.

    Args:
        bn_model_name: First model; its answers use the ``bn_exchange_<layers>`` tag.
        ft_model_name: Second model; its answers use the ``ft_exchange_<layers>`` tag.
        exchange_layers: '.'-separated selection. Numeric items select decoder
            layer ``i``; other items are substrings of parameter names.
            NOTE(review): python-fire may parse a value like ``10.10`` as the
            float ``10.1`` before ``str()`` sees it. Confirm CLI usage.
        output_file: Template for the answer files (see ``_tagged_output_file``).
            When None, answers are not written to disk.
        kwargs: Forwarded to ``inference``.
    """
    bn_model, bn_tokenizer = get_model(bn_model_name, peft_model, use_fast_kernels, quantization)
    ft_model, ft_tokenizer = get_model(ft_model_name, peft_model, use_fast_kernels, quantization)

    bn_state_dict = bn_model.state_dict()
    ft_state_dict = ft_model.state_dict()

    exchange_layers = str(exchange_layers)
    print('+'*100)
    print('exchange_layers: ', exchange_layers)
    print('+'*100)
    exchange_layer_fullname = _exchange_param_names(bn_state_dict.keys(), exchange_layers)

    # Record the pre-swap sums, then exchange the selected tensors.
    bn_param_org, ft_param_org = 0, 0
    for _layer in exchange_layer_fullname:
        bn_param_org += torch.sum(bn_state_dict[_layer]).item()
        ft_param_org += torch.sum(ft_state_dict[_layer]).item()

        bn_state_dict[_layer], ft_state_dict[_layer] = copy.deepcopy(ft_state_dict[_layer]), copy.deepcopy(bn_state_dict[_layer])

    bn_model.load_state_dict(bn_state_dict)
    ft_model.load_state_dict(ft_state_dict)

    # Re-read from the models to confirm load_state_dict really applied the swap.
    # The same set is iterated, so the float sums accumulate in the same order.
    bn_state_dict = bn_model.state_dict()
    ft_state_dict = ft_model.state_dict()

    bn_param_new = sum(torch.sum(bn_state_dict[_layer]).item() for _layer in exchange_layer_fullname)
    ft_param_new = sum(torch.sum(ft_state_dict[_layer]).item() for _layer in exchange_layer_fullname)

    assert bn_param_org == ft_param_new and ft_param_org == bn_param_new, (bn_param_org, ft_param_org, bn_param_new, ft_param_new)

    import pickle
    # The temp directory is not guaranteed to exist.
    os.makedirs('./temp', exist_ok=True)
    with open('./temp/exchange_layer_fullname.pkl', 'wb') as handle:
        pickle.dump(exchange_layer_fullname, handle, protocol=pickle.HIGHEST_PROTOCOL)

    del bn_state_dict, ft_state_dict
    torch.cuda.empty_cache()

    inference(bn_model, bn_tokenizer, _tagged_output_file(output_file, 'bn_exchange_%s' % exchange_layers), **kwargs)
    print('+'*100)
    print('+'*100)
    print('+'*100)
    inference(ft_model, ft_tokenizer, _tagged_output_file(output_file, 'ft_exchange_%s' % exchange_layers), **kwargs)
    del bn_model, ft_model, bn_tokenizer, ft_tokenizer
    torch.cuda.empty_cache()
    
# def main(window_size: int, **kwargs):
#     for l in range(33-window_size):
#         exchange_layers = [str(_l) for _l in range(l, l+window_size)]
        
#         print('+'*100)
#         print('exchange layers', exchange_layers)
#         print('+'*100)
        
#         exchange_layer(exchange_layers, **kwargs) 
    
    
if __name__ == "__main__":
    # Command-line entry point: python-fire exposes `main` as the CLI.
    fire.Fire(component=main)