import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import torch
import pickle
import os
import math
import numpy as np
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef, accuracy_score
from IPython import embed

from src.arguments import parse_args
from src.models.get_model import get_model
from src.data.get_data import get_icl_dataset


# Token id used to split a prompt into segments: every occurrence starts a
# new segment.
# NOTE(review): 20560 is tokenizer-specific and presumably opens each
# demonstration in the prompt template -- confirm it matches args.model.
_SEGMENT_MARKER_TOKEN_ID = 20560


def _one_hot(index, num_labels):
    """Return a float vector of length *num_labels* with a 1 at *index*."""
    vec = np.zeros(num_labels)
    vec[index] = 1
    return vec


def _segment_spans(input_ids, marker_id=_SEGMENT_MARKER_TOKEN_ID):
    """Return ``(start, end)`` token spans lying between consecutive markers.

    Each span runs from one occurrence of *marker_id* up to (excluding) the
    next one, so whatever follows the last marker is not returned.  The
    start is clamped to 1, i.e. position 0 is never part of a span
    (presumably to skip a leading special token -- TODO confirm).
    """
    markers = np.flatnonzero(np.asarray(input_ids) == marker_id).tolist()
    return [(max(start, 1), end)
            for start, end in zip(markers, markers[1:])]


def _reconstruct_labels(attention, spans, span_labels, num_labels):
    """Attention-weighted average of the span labels, for every head at once.

    Args:
        attention: array indexed as ``[layer, head, token position]``.
        spans: ``(start, end)`` token spans, one per demonstration.
        span_labels: integer label of each span, aligned with *spans*.
        num_labels: number of distinct labels.

    Returns:
        Array of shape ``(layers, heads, num_labels)``.  A span's weight is
        the maximum attention the head puts on any token inside it.
    """
    attention = np.asarray(attention)
    head_shape = attention.shape[:2]
    # Reduce over everything after the (layer, head) axes so the weight is
    # the maximum of the whole slice, exactly like a per-head ``.max()``.
    reduce_axes = tuple(range(2, attention.ndim))
    weighted_labels = np.zeros(head_shape + (num_labels,))
    total_attn = np.zeros(head_shape)
    for (start, end), label in zip(spans, span_labels):
        span_attn = attention[:, :, start:end].max(axis=reduce_axes)
        weighted_labels = weighted_labels + (
            span_attn[:, :, None] * _one_hot(label, num_labels))
        total_attn = total_attn + span_attn
    return weighted_labels / total_attn[:, :, None]


def main(args):
    """Evaluate a model on an ICL dataset and score every attention head.

    Part 1 prints the metrics named in ``args.metrics`` for the model's
    predictions.  Part 2 reconstructs, for every (layer, head), a label from
    the attention paid to each in-context demonstration, and compares it
    with the model's own choice and with the gold label.

    Args:
        args: parsed arguments.  Reads ``output_dir``, ``cuda``, ``dataset``,
            ``split``, ``subset``, ``sampling``, ``num_samples``,
            ``prompt_type``, ``model``, ``low_resource_mode`` and ``metrics``.

    Raises:
        ValueError: if a prompt contains no demonstration span, which would
            otherwise end in a division by zero.

    Side effects:
        Writes ``sample_correlation.pkl`` into ``args.output_dir``: a float
        array of shape ``(layers, heads, 3)`` holding, per head, the
        correlation with the model's choice, the agreement with the model's
        choice, and the agreement with the gold label.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")

    print("loading dataset ...")
    icl_dataset, dataset_cls = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        args.num_samples, args.prompt_type)
    num_data = len(icl_dataset)
    print("dataset size:", num_data)
    print("Example datum:", icl_dataset[0])

    print("loading model ...")
    model = get_model(args.model, args.low_resource_mode, device)
    model.eval()
    # Pure inference: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        features = [
            model.get_features(
                _datum, ['logits', 'attentions', 'input_ids', 'prediction'])
            for _datum in tqdm(icl_dataset)
        ]

    # Part 1: task metrics.
    # The first character of the generated prediction is dropped before
    # conversion (presumably a leading space -- TODO confirm).
    predictions = np.array(
        [dataset_cls.convert_output(_feat["prediction"][1:])
         for _feat in features]
    )
    labels = np.array(
        [dataset_cls.convert_output(_datum["output"])
         for _datum in icl_dataset]
    )

    metric_dict = {
        "accuracy": accuracy_score,
        "matthews": matthews_corrcoef
    }
    for _metric in args.metrics:
        # scikit-learn metrics take (y_true, y_pred).
        print(_metric, ":", metric_dict[_metric](labels, predictions))

    # Part 2: label reconstruction by head.
    # Everything here depends only on the datum, so it is computed once per
    # datum for all heads instead of once per (layer, head, datum).
    num_labels = 2
    reconstructed = []  # per datum: (layers, heads, num_labels)
    model_choice = []   # per datum: label index the model leans towards
    for data_i, (_datum, _feat) in enumerate(
            tqdm(zip(icl_dataset, features), total=num_data)):
        spans = _segment_spans(_feat["input_ids"][0])
        if not spans:
            raise ValueError(
                "datum %d has no demonstration span (fewer than two "
                "occurrences of token id %d)"
                % (data_i, _SEGMENT_MARKER_TOKEN_ID))
        span_labels = [
            icl_dataset[_datum["sample_index"][_seg_i]]["original"]["label"]
            for _seg_i in range(len(spans))
        ]
        reconstructed.append(_reconstruct_labels(
            _feat["attentions"], spans, span_labels, num_labels))

        prediction = _one_hot(
            dataset_cls.convert_output(_feat["prediction"][1:]), num_labels)
        prob = torch.softmax(torch.tensor(
            _feat["logits"][-1]), 0).max().item()
        # The predicted label gets `prob`, every other label `1 - prob`; the
        # argmax is therefore the other label whenever prob < 0.5.
        confidence = prediction * prob + (1 - prediction) * (1 - prob)
        model_choice.append(confidence.argmax(-1))

    # (data, layers, heads) -> (layers, heads, data)
    x = np.stack(reconstructed).argmax(-1).transpose(1, 2, 0)
    y = np.array(model_choice)
    is_positive = np.array([_datum["output"] == "positive" for _datum in
                            icl_dataset])

    num_layers, num_heads = x.shape[:2]
    head_correlations = np.empty((num_layers, num_heads, 3))
    for layer_i in range(num_layers):
        for head_i in range(num_heads):
            x_head = x[layer_i, head_i]
            # corrcoef is nan when either side is constant over the data.
            corre = np.corrcoef(x_head, y)[0, 1]
            acc = (x_head == y).astype(float).mean()
            acc2 = (x_head == is_positive).astype(float).mean()
            head_correlations[layer_i, head_i] = (corre, acc, acc2)

    with open(os.path.join(args.output_dir, "sample_correlation.pkl"), "wb"
              ) as f:
        pickle.dump(head_correlations, f)


if __name__ == "__main__":
    # Script entry point: parse the command line and run the analysis.
    main(parse_args())
