import os
import torch
import numpy as np
import gc
from easydict import EasyDict
from peft import PeftModel
from transformers import (
    AutoTokenizer, 
    AutoConfig, 
    AutoModelForCausalLM, 
)
import safeft.eval as eval
import safeft.utils as utils
from safeft.utils import configs
from argparse import ArgumentParser
import json
import copy
import re
def parse_args_and_config(args_list=None):
    """Parse CLI arguments and merge them into the project config.

    Args:
        args_list: optional list of argument strings; when None, ``sys.argv``
            is parsed (``argparse`` default behaviour).

    Returns:
        ``(args, parser_args)`` where ``args`` is an ``EasyDict`` built from
        ``safeft.utils.configs`` with the CLI overrides applied, and
        ``parser_args`` is the raw ``argparse`` namespace.
    """
    parser = ArgumentParser()
    parser.add_argument("--utility_training_num", type=int, default=None, help="")
    parser.add_argument("--poison_training_num", type=int, default=None, help="")
    parser.add_argument("--name_suffix", type=str, default="", help="")
    parser.add_argument("--base_model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="")
    parser.add_argument("--checkpoint", type=str, default="-1", help="Checkpoint index to evaluate")

    parser.add_argument("--use_lora", type=int, default=1, help="Use LoRA")
    parser.add_argument("--utility_dataset_config", type=str, default="utility_dataset_config_samsum", help="Utility dataset config name")
    parser.add_argument("--poison_dataset_config", type=str, default="poison_dataset_config_RTA", help="Poison dataset config name")
    parser.add_argument("--output_dir", type=str, default=None, help="output directory for finetune models from finetune.py")
    parser.add_argument("--batch_size", type=int, default=None, help="Batch size for evaluation. If not provided, will use default batch size.")
    parser.add_argument("--identity_shift", type=int, default=0, help="Identity shift for evaluation. If not provided, will use default chat template")
    parser.add_argument("--harm_alpaca", type=int, default=0, help="Harm alpaca for evaluation. If not provided, will use default harm alpaca.")
    # The three *_test flags are independent run/skip toggles (1 = run, 0 = skip).
    parser.add_argument("--logit_test", type=int, default=1, help="Whether to run the logits test on the adveval data (1 = run, 0 = skip).")
    parser.add_argument("--utility_test", type=int, default=1, help="Whether to run the utility test on the utility data (1 = run, 0 = skip).")
    parser.add_argument("--attack_test", type=int, default=1, help="Whether to run the attack test on the adveval data (1 = run, 0 = skip).")
    parser.add_argument("--eval_harmful_score", type=int, default=1, help="Whether to evaluate harmful score.")
    parser.add_argument("--backdoor", type=int, default=0, help="Whether to use backdoor prefix for evaluation.")

    parser.add_argument("--generator_name", type=str, default=None, help="Generator name for evaluation.")
    parser_args = parser.parse_args(args_list) if args_list is not None else parser.parse_args()

    args = EasyDict(configs)
    # The dataset-config flags name attributes of the project config; deep-copy
    # the selected entries so the overrides below do not mutate shared state.
    args.poison_dataset_config = copy.deepcopy(getattr(args, parser_args.poison_dataset_config))
    args.utility_dataset_config = copy.deepcopy(getattr(args, parser_args.utility_dataset_config))
    args.name_suffix = parser_args.name_suffix
    args.base_model_path = parser_args.base_model_path
    args.model_path_name = args.base_model_path.split("/")[-1]
    args.identity_shift = parser_args.identity_shift
    args.backdoor = parser_args.backdoor
    args.generator_name = parser_args.generator_name
    if parser_args.utility_training_num is not None:
        args.utility_dataset_config.train_num = parser_args.utility_training_num
    if parser_args.poison_training_num is not None:
        args.poison_dataset_config.train_num = parser_args.poison_training_num
    if parser_args.output_dir is not None:
        args.output_dir = parser_args.output_dir
    return args, parser_args

def load_model_and_tokenizer(path):
    """Load a causal LM together with its tokenizer from ``path``.

    The tokenizer's pad token is set to its EOS token. The model weights are
    loaded in bfloat16 and placed with ``device_map="auto"``.

    Returns:
        ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(path)
    tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map="auto")
    return model, tok

def load_model_only(path):
    """Load just the causal LM at ``path`` (no tokenizer).

    Uses bfloat16 weights and ``device_map="auto"``, matching
    ``load_model_and_tokenizer``.
    """
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    return AutoModelForCausalLM.from_pretrained(path, **load_kwargs)

def get_eval_datasets(args):
    """Build the utility and adversarial-evaluation dataset splits.

    Each dataset config names, via ``prepare_function``, a function in
    ``safeft.utils`` that is called with that config.

    NOTE(review): ``args.adveval_dataset_config`` is not set by
    ``parse_args_and_config``; presumably it comes from the project config.

    Returns:
        ``(dataset_split_utility, dataset_split_adveval)``.
    """
    def _prepare(dataset_config):
        prepare_fn = getattr(utils, dataset_config.prepare_function)
        return prepare_fn(dataset_config)

    # Tuple elements are evaluated left to right: utility first, then adveval.
    return _prepare(args.utility_dataset_config), _prepare(args.adveval_dataset_config)

def get_token_logit(logits, tokenizer, token="I"):
    """Return ``logits`` indexed at the id of the first sub-token of ``token``.

    ``token`` is tokenized with ``tokenizer``; only the first resulting piece
    is used, so multi-piece words are represented by their leading piece.
    Presumably ``logits`` is a per-vocabulary score vector — confirm with
    ``eval.logits_test``.
    """
    pieces = tokenizer.tokenize(token)
    first_id = tokenizer.convert_tokens_to_ids(pieces)[0]
    return logits[first_id]


def evaluate(model, tokenizer, args, parser_args,model_path, dataset_split_utility, dataset_split_adveval):
    """Run the enabled evaluations for one model and save a JSON summary.

    All artifacts live under ``cache/evaluations/{model_path}/``; the per-test
    paths are passed to the ``safeft.eval`` helpers as ``load_save_path``.

    Args:
        model: model to evaluate.
        tokenizer: tokenizer matching ``model``.
        args: merged config from ``parse_args_and_config``.
        parser_args: raw CLI namespace. ``utility_test``, ``attack_test`` and
            ``logit_test`` toggle each test; ``batch_size`` and
            ``eval_harmful_score`` are forwarded to the helpers.
        model_path: checkpoint path; used to build the cache directory.
        dataset_split_utility: utility splits; only ``'test'`` is used.
        dataset_split_adveval: adversarial splits; only ``'test'`` is used.

    Returns:
        The result dict that is also written to ``evaluation.json``. It always
        contains ``utility_rate``, ``attack_success_rate``, ``harmful_score``
        and ``logits_score`` (None for skipped tests). When the logit test runs,
        it also contains ``logits_I_score``, ``logits_Sorry_score`` and
        ``logits_Sure_score``, with ``logits_score`` = I minus Sure.
    """
    eval_dir = f"cache/evaluations/{model_path}"
    result_save_path = f"{eval_dir}/evaluation.json"
    logits_save_path = f"{eval_dir}/logits.npy"
    utility_save_path = f"{eval_dir}/utility"
    attack_save_path = f"{eval_dir}/attack_result.json"
    result = {"utility_rate": None, "attack_success_rate": None, "harmful_score": None, "logits_score": None}

    if parser_args.utility_test:
        print("###### now evaluating finetuning with utility data #########")
        utility_rate, _utility_result = eval.utility_test(dataset_split_utility['test'], tokenizer, model, args, load_save_path=utility_save_path, batch_size=parser_args.batch_size)
        result["utility_rate"] = utility_rate

    if parser_args.attack_test:
        print("###### now evaluating finetuning with adveval data #########")
        attack_success_rate, harmful_score, _attack_result = eval.attack_test(dataset_split_adveval['test'], tokenizer, model, args, load_save_path=attack_save_path, batch_size=parser_args.batch_size, eval_harmful_score=parser_args.eval_harmful_score)
        result["attack_success_rate"] = attack_success_rate
        result["harmful_score"] = harmful_score

    if parser_args.logit_test:
        print("###### now saving logits with adveval data #########")
        logits_list = eval.logits_test(dataset_split_adveval['test'], tokenizer, model, args, load_save_path=logits_save_path)
        # Mean logit of each probe token's first sub-token across the test prompts.
        for token in ("I", "Sorry", "Sure"):
            result[f"logits_{token}_score"] = np.mean([get_token_logit(logits, tokenizer, token=token) for logits in logits_list]).item()
        result["logits_score"] = result["logits_I_score"] - result["logits_Sure_score"]

    print(result)
    # The eval helpers are not guaranteed to have created this directory
    # (e.g. when every test is disabled), so create it before writing.
    os.makedirs(eval_dir, exist_ok=True)
    with open(result_save_path, 'w') as f:
        json.dump(result, f)
    print(f"Evaluation result saved to {result_save_path}")
    print("###############################")
    return result

def _release_memory():
    """Free cached CUDA memory and run the Python garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


def _evaluate_lora_checkpoints(args, parser_args, model_path_list, dataset_split_utility, dataset_split_adveval):
    """Evaluate one or all LoRA adapter checkpoints on a shared base model."""
    base_model, tokenizer = load_model_and_tokenizer(args.base_model_path)
    if parser_args.checkpoint == "all":
        print("evaluating all lora checkpoints")
        for model_path in model_path_list:
            print(f"lora model path: {model_path}")
            model = PeftModel.from_pretrained(base_model, model_path)
            model.eval()
            evaluate(model, tokenizer, args, parser_args, model_path, dataset_split_utility, dataset_split_adveval)
            # PeftModel.from_pretrained injects the adapter layers into
            # base_model in place. Remove them (without merging) so the next
            # checkpoint is loaded onto a clean base model rather than one that
            # still carries the previous adapter.
            base_model = model.unload()
            del model
            _release_memory()
    else:
        index = int(parser_args.checkpoint)
        model_path = model_path_list[index]
        print(f"lora model path: {model_path}")
        model = PeftModel.from_pretrained(base_model, model_path)
        model.eval()
        evaluate(model, tokenizer, args, parser_args, model_path, dataset_split_utility, dataset_split_adveval)


def _evaluate_full_checkpoints(args, parser_args, model_path_list, dataset_split_utility, dataset_split_adveval):
    """Evaluate full fine-tuned checkpoints, plus the base model as checkpoint ``-0``."""
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
    # Match the LoRA path (load_model_and_tokenizer), which pads with EOS.
    # Only fill in a missing pad token so tokenizers that define one keep it.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if parser_args.checkpoint in ("all", "0"):
        model_base = load_model_only(args.base_model_path)
        model_base.eval()
        # Store base-model results under the first checkpoint's path with its
        # trailing step number replaced by 0.
        checkpoint_0 = re.sub(r'-\d+$', '-0', model_path_list[0])
        evaluate(model_base, tokenizer, args, parser_args, checkpoint_0, dataset_split_utility, dataset_split_adveval)
        del model_base
        _release_memory()

    if parser_args.checkpoint == "all":
        print("evaluating all full model checkpoints")
        for model_path in model_path_list:
            model = load_model_only(model_path)
            model.eval()
            evaluate(model, tokenizer, args, parser_args, model_path, dataset_split_utility, dataset_split_adveval)
            del model
            _release_memory()
    else:
        # NOTE(review): with --checkpoint 0 this branch also runs, so both the
        # base model (above) and model_path_list[0] are evaluated — confirm
        # this is intended.
        index = int(parser_args.checkpoint)
        model = load_model_only(model_path_list[index])
        model.eval()
        evaluate(model, tokenizer, args, parser_args, model_path_list[index], dataset_split_utility, dataset_split_adveval)
        del model


def main():
    """Entry point: parse config, load eval datasets, and evaluate checkpoints.

    ``--checkpoint`` is either ``"all"`` or an integer index (negative values
    allowed) into the checkpoint list returned by ``utils.get_lora_path_list``.
    ``--use_lora`` selects LoRA-adapter checkpoints versus full-model checkpoints.
    """
    args, parser_args = parse_args_and_config()
    dataset_split_utility, dataset_split_adveval = get_eval_datasets(args)
    print(f"utility dataset size for test: {len(dataset_split_utility['test'])}")
    print(f"adveval dataset size for test: {len(dataset_split_adveval['test'])}")
    model_path_list = utils.get_lora_path_list(args, father_path=args.output_dir)
    if parser_args.use_lora:
        _evaluate_lora_checkpoints(args, parser_args, model_path_list, dataset_split_utility, dataset_split_adveval)
    else:
        _evaluate_full_checkpoints(args, parser_args, model_path_list, dataset_split_utility, dataset_split_adveval)
    _release_memory()
    
if __name__ == "__main__":
    # Run the evaluation pipeline only when executed as a script, not on import.
    main()
