from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from .tokenize_utils import split_tokenize_context_prompt_targets

from ..rome import repr_tools
from ...util import nethook

from .sue_hparams import SUEFreeHyperParams
from .func_utils import gather_kl_tensor


def compute_z(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: SUEFreeHyperParams,
    layer: int,
    context_templates: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the value (right) vector for the rank-1 update.
    Runs a simple optimization procedure.

    A trainable offset ``delta`` (one row per prompt token) is added to the
    output of the rewrite layer at the positions selected by the prompt mask,
    and is optimized with Adam while the model weights stay frozen. The
    objective is the sum of:

    * ``-log(1 - p(target token) + eps)`` summed over target positions and
      divided by the number of target tokens (sign flipped when
      ``hparams.reverse_objective is True``);
    * ``hparams.kl_factor`` times a KL term between the first iteration's
      gathered log-probabilities and the current ones;
    * ``hparams.v_weight_decay * ||delta|| / ||target_init||**2``.

    After every step ``delta`` is rescaled so that its norm does not exceed
    ``hparams.clamp_norm_factor * ||target_init||``.

    Args:
        model: Causal LM whose layer output is intervened on. Its parameters
            get ``requires_grad=False`` as a side effect.
        tok: Tokenizer forwarded to ``split_tokenize_context_prompt_targets``.
        request: Dict with keys ``"prompt"`` (format string taking the
            subject), ``"subject"`` and ``"target_new"``.
        hparams: Hyper-parameters; this function reads ``lm_head_module``,
            ``ln_f_module``, ``layer_module_tmp``, ``device``,
            ``v_loss_layer``, ``v_lr``, ``v_num_grad_steps``,
            ``reverse_objective``, ``kl_factor``, ``v_weight_decay`` and
            ``clamp_norm_factor``.
        layer: Index of the layer whose output receives ``delta``.
        context_templates: Iterated as a list of lists of template strings
            (despite the annotation); every template must end with ``"{}"``.

    Returns:
        ``(target, delta)`` where ``target = target_init + delta`` and
        ``target_init`` is the un-edited rewrite-layer output of the first
        batch row at its prompt positions. ``delta`` still requires grad.

    Raises:
        NotImplementedError: if ``model.config`` has neither ``n_embd`` nor
            ``hidden_size``.
        AssertionError: if a context template does not end with ``"{}"``.
    """
    # Get model parameters
    lm_w, ln_f = (
        nethook.get_parameter(model, f"{hparams.lm_head_module}.weight").T,
        nethook.get_module(model, hparams.ln_f_module),
    )
    try:
        lm_b = nethook.get_parameter(model, f"{hparams.lm_head_module}.bias")
    except LookupError:
        # No bias on the LM head: use zeros so the projection below is uniform.
        lm_b = next(model.parameters()).new_zeros(model.config.vocab_size)

    print("Computing right vector (v)")

    flat_contexts = [
        context for context_types in context_templates for context in context_types
    ]
    # we only deal with contexts with prompt at the end.
    assert all(item.endswith("{}") for item in flat_contexts)
    # Strip the placeholder; the prompt is tokenized separately and appended.
    flat_contexts = [item.replace("{}", "").strip() for item in flat_contexts]

    tok_ret = split_tokenize_context_prompt_targets(
        tok,
        flat_contexts,
        request["prompt"].format(request["subject"]),
        request["target_new"],
    )
    device = f"cuda:{hparams.device}"
    input_tensor = tok_ret["decoder_inputs"].to(device)
    attention_mask = tok_ret["attention_mask"].to(device)
    bsz = input_tensor.shape[0]
    prompt_mask, rewriting_targets = tok_ret["prompt_mask"], tok_ret["rewriting_targets"]
    target_ids = tok_ret["target_tokens"]

    # NOTE: we do not contain KL prompts in our methods.

    # Finalize rewrite and loss layers
    loss_layer = max(hparams.v_loss_layer, layer)
    print(f"Rewrite layer is {layer}")
    print(f"Tying optimization objective to {loss_layer}")
    rewrite_layer_name = hparams.layer_module_tmp.format(layer)
    loss_layer_name = hparams.layer_module_tmp.format(loss_layer)

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer, i.e. hypothesized fact lookup location, will induce the
    # target token to be predicted at the final layer.
    prompt_len = tok_ret["prompt_tokens"].shape[1]
    if hasattr(model.config, "n_embd"):
        hidden_dim = model.config.n_embd
    elif hasattr(model.config, "hidden_size"):
        hidden_dim = model.config.hidden_size
    else:
        raise NotImplementedError
    delta = torch.zeros((prompt_len, hidden_dim), requires_grad=True, device=device)
    target_init, kl_distr_init = None, None
    batch_first = True

    # Inserts new "delta" variable at the appropriate part of the computation
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init
        nonlocal batch_first

        if cur_layer == rewrite_layer_name:
            # A leading dimension different from the batch size means the
            # module emits a sequence-first tensor.
            if bsz != len(cur_out[0]):
                batch_first = False

            # Store initial value of the vector of interest
            if target_init is None:
                print("Recording initial value of v*")
                # Initial value is recorded for the clean sentence
                first_row = cur_out[0][0] if batch_first else cur_out[0][:, 0]
                target_init = first_row[prompt_mask[0]].detach().clone()

            # Add intervened delta at every prompt position of every row.
            # Errors here (e.g. a mask/delta size mismatch) propagate to the
            # caller instead of being swallowed.
            if batch_first:
                cur_out[0][prompt_mask] += delta[None].expand(bsz, -1, -1).flatten(0, 1)
            else:
                cur_out[0][prompt_mask.T] += delta[:, None].expand(-1, bsz, -1).flatten(0, 1)

        return cur_out

    # Optimizer
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    nethook.set_requires_grad(False, model)

    # Loop-invariant helpers: -100 marks non-target positions; they are mapped
    # to index 0 for the gather and zeroed out afterwards through the mask.
    gather_index = torch.where(
        rewriting_targets != -100, rewriting_targets, 0
    ).unsqueeze(2)
    target_mask = (rewriting_targets != -100).float()
    eps = 1e-7  # keeps log(1 - p) finite when p reaches 1

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation
        with nethook.TraceDict(
            module=model,
            layers=[loss_layer_name, rewrite_layer_name],
            retain_input=False,
            retain_output=True,
            edit_output=edit_output_fn,
        ) as tr:
            logits = model(input_tensor, attention_mask=attention_mask).logits
            # Compute distribution for KL divergence
            # TODO: pop out the target_new
            gather_kl_logits = gather_kl_tensor(tensor=logits, index=rewriting_targets)

            kl_log_probs = torch.nn.functional.log_softmax(gather_kl_logits, dim=1)
            if kl_distr_init is None:
                # Reference distribution comes from the first iteration.
                kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        output = tr[loss_layer_name].output[0]
        if output.shape[1] != rewriting_targets.shape[1]:
            # Sequence-first output: bring the batch dimension to the front.
            output = torch.transpose(output, 0, 1)
        full_repr = output[:bsz]

        probs = torch.softmax(
            ln_f(full_repr) @ lm_w.to(full_repr.device) + lm_b.to(full_repr.device),
            dim=2,
        )
        # log(1 - p): the loss falls as the target tokens become less likely.
        lprobs = torch.log(1 - probs + eps)

        index = gather_index.to(lprobs.device)
        target_lprobs = torch.gather(lprobs, 2, index).squeeze(2)
        gather_probs = torch.gather(probs, 2, index).squeeze(2)

        # Aggregate total losses
        nll_loss_each = (
            -(target_lprobs * target_mask.to(target_lprobs.device)).sum(1)
            / target_ids.size(0)
        )
        if hparams.reverse_objective is True:
            nll_loss_each = -1 * nll_loss_each
        nll_loss = nll_loss_each.mean()
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init, kl_log_probs, log_target=True, reduction="batchmean"
        )
        weight_decay = hparams.v_weight_decay * (
            torch.norm(delta) / torch.norm(target_init) ** 2
        )
        loss = nll_loss + kl_loss.to(nll_loss.device) + weight_decay.to(nll_loss.device)

        print(
            f"loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + {np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} "
            f"avg prob of [{request['target_new']}] "
            f"{gather_probs.mean().item()}"
        )
        if loss < 5e-2:
            break

        # The last iteration only evaluates; no update is applied after it.
        if it == hparams.v_num_grad_steps - 1:
            break

        # Backpropagate
        loss.backward()
        opt.step()
        print(delta.norm(dim=-1))

        # Project within L2 ball
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta
    print(
        f"Init norm {target_init.norm()} | Delta norm {delta.norm()} | Target norm {target.norm()}"
    )

    return target, delta


def get_module_input_output_at_words(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer: int,
    context_templates: List[str],
    words: List[str],
    module_template: str,
    fact_token_strategy: str,
    track=None,
) -> Tuple[torch.Tensor]:
    """
    Fetches the representations of a word at the input and/or output of one
    layer module via ``repr_tools.get_reprs_at_word_tokens``.

    ``fact_token_strategy`` is expected to look like ``"subject_<subtoken>"``;
    the first ``len("subject_")`` characters are dropped without checking the
    prefix and the remainder is forwarded as ``subtoken``.

    When ``track`` is ``"in"`` or ``"out"`` the helper's result for that side
    is returned as-is. For any other value both sides are requested and
    returned as a detached ``(input, output)`` pair.
    """

    # Only the part after the "subject_" prefix is meaningful to repr_tools.
    subtoken = fact_token_strategy[len("subject_") :]

    lookup_kwargs = dict(
        context_templates=context_templates,
        words=words,
        model=model,
        tok=tok,
        layer=layer,
        module_template=module_template,
        subtoken=subtoken,
    )

    # Single-sided request: hand back whatever the helper produces.
    if track in ("out", "in"):
        return repr_tools.get_reprs_at_word_tokens(track=track, **lookup_kwargs)

    layer_in, layer_out = repr_tools.get_reprs_at_word_tokens(
        track="both", **lookup_kwargs
    )
    return layer_in.detach(), layer_out.detach()



# def find_fact_lookup_idx(
#     prompt: str,
#     subject: str,
#     tok: AutoTokenizer,
#     fact_token_strategy: str,
#     verbose=True,
# ) -> int:
#     """
#     Computes hypothesized fact lookup index given a sentence and subject.
#     """

#     ret = None
#     if fact_token_strategy == "last":
#         ret = -1
#     elif (
#         "subject_" in fact_token_strategy and fact_token_strategy.index("subject_") == 0
#     ):
#         ret = repr_tools.get_words_idxs_in_templates(
#             tok=tok,
#             context_templates=[prompt],
#             words=[subject],
#             subtoken=fact_token_strategy[len("subject_") :],
#         )[0][0]
#     else:
#         raise ValueError(f"fact_token={fact_token_strategy} not recognized")

#     sentence = prompt.format(subject)
#     if verbose:
#         print(
#             f"Lookup index found: {ret} | Sentence: {sentence} | Token:",
#             tok.decode(tok(sentence)["input_ids"][ret]),
#         )

#     return ret
