'''For every example we have a list of gather (G) and aggregate (A) heads.
For each G head we compute the portion sum(word2anchor)/sum(gather_head), and
for each A head the portion sum(anchor2end)/sum(aggregate_head).'''

from .checkpoint_surprisal import process_context_simple, add_tag
import torch
from .utils import context_file_template, context_file_idxs, tokenizer, word_list
from .test_grad_attn import checkpoint_path_to_model_w_attn, call_model, get_anchor_id, get_saliency_matrix
from pathlib import Path
import json
from collections import defaultdict


def _anchor_masks(seq_len, anchor_env):
    """Return the (word2anchor, anchor2end) boolean masks of shape (seq_len, seq_len)."""
    # word2anchor, or gather: the anchor row(s), restricted to the lower
    # triangle (columns at or before the anchor position).
    w2a_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
    w2a_mask[anchor_env, :] = True
    w2a_mask = torch.tril(w2a_mask)
    # anchor2end, or aggregate: the last row, at the anchor column(s).
    a2e_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
    a2e_mask[-1, anchor_env] = True
    return w2a_mask, a2e_mask


def _head_portions(saliency, ga_info, w2a_mask, a2e_mask):
    """Compute the masked-saliency ratio for every listed gather / aggregate head."""
    portions = defaultdict(list)

    # The mask size is the same for every gather head of this sentence.
    w2a_size = torch.sum(w2a_mask).item()
    for (gi, gj) in ga_info['gather']:
        head = saliency[gi, gj]
        ratio = (torch.sum(head[w2a_mask]) / torch.sum(head)).item()
        portions['gather'].append((ratio, w2a_size))

    for (ai, aj) in ga_info['agg']:
        head = saliency[ai, aj]
        # Sum the masked entries before `.item()`: the mask selects one entry
        # per anchor token, and `.item()` only accepts a single-element tensor.
        ratio = (torch.sum(head[a2e_mask]) / torch.sum(head)).item()
        portions['aggregate'].append(ratio)

    return portions


def main(checkpoint_dir_template, item_name, model_name):
    """Compute per-head saliency portions for the gather/aggregate heads and save them.

    For every context file and every key in it, the saliency matrix is taken
    from ``saliency_orig/<item_name>/ckpt_<step>/<context>/key_<k>.pt`` when a
    usable cached copy exists, and recomputed with the model otherwise. For each
    gather head listed in ``test/g_and_a_head/all_heads.json`` the stored value
    is ``(sum(word2anchor) / sum(head), mask_size)``; for each aggregate head it
    is ``sum(anchor2end) / sum(head)``. The result is written to
    ``test/g_and_a_head/portion_<step>_<item_name>.json``.

    NOTE(review): this function reads the module-level names ``device`` and
    ``seed2idx``, which are only assigned under the ``__main__`` guard of this
    file, so it cannot be called from an importing module without defining them.

    Args:
        checkpoint_dir_template: format string receiving ``item_name`` and
            ``model_name`` (in that order) to build the checkpoint path.
        item_name: run name; also the key looked up in ``seed2idx``.
        model_name: checkpoint file name; the text after the last ``_`` (minus
            ``.pt``) is used as the step in input/output paths.
    """
    checkpoint_path = checkpoint_dir_template.format(item_name, model_name)
    model_step = model_name.split('_')[-1].replace('.pt', '')
    model = checkpoint_path_to_model_w_attn(checkpoint_path, device=device)

    with open('test/g_and_a_head/all_heads.json', 'r', encoding='utf-8') as f:
        ga_head = json.load(f)
    ga_portion = defaultdict(dict)

    # Load every context file up front so the progress lines are printed
    # before any per-sentence output.
    context_fls = []
    for file_idx in context_file_idxs:
        filename = context_file_template.format(file_idx)
        print('now process: ' + filename)
        context_fls.append(process_context_simple(filename, 'normal'))

    for context_id, (file_idx, updated_content) in enumerate(
            zip(context_file_idxs, context_fls)):
        # `file_idx` is re-bound here so skip messages name the context being
        # processed, not the last file of the loading loop above.
        context_name = 'context' + file_idx
        for k in updated_content:
            assert k in word_list
            sentence = add_tag(
                updated_content[k]['env'], ':<ENV>') + " <CHI> " + add_tag(updated_content[k]['lan'])

            try:  # get idx of anchor tokens in sentence
                anchor_env = get_anchor_id(sentence, k.lower())
            except AssertionError:
                print(
                    f"In file {file_idx}, key {k}: Skipping {updated_content[k]}, sentence: {sentence}")
                continue

            enc = tokenizer(sentence, return_tensors="pt")
            expected = add_tag(k)
            exp_id = tokenizer.encode(expected, add_special_tokens=False)[0]

            pth = Path(
                f'saliency_orig/{item_name}/ckpt_{model_step}/{context_name}/')
            try:
                saliency = torch.load(pth / f'key_{k}.pt')
                # Explicit check (not `assert`) so a stale cache is still
                # rejected when Python runs with -O.
                if saliency.shape[-1] != enc['input_ids'].shape[1]:
                    raise ValueError('cached saliency does not match sentence length')
            except Exception:
                # Missing, unreadable or mismatching cache: recompute. Unlike a
                # bare `except`, this lets KeyboardInterrupt/SystemExit through.
                outputs, loss = call_model(
                    enc, model=model, device=device, exp_id=exp_id)
                saliency = get_saliency_matrix(outputs, loss, sum=False)
            # Indexed below as saliency[i, j] -> (S, S); the first two axes are
            # presumably (layer, head) -- TODO confirm against get_saliency_matrix.
            assert saliency.dim() == 4

            w2a_mask, a2e_mask = _anchor_masks(saliency.shape[-1], anchor_env)

            try:
                # word (100, key) -> seed (5) -> context (10) -> {'gather': xx, 'agg': xx}
                ga_info = ga_head[k][seed2idx[item_name]][context_id]
            except KeyError as e:
                print(e)
                continue

            ga_portion[k][context_name] = _head_portions(
                saliency, ga_info, w2a_mask, a2e_mask)

    out_path = f'test/g_and_a_head/portion_{model_step}_{item_name}.json'
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(ga_portion, f, ensure_ascii=False, indent=2)
    print(f'saved {out_path}')


if __name__ == '__main__':
    import argparse

    # `seed2idx` and `device` are module-level names that main() reads directly.
    # Run name -> position of that seed in the per-word list of all_heads.json.
    seed2idx = {
        run_name: position
        for position, run_name in enumerate((
            'childes_warmup_s42_shuffled',
            'childes_warmup_s142_shuffled',
            'childes_warmup_s242',
            'childes_warmup_s342',
            'childes_warmup_s442',
        ))
    }

    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    arg_parser = argparse.ArgumentParser(
        description='analyze attention flow in GPT-2')
    arg_parser.add_argument(
        '--item_name',
        type=str,
        default='childes_warmup_s42_shuffled',
        help='name of the item to analyze',
    )
    arg_parser.add_argument(
        '--checkpoint_dir_template',
        type=str,
        default='model/{}/{}',
        help='directory for saving checkpoints',
    )
    arg_parser.add_argument(
        '--model_name',
        type=str,
        default='checkpoint_14_20000.pt',
        help='exact model name',
    )
    args = arg_parser.parse_args()

    print(f'****** Running with args: {args} ******')

    main(args.checkpoint_dir_template, args.item_name, args.model_name)
