import torch
import json
import argparse
import sys
import os
import threading
import time

from eval.perspective_api import perspectiveapi
from eval.moderation_api import ModerationAPI
from utils.utils import load_datasets, _find_save_path
from utils.prompts import apply_prompt_template

import numpy as np
import pandas as pd

from peft import PeftModel, PeftConfig
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from torch import nn
from tqdm import tqdm




# def parsing_argument():
#     parser = argparse.ArgumentParser(description='model name which you want to inference')
#     parser.add_argument('--dataset', required=True, help='choose dataset to inference')
    
#     return parser.parse_args()


# Dataset columns copied into each record's 'meta' dict, in output order.
_META_COLUMNS = (
    'formatted_noun_phrase',
    'context',
    'template',
    'axis',
    'bucket',
    'descriptor',
    'descriptor_gender',
    'descriptor_preference',
    'noun',
    'plural_noun',
    'noun_gender',
    'noun_phrase',
    'plural_noun_phrase',
    'noun_phrase_type',
)


def _load_model_and_tokenizer(model_name, llama_version):
    """Load the LLaMA model (plus a PEFT adapter when requested) and its tokenizer.

    Args:
        model_name: 'vanilla' loads the plain base model; 'samsum', 'alpaca',
            'grammar' or 'dolly' load an adapter from the llama-recipes
            checkpoint tree; any other name is looked up under the VIM
            checkpoint tree via `_find_save_path`.
        llama_version: 'llama2' or a string containing 'chat'.

    Returns:
        A `(model, tokenizer)` tuple.

    Raises:
        ValueError: if `llama_version` is not recognised for a branch that
            needs it to pick a base model or checkpoint directory.
    """
    is_base = llama_version == 'llama2'
    is_chat = 'chat' in llama_version.lower()

    if model_name == 'vanilla':
        if is_base:
            base_model = 'meta-llama/Llama-2-7b-hf'
        elif is_chat:
            base_model = 'meta-llama/Llama-2-7b-chat-hf'
        else:
            raise ValueError(f"Unsupported llama_version: {llama_version!r}")
        model = LlamaForCausalLM.from_pretrained(base_model)
        tokenizer = LlamaTokenizer.from_pretrained(base_model)
        return model, tokenizer

    if model_name in ('samsum', 'alpaca', 'grammar', 'dolly'):
        if is_base:
            peft_model_id = f'../llama-recipes/recipes/finetuning/ckpts/llama2-7b-new/{model_name}'
        elif is_chat:
            peft_model_id = f'../llama-recipes/recipes/finetuning/ckpts/llama2-7b-chat/{model_name}'
        else:
            raise ValueError(f"Unsupported llama_version: {llama_version!r}")
    else:
        epoch_num = _find_save_path(f"../VIM/src/ckpt/argument_survey/{llama_version}/min_TH_3/{model_name}")
        peft_model_id = f"../VIM/src/ckpt/argument_survey/{llama_version}/min_TH_3/{model_name}/{epoch_num}"

    # The adapter config records which base model it was trained on.
    config = PeftConfig.from_pretrained(peft_model_id)
    model = LlamaForCausalLM.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(model, peft_model_id)
    tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)
    return model, tokenizer


def inference(model_name):
    """Generate a continuation for every dataset prompt and save them as JSON.

    Loads the model selected by `model_name`, generates up to 40 new tokens
    for each entry of the 'holisticbiasr_dispreferred' dataset, and writes
    the list of records to
    `results/<dataset>/finetuning/<llama_version>-non-new/full/<model_name>.json`.
    Each record holds the prompt, the decoded output sequence and a 'meta'
    dict with the dataset columns listed in `_META_COLUMNS`.

    Args:
        model_name: Model selector; see `_load_model_and_tokenizer`.

    Returns:
        None. Results are written to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset_name = 'holisticbiasr_dispreferred'
    llama_version = 'llama2'

    output_dir = f'results/{dataset_name}/finetuning/{llama_version}-non-new/full'
    os.makedirs(output_dir, exist_ok=True)
    output_path = f'{output_dir}/{model_name}.json'

    model, tokenizer = _load_model_and_tokenizer(model_name, llama_version)

    # nn.DataParallel only parallelizes forward() and does not expose
    # generate(), so wrapping the model would break the generation call
    # below. Generation therefore runs on a single device.
    if torch.cuda.device_count() > 1:
        print(f"{torch.cuda.device_count()} GPUs visible; generation runs on a single device ({device}).")

    model = model.to(device)
    model.eval()

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        do_sample=True,
    )

    data = []
    dataset = load_datasets(dataset_name)

    # Fetch each column once instead of once per sample. Rows are still
    # addressed with `[i]`, exactly as before, because the container type
    # returned by load_datasets is not visible here.
    prompts = dataset['prompt_text']
    meta_columns = {name: dataset[name] for name in _META_COLUMNS}

    for i in tqdm(range(len(dataset))):
        prompt = prompts[i]
        inputs = tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(
                # Follow the selected device so CPU-only runs work too.
                input_ids=inputs["input_ids"].to(device),
                generation_config=generation_config,
                return_dict_in_generate=True,
                repetition_penalty=1.2,
                output_scores=True,
                max_new_tokens=40,
            )

        # The full output sequence is decoded and stored as the response.
        output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

        data.append({
            'prompt_text': prompt,
            'response': output,
            'meta': {name: column[i] for name, column in meta_columns.items()},
        })

        # The whole results file is rewritten after every sample, presumably
        # so that partial results survive an interrupted run.
        with open(output_path, 'w') as outfile:
            json.dump(data, outfile, indent=4)
            
            
# Country names from the tab-separated country/group table, skipping the
# first 28 rows.
df = pd.read_csv('data/country_and_group.csv', sep='\t')
country_list = df['Country'].values.tolist()[28:]

# Model names passed to inference() when this file is run as a script.
non_value_list = ['samsum']

# NOTE(review): these look like value-dimension abbreviations and their
# higher-level categories; confirm against the checkpoint naming scheme.
val_list = [
    'Ach',
    'Ben',
    'Con',
    'Hed',
    'Pow',
    'Sec',
    'SD',
    'Sti',
    'Tra',
    'Uni',
]
cat_list = [
    'Conservation',
    'Openness_to_Change',
    'Self-Enhancement',
    'Self-Transcendence',
]

# Currently enabled 'close_*' names; the commented entries are disabled.
close_list = [
    'close_Ach_9',
    'close_Ben_9',
    'close_Con_9',
    'close_Hed_9',
    'close_Pow_9',
    'close_Sec_9',
    'close_SD_9',
    # 'close_Sti_9',
    # 'close_Tra_9',
    # 'close_Uni_9',
    # 'close_Openness_to_Change_9',
    # 'close_Self-Enhancement_9',
    # 'close_Conservation_9',
    # 'close_Self-Transcendence_9'
]

if __name__ == '__main__':
    # Script entry point: run inference once per configured model name,
    # in list order.
    print('models to inference: ', non_value_list)
    for target_model in non_value_list:
        inference(target_model)

    print("Inference is done")