import json
from transformers import set_seed
import torch
import os
import argparse
import logging

from utils import (
    get_likelihood,
    get_perplexity,
    get_entropy_binoculars,
    get_roc_metrics,
    evaluate_detectrl,
    load_model,
    get_sampling_discrepancy_fast_detect_gpt,
    get_sampling_discrepancy_analytic_fast_detect_gpt,
    get_sampling_discrepancy_lastde_doubleplus,
)

# Configure the root logger once at import time: INFO level, with each record
# rendered as "<timestamp> <level> <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)


# Values of ``args.method`` that ``get_score`` (inside ``main``) knows how to handle.
_SUPPORTED_METHODS = ('iRM', 'binoculars', 'fast_detect_gpt', 'lastde_doubleplus')


def _parse_data_paths(raw_paths):
    """Split a comma-separated list of paths, dropping empty entries.

    Both ``"a.json,b.json,"`` (trailing comma) and ``"a.json,b.json"`` yield
    ``["a.json", "b.json"]``; surrounding whitespace is stripped.
    """
    return [path.strip() for path in raw_paths.split(",") if path.strip()]


def _output_stem(filename):
    """Return the output file stem for *filename* under the detectrl benchmark.

    The stem is the base name truncated at ``".json"`` and then at the first
    ``"."`` (e.g. ``"dir/foo.bar.json"`` -> ``"foo"``).
    """
    # NOTE: the previous code also had a 'glimpse' naming rule (truncate at
    # ".raw" instead of "."), but it was unreachable because only 'detectrl'
    # has an evaluation function; re-add it together with a glimpse evaluator.
    return os.path.basename(filename).split(".json")[0].split(".")[0]


def main(args):
    """Score every test file with the selected detector and save the results.

    For each path in ``args.test_data_path`` (comma-separated) the JSON file is
    loaded, evaluated with ``evaluate_detectrl`` using a per-text scoring
    function chosen by ``args.method``, summarised with ``get_roc_metrics``,
    and written to ``<save_path>/<stem>_data.json`` (scored samples) and
    ``<save_path>/<stem>_result.json`` (metrics).

    Args:
        args: Parsed command-line namespace. Reads ``benchmark``, ``method``,
            ``save_path``, ``test_data_path``, ``seed``, ``base_model``,
            ``ref_model``, ``discrepancy_analytic`` and ``skip_fail``; the
            whole namespace is also forwarded to ``load_model``.

    Raises:
        ValueError: If the benchmark or method is not implemented, if
            ``save_path`` is missing, or if no test file path was given.
    """
    # Validate everything that is cheap to check before loading any model.
    if args.benchmark != 'detectrl':
        raise ValueError('No implementation.')
    if args.method not in _SUPPORTED_METHODS:
        raise ValueError('No implementation')
    if not args.save_path:
        raise ValueError('--save_path is required to store the results.')
    filenames = _parse_data_paths(args.test_data_path)
    if not filenames:
        raise ValueError('--test_data_path does not contain any file path.')
    evaluate_func = evaluate_detectrl

    os.environ['TOKENIZERS_PARALLELISM'] = 'True'

    base_model, ref_model, base_tokenizer, ref_tokenizer = load_model(args)

    set_seed(args.seed)

    def get_paired_logits(text, tokenized, labels):
        """Return ``(logits_ref, logits_score)`` for *text*, on one device.

        When base and reference model are the same, the base logits are reused.
        Otherwise the text is re-tokenized with the reference tokenizer and the
        resulting token ids must match *labels*.
        """
        logits_score = base_model(**tokenized).logits[:, :-1]
        if args.base_model == args.ref_model:
            return logits_score, logits_score

        ref_tokenized = ref_tokenizer(text, return_tensors="pt", return_token_type_ids=False).to(ref_model.device)
        # Compare on a common device; torch.equal also copes with differing
        # lengths, which an element-wise ``==`` would not.
        ref_labels = ref_tokenized.input_ids[:, 1:].to(labels.device)
        if not torch.equal(ref_labels, labels):
            # Explicit raise (same exception type as before) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError("Tokenizer is mismatch.")
        logits_ref = ref_model(**ref_tokenized).logits[:, :-1]
        # No-op when both models live on the same device.
        return logits_ref.to(logits_score.device), logits_score

    def get_score(text):
        """Compute the detection score of *text* with ``args.method``."""
        tokenized = base_tokenizer(text, return_tensors="pt", return_token_type_ids=False).to(base_model.device)
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            if args.method == 'iRM':
                logits = base_model(**tokenized).logits[:, :-1]
                likelihood = get_likelihood(logits, labels.to(logits.device), return_sum=True)

                # BatchEncoding.to() replaces its tensors in place, so `labels`
                # (sliced above) stays on the base device and must be moved
                # explicitly next to the reference logits.
                ref_inputs = tokenized.to(ref_model.device)
                ref_logits = ref_model(**ref_inputs).logits[:, :-1]
                ref_likelihood = get_likelihood(ref_logits, labels.to(ref_logits.device), return_sum=True)

                return likelihood - ref_likelihood

            if args.method == 'binoculars':
                base_logits = base_model(**tokenized).logits
                ref_logits = ref_model(**tokenized.to(ref_model.device)).logits

                # Bring the inputs and reference logits next to the base logits.
                inputs = tokenized.to(base_logits.device)
                ppl = get_perplexity(inputs, base_logits)
                entropy = get_entropy_binoculars(
                    ref_logits.to(base_logits.device), base_logits, inputs, base_tokenizer.pad_token_id
                )
                score = (ppl / entropy).tolist()
                return -score[0]

            if args.method == 'fast_detect_gpt':
                logits_ref, logits_score = get_paired_logits(text, tokenized, labels)
                labels = labels.to(logits_score.device)
                if args.discrepancy_analytic:
                    return get_sampling_discrepancy_analytic_fast_detect_gpt(logits_ref, logits_score, labels)
                return get_sampling_discrepancy_fast_detect_gpt(logits_ref, logits_score, labels)

            if args.method == 'lastde_doubleplus':
                logits_ref, logits_score = get_paired_logits(text, tokenized, labels)
                labels = labels.to(logits_score.device)
                return get_sampling_discrepancy_lastde_doubleplus(logits_ref, logits_score, labels)

            raise ValueError('No implementation')

    os.makedirs(args.save_path, exist_ok=True)

    for filename in filenames:
        logging.info("Test in %s", filename)
        with open(filename, 'r', encoding='utf-8') as fin:
            data = json.load(fin)

        predictions, scored_data = evaluate_func(data, get_score, skip_fail=args.skip_fail)

        (
            roc_auc,
            optimal_threshold,
            conf_matrix,
            precision,
            recall,
            f1,
            accuracy,
            tpr_at_fpr_0_01,
        ) = get_roc_metrics(predictions['human'], predictions['llm'])

        result = {
            "roc_auc": roc_auc,
            "optimal_threshold": optimal_threshold,
            "conf_matrix": conf_matrix,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "accuracy": accuracy,
            "tpr_at_fpr_0_01": tpr_at_fpr_0_01,
        }
        logging.info("%s", result)

        save_path = os.path.join(args.save_path, _output_stem(filename))
        with open(save_path + "_data.json", "w", encoding='utf-8') as f:
            json.dump(scored_data, f, indent=4)

        with open(save_path + "_result.json", "w", encoding='utf-8') as f:
            json.dump(result, f, indent=4)



if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--benchmark', type=str, required=True)
    # Comma-separated list of JSON test files.
    parser.add_argument('--test_data_path', type=str, required=True)
    parser.add_argument('--base_model', default=None, type=str, required=False)
    parser.add_argument('--ref_model', default=None, type=str, required=False)
    # NOTE(review): not read in this file; presumably consumed by load_model — confirm.
    parser.add_argument('--device', default="0,1", type=str, required=False)
    parser.add_argument('--seed', default=42, type=int, required=False)
    # main() cannot write its outputs without a directory, so require it here
    # instead of failing later on a None path.
    parser.add_argument('--save_path', type=str, required=True)
    # The option used to combine a default with required=True, which made the
    # default unreachable; it is now optional and restricted to the methods
    # that main() implements.
    parser.add_argument('--method', default='iRM', type=str,
                        choices=['iRM', 'binoculars', 'fast_detect_gpt', 'lastde_doubleplus'])
    parser.add_argument('--discrepancy_analytic', action='store_true',
                        help='For fast_detect_gpt.')
    parser.add_argument('--skip_fail', action='store_true',
                        help='For lastde.')
    args = parser.parse_args()

    main(args)
    
    
    