import numpy as np
import torch
import torchmetrics


def init_best_metrics():
    """Return a fresh record for tracking the best results seen so far.

    Returns:
        A new dict with ``best_epoch`` set to 0 and both
        ``dev_best_perf`` and ``test_best_perf`` set to ``None``.
    """
    return dict(best_epoch=0, dev_best_perf=None, test_best_perf=None)

def init_perf_metrics(num_classes):
    """Build the collection of performance metrics.

    Args:
        num_classes: Number of target classes; must be at least 2. It is
            only validated here and is not passed on to the metric.

    Returns:
        A ``torch.nn.ModuleDict`` with a single ``'acc'`` entry holding a
        ``torchmetrics.Accuracy`` instance.
    """
    assert num_classes >= 2
    # NOTE(review): newer torchmetrics releases appear to require a `task`
    # argument for Accuracy -- confirm against the pinned version.
    metrics = torch.nn.ModuleDict()
    metrics['acc'] = torchmetrics.Accuracy()
    return metrics

def calc_preds(logits):
    """Return the index of the largest value along dimension 1 of ``logits``.

    Args:
        logits: Tensor of scores (presumably (batch, num_classes) -- verify
            against the caller).

    Returns:
        Tensor of argmax indices with dimension 1 reduced away.
    """
    predicted = logits.argmax(dim=1)
    return predicted

def get_step_metrics(preds, labels, metrics):
    """Evaluate every metric on one batch and scale the results by 100.

    Args:
        preds: Predictions passed as the first argument to each metric.
        labels: Targets passed as the second argument to each metric.
        metrics: Mapping from metric name to a callable metric.

    Returns:
        Dict mapping each metric name to ``metric(preds, labels) * 100``.
    """
    return {
        name: metric(preds, labels) * 100
        for name, metric in metrics.items()
    }

def get_epoch_metrics(metrics):
    """Collect each metric's computed value (scaled by 100), then reset it.

    Args:
        metrics: Mapping from metric name to an object exposing
            ``compute()`` and ``reset()``.

    Returns:
        Dict mapping each metric name to ``metric.compute() * 100``.
    """
    epoch_values = {}
    for name, metric in metrics.items():
        epoch_values[name] = metric.compute() * 100
        # Reset right after reading so each metric starts clean afterwards.
        metric.reset()
    return epoch_values

def process_outputs(outputs, io_mode, tokenizer):
    """Decode generated sequences and split each into a label and a rationale.

    Args:
        outputs: Generated sequences, passed to ``tokenizer.batch_decode``.
        io_mode: ``'I-OR'`` (text before ``'explanation:'`` is the label,
            text after it is the rationale) or ``'I-RO'`` (text before
            ``'label:'`` is the rationale, text after it is the label).
        tokenizer: Object exposing
            ``batch_decode(outputs, skip_special_tokens=True)``.

    Returns:
        A ``(labels, rationales)`` pair of equally long lists of strings.
        When the delimiter is missing, the second field is ``''``.
    """
    assert io_mode in ['I-OR', 'I-RO']
    label_first = io_mode == 'I-OR'
    delimiter = 'explanation:' if label_first else 'label:'

    labels, rationales = [], []
    for decoded in tokenizer.batch_decode(outputs, skip_special_tokens=True):
        pieces = decoded.split(delimiter)
        head = pieces[0].strip()
        tail = ''
        if len(pieces) > 1:
            # Only the segment right after the first delimiter is kept.
            # Also split on extra id token (which tends to appear as a delimiter frequently)
            # Ref: https://github.com/allenai/label_rationale_association/blob/main/custom_args.py#L177
            tail = pieces[1].strip().split('<extra_id')[0].strip()

        if label_first:
            label, rationale = head, tail
        else:
            label, rationale = tail, head
        labels.append(label)
        rationales.append(rationale)

    return labels, rationales