import torch
import pickle
import os
import numpy as np
from tqdm import tqdm
from IPython import embed

from src.arguments import parse_args
from src.models.get_model import get_model
from src.data.get_data import get_icl_dataset


def _collect_icl_kv(model, icl_dataset, sep_token_id=20560, window=5):
    """Slice cached keys/values from a fixed window before each ICL segment.

    Each datum is run through ``model.get_features`` once. Positions of
    ``sep_token_id`` in the prompt mark segment starts. For every start
    except the first, the ``window`` positions ``[pos - window, pos)`` are
    sliced out of ``past_keys`` / ``past_values`` along their third axis.

    Args:
        model: object exposing ``get_features(datum, names)``.
        icl_dataset: iterable of ICL prompts accepted by ``get_features``.
        sep_token_id: token id marking the start of an in-context segment.
            NOTE(review): the meaning of 20560 is assumed from usage; confirm
            against the tokenizer.
        window: number of positions sliced before each segment start.

    Returns:
        ``(keys, values)`` numpy arrays of shape
        ``(n_datum, n_icl_sample, *slice_shape)``. Per the script's shape
        notes, this is presumably ``(n_datum, n_icl_sample, layer, head,
        window, dim)``; TODO confirm. ``np.stack`` requires every datum to
        have the same number of separators.
    """
    keys, values = [], []
    for datum in tqdm(icl_dataset):
        feat = model.get_features(
            datum, ['logits', 'input_ids',
                    'past_keys', 'past_values',
                    'prediction']
        )
        input_ids = feat["input_ids"][0]
        # Start positions of each segment; the first one is skipped.
        starts = (input_ids == sep_token_id).nonzero()[0].tolist()
        interior = starts[1:]
        # Original note: with window=5, ``pos - 4`` is the label location of
        # the preceding demonstration.
        keys.append(np.stack(
            [feat["past_keys"][:, :, pos - window:pos] for pos in interior]))
        values.append(np.stack(
            [feat["past_values"][:, :, pos - window:pos] for pos in interior]))
    return np.stack(keys), np.stack(values)


def _icl_sample_labels(icl_dataset):
    """Return, per datum, the gold labels of its in-context examples.

    ``datum["sample_index"]`` is used to index into the array of every
    datum's ``["original"]["label"]``. The result has shape
    ``(n_datum, n_icl_sample)``.
    """
    dataset_labels = np.array(
        [datum["original"]["label"] for datum in icl_dataset])
    return np.array(
        [dataset_labels[datum["sample_index"]] for datum in icl_dataset])


def _ridge_probe(X, y, n_train, reg=0.01):
    """Fit ridge regression on the first ``n_train`` rows and score the rest.

    Args:
        X: float tensor ``(n_rows, dim)``.
        y: float tensor ``(n_rows,)`` on the same device as ``X``.
        n_train: number of leading rows used for fitting.
        reg: L2 regularization strength added to the Gram diagonal.

    Returns:
        ``(sign_accuracy, pearson_corr)`` as Python floats. The first is the
        fraction of held-out rows where ``y_hat > 0`` agrees with
        ``y > 0``; the second is the correlation of ``y_hat`` with ``y``.
    """
    X_train, y_train = X[:n_train], y[:n_train]
    X_test, y_test = X[n_train:], y[n_train:]
    gram = X_train.T @ X_train
    eye = torch.eye(gram.shape[0], device=X.device, dtype=X.dtype)
    # Solve (X^T X + reg*I) w = X^T y rather than forming an explicit inverse.
    w = torch.linalg.solve(gram + reg * eye, X_train.T @ y_train)
    y_hat = X_test @ w
    acc = ((y_hat > 0) == (y_test > 0)).float().mean().item()
    corr = np.corrcoef(y_hat.cpu().numpy(), y_test.cpu().numpy())[0, 1]
    return acc, float(corr)


def _probe_grid(features, target, n_train, device, reg=0.01):
    """Run ``_ridge_probe`` for every (layer, head, window-position) slice.

    Args:
        features: array ``(n_datum, n_icl_sample, n_layer, n_head, window,
            dim)``.
        target: array with ``n_datum * n_icl_sample`` entries; it is
            flattened row-major to line up with the flattened features.
        n_train: number of leading rows used for fitting each probe.
        device: torch device to run the regressions on.
        reg: ridge regularization strength.

    Returns:
        ``(acc, corr)`` arrays of shape ``(n_layer, n_head, window)``, with
        dtypes float32 and float64 respectively.

    Raises:
        ValueError: if there are not more rows than ``n_train``, which would
            leave an empty evaluation split.
    """
    n_layer, n_head, window, dim = features.shape[2:]
    y = torch.tensor(np.asarray(target).reshape(-1)).to(device).float()
    if y.shape[0] <= n_train:
        raise ValueError(
            f"need more than n_train={n_train} rows to evaluate probes, "
            f"got {y.shape[0]}")
    acc = np.zeros((n_layer, n_head, window), dtype=np.float32)
    corr = np.zeros((n_layer, n_head, window), dtype=np.float64)
    for i in tqdm(range(n_layer)):
        for j in range(n_head):
            for k in range(window):
                X = features[:, :, i, j, k].reshape(-1, dim)
                X = torch.tensor(X).to(device).float()
                acc[i, j, k], corr[i, j, k] = _ridge_probe(X, y, n_train, reg)
    return acc, corr


def main(args):
    """Probe ICL key/value caches for label and zero-shot-prediction signal.

    Steps:

    1. For every ICL prompt, cache keys and values over a 5-token window
       before each interior segment separator.
    2. Run the zero-shot ("singular", ``num_samples=0``) prompts and keep
       the last-position logits of each.
    3. For every (layer, head, window-position) slice, fit ridge probes:
       values -> in-context example label, and keys -> zero-shot logit
       margin of token 39793 over token 25741.
    4. Pickle ``[label_acc, label_corr, singular_acc, singular_corr]`` to
       ``<output_dir>/interpret_elements.pkl``.

    Args:
        args: parsed CLI namespace. Reads ``output_dir``, ``cuda``,
            ``dataset``, ``split``, ``subset``, ``sampling``,
            ``num_samples``, ``prompt_type``, ``model`` and
            ``low_resource_mode``.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")

    print("loading dataset ...")
    icl_dataset, _ = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        args.num_samples, args.prompt_type)
    num_data = len(icl_dataset)
    print("dataset size:", num_data)
    print("Example datum:", icl_dataset[0])

    print("loading model ...")
    model = get_model(args.model, args.low_resource_mode, device)
    model.eval()

    # shape (both): n_datum, n_icl_sample, layer, head, window=5, dim
    sample_keys, sample_values = _collect_icl_kv(model, icl_dataset)
    # shape: n_datum, n_icl_sample
    sample_labels = _icl_sample_labels(icl_dataset)

    # Zero-shot prompts: the same data built with 0 in-context examples.
    singular_dataset, _ = get_icl_dataset(
        args.dataset, args.split, args.subset, args.sampling,
        0, args.prompt_type)
    # Only the last-position logits are used, so don't keep full feature dicts.
    singular_logits = np.stack([
        model.get_features(
            _datum, ['logits', 'prediction', 'last_hidden_states']
        )["logits"][-1]
        for _datum in tqdm(singular_dataset)
    ])
    # NOTE(review): assumes singular_dataset[i] is built from original
    # example i, the same indexing sample_index uses for labels; confirm.
    # shape: n_datum, n_icl_sample, vocab_size
    sample_singular_logits = np.stack(
        [singular_logits[_datum["sample_index"]] for _datum in icl_dataset])

    n_train = 3000
    # Value probes -> labels. NOTE(review): labels are shifted by 1 and the
    # sign-accuracy threshold is 0; confirm the label encoding makes this the
    # intended decision boundary.
    label_acc, label_corr = _probe_grid(
        sample_values, sample_labels - 1, n_train, device)
    # Key probes -> zero-shot logit margin between two label tokens
    # (ids 39793 and 25741; their meaning is assumed from usage).
    singular_margin = (sample_singular_logits[:, :, 39793]
                       - sample_singular_logits[:, :, 25741])
    singular_acc, singular_corr = _probe_grid(
        sample_keys, singular_margin, n_train, device)

    with open(os.path.join(args.output_dir, "interpret_elements.pkl"),
              "wb") as f:
        pickle.dump([label_acc, label_corr, singular_acc, singular_corr], f)


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the probing pipeline.
    main(parse_args())
