from typing import Dict, List

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .compute_z import get_module_input_output_at_words
from .sue_hparams import SUEFreeHyperParams
from .tokenize_utils import split_tokenize_context_prompt_targets
from ...util import nethook

def left_padding(ts, padding_value):
    """
    Stack a list of `[1, L_i]` tensors into one `[len(ts), max(L_i)]` tensor,
    filling the missing positions on the LEFT with `padding_value`.

    Only the first element's shape is checked; the remaining elements are
    indexed the same way (`t[0]`) without validation.
    """
    head = ts[0]
    assert head.ndim == 2 and head.shape[0] == 1

    # `pad_sequence` can only pad on the right, so each row is reversed first,
    # right-padded, and then the whole batch is reversed back along the
    # sequence axis, which moves the padding to the left.
    reversed_rows = [torch.flip(t[0], dims=[-1]) for t in ts]
    right_padded = torch.nn.utils.rnn.pad_sequence(
        reversed_rows,
        batch_first=True,
        padding_value=padding_value,
    )
    return torch.flip(right_padded, dims=[1])


def compute_ks(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: SUEFreeHyperParams,
    layer: int,
    context_templates: List[List[str]],
    trace: str = 'in',
) -> torch.Tensor:
    """
    Run every (context, request prompt) pair through the model and return, per
    request, the representation at the prompt-masked token positions of the
    traced module, averaged first within each context type and then across
    context types.

    Args:
        model: Model whose module is traced.
        tok: Tokenizer passed on to the tokenisation helper and the tracer.
        requests: Iterable of dicts; each must provide "prompt" (a format
            string with one `{}` slot), "subject" and "target_new".
        hparams: Hyper-parameters forwarded to
            `get_module_input_output_at_mask`.
        layer: Layer index forwarded to `get_module_input_output_at_mask`.
        context_templates: List of context types, each a list of template
            strings. Every template must end with `{}`.
        trace: `'in'` selects the traced module's input; any other value
            selects its output.

    Returns:
        One averaged `[pl, D]` slice per group of `sum(len(t) for t in
        context_templates)` consecutive traced rows, stacked along dim 0.
        This is one slice per request provided the tokenisation helper emits
        one row per flattened context (presumed; confirm against
        `split_tokenize_context_prompt_targets`).

    Raises:
        AssertionError: if a template does not end with `{}`, or if the tracer
            returned more than one batch.
    """
    # The templates are identical for every request, so flatten, validate and
    # strip the trailing placeholder once instead of once per request.
    flat_contexts = [
        context for context_type in context_templates for context in context_type
    ]
    # We only deal with contexts that have the prompt at the very end.
    assert all(item.endswith('{}') for item in flat_contexts), \
        "every context template must end with '{}'"
    flat_contexts = [item.replace('{}', "").strip() for item in flat_contexts]

    all_inputs, all_masks = [], []
    for request in requests:
        prompts = request["prompt"].format(request["subject"])
        targets = request["target_new"]
        tok_ret = split_tokenize_context_prompt_targets(
            tok, flat_contexts, prompts, targets
        )
        # Split along dim 0 so that every list entry keeps a leading dim of 1.
        all_inputs += list(torch.split(tok_ret['cp_inputs'], 1, dim=0))
        all_masks += list(torch.split(tok_ret['cp_prompt_mask'], 1, dim=0))

    # Traced tensors have shape [N_batches, N_context, pl, D].
    traced = get_module_input_output_at_mask(
        tok, model, hparams, input_list=all_inputs, mask_list=all_masks, layer=layer
    )
    layer_ks = traced[0] if trace == 'in' else traced[1]
    assert layer_ks.shape[0] == 1, \
        "expected all contexts to be processed in a single batch"
    layer_ks = layer_ks.squeeze(0)  # [N_context, pl, D]

    # Boundaries of each context type inside one request's block of rows.
    type_sizes = [len(context_type) for context_type in context_templates]
    bounds = np.cumsum([0] + type_sizes).tolist()
    rows_per_request = sum(type_sizes)

    ans = []
    for offset in range(0, layer_ks.size(0), rows_per_request):
        # Mean within each context type, then mean over the types, so that
        # every type carries equal weight regardless of its template count.
        type_means = [
            layer_ks[offset + start: offset + end].mean(0)
            for start, end in zip(bounds[:-1], bounds[1:])
        ]
        ans.append(torch.stack(type_means, 0).mean(0))
    return torch.stack(ans, dim=0)

def get_module_input_output_at_mask(
    tok,
    model,
    hparams,
    input_list,
    mask_list,
    layer,
    batch_size: int = 128,
):
    """
    Run the given token sequences through `model` and collect the input and
    output of one module at the positions selected by the masks.

    Args:
        tok: Tokenizer; only `tok.pad_token_id` is read (used as the padding
            value and to derive the attention mask).
        model: Model that is called with `input_ids` and `attention_mask`.
        hparams: Must provide `device` (formatted into `"cuda:{device}"`) and
            `rewrite_module_tmp` (formatted with `layer` to get the module
            name to trace).
        input_list: List of `[1, L_i]` token-id tensors.
        mask_list: List of `[1, L_i]` masks aligned with `input_list`; they are
            padded with `False` and used to index the traced tensors
            (presumably boolean - confirm against the caller).
        layer: Value substituted into `hparams.rewrite_module_tmp`.
        batch_size: Maximum number of sequences per forward pass.

    Returns:
        A tuple `(inputs, outputs)`; each is the per-batch results stacked on a
        new leading dim, i.e. `[N_batches, rows_in_batch, pl, D]`. Every row
        must select the same number of positions `pl`, and all batches must
        hold the same number of rows, otherwise the reshape / stack raises.
    """
    # NOTE(review): the device is always a CUDA device built from
    # `hparams.device`; the model is assumed to live there.
    device = f"cuda:{hparams.device}"
    module_name = hparams.rewrite_module_tmp.format(layer)

    # Both sides of the module are always traced; the flags are kept so the
    # selection logic at the end stays explicit.
    tin, tout = True, True
    to_return = {"in": [], "out": []}

    def _batch(n):
        # Yield left-padded (input_ids, mask) pairs of at most `n` rows each.
        for i in range(0, len(input_list), n):
            contexts = left_padding(
                input_list[i:i + n], padding_value=tok.pad_token_id
            )
            masks = left_padding(mask_list[i:i + n], padding_value=False)
            yield contexts, masks

    def _process(cur_repr, batch_mask, key):
        # Traced values may arrive wrapped in a tuple; the first element is
        # the tensor of interest.
        cur_repr = cur_repr[0] if isinstance(cur_repr, tuple) else cur_repr
        # If the leading dim is not the batch dim, swap the first two dims
        # (presumably the module emits [seq, batch, D] - TODO confirm).
        if cur_repr.shape[0] != batch_mask.shape[0]:
            cur_repr = cur_repr.transpose(0, 1)
        # Use THIS batch's row count, not the total number of inputs: with
        # more inputs than `batch_size` the total would either make the
        # reshape fail or silently regroup tokens across rows.
        rows = batch_mask.shape[0]
        selected = cur_repr[batch_mask]
        to_return[key].append(
            selected.reshape(rows, -1, cur_repr.shape[-1])  # [rows, pl, D]
        )

    for input_ids, batch_mask in _batch(n=batch_size):
        # NOTE(review): any genuine token whose id equals `tok.pad_token_id`
        # is masked out here as well.
        attention_mask = input_ids != tok.pad_token_id

        with torch.no_grad():
            with nethook.Trace(
                module=model,
                layer=module_name,
                retain_input=tin,
                retain_output=tout,
            ) as tr:
                model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                )

        if tin:
            _process(tr.input, batch_mask, "in")
        if tout:
            _process(tr.output, batch_mask, "out")

    to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}

    if len(to_return) == 1:
        return to_return["in"] if tin else to_return["out"]
    else:
        return to_return["in"], to_return["out"]

#     layer_ks = get_module_input_output_at_words(
#         model,
#         tok,
#         layer,
#         context_templates=[
#             context.format(request["prompt"])
#             for request in requests
#             for context_type in context_templates
#             for context in context_type
#         ],
#         words=[
#             request["subject"]
#             for request in requests
#             for context_type in context_templates
#             for _ in context_type
#         ],
#         module_template=hparams.rewrite_module_tmp,
#         fact_token_strategy='subject_last',
#     )[0]

#     context_type_lens = [0] + [len(context_type) for context_type in context_templates]
#     context_len = sum(context_type_lens)
#     context_type_csum = np.cumsum(context_type_lens).tolist()

#     ans = []
#     for i in range(0, layer_ks.size(0), context_len):
#         tmp = []
#         for j in range(len(context_type_csum) - 1):
#             start, end = context_type_csum[j], context_type_csum[j + 1]
#             tmp.append(layer_ks[i + start : i + end].mean(0))
#         ans.append(torch.stack(tmp, 0).mean(0))
#     return torch.stack(ans, dim=0)



