from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .FT_ewc import Finetune_ewc
from .ft_ewc_hparams import FTEWCHyperParams
from ..grace.utils import tokenize
# Module-level window of tokenized edit batches. It is shared by every call
# to ``apply_ft_ewc_to_model`` in this process, regardless of which model or
# hparams are passed, and is handed to ``Finetune_ewc.edit``. It is only ever
# mutated in place (never rebound), so references held elsewhere stay valid.
batch_history = []


def apply_ft_ewc_to_model(
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        requests: List[Dict],
        hparams: FTEWCHyperParams,
        copy=False,
        return_orig_weights=False,
        keep_original_weight=False,
        **kwargs: Any,
) -> Tuple[Finetune_ewc, Dict[str, Any]]:
    """Apply one ``Finetune_ewc`` edit for ``requests`` and return the editor.

    The requests are tokenized and appended to the module-level
    ``batch_history``. ``Finetune_ewc.edit`` is then run with both the current
    tokens and the history. Afterwards the history is trimmed to the newest
    ``hparams.fisher_mem`` batches.

    Args:
        model: Model to edit. It is edited in place unless ``copy`` is true.
        tok: Tokenizer passed to ``tokenize``.
        requests: Edit requests. Their schema is defined by
            ``..grace.utils.tokenize``; not visible here.
        hparams: Hyper-parameters. This function reads ``device`` and
            ``fisher_mem``; the whole object is also forwarded to the editor.
        copy: If true, ``model`` is deep-copied first, so the caller's object
            is left untouched.
        return_orig_weights: Accepted for API compatibility; unused here.
        keep_original_weight: Accepted for API compatibility; unused here.
        **kwargs: Accepted for API compatibility; unused here.

    Returns:
        ``(editor, editor.weights_copy)``. Note that the first element is the
        ``Finetune_ewc`` editor, not the bare model. The contents of
        ``weights_copy`` are defined by ``FT_ewc``; presumably they are the
        pre-edit weights, but confirm there.
    """
    if copy:
        model = deepcopy(model)
    editor = Finetune_ewc(model=model, config=hparams, device=hparams.device)

    tokens = tokenize(requests, tokenizer=tok, device=hparams.device)

    # The current batch is appended *before* editing, so ``edit`` can see up
    # to ``fisher_mem`` earlier batches plus the current one.
    batch_history.append(tokens)
    editor.edit(config=hparams, tokens=tokens, batch_history=batch_history)

    # Keep only the newest ``fisher_mem`` batches. Deleting every excess entry
    # (instead of popping exactly one and asserting) keeps the window correct
    # in two cases: when ``fisher_mem`` shrank between calls, and when an
    # earlier ``edit`` raised before this trim ran.
    excess = len(batch_history) - max(hparams.fisher_mem, 0)
    if excess > 0:
        del batch_history[:excess]  # oldest batches sit at the front

    return editor, editor.weights_copy


