import json
from transformers import set_seed
import torch
import os
import argparse
import logging

from utils import (
    get_lastde,
    get_likelihood,
    get_entropy,
    get_lrr,
    get_rank,
    get_logrank,
    get_roc_metrics,
    evaluate_detectrl,
    load_model
)


# Root logging for the script: INFO and above, each line prefixed with a timestamp and level.
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)


def _parse_test_files(test_data_path):
    """Split a comma-separated list of test-file paths.

    Empty entries and surrounding whitespace are dropped, so both
    ``"a.json,b.json"`` and the trailing-comma form ``"a.json,b.json,"``
    yield ``["a.json", "b.json"]``.
    """
    return [path.strip() for path in test_data_path.split(",") if path.strip()]


def _output_prefix(save_dir, filename):
    """Return the output path prefix for one test file.

    The file's basename is cut at the first ``.json`` and then at the first
    ``.raw_data``; the remaining stem is placed inside ``save_dir``.
    """
    stem = os.path.basename(filename).split(".json")[0].split(".raw_data")[0]
    return os.path.join(save_dir, stem)


def main(args):
    """Score each test file with the selected detector and save its metrics.

    Args:
        args: parsed command-line namespace. Reads ``benchmark``,
            ``test_data_path`` (comma-separated JSON files), ``base_model``,
            ``method``, ``seed`` and ``save_path``; the whole namespace is
            also passed to ``load_model``.

    For every test file, two files are written into ``args.save_path``:
    ``<stem>_data.json`` holds the scored data returned by the benchmark
    evaluator, and ``<stem>_result.json`` holds the metrics returned by
    ``get_roc_metrics``.

    Raises:
        ValueError: if ``save_path`` is missing, or the benchmark or method
            is not supported.
    """
    os.environ['TOKENIZERS_PARALLELISM'] = 'True'

    # Scorers invoked as ``scorer(logits, labels)`` on shifted next-token
    # logits and their target token ids.
    token_scorers = {
        'likelihood': get_likelihood,
        'entropy': get_entropy,
        'rank': get_rank,
        'logrank': get_logrank,
        'lastde': get_lastde,
        'lrr': get_lrr,
    }
    # Methods that take the model's first logits row directly as a scalar score.
    reward_methods = ('RM', 'RM-template')

    # Validate the configuration before the (expensive) model load.
    if not args.save_path:
        raise ValueError('--save_path must be provided.')
    if args.benchmark == 'detectrl':
        evaluate_func = evaluate_detectrl
    else:
        raise ValueError('No implementation.')
    if args.method not in token_scorers and args.method not in reward_methods:
        raise ValueError('No implementation')

    # The reference model/tokenizer are not used by any method in this script.
    base_model, _, base_tokenizer, _ = load_model(args)

    set_seed(args.seed)

    # Model names containing 'Ray2333' get chat-template formatting
    # (presumably those checkpoints expect chat input — confirm).
    use_chat_template = 'Ray2333' in (args.base_model or '')

    def get_score(text):
        """Return the detection score of ``text`` under ``args.method``."""
        if use_chat_template:
            # Wrap the text as the assistant reply to an empty user turn.
            message = [
                {'role': 'user', 'content': ''},
                {'role': 'assistant', 'content': text},
            ]
            text = base_tokenizer.apply_chat_template(message, tokenize=False)
        tokenized = base_tokenizer(text, return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            base_inputs = tokenized.to(base_model.device)
            if args.method in reward_methods:
                return base_model(**base_inputs).logits[0].cpu().item()
            # Logit at position t predicts token t + 1, so drop the last logit
            # and the first token to align predictions with their targets.
            logits = base_model(**base_inputs).logits[:, :-1]
            # Slice labels from the device-side inputs so they share a device
            # with ``logits``.
            labels = base_inputs.input_ids[:, 1:]
            return token_scorers[args.method](logits, labels)

    os.makedirs(args.save_path, exist_ok=True)

    for filename in _parse_test_files(args.test_data_path):
        logging.info(f"Test in {filename}")
        with open(filename, 'r', encoding='utf-8') as fin:
            data = json.load(fin)

        predictions, scored_data = evaluate_func(data, get_score)

        roc_auc, optimal_threshold, conf_matrix, precision, recall, f1, accuracy, tpr_at_fpr_0_01 = get_roc_metrics(predictions['human'], predictions['llm'])

        result = {
            "roc_auc": roc_auc,
            "optimal_threshold": optimal_threshold,
            "conf_matrix": conf_matrix,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "accuracy": accuracy,
            "tpr_at_fpr_0_01": tpr_at_fpr_0_01
        }

        logging.info(f"{result}")
        save_path = _output_prefix(args.save_path, filename)
        with open(save_path + "_data.json", "w", encoding='utf-8') as f:
            json.dump(scored_data, f, indent=4)

        with open(save_path + "_result.json", "w", encoding='utf-8') as f:
            json.dump(result, f, indent=4)


def build_parser():
    """Build the command-line parser for this evaluation script.

    Returns:
        argparse.ArgumentParser: parser producing the namespace ``main`` expects.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--benchmark', type=str, required=True)
    parser.add_argument('--test_data_path', type=str, required=True,
                        help="Comma-separated JSON test files, e.g. 'a.json,b.json,'.")
    parser.add_argument('--base_model', default=None, type=str, required=False)
    parser.add_argument('--ref_model', default=None, type=str, required=False)
    parser.add_argument('--device', default="0,1", type=str, required=False)
    parser.add_argument('--seed', default=42, type=int, required=False)
    # Required: main() cannot run without an output directory.
    parser.add_argument('--save_path', type=str, required=True,
                        help='Directory receiving the *_data.json and *_result.json outputs.')
    parser.add_argument('--method', type=str, required=True)
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
    
    
    