import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import transformers
import higher
import logging
# from higher.patch import monkeypatch as make_functional
from .patch import monkeypatch as make_functional
from collections import defaultdict

from editable_model import EditableModel
from hooks import hook_model
import nn as local_nn
from utils import _logits, _inner_params

# Logger for this module, named after the module itself.
LOG = logging.getLogger(name=__name__)


def update_counter(x, m, s, k):
    """Fold sample ``x`` into running statistics using Welford's update.

    Args:
        x: The new sample (scalar or tensor).
        m: Running mean over the previous ``k - 1`` samples.
        s: Running sum of squared deviations over the previous samples.
        k: Sample count *including* ``x``.

    Returns:
        ``(mean, sum_sq)`` updated to include ``x``; the variance estimate is
        ``sum_sq / (k - 1)``.
    """
    deviation = x - m
    mean = m + deviation / k
    return mean, s + deviation * (x - mean)


class GradientTransform(nn.Module):
    """Learned transform applied to the two gradient factors of one weight.

    ``forward`` takes factors ``u`` (last dimension ``x_dim``) and ``v`` (last
    dimension ``delta_dim``). It optionally standardizes them with running
    statistics accumulated in training mode, then maps them through MLPs built
    from ``cfg.mlp_class`` in the local ``nn`` module.

    Args:
        x_dim: Width of the ``u`` factor.
        delta_dim: Width of the ``v`` factor.
        cfg: Config object. Attributes read here: ``combine``, ``one_sided``,
            ``x_only``, ``delta_only``, ``mlp_class``, ``n_hidden``, ``init``,
            ``act``, ``rank`` and ``norm``.
        n_modes: Forwarded to the MLP constructor. ``MEND`` passes the number
            of parameters sharing this transform when ``mend.shared`` is set.

    Raises:
        ValueError: If ``cfg.combine`` is set together with a one-sided
            variant (``one_sided``, ``x_only`` or ``delta_only``).
    """

    def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes=None):
        super().__init__()

        self.x_dim = x_dim
        self.delta_dim = delta_dim
        self.cfg = cfg
        if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only):
            raise ValueError("cfg.combine cannot be used with one-sided MEND variants")

        # Running statistics used to standardize the inputs when cfg.norm is set.
        # They start as NaN and are seeded by the first non-zero row seen in
        # training mode. norm_init is a plain attribute (not a buffer), so it is
        # not part of state_dict and is False on every new instance.
        self.norm_init = False
        self.register_buffer("u_mean", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_mean", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_std", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_std", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_s", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_s", torch.full((delta_dim,), float("nan")))
        self.register_buffer("k", torch.full((1,), float("nan")))  # number of rows accumulated

        MlpClass = getattr(local_nn, cfg.mlp_class)
        LOG.info(f"Building Gradient Transform with MLP class {MlpClass}")

        def delta_net():
            return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def x_net():
            return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def combined_net():
            return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2,
                            cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def ID():
            # Identity stand-in for an untransformed side; accepts and ignores `mode`.
            return lambda x, mode=None: x

        if cfg.combine:
            # A single MLP over the concatenation [u, v].
            self.mlp = combined_net()
        elif cfg.one_sided:
            # Transform only the narrower factor (u when the widths are equal).
            if x_dim > delta_dim:
                self.mlp1, self.mlp2 = ID(), delta_net()
            else:
                self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.x_only:
            self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.delta_only:
            self.mlp1, self.mlp2 = ID(), delta_net()
        else:
            self.mlp1, self.mlp2 = x_net(), delta_net()

    def forward(self, u, v, param_idx=None):
        """Transform the gradient factors ``u`` and ``v``.

        Both inputs are cast to float32 and flattened to 2-D over their last
        dimension. Rows where either factor is entirely zero are dropped. In
        training mode the remaining rows update the running mean and variance
        buffers.

        Args:
            u: Tensor whose last dimension is ``x_dim``.
            v: Tensor whose last dimension is ``delta_dim``. It must flatten to
                the same number of rows as ``u``.
            param_idx: Passed to the MLP(s) as ``mode``.

        Returns:
            Tuple ``(u_out, v_out)`` with one row per kept input row.

        Raises:
            RuntimeError: In training mode, if fewer than two rows have been
                accumulated. With ``cfg.norm`` set, this is also raised if no
                non-zero row has been seen yet.
        """
        # Per-call trace; DEBUG so it does not flood INFO-level logs.
        LOG.debug("call_forward of GradientTransform")
        u, v = u.to(torch.float32), v.to(torch.float32)

        # reshape (rather than view) also accepts non-contiguous inputs.
        u_ = u.reshape(-1, u.shape[-1])
        v_ = v.reshape(-1, v.shape[-1])

        nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1)  # Skip batch elements with zero grad
        u_ = u_[nz_mask]
        v_ = v_[nz_mask]

        if self.training:
            for idx in range(u_.shape[0]):
                if not self.norm_init:
                    # First row ever: seed the means and reset the accumulators.
                    self.u_mean = u_[idx].clone().detach()
                    self.v_mean = v_[idx].clone().detach()
                    self.u_s.zero_()
                    self.v_s.zero_()
                    self.k[:] = 1
                    self.norm_init = True
                else:
                    self.k += 1
                    self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k)
                    self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k)

            # On a fresh module, k stays NaN until a non-zero row has been seen,
            # and `NaN < 2` is False. Catch that case explicitly when the stats
            # are used; otherwise NaN standard deviations reach the outputs.
            if self.k < 2 or (self.cfg.norm and torch.isnan(self.k).any()):
                raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far")
            # Sample standard deviation from the Welford sum of squares.
            self.u_std = (self.u_s / (self.k - 1)) ** 0.5
            self.v_std = (self.v_s / (self.k - 1)) ** 0.5

        if self.cfg.norm:
            u_input = (u_ - self.u_mean) / (self.u_std + 1e-7)
            v_input = (v_ - self.v_mean) / (self.v_std + 1e-7)
        else:
            u_input = u_
            v_input = v_

        if self.cfg.combine:
            output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx)
            out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1)
            return out1, out2
        else:
            return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx)


class MEND(EditableModel):
    """Editable-model wrapper that applies MEND-transformed gradients as edits.

    Every parameter named in ``config.model.inner_params`` is served by a
    :class:`GradientTransform`. Without ``config.mend.shared`` there is one
    transform per parameter; with it, there is one per weight shape.
    :meth:`edit` back-propagates the edit loss, transforms each parameter's
    recorded ``__x__``/``__delta__`` factors, applies the scaled update, and
    returns a new ``MEND`` around the edited model.
    """

    def get_shape(self, p):
        """Return the ``(x_dim, delta_dim)`` pair used to size p's transform.

        GPT-2 in ``transformers`` stores its projections as Conv1D weights
        shaped ``(in, out)``, whereas ``nn.Linear`` weights are ``(out, in)``,
        so the shape is flipped for every model other than GPT-2.
        """
        return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])

    def __init__(self, model, config, model_constructor, mend=None, edit_lrs=None):
        super().__init__(model, config, model_constructor)

        # One learnable edit learning rate per inner parameter, unless an
        # existing set is supplied (edit() passes its own to the edited copy).
        if edit_lrs is None:
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        # Presumably installs the hooks that populate p.__x__ / p.__delta__,
        # which edit() reads. Skipped if the model was already hooked.
        if not hasattr(self.model, "handles"):
            hook_model(self.model, self.config.model.inner_params)
            LOG.info(f"Hooked {len(self.model.handles)//2} modules")

        if config.mend.shared:
            # Group inner-parameter names by shape; each group shares one
            # transform, and a name's position in its group selects the mode.
            shape_dict = defaultdict(list)
            for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params):
                shape_dict[self.get_shape(p)].append(n)
            self.shape_dict = shape_dict

        if mend is None:
            if not config.mend.shared:
                # nn.ModuleDict keys may not contain ".", hence the "#" substitution.
                self.mend = nn.ModuleDict({
                    n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.mend)
                    for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params)
                })
            else:
                self.mend = nn.ModuleDict({
                    str(tuple(s)): GradientTransform(*s, config.mend, len(shape_dict[s]))
                    for s in shape_dict.keys()
                })
        else:
            self.mend = mend

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return this module's state without the wrapped model's own tensors.

        The wrapped model's config is stored under ``"model_config"`` so that
        :meth:`load_state_dict` can report mismatches. ``destination`` is
        accepted for signature compatibility but ignored.
        """
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        # List the wrapped model's keys without a prefix. In the outer dict they
        # appear as f"{prefix}model.{k}", which is what must be removed.
        for k in self.model.state_dict(keep_vars=keep_vars):
            del state_dict[f"{prefix}model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load weights produced by :meth:`state_dict`.

        The wrapped model's tensors are expected to be absent, so loading is
        always non-strict; ``strict`` is accepted but ignored. Missing keys are
        tolerated only under ``model.``, and no unexpected keys are allowed.
        The caller's mapping is left unmodified.
        """
        # Shallow copy so that popping "model_config" does not mutate the
        # caller's dict. copy.copy keeps attributes attached to the mapping
        # (e.g. torch's _metadata).
        state_dict = copy.copy(state_dict)
        config = state_dict.pop("model_config")
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self):
        """Return the meta-learned parameters: all transform weights plus ``edit_lrs``."""
        return list(self.mend.parameters()) + [self.edit_lrs]

    def edit(self, batch, condition=None, detach_history=False, gptj_labels=None):
        """Apply one MEND edit computed from ``batch``.

        Args:
            batch: Passed unchanged to the wrapped model's forward.
            condition: Unused; kept for interface compatibility.
            detach_history: If true, build a fresh model with
                ``model_constructor`` and load the edited state into it, so the
                returned model is not the patched module.
            gptj_labels: Unused; kept for interface compatibility.

        Returns:
            ``(edited, info_dict)``. ``edited`` is a new ``MEND`` around the
            edited model, sharing this instance's transforms and ``edit_lrs``.
            ``info_dict`` holds per-parameter diagnostics comparing the true
            gradient with the transformed one.

        Raises:
            TypeError: If the model returns a bare tensor. The edit loss needs
                labels, which are read from the model's output object.
        """
        outputs = self.model(batch)
        if isinstance(outputs, torch.Tensor):
            raise TypeError(
                "MEND.edit expects the model to return an object with `.logits` and `.labels`, "
                "but it returned a bare tensor"
            )
        batch_labels = outputs.labels
        outputs = outputs.logits
        loss = self.edit_loss_fn(outputs, batch_labels)["nll"]

        # Fail early if a configured inner parameter does not exist in the model.
        names = {n for n, _ in self.model.named_parameters()}
        pset = set(self.config.model.inner_params)
        for p in pset:
            assert p in names, f"inner param {p} not in model"

        loss.backward()

        inner = list(_inner_params(self.model.named_parameters(), self.config.model.inner_params))

        if self.config.mend.shared:
            def param_idx(n, p):
                # Position of n within its shape group; used as the MLP mode.
                return self.shape_dict[self.get_shape(p)].index(n)

            if self.config.model.use_parallelize:
                # Presumably the model is spread over devices; co-locate the
                # transforms with the first inner parameter.
                self.mend.to(inner[0][1].device)
            transformed_factors = {
                n: self.mend[str(tuple(self.get_shape(p)))](p.__x__, p.__delta__, param_idx(n, p))
                for n, p in inner
            }
        else:
            transformed_factors = {
                n: self.mend[n.replace(".", "#")](p.__x__, p.__delta__)
                for n, p in inner
            }

        # Pseudo-gradient = sum over rows of outer(x, delta). nn.Linear weights
        # are (out, in), hence "ji"; GPT-2's Conv1D weights are (in, out), hence "ij".
        if isinstance(self.model, transformers.GPT2LMHeadModel):
            targ = "ij"
        else:
            targ = "ji"
        mean_grads = {
            n: torch.einsum(f"bi,bj->{targ}", x, delta)
            for n, (x, delta) in transformed_factors.items()
        }

        info_dict = {}
        for idx, (n, p) in enumerate(inner):
            true_grad, pseudo_grad = p.grad, mean_grads[n]
            info_dict[f"grad/true_mag{idx}"] = true_grad.norm(2).item()
            info_dict[f"grad/pseudo_mag{idx}"] = pseudo_grad.norm(2).item()
            info_dict[f"grad/true_std{idx}"] = true_grad.std().item()
            info_dict[f"grad/pseudo_std{idx}"] = pseudo_grad.std().item()
            info_dict[f"grad/diff{idx}"] = (true_grad - pseudo_grad).norm(2).item()
            info_dict[f"grad/cos{idx}"] = F.cosine_similarity(true_grad.reshape(-1), pseudo_grad.reshape(-1), dim=0).item()

        self.model.zero_grad()

        # edit_lrs is indexed positionally, in inner_params order.
        assert len(self.edit_lrs) == len(mean_grads)
        updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())}

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        # Add the update to each inner parameter (cast to its dtype); keep the rest.
        new_params = [
            p + updates[n].to(p.dtype) if n in pset else p
            for n, p in edited_model.named_parameters()
        ]
        edited_model.update_params(new_params)

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return MEND(edited_model, self.config, self.model_constructor, self.mend, edit_lrs=self.edit_lrs), info_dict