import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import torch
import pickle
import os
import math
import numpy as np
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef, accuracy_score
from IPython import embed

from src.arguments import parse_args
from src.models.get_model import get_model
from src.data.get_data import get_icl_dataset


# Token id whose occurrences split a prompt into segments.
# NOTE(review): presumably the id of the token that opens every demonstration
# for the tokenizer in use -- confirm before switching models/prompts.
_SEGMENT_START_TOKEN_ID = 20560


def _mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, skipping masked tokens.

    ``model_output[0]`` is taken as the per-token embeddings; the mask is
    broadcast over the embedding dimension and the divisor is clamped so an
    all-zero mask cannot divide by zero.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(
        token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    return summed / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def _segment_attention(input_ids, attention):
    """Return the max attention received by each non-final prompt segment.

    Segments are delimited by ``_SEGMENT_START_TOKEN_ID``; the last segment
    is excluded. Position 0 is never included in a segment. The maximum is
    taken over axis 2 of ``attention`` (assumed to be the key-position axis
    of a (layer, head, position) array -- TODO confirm with get_features).
    """
    start_pos = (input_ids == _SEGMENT_START_TOKEN_ID).nonzero()[0].tolist()
    start_pos.append(input_ids.shape[0])
    attn_list = []
    for seg_i, seg_end in enumerate(start_pos[1:-1]):
        seg_start = max(start_pos[seg_i], 1)
        attn_list.append(attention[:, :, seg_start:seg_end].max(2))
    return np.stack(attn_list)


def _gather_sample_similarity(pairwise_similarity, icl_dataset):
    """Pick, for every datum, its similarity to its own ``sample_index``."""
    return np.stack([
        pairwise_similarity[_i][_datum["sample_index"]]
        for _i, _datum in enumerate(icl_dataset)
    ])


def _per_head_correlation(similarity, sample_attentions):
    """Pearson correlation between similarity and attention, per head.

    ``sample_attentions`` is indexed as ``[:, :, layer, head]``; the layer
    and head counts are read from its shape. Returns an array of shape
    (num_layers, num_heads).
    """
    flat_similarity = similarity.reshape(-1)
    num_layers, num_heads = sample_attentions.shape[2:4]
    return np.array([
        [
            np.corrcoef(
                flat_similarity,
                sample_attentions[:, :, layer, head].reshape(-1)
            )[0, 1]
            for head in range(num_heads)
        ]
        for layer in range(num_layers)
    ])


def main(args):
    """Correlate per-head attention with several similarity measures.

    Steps:
      1. Load the ICL dataset and the model, and extract features for every
         datum.
      2. Print the requested metrics for the model predictions.
      3. Reduce the attentions to one value per prompt segment, layer and
         head.
      4. Correlate those values with (a) sentence-encoder similarity,
         (b) similarity of the softmaxed logits and (c) similarity of the
         mean-centred last hidden states obtained from the dataset built
         with zero samples.
      5. Pickle the logit correlation matrix to
         ``<output_dir>/similarity.pkl``.

    Args:
        args: parsed arguments; reads ``output_dir``, ``cuda``, ``dataset``,
            ``split``, ``subset``, ``sampling``, ``num_samples``,
            ``prompt_type``, ``model``, ``low_resource_mode`` and
            ``metrics``.

    Raises:
        KeyError: if an entry of ``args.metrics`` is not "accuracy" or
            "matthews".
    """
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")

    print("loading dataset ...")
    icl_dataset, dataset_cls = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        args.num_samples, args.prompt_type)
    num_data = len(icl_dataset)
    print("dataset size:", num_data)
    print("Example datum:", icl_dataset[0])

    print("loading model ...")
    model = get_model(args.model, args.low_resource_mode, device)
    model.eval()
    features = [
        model.get_features(
            _datum, ['logits', 'attentions', 'input_ids', 'prediction'])
        for _datum in tqdm(icl_dataset)
    ]

    # The first element of each prediction is dropped before conversion.
    predictions = np.array(
        [dataset_cls.convert_output(_feat["prediction"][1:])
         for _feat in features]
    )
    labels = np.array(
        [dataset_cls.convert_output(_datum["output"])
         for _datum in icl_dataset]
    )

    # interpreting the attention
    sample_attentions = np.stack([
        _segment_attention(_feat["input_ids"][0], _feat["attentions"])
        for _feat in features
    ])

    # Features of the same dataset built with zero samples.
    # NOTE(review): the similarity lookups below assume this dataset is
    # ordered like icl_dataset and that "sample_index" indexes into it --
    # confirm against get_icl_dataset.
    singular_dataset, _ = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        0, args.prompt_type)
    singular_features = [
        model.get_features(
            _datum, ['logits', 'prediction', 'last_hidden_states'])
        for _datum in tqdm(singular_dataset)
    ]

    metric_dict = {
        "accuracy": accuracy_score,
        "matthews": matthews_corrcoef
    }
    for _metric in args.metrics:
        # scikit-learn metrics take (y_true, y_pred).
        print(_metric, ":", metric_dict[_metric](labels, predictions))

    # encoder similarity
    sim_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    tokenizer = AutoTokenizer.from_pretrained(sim_model_name)
    sim_model = AutoModel.from_pretrained(sim_model_name).to(device)
    sentence_embeddings = []
    for _datum in tqdm(icl_dataset):
        _sentence = _datum["original"]["sentence"]
        encoded_input = tokenizer(_sentence, padding=True, truncation=True,
                                  return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = sim_model(**encoded_input)
            _embedding = _mean_pooling(model_output,
                                       encoded_input['attention_mask'])
            # Unit-normalise so the dot products below are cosines.
            _embedding = F.normalize(_embedding, p=2, dim=1)
        sentence_embeddings.append(_embedding)
    # Keep the first row of every per-datum embedding batch.
    sentence_embeddings = torch.stack(sentence_embeddings)[:, 0]
    sample_embeddings = torch.stack([
        sentence_embeddings[_datum["sample_index"]]
        for _datum in icl_dataset
    ])
    # Dot product of every datum's embedding with each of its samples.
    sample_encoder_similarity = sample_embeddings.matmul(
        sentence_embeddings[:, :, None])[:, :, 0].cpu().numpy()
    encoder_corr = _per_head_correlation(
        sample_encoder_similarity, sample_attentions)

    # GPT-j similarity
    # logit:
    singular_logits = np.stack([_feat["logits"][-1]
                                for _feat in singular_features])
    singular_logits = torch.tensor(singular_logits).softmax(1).numpy()
    sample_logit_similarity = _gather_sample_similarity(
        np.matmul(singular_logits, singular_logits.transpose()),
        icl_dataset)
    logit_corr = _per_head_correlation(
        sample_logit_similarity, sample_attentions)

    # hidden state:
    singular_hidden = np.stack([_feat["last_hidden_states"][-1]
                                for _feat in singular_features])
    # Centre the hidden states across the dataset before comparing them.
    singular_hidden -= singular_hidden.mean(0)[None]
    sample_hidden_similarity = _gather_sample_similarity(
        np.matmul(singular_hidden, singular_hidden.transpose()),
        icl_dataset)
    hidden_corr = _per_head_correlation(
        sample_hidden_similarity, sample_attentions)

    # NOTE(review): only logit_corr is written; encoder_corr and hidden_corr
    # are computed but not saved. The payload is left as-is so the on-disk
    # format stays the same -- decide whether they should be stored too.
    with open(os.path.join(args.output_dir, "similarity.pkl"), "wb"
              ) as f:
        pickle.dump(logit_corr, f)

if __name__ == "__main__":
    # Script entry point: parse the command-line arguments and hand them
    # straight to main().
    main(parse_args())
