# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from kmeng01/rome.
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn as nn

from swift import SwiftConfig
from swift.tuners.utils import SwiftAdapter, SwiftOutput
from swift.utils import get_logger
from .compute_u import compute_u
from .compute_v import compute_v
from .context_template import context_template
from .nethook import get_parameter
from .rome_hparams import ROMEHyperParams

# Logger shared by the ROME helpers in this module.
logger = get_logger()

# NOTE(review): not referenced by the code in this module; presumably kept for
# parity with the upstream kmeng01/rome implementation -- confirm before removing.
CONTEXT_TEMPLATES_CACHE = None


@dataclass
class RomeConfig(SwiftConfig):
    """
    The configuration class for the ROME module.
    This adapter can be used to inject/modify knowledge to models, without any training.
    ROME: [Rank-One Editing of Encoder-Decoder Models](https://arxiv.org/abs/2211.13317)
    Args:
        model_type(`str`): The model type, used to look up the ROME hyperparameters;
            now supports llama-7b/llama-13b.
        tokenizer(`AutoTokenizer`): The tokenizer matching the model. It is removed
            from the `__dict__` view of this config. When it is None, no edit is
            applied to the model.
        knowledge(`List[Dict]`): The knowledge to be injected into the model.
            format:
                [
                    {
                        "prompt": "{} was the founder of",
                        "subject": "Steve Jobs",
                        "target": "Microsoft"
                    }
                ]
        batch_first(`bool`): Whether the batch dimension is the first dimension.
    """
    model_type: Optional[str] = field(
        default=None, metadata={'help': 'The model type'})

    tokenizer: Any = field(
        default=None, metadata={'help': 'The tokenizer matching this model'})

    # Default is None: a list-typed field cannot meaningfully default to False,
    # and a mutable list default is not allowed by dataclasses.
    knowledge: Optional[List[Dict]] = field(
        default=None, metadata={'help': 'The knowledge to be used'})

    batch_first: bool = field(
        default=True, metadata={'help': 'Batch at the first dimension or not'})

    def __post_init__(self):
        from swift.tuners.mapping import SwiftTuners
        self.swift_type = SwiftTuners.ROME

    @property
    def __dict__(self):
        """Return the config fields as a dict, without the tokenizer."""
        # Copy first so that removing the tokenizer can never delete the live
        # instance attribute, regardless of what the parent's `__dict__` returns.
        _dict = dict(super(RomeConfig, self).__dict__)
        # `pop` with a default keeps this safe if the key is absent.
        _dict.pop('tokenizer', None)
        return _dict


class Rome(SwiftAdapter):
    """Swift adapter that edits the model weights in place with ROME."""

    @staticmethod
    def prepare_model(model: nn.Module, config: RomeConfig, adapter_name: str):
        """
        Apply the ROME edits described by `config` to `model` in place.

        When `config.tokenizer` is set, every parameter of `model` gets
        `requires_grad=True`, the hyperparameters for `config.model_type` are
        loaded, and each request in `config.knowledge` is applied through
        `apply_rome_to_model`. When `config.tokenizer` is None, the model is
        left untouched.

        Args:
            model(`nn.Module`): The model to edit.
            config(`RomeConfig`): The ROME configuration.
            adapter_name(`str`): The adapter name (not used by this method).

        Returns:
            `SwiftOutput`: Holds the config and two callbacks. The state-dict
            callback keeps only the weights modified here; it returns an empty
            dict when no edit was applied. The mark-trainable callback is a no-op.
        """
        modified_keys = set()
        if config.tokenizer is not None:
            # NOTE(review): presumably needed by compute_u/compute_v; the flag
            # is not reset after editing -- confirm this is intended.
            for param in model.parameters():
                param.requires_grad = True

            # Treat a missing knowledge list (None/False) as "nothing to inject"
            # instead of failing while iterating it.
            knowledge = config.knowledge or []
            if not knowledge:
                logger.warning(
                    'RomeConfig.knowledge is empty, no knowledge will be injected.'
                )

            hparams = ROMEHyperParams.from_name(config.model_type)
            modified_keys = apply_rome_to_model(model, config.tokenizer,
                                                knowledge, hparams,
                                                config.batch_first)

        def state_dict_callback(state_dict, adapter_name):
            # Only the weights rewritten by ROME belong to this adapter.
            return {
                key: value
                for key, value in state_dict.items() if key in modified_keys
            }

        def mark_trainable_callback(model):
            # ROME does not train anything, so there is nothing to mark.
            pass

        return SwiftOutput(config, state_dict_callback,
                           mark_trainable_callback)

    @staticmethod
    def has_additional_modules():
        """ROME edits existing weights and adds no new modules."""
        return False


def apply_rome_to_model(
    model: torch.nn.Module,
    tokenizer: Any,
    knowledge: List[Dict],
    hparams: ROMEHyperParams,
    batch_first: bool,
) -> Set[str]:
    """Apply ROME to a model

    The requests are processed one after another. The deltas of each request
    are computed by `execute_rome` on the current model, which already contains
    the edits of the previous requests. They are then added to the weights in
    place.

    Args:
        model(`torch.nn.Module`): The model instance.
        tokenizer(`Any`): The tokenizer.
        knowledge(`List[Dict]`): The knowledge to be filled into the model.
        hparams(`ROMEHyperParams`): The hyperparameter of ROME
        batch_first(`bool`): Batch first or not.

    Returns:
        `Set[str]`: The names of all weights modified by this call.
    """
    modified_keys = set()
    for request in knowledge:
        deltas = execute_rome(model, tokenizer, request, hparams, batch_first)

        with torch.no_grad():
            for w_name, (delta_u, delta_v) in deltas.items():
                upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0)
                w = get_parameter(model, w_name)
                # Some architectures store the weight transposed.
                upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)
                w[...] += upd_matrix
        # Iterating a dict yields its keys, i.e. the modified weight names.
        modified_keys.update(deltas)
    return modified_keys


def execute_rome(
    model: torch.nn.Module,
    tok: Any,
    knowledge: Dict,
    hparams: ROMEHyperParams,
    batch_first: bool,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Executes the ROME update algorithm for the specified update at the specified layers.

    The layers in `hparams.layers` are processed in ascending order. Each
    layer's update is computed while the updates of the earlier layers are
    already applied to the model. Afterwards the original weights are restored.

    Invariant: model at beginning of function == model at end of function.
    This also holds when computing a delta raises; the exception is re-raised
    after the weights are restored.

    Args:
        model(`torch.nn.Module`): The model instance.
        tok(`Any`): The tokenizer.
        knowledge(`Dict`): One request with the keys 'prompt', 'subject' and 'target'.
        hparams(`ROMEHyperParams`): The hyperparameter of ROME.
        batch_first(`bool`): Batch first or not.

    Returns:
        A dict mapping each rewritten weight name to `(left_vector, right_vector)`.
        The applied update is `left.unsqueeze(1) @ right.unsqueeze(0)`, an outer
        product for 1-D vectors, transposed if needed to match the weight.
    """

    # Update target and print info
    request = deepcopy(knowledge)
    logger.info(
        f'Executing ROME algorithm for the update: '
        f"[{request['prompt'].format(request['subject'])}] -> [{request['target']}]"
    )

    # Retrieve weights that user desires to change
    weights = {
        f'{hparams.rewrite_module_tmp.format(layer)}.weight':
        get_parameter(model,
                      f'{hparams.rewrite_module_tmp.format(layer)}.weight')
        for layer in hparams.layers
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    # Update loop: sequentially intervene at each specified layer
    deltas = {}
    try:
        for layer in sorted(hparams.layers):
            # Compute rank-1 update matrix
            left_vector: torch.Tensor = compute_u(
                model,
                tok,
                request,
                hparams,
                layer,
                context_template,
                batch_first=batch_first,
            )
            logger.info(f'Left vector shape: {left_vector.shape}')
            right_vector: torch.Tensor = compute_v(
                model,
                tok,
                request,
                hparams,
                layer,
                left_vector,
                context_template,
                batch_first=batch_first,
            )
            logger.info(f'Right vector shape: {right_vector.shape}')
            right_vector = right_vector.to(left_vector.dtype)

            with torch.no_grad():
                # Determine correct transposition of delta matrix
                weight_name = f'{hparams.rewrite_module_tmp.format(layer)}.weight'
                upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0)
                upd_matrix = upd_matrix_match_shape(upd_matrix,
                                                    weights[weight_name].shape)

                # Update model weights so later layers see this edit, and
                # record the desired change in `deltas`.
                weights[weight_name][...] += upd_matrix
                deltas[weight_name] = (
                    left_vector.detach(),
                    right_vector.detach(),
                )
    finally:
        # Restore state of original model. This runs on failure as well, so a
        # failing request never leaves the model partially edited.
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    logger.info(f'Deltas successfully computed for {list(weights.keys())}')

    return deltas


def upd_matrix_match_shape(matrix: torch.Tensor,
                           shape: torch.Size) -> torch.Tensor:
    """
    Orient a ROME update matrix so it can be added to a weight of `shape`.

    GPT-2 and GPT-J have transposed weight representations. The matrix is
    therefore returned unchanged when it already has `shape`, or transposed
    when its transpose does.

    NOTE(review): for a square `shape` the first check always succeeds, so no
    transposition happens even if the weight layout is transposed.

    Raises:
        ValueError: if neither the matrix nor its transpose has `shape`.
    """
    # Guard clause: already in the right orientation.
    if matrix.shape == shape:
        return matrix

    flipped = matrix.T
    if flipped.shape != shape:
        raise ValueError(
            'Update matrix computed by ROME does not match original weight shape. '
            'Check for bugs in the code?')
    return flipped
