import torch
import json
import argparse
import sys
import os
import threading
import time
from tqdm import tqdm

from eval.perspective_api import perspectiveapi
from eval.moderation_api import ModerationAPI
from utils.utils import load_datasets, _find_save_path
from utils.constants import non_value_datasets

import numpy as np
import pandas as pd

from peft import PeftModel, PeftConfig
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from torch import nn




def parsing_argument():
    """Parse the command-line options of this script.

    Two options are registered, both mandatory:

    * ``--dataset``: the dataset to run inference on.
    * ``--model_name``: the model name to run inference with.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, exposing
        ``dataset`` and ``model_name`` attributes.
    """
    cli = argparse.ArgumentParser(description='model name which you want to inference')

    # (flag, help text) pairs; every option is required.
    required_options = (
        ('--dataset', 'choose dataset to inference'),
        ('--model_name', 'choose model name to inference'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, help=help_text)

    return cli.parse_args()


def inference(model_name, version, dataset_name):
    """Generate completions for a dataset and score them with the Perspective API.

    For every prompt the model generates up to 128 new tokens, the prompt text
    is stripped from the decoded sequence, and the remainder is scored with
    ``perspectiveapi``. After each successfully scored sample the accumulated
    results are rewritten to
    ``results/{dataset_name}/finetuning/{model_name}-non-new/{version}.json``,
    so partial results survive an interrupted run.

    Args:
        model_name: Name used to build the checkpoint and output paths.
        version: Selects which weights are loaded:
            * ``'vanilla'``: the base ``meta-llama/Llama-2-7b-chat-hf`` model;
            * a name contained in ``non_value_datasets``: a PEFT adapter from
              the ``llama-recipes`` checkpoint directory;
            * anything else: a PEFT adapter from the ``VIM`` checkpoint
              directory, in the sub-directory returned by ``_find_save_path``.
        dataset_name: Name passed to ``load_datasets``. For ``'rtp'`` at most
            the first 3000 samples are processed.

    Returns:
        None. Results are written to disk as a side effect.
    """
    # Fall back to CPU when no GPU is available; the same device is used for
    # the model and for the input tensors below.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = []
    dataset = load_datasets(dataset_name)

    output_folder = f'results/{dataset_name}/finetuning/{model_name}-non-new'
    os.makedirs(output_folder, exist_ok=True)
    output_path = f'{output_folder}/{version}.json'

    # load model and tokenizer
    if version == 'vanilla':
        base_model = 'meta-llama/Llama-2-7b-chat-hf'
        model = LlamaForCausalLM.from_pretrained(base_model)
    else:
        # Both fine-tuned variants are PEFT adapters; only the checkpoint
        # location differs.
        if version in non_value_datasets:
            peft_model_id = f'../llama-recipes/recipes/finetuning/ckpts/{model_name}-7b-new/{version}'
        else:
            epoch_num = _find_save_path(f'../VIM/src/ckpt/argument_survey/{model_name}/min_TH_3/{version}')
            peft_model_id = f"../VIM/src/ckpt/argument_survey/{model_name}/min_TH_3/{version}/{epoch_num}"
        config = PeftConfig.from_pretrained(peft_model_id)
        base_model = config.base_model_name_or_path
        model = LlamaForCausalLM.from_pretrained(base_model)
        model = PeftModel.from_pretrained(model, peft_model_id)
    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    model = model.to(device)
    model.eval()

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        do_sample=True,
    )

    # inference
    if dataset_name == 'rtp':
        # Cap at 3000 samples, but never index past the end of the dataset.
        length = min(3000, len(dataset))
    else:
        length = len(dataset)

    for i in tqdm(range(length)):
        print(f'====================[{version}:{i+1}/{length}]====================')

        prompt = dataset[i][0]
        inputs = tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            sample_outputs = model.generate(
                # Inputs must live on the same device as the model.
                input_ids=inputs["input_ids"].to(device),
                generation_config=generation_config,
                return_dict_in_generate=True,
                repetition_penalty=1.2,
                output_scores=True,
                max_new_tokens=128,
            )

        s = sample_outputs.sequences[0]
        gen = tokenizer.decode(s, skip_special_tokens=True)
        # Remove the prompt text from the decoded sequence so that only the
        # model's continuation is scored.
        output = gen.replace(prompt, '')

        print('input: ', prompt)
        print(f"LLaMA({version}): ", output)

        try:
            perspective_score = perspectiveapi(output)

            set_data = {
                'input': prompt,
                'output': output,
                'perspectiveAPI': perspective_score
            }

            data.append(set_data)

            # Rewrite the whole file after every sample as a checkpoint.
            with open(output_path, 'w') as outfile:
                json.dump(data, outfile, indent=4)

        except Exception as exc:
            # Best effort: one failed sample must not abort the run, but
            # Ctrl-C / SystemExit are no longer swallowed and the cause is shown.
            print(f"Error on sample {i+1}: {exc!r}")
            print("Ignore the error and continue to inference the next data...")
            continue
            
            
# df = pd.read_csv('data/country_and_group.csv', sep='\t')
# country_list = df['Country'].values.tolist()
# country_list = country_list[28:]


# country_list = [f'Group_{num+1}' for num in range(87, 89)]
# country_list = ['Group_89']
# cat_list = ['Conservation', 'Openness_to_Change', 'Self-Enhancement', 'Self-Transcendence']
# Candidate version names; entries that are commented out are disabled.
# NOTE(review): `close_list` is not referenced anywhere in the visible part of
# this file — confirm whether it is still needed.
close_list = [
    'close_Ach_8',
    'close_Ben_8',
    'close_Con_8',
    'close_Hed_8',
    # 'close_Pow_8',
    # 'close_Sec_8',
    # 'close_SD_8',
    # 'close_Sti_8',
    # 'close_Tra_8',
    # 'close_Uni_8',
    # 'close_Openness_to_Change_8',
    # 'close_Self-Enhancement_8',
    # 'close_Conservation_8',
    # 'close_Self-Transcendence_8'
]
# Versions actually run by the `__main__` block below: each entry is passed to
# `inference()` as its `version` argument.
ls = ['samsum']

if __name__ == '__main__':
    # Script entry point: run inference once per entry of `ls`, using the
    # model name and dataset given on the command line.
    cli_args = parsing_argument()
    selected_model = cli_args.model_name
    selected_dataset = cli_args.dataset

    print(ls)
    for version_name in ls:
        # Each entry of `ls` is forwarded as the `version` argument.
        inference(selected_model, version_name, selected_dataset)

    print("Inference is done")