# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from kmeng01/rome.
from typing import Dict, List

import torch
from modelscope import AutoTokenizer

from swift.utils.logger import get_logger
from .repr_tools import get_reprs_at_idxs, get_reprs_at_word_tokens
from .rome_hparams import ROMEHyperParams

logger = get_logger()


def compute_u(
    model: torch.nn.Module,
    tokenizer: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    context_templates: List[str],
    batch_first=True,
) -> torch.Tensor:
    """
    Computes the left vector used in constructing the rank-1 update matrix.
    """

    logger.info('Computing left vector (u)...')

    # Compute projection token
    word_repr_args = dict(
        model=model,
        tokenizer=tokenizer,
        layer=layer,
        module_template=hparams.rewrite_module_tmp,
        track='in',
        batch_first=batch_first,
    )
    if 'subject_' in hparams.fact_token and hparams.fact_token.index(
            'subject_') == 0:
        word = request['subject']
        logger.info(f'Selected u projection object {word}')
        cur_repr = get_reprs_at_word_tokens(
            context_templates=[
                templ.format(request['prompt']) for templ in context_templates
            ],
            words=[word for _ in range(len(context_templates))],
            subtoken=hparams.fact_token[len('subject_'):],
            **word_repr_args,
        ).mean(0)
    elif hparams.fact_token == 'last':
        # Heuristic to choose last word. Not a huge deal if there's a minor
        # edge case (e.g. multi-token word) because the function below will
        # take the last token.
        cur_repr = get_reprs_at_idxs(
            contexts=[
                templ.format(request['prompt'].format(request['subject']))
                for templ in context_templates
            ],
            idxs=[[-1] for _ in range(len(context_templates))],
            **word_repr_args,
        ).mean(0)
        logger.info('Selected u projection token with last token')
    else:
        raise ValueError(f'fact_token={hparams.fact_token} not recognized')

    # Apply inverse second moment adjustment
    u = cur_repr
    return u / u.norm()
