import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import torch
import pickle
import os
import math
import numpy as np
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef, accuracy_score
from IPython import embed

from src.arguments import parse_args
from src.models.get_model import get_model
from src.data.get_data import get_icl_dataset


def _segment_bins(input_ids, n_sample, sample_grid, sep_token_id=20560):
    """Return the grid bin index for each token position of one example.

    The sequence is cut at every occurrence of ``sep_token_id`` (plus the end
    of the sequence). Segment ``i`` is stretched linearly over ``[i, i + 1)``
    and quantised into ``sample_grid`` bins, so segments of different lengths
    share a common axis.

    NOTE(review): ``nonzero()[0]`` follows NumPy semantics (tuple of index
    arrays) - presumably ``input_ids`` is a 1-D NumPy array; confirm.
    NOTE(review): positions before the first separator get no entry, so the
    result lines up with token positions only if the first separator sits at
    index 0 - confirm against the prompt format.
    NOTE(review): 20560 is presumably the id of the token that opens every
    demonstration for the tokenizer in use - confirm.

    Args:
        input_ids: token ids of a single example.
        n_sample: number of in-context demonstrations; ``n_sample + 1``
            segments are expected (demonstrations plus the query).
        sample_grid: number of bins per segment.
        sep_token_id: token id that marks the start of a segment.

    Returns:
        Integer NumPy array of bin indices, one per covered position.
    """
    start_pos = (input_ids == sep_token_id).nonzero()[0].tolist() + \
        [input_ids.shape[0]]
    pos_scale = np.concatenate([
        np.linspace(
            seg_i, seg_i + 1, start_pos[seg_i + 1] - start_pos[seg_i],
            endpoint=False)
        for seg_i in range(n_sample + 1)
    ])
    # Quantise once per example instead of once per (layer, head, position).
    return np.floor(pos_scale * sample_grid).astype(int)


def _accumulate_attention(overall_attention, attentions, bins):
    """Add one example's attention weights into ``overall_attention`` in place.

    ``attentions`` is iterated as layers -> heads -> per-position weights;
    each weight is added to the bin given by ``bins`` for its position.
    """
    for layer_i, layer_attn in enumerate(attentions):
        for head_i, head_attn in enumerate(layer_attn):
            for pos, attn in enumerate(head_attn):
                overall_attention[layer_i, head_i, bins[pos]] += attn


def main(args):
    """Evaluate a model on an ICL dataset and dump a binned attention map.

    Loads the dataset and model, extracts features for every datum, prints
    the metrics listed in ``args.metrics``, then accumulates the attention
    weights of all examples on a normalised per-segment grid and pickles the
    result to ``<args.output_dir>/overall_attention.pkl``.

    Args:
        args: parsed command-line namespace. Reads ``output_dir``, ``cuda``,
            ``dataset``, ``split``, ``subset``, ``sampling``, ``num_samples``,
            ``prompt_type``, ``model``, ``low_resource_mode`` and ``metrics``.

    Raises:
        KeyError: if ``args.metrics`` names a metric other than
            ``"accuracy"`` or ``"matthews"``.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")

    print("loading dataset ...")
    icl_dataset, dataset_cls = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        args.num_samples, args.prompt_type)
    num_data = len(icl_dataset)
    print("dataset size:", num_data)
    print("Example datum:", icl_dataset[0])

    print("loading model ...")
    model = get_model(args.model, args.low_resource_mode, device)
    model.eval()
    features = [
        model.get_features(
            _datum, ['logits', 'attentions', 'input_ids', 'prediction'])
        for _datum in tqdm(icl_dataset)
    ]

    # NOTE(review): the first element of each prediction is dropped before
    # conversion - presumably a leading space/token; confirm.
    predictions = np.array(
        [dataset_cls.convert_output(_feat["prediction"][1:])
         for _feat in features]
    )
    labels = np.array(
        [dataset_cls.convert_output(_datum["output"])
         for _datum in icl_dataset]
    )

    # A leftover interactive breakpoint (embed() + exit()) used to stop the
    # script here, which made everything below unreachable.

    metric_dict = {
        "accuracy": accuracy_score,
        "matthews": matthews_corrcoef
    }
    for _metric in args.metrics:
        # sklearn metrics take (y_true, y_pred) in that order.
        print(_metric, ":", metric_dict[_metric](labels, predictions))

    # Part 1:
    # overall attention map
    # NOTE(review): layer/head counts are hard-coded; presumably they match
    # the architecture selected by args.model - verify before changing model.
    n_layer = 28
    n_head = 16
    sample_grid = 30
    n_sample = args.num_samples
    overall_attention = np.zeros(
        (n_layer, n_head, sample_grid * (n_sample + 1)))

    for _feat in tqdm(features):
        bins = _segment_bins(_feat["input_ids"][0], n_sample, sample_grid)
        _accumulate_attention(overall_attention, _feat["attentions"], bins)

    out_path = os.path.join(args.output_dir, "overall_attention.pkl")
    with open(out_path, "wb") as f:
        pickle.dump(overall_attention, f)


if __name__ == "__main__":
    # Script entry point: parse the command-line arguments and run main().
    main(parse_args())
