"""

Script adapted from:
    https://github.com/huggingface/transformers/blob/master/examples/run_glue.py

"""

import logging

from transformers import BertConfig
from model.model import BertForDistantRE

from utils.train_utils import *

# Root logging: INFO and above, timestamped, tagged with the emitting logger's name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)


def _filter_eval_by_triples(eval_data, set_lower_80, keep_lower_80):
    """Restrict cached raw evaluation outputs by triple membership.

    Each example's triple key is the tab-joined string ``head, relation, tail``
    built from ``groups[i][0]``, ``labels[i]`` and ``groups[i][1]``.

    Args:
        eval_data: raw outputs dict with keys 'loss', 'labels', 'logits',
            'preds', 'groups' (as pickled by ``evaluate_test``).
        set_lower_80: container of triple-key strings supporting ``in``.
        keep_lower_80: keep examples whose key IS in ``set_lower_80`` when
            True, otherwise keep the complement.

    Returns:
        A new dict with the same keys holding only the kept examples. 'loss'
        is copied unchanged, so it still refers to the unfiltered set.

    Raises:
        ValueError: if no example survives the filter (``torch.stack`` on an
            empty list would otherwise fail with an opaque error).
    """
    subset_name = "lower 80" if keep_lower_80 else "upper 20"
    kept_labels, kept_logits, kept_preds, kept_groups = [], [], [], []
    for label, logit, pred, group in zip(eval_data['labels'], eval_data['logits'], eval_data['preds'], eval_data['groups']):
        trip = "\t".join([str(group[0].item()), str(label.item()), str(group[1].item())])
        if (trip in set_lower_80) == keep_lower_80:
            kept_labels.append(label)
            kept_logits.append(logit)
            kept_groups.append(group)  # groups
            kept_preds.append(pred)
    logger.info("Counted {} {} trips.".format(len(kept_labels), subset_name))
    logger.info("Length of labels: {}.".format(len(kept_labels)))
    if not kept_labels:
        raise ValueError("No {} examples found in cached evaluation data.".format(subset_name))
    return {
        'loss': eval_data['loss'],
        'labels': torch.stack(kept_labels),  # B,
        'logits': torch.stack(kept_logits),  # B x C
        'preds': np.asarray(kept_preds),
        'groups': torch.stack(kept_groups),  # B x 2
    }


def _run_model_eval(model, set_type):
    """Run ``model`` over the ``set_type`` split and collect its raw outputs.

    Returns:
        dict with 'loss' (mean over batches of each batch's mean loss),
        'labels', 'logits' and 'groups' (CPU tensors concatenated over
        batches) and 'preds' (numpy array of per-row argmax predictions).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    eval_dataset = load_dataset(set_type, logger, ent_types=False)
    eval_dataloader = DataLoader(eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=config.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(set_type))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", config.eval_batch_size)

    model.eval()  # switching to eval mode once before the loop is sufficient
    eval_loss = 0.0
    nb_eval_steps = 0
    eval_labels, eval_logits, eval_preds, eval_groups = [], [], [], []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(config.device) for t in batch)
        with torch.no_grad():
            # Batch layout as indexed here: 0 input_ids, 1 entity_ids,
            # 2 attention_mask, 3 groups, 4 labels.
            tmp_eval_loss, logits = model(
                input_ids=batch[0],
                entity_ids=batch[1],
                attention_mask=batch[2],
                labels=batch[4],
                is_train=False,
            )
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1

        logits = logits.detach().cpu()
        eval_labels.append(batch[4].detach().cpu())
        eval_logits.append(logits)
        eval_groups.append(batch[3].detach().cpu())  # groups
        # extend() with one prediction per logits row works for any batch size
        # (a bare .item() only works when the batch has exactly one row).
        eval_preds.extend(torch.argmax(logits, dim=1).tolist())

    if nb_eval_steps == 0:
        raise ValueError("Evaluation dataloader for '{}' produced no batches.".format(set_type))

    # Batch tensors, the dataset and the dataloader are released when this helper returns.
    return {
        'loss': eval_loss / nb_eval_steps,
        'labels': torch.cat(eval_labels),  # B,
        'logits': torch.cat(eval_logits),  # B x C
        'preds': np.asarray(eval_preds),
        'groups': torch.cat(eval_groups),  # B x 2
    }


def evaluate_test(model, model_dir, set_type="dev", eval_lower_80=False, eval_upper_20=False, set_lower_80=False, load_eval=False, cumulative_perf=False):
    """Evaluate a distant-RE model on ``set_type`` and write metrics to ``model_dir``.

    Either runs ``model`` over the split, or (``load_eval=True``) reloads the
    raw outputs a previous run pickled to ``<model_dir>/<set_type>_raw_eval_data.pkl``.
    The cached outputs can optionally be restricted to the triples in (or not
    in) ``set_lower_80``.

    Args:
        model: called as ``model(input_ids=..., entity_ids=..., attention_mask=...,
            labels=..., is_train=False)`` and expected to return ``(loss, logits)``.
            Unused when ``load_eval`` is True, so it may be ``None``.
        model_dir: directory the raw outputs and metrics files are read from and
            written to.
        set_type: split name passed to ``load_dataset`` and used as filename prefix.
        eval_lower_80: with ``load_eval``, keep only examples whose triple is in
            ``set_lower_80``.
        eval_upper_20: with ``load_eval``, keep only examples whose triple is NOT
            in ``set_lower_80``. Takes precedence over ``eval_lower_80``.
        set_lower_80: container of tab-joined ``head, relation, tail`` strings.
            ``False`` (the default) or ``None`` means "not provided".
        load_eval: reuse pickled raw outputs instead of running the model.
        cumulative_perf: currently unused; kept for interface compatibility.

    Raises:
        ValueError: if subset filtering is requested without ``set_lower_80``,
            if no example survives the filter, or if the dataloader is empty.
    """
    has_lower_80 = set_lower_80 is not False and set_lower_80 is not None

    if load_eval:
        fname = os.path.join(model_dir, set_type + "_raw_eval_data.pkl")
        logger.info("Loading raw results file: {}".format(fname))
        if has_lower_80:
            logger.info("Length of lower 80 set: {}".format(len(set_lower_80)))
        # NOTE: pickle.load can execute arbitrary code; only load files this script wrote itself.
        with open(fname, "rb") as rf:
            eval_data = pickle.load(rf)
        if eval_lower_80 or eval_upper_20:
            if not has_lower_80:
                raise ValueError("eval_lower_80 / eval_upper_20 require set_lower_80 to be provided.")
            eval_data = _filter_eval_by_triples(eval_data, set_lower_80, keep_lower_80=not eval_upper_20)
    else:
        eval_data = _run_model_eval(model, set_type)

    # Positive relation ids: every relation id except the "na" one. Filtering by
    # value (rather than deleting list position rel_idx_na) does not assume the
    # ids are 0..n-1 in insertion order.
    rel2idx = read_relations(config.relations_file)
    rel_idx_na = rel2idx['na']
    pos_rel_idxs = [idx for idx in rel2idx.values() if idx != rel_idx_na]

    labels_np = eval_data['labels'].numpy()
    a = accuracy_score(labels_np, eval_data['preds'])
    p, r, f1, _support = precision_recall_fscore_support(labels_np, eval_data['preds'], average='micro', labels=pos_rel_idxs)
    logger.info('Accuracy (including "NA"): {}\nP: {}, R: {}, F1: {}'.format(a, p, r, f1))
    results = {
        'new_results': {
            'acc_with_na': a,
            'scikit_precision': p,
            'scikit_recall': r,
            'scikit_f1': f1,
            # The stored mean loss, consistent across all modes (cached/filtered/fresh).
            "loss": eval_data['loss'],
            "counter": eval_data['labels'].shape,
        },
    }
    results['original'] = compute_metrics(eval_data['logits'], eval_data['labels'], eval_data['groups'], set_type, logger)
    logger.info("Results: %s", results)

    if load_eval:
        # Save evaluation results
        with open(os.path.join(model_dir, set_type + "_metrics_from_load.txt"), "w") as wf:
            json.dump(results, wf, indent=4)
    else:
        # Save evaluation results
        with open(os.path.join(model_dir, set_type + "_metrics.txt"), "w") as wf:
            json.dump(results, wf, indent=4)

        # Save evaluation raw data so later runs can use load_eval=True
        with open(os.path.join(model_dir, set_type + "_raw_eval_data.pkl"), "wb") as wf:
            pickle.dump(eval_data, wf)



def main():
    """Evaluate one checkpoint on the test split using the hard-coded settings below.

    With ``load_eval`` set, metrics are recomputed from the raw outputs a
    previous run pickled into ``model_dir``, and no model weights are loaded.
    Edit ``model_dir`` and the relations path before running.
    """
    config.relations_file = '/projects/ibm_aihl/whogan/umls-main/data/original/processed/2019.spacy/relations.txt'

    # Load raw results (don't re-run test eval model)
    load_eval = True

    # Lower 80 / upper 20 subset selection (only honoured together with load_eval)
    eval_lower_80 = False
    eval_upper_20 = False
    if eval_lower_80:
        logger.info("EVAL'N LOWER 80 ONLY!")
    elif eval_upper_20:
        logger.info("EVAL'N UPPER 20 ONLY!")
    else:
        logger.info("EVAL FULL TEST SET")
    with open(config.lower_half_trips, "rb") as f:
        set_lower_80 = pickle.load(f)
    logger.info("Loaded lower 80 file: {}".format(config.lower_half_trips))
    logger.info("Lower 80 set length: {}".format(len(set_lower_80)))

    # Evaluation
    model_dir = '[insert model dir here]'
    logger.info("Evaluate the checkpoint: %s", model_dir)
    if load_eval:
        # evaluate_test never touches the model when reusing cached outputs,
        # so skip reading the config and weights from disk.
        model = None
    else:
        num_labels = len(read_relations(config.relations_file))
        model = BertForDistantRE(BertConfig.from_pretrained(model_dir), num_labels, bag_attn=config.use_bag_attn)
        weights_path = os.path.join(model_dir, "pytorch_model.bin")
        model.load_state_dict(torch.load(weights_path, map_location=torch.device(config.device)))
        model.to(config.device)
    evaluate_test(model, model_dir, "test", eval_lower_80=eval_lower_80, eval_upper_20=eval_upper_20, set_lower_80=set_lower_80, load_eval=load_eval)



if __name__ == "__main__":
    # Script entry point; importing this module does not trigger an evaluation.
    main()
