'''Code for analysing attention patterns. Only the ENV anchor and the normal context are used.'''

from .checkpoint_surprisal import get_files_sorted, process_context_simple, add_tag
import os
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from .utils import context_file_template, context_file_idxs, tokenizer, word_list
from typing import Tuple
from pathlib import Path
from math import sqrt


def checkpoint_path_to_model_w_attn(path, device, layer_num=12):
    """Rebuild a GPT-2 LM head model from a checkpoint, with attentions exposed.

    Args:
        path: Checkpoint file readable by ``torch.load``; the weights are read
            from its ``"model_state_dict"`` entry.
        device: Forwarded to ``torch.load`` as ``map_location``. The model is
            not explicitly moved to this device here.
        layer_num: Number of transformer layers (``n_layer``) for the config.

    Returns:
        The model switched to eval mode, configured to return attention maps
        and dict-style outputs.
    """
    state = torch.load(path, map_location=device)

    gpt2_config = GPT2Config(n_layer=layer_num)
    # Ask every forward pass to hand back the per-layer attention maps.
    gpt2_config.output_attentions = True
    gpt2_config.return_dict = True

    lm = GPT2LMHeadModel(gpt2_config)
    # Size the embedding table to the project tokenizer before loading weights.
    lm.resize_token_embeddings(tokenizer.vocab_size)
    lm.load_state_dict(state["model_state_dict"])
    lm.eval()
    return lm


def call_model(enc, model: GPT2LMHeadModel, device, exp_id: int) -> Tuple[CausalLMOutputWithCrossAttentions, torch.Tensor]:
    """Run one forward pass and score the expected next token.

    Args:
        enc: Mapping of keyword arguments for the model (it is unpacked with
            ``**``), e.g. the tokenizer output.
        model: The language model to call.
        device: Unused here; the target tensor is placed on ``model.device``.
        exp_id: Vocabulary id of the token expected at the next position.

    Returns:
        ``(outputs, loss)`` where ``loss`` is the cross-entropy between the
        logits at the last position and ``exp_id``. Because the target holds a
        single id, the input is expected to be a batch of one.
    """
    model_out = model(**enc)
    # Only the final position matters: it predicts the token that follows.
    last_step_logits = model_out.logits[:, -1, :]
    gold = torch.tensor([exp_id], device=model.device)
    return model_out, F.cross_entropy(last_step_logits, gold)


def get_saliency_matrix(outputs, loss, sum=True) -> torch.Tensor:
    """Compute attention saliency ``|A * dLoss/dA|`` for every layer.

    Calls ``loss.backward()``, so the graph connecting ``loss`` to
    ``outputs.attentions`` must still be alive, and any other tensors in that
    graph that accumulate gradients will receive them as a side effect.

    Args:
        outputs: Model output exposing ``attentions``, a sequence with one
            attention tensor per layer. With a batch of one, each is
            (1, H, S, S).
        loss: Scalar loss computed from ``outputs``.
        sum: If true, sum the saliency over the head axis.

    Returns:
        A detached tensor: (L, S, S) when ``sum`` is true, otherwise
        (L, H, S, S), for a batch of one.
    """
    attentions = outputs.attentions
    # Attention maps are non-leaf tensors; without retain_grad() their .grad
    # would be discarded after backward().
    for att in attentions:  # type: ignore
        att.retain_grad()
    loss.backward()

    # Stacking gives (L, B, H, S, S). Drop only the batch axis: a bare
    # squeeze() would also collapse a length-1 sequence, head or layer axis.
    attn = torch.stack(attentions).squeeze(1)
    attn_grad = torch.stack(
        [att.grad for att in attentions]).squeeze(1)  # type: ignore

    # Detach in both modes so callers never hold on to the autograd graph.
    saliency = torch.abs(attn * attn_grad).detach()
    if sum:
        saliency = torch.sum(saliency, dim=1)
    return saliency


def get_anchor_id(sentence: str, env: str, target=':<env>'):
    """Locate the single anchor word ``env + target`` in ``sentence``.

    Matching is case-insensitive and works on whitespace-separated words.
    An exact word match is tried first; if none is found, any word that
    contains the anchor as a substring is accepted instead.

    Args:
        sentence: Text to search.
        env: Word whose tagged form is the anchor.
        target: Tag suffix appended to ``env``.

    Returns:
        A 1-element ``torch.int32`` tensor holding the word index of the anchor.

    Raises:
        AssertionError: If the anchor is missing or matches more than one word.
    """
    # Lowercase the needle once so both passes compare like with like; the
    # words are lowercased, so a mixed-case needle could never match exactly.
    needle = (env + target).lower()
    words = sentence.lower().split()

    idx_ls = [idx for idx, w in enumerate(words) if w == needle]
    if not idx_ls:
        # Fallback for words that embed the anchor (e.g. a changed word form).
        idx_ls = [idx for idx, w in enumerate(words) if needle in w]

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the exception type is unchanged for callers that catch AssertionError.
    if len(idx_ls) != 1:
        raise AssertionError(
            f"Token '{env+target}' has problem in sentence: {sentence}")
    return torch.tensor(idx_ls, dtype=torch.int32)


def _flow_saliency_scores(saliency, anchor_env):
    """Average saliency per layer over word->anchor, anchor->end and the rest.

    Args:
        saliency: (L, S, S) saliency, or (L, H, S, S) per-head saliency, which
            is summed over the head axis first.
        anchor_env: Index tensor with the anchor position in the sequence.

    Returns:
        ``(S_wp, S_pq, S_ww)``, each of shape (L,).
    """
    # The masks below are (S, S); per-head saliency has to be reduced first,
    # otherwise the anchor position would index the head axis.
    if saliency.dim() == 4:
        saliency = saliency.sum(dim=1)
    anchor = anchor_env.long()

    # mask2: word2anchor -- the anchor row, restricted to earlier positions
    mask2 = torch.zeros_like(saliency[0], dtype=torch.bool)
    mask2[anchor, :] = True
    mask2 = torch.tril(mask2, diagonal=-1)
    # mask3: anchor2end -- the last position attending to the anchor
    mask3 = torch.zeros_like(saliency[0], dtype=torch.bool)
    mask3[-1, anchor] = True
    # mask1: other -- every remaining causal (lower-triangular) entry
    mask1 = torch.tril(torch.ones_like(
        saliency[0], dtype=torch.bool) & ~mask2 & ~mask3)

    if anchor_env.numel() and torch.sum(mask2) > 0:
        S_wp = torch.sum(saliency[:, mask2], dim=1) / torch.sum(mask2)
    else:
        # Anchor at position 0 has no earlier words; keep dtype/device aligned.
        S_wp = saliency.new_zeros(saliency.shape[0])
    S_pq = torch.sum(saliency[:, mask3], dim=1) / torch.sum(mask3)
    S_ww = torch.sum(saliency[:, mask1], dim=1) / torch.sum(mask1)
    return S_wp, S_pq, S_ww


def main(checkpoint_dir_template, item_name, model_name, layer_num=12, device=None):
    """Compute and store attention-saliency flow statistics for one checkpoint.

    For every context file and every word in ``word_list`` this builds the
    tagged sentence, computes per-head saliency, saves it under
    ``saliency_orig/``, reduces it to three per-layer scores (word2anchor,
    anchor2end, word2word) saved under ``attn_analysis/``, and finally writes
    the mean and standard error over the processed words per context.

    Words whose saliency file already exists are skipped entirely, so the
    per-context statistics only cover words processed in the current run.
    NOTE(review): a resumed run therefore overwrites ``*_stat.pt`` with
    statistics over the newly processed words only -- confirm this is intended.

    Args:
        checkpoint_dir_template: Format string taking ``(item_name, model_name)``
            and yielding the checkpoint path.
        item_name: Name of the run; also used in the output paths.
        model_name: Checkpoint file name; the part after the last ``_`` (minus
            ``.pt``) is used as the step label in output paths.
        layer_num: Number of transformer layers of the model.
        device: Device passed on when loading the checkpoint. Defaults to CUDA
            when available, otherwise CPU.
    """
    # Resolve the device locally so main() also works when imported, where the
    # script-level `device` global does not exist.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    checkpoint_path = checkpoint_dir_template.format(item_name, model_name)

    model_step = model_name.split('_')[-1].replace('.pt', '')
    model = checkpoint_path_to_model_w_attn(checkpoint_path, device, layer_num)

    for file_num_idx, file_idx in enumerate(context_file_idxs):
        context_name = f'context{file_idx}'

        filename = context_file_template.format(file_idx)
        print('now process: ' + filename)
        updated_content = process_context_simple(filename, 'normal')

        pth = Path(
            f'saliency_orig/{item_name}/ckpt_{model_step}/{context_name}/')
        case_path = Path(
            f'attn_analysis/{item_name}/ckpt_{model_step}/{context_name}/')
        stat_file = f'attn_analysis/{item_name}/ckpt_{model_step}/{context_name}_stat.pt'

        attn_across_content = []
        # TODO: make it parallel
        for k in word_list:
            try:
                sentence = add_tag(
                    updated_content[k]['env'], ':<ENV>') + " <CHI> " + add_tag(updated_content[k]['lan'])
            except KeyError as e:
                print(e, f'skip key {k}')
                continue

            # Skip before tokenizing: nothing below runs for finished keys.
            if (pth/f'key_{k}.pt').exists():
                print(
                    f'saliency_orig/{item_name}/ckpt_{model_step}/{context_name}/key_{k}.pt exist, continue')
                continue

            enc = tokenizer(sentence, return_tensors="pt")
            expected = add_tag(k)
            exp_id = tokenizer.encode(
                expected, add_special_tokens=False)[0]

            outputs, loss = call_model(
                enc, model=model, device=device, exp_id=exp_id)
            # Keep the per-head saliency on disk; detach so neither the saved
            # tensor nor the statistics below keep the autograd graph alive.
            saliency = get_saliency_matrix(
                outputs, loss, sum=False).detach()
            pth.mkdir(parents=True, exist_ok=True)
            torch.save(saliency, pth/f'key_{k}.pt')

            try:  # get idx of anchor tokens in sentence
                # NOTE(review): this is a whitespace-word index; it is used as
                # a token position below, which assumes one token per word --
                # confirm against the tokenizer.
                anchor_env = get_anchor_id(sentence, k.lower())
            except AssertionError:
                print(
                    f"In file {file_num_idx}, key {k}: Skipping {updated_content[k]}, sentence: {sentence}")
                continue

            S_wp, S_pq, S_ww = _flow_saliency_scores(saliency, anchor_env)
            cur_attn = torch.stack((S_wp, S_pq, S_ww))

            case_path.mkdir(parents=True, exist_ok=True)
            torch.save({'word2anchor': S_wp.cpu(), 'anchor2end': S_pq.cpu(),
                        'word2word': S_ww.cpu()}, case_path / f'{k}.pt')
            attn_across_content.append(cur_attn)

        num = len(attn_across_content)
        if num == 0:
            # torch.stack([]) would raise and abort the remaining contexts.
            print(f'no cases processed for {context_name}, skip stat file')
            continue

        stacked = torch.stack(attn_across_content)
        attn = torch.mean(stacked, dim=0)
        # With a single case the sample standard deviation is NaN.
        stderr = torch.std(stacked, dim=0)/sqrt(num)
        torch.save({'saliency': attn, 'se': stderr}, stat_file)
        print(f'saved stat results for {num} cases', stat_file)


if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the analysis.
    # `device` is defined at module scope because main() may read it as a
    # global.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    import argparse
    parser = argparse.ArgumentParser(
        description='analyze attention flow in GPT-2')
    parser.add_argument('--item_name', type=str, default='childes_warmup_s42_shuffled',
                        help='name of the item to analyze')
    parser.add_argument('--checkpoint_dir_template', type=str, default='model/{}/{}',
                        help='checkpoint path template, formatted with (item_name, model_name)')
    parser.add_argument('--model_name', type=str, default='checkpoint_14_20000.pt',
                        help='exact model name')
    parser.add_argument('--layer_num', type=int, default=12,
                        help='number of transformer layers in the model')

    args = parser.parse_args()

    print(f'****** Running with args: {args} ******')
    main(args.checkpoint_dir_template, args.item_name,
         args.model_name, args.layer_num)
