import os
import torch
import json

from evaluation.evaluate import initialize_model, evaluate

class Arguments:
    """Fixed evaluation/run configuration exposed as plain attributes.

    Every setting is a hard-coded constant except the two constructor
    arguments, which are stored verbatim as ``prediction_file`` and
    ``evaluate_on_test``. Callers may attach further attributes afterwards
    (e.g. ``run_KQA`` sets ``device``).
    """

    def __init__(self, prediction_file, evaluate_on_test):
        # Kept as a single ordered mapping so the attribute set (and its
        # insertion order) is visible at a glance.
        settings = {
            "data_dir": "./dataset/opendialkg",
            "pickle_folder": "./evaluation/dataset",
            "hidden_size": 768,
            "train_batch_size": 12,
            "eval_batch_size": 12,
            "doc_stride": 128,
            "max_seq_length": 384,
            "max_mention_length": 30,
            "max_query_length": 64,
            "num_train_epochs": 2,
            "debug": False,
            "learning_rate": 3e-5,
            "warmup_steps": 0,
            "adam_epsilon": 1e-8,
            "weight_decay": 0.01,
            "bert_model": "bert-base-uncased",
            "baseline": False,
            "seed": 42,
            "version_2_with_negative": False,
            "output_dir": "./save/KQA_tmp",
            "n_best_size": 20,
            "verbose_logging": False,
            "max_answer_length": 30,
            "save_steps": 0,
            "gradient_accumulation_steps": 1,
            "null_score_diff_threshold": 0.0,
            "local_rank": -1,
            "read_data": False,
            "do_eval": False,
            "no_report": False,
            # NOTE(review): looks like a placeholder path — presumably must be
            # pointed at a real checkpoint directory before use.
            "checkpoint": "./path/to/checkpoint",
            "prediction_file": prediction_file,
            "evaluate_on_test": evaluate_on_test,
            "lm_type": "bert",
            "do_lower_case": True,
            "eval_target": "all",
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

def run_KQA(prediction_file, fold="dev", device=None):
    """Load the checkpoint at ``args.checkpoint`` and evaluate it on one fold.

    Args:
        prediction_file: Stored on the ``Arguments`` object as
            ``prediction_file``; how it is consumed is up to
            ``initialize_model`` / ``evaluate`` (not visible here).
        fold: ``"dev"`` or ``"test"``.
        device: Torch device to run the model on. Defaults to ``"cuda"``
            when CUDA is available, otherwise ``"cpu"``.

    Returns:
        ``(exact, f1)`` read from the ``results`` dict returned by
        ``evaluate``. The full dict is also written to
        ``<output_dir>/results.json``.

    Raises:
        ValueError: if ``fold`` is not ``"dev"`` or ``"test"``.
    """
    # Previously any unrecognised fold silently evaluated the dev split.
    if fold not in ("dev", "test"):
        raise ValueError(f"fold must be 'dev' or 'test', got {fold!r}")

    evaluate_on_test = fold == "test"
    args = Arguments(prediction_file, evaluate_on_test)
    if device is None:
        # Fall back to CPU instead of crashing on machines without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = device

    model, _tokenizer = initialize_model(args)
    checkpoint_path = os.path.join(args.checkpoint, "pytorch_model.bin")
    # Load onto CPU first so the checkpoint loads regardless of the device it
    # was saved from; the model is moved to the target device afterwards.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.to(args.device)

    # Ensure the directory exists before results.json is written into it.
    os.makedirs(args.output_dir, exist_ok=True)
    # NOTE(review): predictions go to <output_dir>/predictions.json regardless
    # of `prediction_file` — confirm this is intended.
    output_file = os.path.join(args.output_dir, "predictions.json")
    results = evaluate(args, model, fold=fold, output_file=output_file)

    results_path = os.path.join(args.output_dir, "results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f)
    return results["exact"], results["f1"]