import torch
import re

from config_evaluator import Lazy

def scale_grad(model, scaling):
    """Rescale the gradients of ``model`` in place to cap their global L2 norm.

    The norm is computed over all parameters of ``model`` whose ``.grad`` is
    not ``None``. If it exceeds ``scaling``, every such gradient is divided by
    ``norm / scaling`` so the resulting global norm equals ``scaling``.

    :param model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
    :param scaling: maximum allowed gradient norm; ``None`` disables clipping
    :return: ``None``; gradients are modified in place
    """
    if scaling is None:
        return

    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        # Nothing to clip. Without this guard the accumulator would stay a
        # Python float and torch.sqrt() would raise a TypeError.
        return

    squared_norm = 0.0
    for grad in grads:
        squared_norm += torch.square(grad).sum()
    norm = torch.sqrt(squared_norm)

    if norm > scaling:
        shrink = norm / scaling
        for grad in grads:
            # In-place division keeps the same tensor object bound to p.grad.
            grad /= shrink


def get_optimizer(model, optimizer, optimizer_groups):
    """Instantiate an optimizer for ``model``, optionally with per-group options.

    :param model: object exposing ``parameters()`` and ``named_parameters()``
    :param optimizer: object with a ``run(params=...)`` method that builds the
        optimizer; when ``None``, ``Lazy(dict(), torch.optim.Adam)`` is used
    :param optimizer_groups: optional list of ``[regex, hyperparams]`` pairs.
        Each parameter is assigned to the first pair whose regex matches its
        name via ``re.match``; parameters matching no regex are left out of
        the optimizer. Groups whose ``"lr"`` equals ``0.0`` are frozen
        (``requires_grad_(False)``) and not handed to the optimizer.
    :return: whatever ``optimizer.run(...)`` returns
    """
    if optimizer is None:
        optimizer = Lazy(dict(), torch.optim.Adam)

    if not optimizer_groups:
        return optimizer.run(params=model.parameters())

    # Example config
    # "optimizer_groups": [
    #     [".*prefix_embedding.*", {"lr": 1.0}],
    #     [".*lm_head.*", {"lr": 1e-5}],
    #     [".*", {"lr": 0.0}]  # all other parameters are frozen
    # ]

    # Copy each hyperparameter dict so the caller's config is not mutated.
    groups = []
    for _, hyperparam in optimizer_groups:
        group = dict(hyperparam)
        group["params"] = []
        groups.append(group)

    # First matching regex wins, so the order of optimizer_groups matters.
    for name, param in model.named_parameters():
        for (regex, _), group in zip(optimizer_groups, groups):
            if re.match(regex, name):
                group["params"].append(param)
                break

    # Freeze groups with learning rate 0 and keep them out of the optimizer.
    active_groups = []
    for group in groups:
        if group.get("lr") == 0.0:
            for param in group["params"]:
                param.requires_grad_(False)
        else:
            active_groups.append(group)

    # NOTE(review): if every group is frozen, an empty list is passed on;
    # torch optimizers reject an empty parameter list - confirm this is the
    # desired failure mode.
    return optimizer.run(params=active_groups)


class Regularizer:
    """Base interface for regularization terms computed from a model."""

    def __init__(self, initial_point: torch.nn.Module):
        """Accept the reference model; the base class stores nothing."""

    def apply_reg(self, model: torch.nn.Module) -> torch.Tensor:
        """Return the regularization term for ``model``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

class IsotropicGaussPrior(Regularizer):
    """Squared-L2 penalty on the distance of selected parameters from a snapshot."""

    def __init__(self, initial_point: torch.nn.Module, coeff: float, regex: str = ".*", invert_regex: bool = False):
        """
        Snapshot the parameters of ``initial_point`` that the prior applies to.

        :param initial_point: module whose named parameters are detached and cloned
        :param coeff: factor the summed squared distance is multiplied by
        :param regex: pattern tested against each parameter name with ``re.fullmatch``
        :param invert_regex: if true, select the names that do NOT match ``regex``
        """
        super().__init__(initial_point)
        self.coeff = coeff

        snapshot = {}
        for name, tensor in initial_point.named_parameters():
            is_match = re.fullmatch(regex, name) is not None
            # Keep matches normally, non-matches when the regex is inverted.
            if (is_match and not invert_regex) or (invert_regex and not is_match):
                snapshot[name] = tensor.detach().clone()
        self.initial_params = snapshot

    def apply_reg(self, model: torch.nn.Module):
        """Return ``coeff`` times the summed squared difference to the snapshot.

        Parameters of ``model`` are looked up by the names stored at
        construction time.
        """
        current = dict(model.named_parameters())
        penalty = 0
        for name, anchor in self.initial_params.items():
            deviation = current[name] - anchor
            penalty += torch.square(deviation).sum()
        return self.coeff * penalty


