import torch.nn as nn

from losses import masked_log_probs
from utils import _logits, shift_targets


class EditableModel(nn.Module):
    """Base wrapper around a model that concrete editing methods subclass.

    The wrapper stores the wrapped model, its config and a constructor
    callable, and exposes thin helpers for the forward pass, generation and
    parameter collection. ``edit`` must be provided by subclasses.
    """

    def __init__(self, model, config, model_constructor):
        """Store the wrapped model and build the loss callables.

        Args:
            model: The wrapped model; it is called in ``forward`` and its
                ``generate`` method is used by ``generate``.
            config: Configuration object. ``config.lr`` is read by
                ``outer_parameters`` and the whole object is handed to
                ``shift_targets`` whenever a loss function is called.
            model_constructor: Stored as-is on the instance; not used by this
                base class (presumably a factory for fresh model copies --
                confirm against subclasses).
        """
        super().__init__()

        self.model = model
        self.config = config
        self.model_constructor = model_constructor

        def _edit_loss_fn(pred, targ, **kwargs):
            # ``shift`` is resolved on every call, so the loss always reflects
            # the current ``self.config`` rather than a value frozen here.
            return masked_log_probs(pred, targ, shift=shift_targets(self.config), **kwargs)

        # Both losses intentionally share the same callable.
        self.edit_loss_fn = _edit_loss_fn
        self.loc_loss_fn = _edit_loss_fn

    def edit(self, batch, condition=None, detach_history=False):
        """Apply an edit; abstract here and must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def forward(self, *inputs, **kwargs):
        """Run the wrapped model and pass its output through ``_logits``.

        All positional and keyword arguments are forwarded unchanged to
        ``self.model``. ``_logits`` presumably extracts the logits from the
        model output -- confirm in ``utils``.
        """
        return _logits(self.model(*inputs, **kwargs))

    def outer_parameters(self, grouped=False):
        """Return this module's parameters for an optimizer.

        Args:
            grouped: When true, return a single parameter group (a list with
                one dict holding ``params`` and ``lr=self.config.lr``);
                otherwise return a flat list of parameters.

        Returns:
            A list, either of one parameter-group dict or of parameters.
        """
        # Materialise the generator in both branches: a bare generator stored
        # in the group would be silently empty after its first iteration.
        params = list(self.parameters())
        if grouped:
            return [dict(params=params, lr=self.config.lr)]
        return params

    def generate(self, *args, **kwargs):
        """Call ``self.model.generate`` with sensible default options.

        Defaults are ``max_new_tokens=100``, ``output_scores=True`` and
        ``return_dict_in_generate=True``. Any of them may be overridden by
        passing the same keyword explicitly; all other positional and keyword
        arguments are forwarded unchanged.

        Returns:
            Whatever ``self.model.generate`` returns.
        """
        # Merge rather than pass the defaults as literal keywords, so that a
        # caller-supplied value overrides the default instead of raising a
        # "multiple values for keyword argument" TypeError.
        gen_kwargs = dict(
            max_new_tokens=100,
            output_scores=True,
            return_dict_in_generate=True,
        )
        gen_kwargs.update(kwargs)
        return self.model.generate(*args, **gen_kwargs)

    def base_loss(self, input_ids, attention_masks, label_ids):
        """Placeholder hook; currently does nothing and returns ``None``."""
        pass
