import torch
import torch.nn as nn
import higher

from editable_model import EditableModel
from utils import _logits


def fomaml_callback(all_grads):
    """Detach each inner-loop gradient (first-order MAML approximation).

    Used as ``grad_callback`` for ``higher``'s differentiable optimizer so the
    update step does not build a graph through the gradients themselves.
    ``None`` entries (parameters that received no gradient) pass through as-is.

    Returns:
        A new list, same length and order as ``all_grads``.
    """
    detached = []
    for grad in all_grads:
        detached.append(None if grad is None else grad.detach())
    return detached


class ENN(EditableModel):
    """Editable model that edits a subset of weights with learned per-parameter step sizes.

    Edits are unrolled with ``higher``'s differentiable inner loop over the
    parameters named in ``config.model.inner_params``. Each of those parameters
    gets its own learnable edit learning rate, stored in ``edit_lrs``.
    """

    def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
        """
        Args:
            model: the wrapped ``nn.Module`` to edit.
            config: experiment config. Reads ``edit_lr``, ``model.inner_params``
                and ``enn.first_order`` here.
            model_constructor: zero-argument callable returning a fresh model.
                Used by ``edit(..., detach_history=True)``.
            edit_lrs: learnable edit step sizes, one per inner parameter. When
                ``None``, an ``nn.Parameter`` with one entry per name in
                ``config.model.inner_params`` is created, each initialised to
                ``config.edit_lr``. Pass an existing value to share step sizes
                between instances, as ``edit`` does for the edited copy.
            edit_loss_fn: optional override for ``self.edit_loss_fn``. When
                ``None`` the attribute is left as the base class set it.
                NOTE(review): confirm ``EditableModel`` always provides one.
        """
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        # First-order mode detaches inner-loop gradients; otherwise they pass through unchanged.
        self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x

    def outer_parameters(self, grouped=False):
        """Return the parameters optimised by the outer (meta) optimizer.

        When ``config.no_grad_layers`` is ``None``, all model parameters are
        included. Otherwise only the parameters of entries at index
        ``no_grad_layers`` onward in each ``nn.ModuleList`` of the model are
        included; parameters that live outside any ``ModuleList`` are excluded.
        ``edit_lrs`` is always included.

        Args:
            grouped: if True, return two optimizer param groups: model
                parameters at ``config.lr`` and ``edit_lrs`` at
                ``config.lr_lr``. If False, return one flat list.
        """
        extra_params = [self.edit_lrs]
        if self.config.no_grad_layers is None:
            # nn.Module.parameters() yields a generator; materialise it once.
            model_params = list(self.model.parameters())
        else:
            model_params = []
            for m in self.model.modules():
                if isinstance(m, nn.ModuleList):
                    model_params.extend(m[self.config.no_grad_layers:].parameters())

        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr)
            ]
        else:
            return model_params + extra_params

    def get_state_dict(self):
        """Return this module's ``state_dict()``, which includes ``edit_lrs`` when it is an ``nn.Parameter``."""
        return self.state_dict()

    def edit(self, batch, condition=None, detach_history=False):
        """Take ``config.enn.n_edit_steps`` SGD steps on ``batch`` and return an edited copy.

        Args:
            batch: mapping of model keyword inputs. Must contain ``"labels"``:
                the whole mapping is splatted into the model call, and
                ``batch["labels"]`` is passed to ``edit_loss_fn``.
            condition: not used by this implementation.
            detach_history: if True, load the edited weights (via
                ``state_dict``) into a fresh model from ``model_constructor``
                instead of returning the functional ``higher`` model itself.

        Returns:
            ``(ENN, {})``: a new ``ENN`` wrapping the edited model, sharing
            ``edit_lrs`` and ``edit_loss_fn`` with ``self``, plus an empty
            info dict.
        """
        inner_params = set(self.config.model.inner_params)
        # One SGD group per inner parameter. The placeholder lr is replaced per
        # group by `override` below, so each group uses its own edit_lrs entry.
        # NOTE(review): groups follow named_parameters() order, not the order of
        # config.model.inner_params, so edit_lrs[i] pairs with the i-th matching
        # parameter in module order.
        opt = torch.optim.SGD([{"params": p, "lr": None}
                               for (n, p) in self.model.named_parameters() if n in inner_params])
        # enable_grad keeps the inner-loop gradients available even if the caller is under no_grad.
        with torch.enable_grad(), higher.innerloop_ctx(
                self.model,
                opt,
                override={'lr': list(self.edit_lrs)},
                # Use the model's own weights as the initial fast weights (no copy),
                # so outer-loop gradients can reach self.model's parameters.
                copy_initial_weights=False,
                # Only build the differentiable unroll graph while training.
                track_higher_grads=self.training,
                in_place=True
        ) as (fmodel, diffopt):
            # Run the edit steps in eval mode (e.g. dropout disabled).
            fmodel.eval()
            for _ in range(self.config.enn.n_edit_steps):
                output = _logits(fmodel(**batch))
                loss = self.edit_loss_fn(output, batch["labels"])["nll"]
                diffopt.step(loss, grad_callback=self.grad_callback)

        if not detach_history:
            model_edited = fmodel
        else:
            model_edited = self.model_constructor()
            model_edited.load_state_dict(fmodel.state_dict())
        # Restore the caller's train/eval mode on the returned model.
        model_edited.train(self.training)

        return ENN(model_edited, self.config, self.model_constructor, edit_lrs=self.edit_lrs, edit_loss_fn=self.edit_loss_fn), {}


def test():
    """Manual smoke test: edit GPT-2's last three MLP blocks on a toy batch and print losses.

    Downloads the ``gpt2`` checkpoint through ``transformers``. Runs on CUDA
    when it is available, otherwise on CPU.
    """
    import transformers
    import types
    import copy

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.edit_lr = 0.1
    # ENN reads these with attribute access (config.model.inner_params,
    # config.enn.first_order, config.enn.n_edit_steps), so they must be
    # namespaces rather than dicts.
    config.model = types.SimpleNamespace(inner_params=[
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ])
    config.enn = types.SimpleNamespace(n_edit_steps=2, first_order=False)
    # Read by ENN.outer_parameters; None means "train all model parameters".
    config.no_grad_layers = None

    enn = ENN(model, config, lambda: copy.deepcopy(model)).to(device)

    x = torch.arange(100).view(5, 20).to(device) + 1000
    # edit() splats the batch into the model call and reads batch["labels"].
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    # edit() returns (edited_model, info_dict).
    edited, _ = enn.edit(batch)

    orig_param = [p for (n, p) in enn.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    print((orig_param - edited_param).abs().max())
    edited.eval()
    print(enn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
    # edit_loss_fn returns a dict of losses; backpropagate the scalar "nll" entry.
    edited.edit_loss_fn(edited(x).logits, x)["nll"].backward()


if __name__ == "__main__":
    # Run the smoke test with autograd anomaly detection enabled, so that
    # errors raised during backward report the forward op that produced them.
    with torch.autograd.set_detect_anomaly(mode=True):
        test()
