from src.common.tf.loaders import load_model, load_tokenizer
from src.settings.pipelines.common import MergeSettings


def _tokenizer_mismatches(tok_1, tok_2):
    """Return a description of every compared attribute that differs between two tokenizers.

    Only ``vocab_size``, ``model_max_length``, ``padding_side`` and
    ``truncation_side`` are compared; the token-to-id mapping itself is NOT
    checked, so two tokenizers of equal size but different vocabularies pass.

    Returns:
        A list of ``'<attr>: <value_1> != <value_2>'`` strings; empty when all
        compared attributes are equal.
    """
    attributes_to_compare = ('vocab_size', 'model_max_length', 'padding_side', 'truncation_side')
    return [
        f'{attr}: {getattr(tok_1, attr)} != {getattr(tok_2, attr)}'
        for attr in attributes_to_compare
        if getattr(tok_1, attr) != getattr(tok_2, attr)
    ]


def _interpolate_parameters(cur_model, old_model, alpha) -> None:
    """Overwrite each parameter of ``cur_model`` in place with ``alpha * old + (1 - alpha) * cur``.

    Parameters are paired positionally via ``named_parameters()`` and must agree
    in count, name and shape. Only parameters are blended; anything not returned
    by ``named_parameters()`` (presumably buffers, if these are torch modules —
    confirm) keeps the values ``cur_model`` was loaded with.

    Raises:
        ValueError: if the two models' parameters do not line up by count,
            name, or shape. Validation finishes before any parameter is
            modified, so ``cur_model`` is untouched on error.
    """
    cur_params = list(cur_model.named_parameters())
    old_params = list(old_model.named_parameters())

    # zip() would silently truncate to the shorter model, so check the count explicitly.
    if len(cur_params) != len(old_params):
        raise ValueError(
            f'Models have different numbers of parameters: '
            f'{len(cur_params)} (current) != {len(old_params)} (old)'
        )

    # Shape check matters: mismatched shapes can broadcast in the blend expression
    # and then be copied back without any error.
    for (cur_name, cur_param), (old_name, old_param) in zip(cur_params, old_params):
        if cur_name != old_name or cur_param.shape != old_param.shape:
            raise ValueError(
                f'Parameter mismatch: current {cur_name} {tuple(cur_param.shape)} '
                f'vs old {old_name} {tuple(old_param.shape)}'
            )

    for (_, cur_param), (_, old_param) in zip(cur_params, old_params):
        # Writing through .data updates the weights in place without recording autograd history.
        cur_param.data.copy_((alpha * old_param.data) + (1.0 - alpha) * cur_param.data)


def merge_models_with_alpha(settings: MergeSettings) -> None:
    """Blend the old and current model weights and save the merged model.

    Every parameter of the current model is replaced by
    ``alpha * old + (1 - alpha) * cur`` where ``alpha = settings.alpha``. The
    merged model and the old model's tokenizer are then saved to
    ``settings.save_path`` via ``save_pretrained``.

    Args:
        settings: Supplies ``tokenizer_settings``, ``old_model_settings``,
            ``cur_model_settings``, ``alpha`` and ``save_path``.

    Raises:
        ValueError: if the two tokenizers differ in any compared attribute, or
            if the two models' parameters do not line up by count, name, and
            shape.
    """
    tokenizer = load_tokenizer(settings.tokenizer_settings, settings.old_model_settings)
    cur_tokenizer = load_tokenizer(settings.tokenizer_settings, settings.cur_model_settings)

    alpha = settings.alpha

    # Explicit raise rather than `assert`: asserts are stripped under `python -O`,
    # which would let incompatible tokenizers through silently.
    mismatches = _tokenizer_mismatches(tokenizer, cur_tokenizer)
    if mismatches:
        raise ValueError('Tokenizers differ: ' + '; '.join(mismatches))

    # Both models are loaded with the old tokenizer; the check above established
    # that it matches the current one on the compared attributes.
    old_model = load_model(settings.old_model_settings, tokenizer)
    cur_model = load_model(settings.cur_model_settings, tokenizer)

    _interpolate_parameters(cur_model, old_model, alpha)

    cur_model.save_pretrained(settings.save_path)
    tokenizer.save_pretrained(settings.save_path)
