import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import torch
import pickle
import os
import math
import numpy as np
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef, accuracy_score
from IPython import embed

from src.arguments import parse_args
from src.models.get_model import get_model
from src.data.get_data import get_icl_dataset


def main(args):
    """Evaluate a model on an in-context-learning dataset and print metrics.

    Loads the dataset and model selected by ``args``, collects the model's
    features for every datum, converts predictions and reference outputs
    with ``dataset_cls.convert_output``, prints each metric listed in
    ``args.metrics``, then opens an IPython shell for interactive inspection
    and terminates the process once that shell is closed.

    Args:
        args: Parsed command-line namespace. This function reads
            ``output_dir``, ``cuda``, ``dataset``, ``split``, ``subset``,
            ``sampling``, ``num_samples``, ``prompt_type``, ``model``,
            ``low_resource_mode`` and ``metrics``.

    Raises:
        KeyError: If ``args.metrics`` contains a name other than
            ``"accuracy"`` or ``"matthews"``.
        SystemExit: Always, after the interactive shell returns.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    metric_dict = {
        "accuracy": accuracy_score,
        "matthews": matthews_corrcoef,
    }
    # Fail fast: reject unknown metric names before the (potentially very
    # expensive) dataset/model loading and inference pass.
    unknown_metrics = [m for m in args.metrics if m not in metric_dict]
    if unknown_metrics:
        raise KeyError(
            f"unsupported metric(s) {unknown_metrics}; "
            f"available: {sorted(metric_dict)}"
        )

    device = torch.device("cuda:0" if args.cuda else "cpu")

    print("loading dataset ...")
    icl_dataset, dataset_cls = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        args.num_samples, args.prompt_type)
    num_data = len(icl_dataset)
    print("dataset size:", num_data)
    print("Example datum:", icl_dataset[0])

    print("loading model ...")
    model = get_model(args.model, args.low_resource_mode, device)
    model.eval()

    requested_features = ['logits', 'attentions', 'input_ids', 'prediction']
    # Pure inference: disable autograd so that the stored features do not
    # keep computation graphs alive for the whole dataset.
    # NOTE(review): assumes get_features does not itself rely on gradients
    # for these outputs -- confirm against the model wrapper.
    with torch.no_grad():
        features = [
            model.get_features(_datum, requested_features)
            for _datum in tqdm(icl_dataset)
        ]

    # The first element of every raw prediction is dropped before conversion
    # (presumably a leading separator/space token -- TODO confirm).
    predictions = np.array(
        [dataset_cls.convert_output(_feature["prediction"][1:])
         for _feature in features]
    )
    labels = np.array(
        [dataset_cls.convert_output(_datum["output"])
         for _datum in icl_dataset]
    )

    for _metric in args.metrics:
        # scikit-learn metrics take (y_true, y_pred) in that order.
        print(_metric, ":", metric_dict[_metric](labels, predictions))

    # Drop into an interactive IPython shell so the collected features,
    # predictions and labels can be inspected by hand.
    embed()
    # ``exit()`` is a helper injected by the ``site`` module for interactive
    # sessions; raising SystemExit directly has the same effect and does not
    # depend on ``site`` being loaded.
    raise SystemExit


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(parse_args())
