from typing import Callable

from pytorch_pretrained_bert.optimization import BertAdamMicroGrad


class MicroGradScheduler:
    """Run per-parameter optimizer updates from module backward hooks.

    A backward hook is registered on the embeddings, ``FinalLayerNorm``,
    every encoder layer and the pooler.  When the hook of one of these
    modules fires, ``micro_step`` is called for every parameter of the
    module that ``next_layer_mapping`` associates with it, and that
    parameter's ``grad`` attribute is deleted.  The embeddings are nobody's
    successor in that mapping, so they are updated by
    ``windup_optim_step`` instead.

    Call order implied by the code: ``get_global_step(step)`` first, then
    the backward pass (which triggers the hooks), then
    ``windup_optim_step()``.
    """

    def __init__(self, model, micro_optimizer: "BertAdamMicroGrad"):
        """Index the model parameters, build the successor map and hook up.

        Args:
            model: Object exposing ``parameters()``, ``bert.embeddings``,
                ``bert.encoder.FinalLayerNorm``, ``bert.encoder.layer``,
                ``bert.pooler`` and ``cls``.
            micro_optimizer: Optimizer exposing
                ``micro_step(global_step, param_group_idx=...)``.
        """
        self.model = model
        self.micro_optim = micro_optimizer
        # Handles returned by register_backward_hook; see remove_hooks().
        self.hook_handlers = []
        # Stays None until get_global_step() records a step.
        self.global_step = None

        # parameter -> its position in model.parameters(); this position is
        # what gets passed to micro_step as ``param_group_idx``.
        self.param_index_mapping = {
            param: idx for idx, param in enumerate(model.parameters())
        }

        # hooked module -> module whose parameters are stepped by its hook.
        self.next_layer_mapping = {}
        self.build_comp_graph()
        self.register_layer_hook()

    def get_global_step(self, global_step):
        """Record the step value forwarded to every later ``micro_step``.

        Despite the name this is a setter; it returns ``None``.
        """
        self.global_step = global_step

    def build_comp_graph(self):
        """Fill ``next_layer_mapping`` with the module successor chain.

        The chain is embeddings -> FinalLayerNorm -> encoder layers in
        order -> pooler -> ``model.cls``.
        """
        bert = self.model.bert
        encoder = bert.encoder
        successor = self.next_layer_mapping

        successor[bert.embeddings] = encoder.FinalLayerNorm
        successor[encoder.FinalLayerNorm] = encoder.layer[0]

        layers = list(encoder.layer.children())
        for current, following in zip(layers, layers[1:]):
            successor[current] = following

        successor[layers[-1]] = bert.pooler
        successor[bert.pooler] = self.model.cls

    def register_layer_hook(self):
        """Attach a backward hook to every module that has a successor.

        ``model.cls`` ends the chain and therefore gets no hook.  The
        returned handles are kept in ``hook_handlers``.
        """
        bert = self.model.bert
        hooked_modules = [
            bert.embeddings,
            bert.encoder.FinalLayerNorm,
            *bert.encoder.layer.children(),
            bert.pooler,
        ]
        # NOTE(review): recent PyTorch releases deprecate
        # register_backward_hook in favour of register_full_backward_hook;
        # the call is kept unchanged so existing behaviour is preserved.
        for module in hooked_modules:
            handle = module.register_backward_hook(self.Layer_hook_factory())
            self.hook_handlers.append(handle)

    def remove_hooks(self):
        """Detach every hook registered by this scheduler.

        Calls ``remove()`` on each stored handle and empties
        ``hook_handlers``; calling it again is a no-op.
        """
        while self.hook_handlers:
            self.hook_handlers.pop().remove()

    def windup_optim_step(self, closure: Callable = None):
        """Step the embedding parameters, which no hook updates.

        Only ``model.bert.embeddings`` is visited here.  Parameters that
        belong neither to the embeddings nor to a module in
        ``next_layer_mapping`` are never stepped by this class; the
        original author flagged this shortcut as possibly flawed.

        Args:
            closure: Accepted for interface compatibility; not used.

        Raises:
            RuntimeError: If ``get_global_step`` has not been called.
        """
        self._step_params(self.model.bert.embeddings.parameters())

    def Layer_hook_factory(self):
        """Return a backward hook that steps the hooked module's successor."""

        def Layer_hook(module, grad_input, grad_output):
            # Implicitly returns None, so the hook never supplies a
            # replacement for grad_input.
            self._step_params(self.next_layer_mapping[module].parameters())

        return Layer_hook

    def _step_params(self, params):
        """Run ``micro_step`` for each parameter, then delete its ``grad``."""
        if self.global_step is None:
            raise RuntimeError(
                "global_step is not set; call get_global_step() before "
                "running the backward pass or windup_optim_step()"
            )
        for param in params:
            self.micro_optim.micro_step(
                self.global_step,
                param_group_idx=self.param_index_mapping[param],
            )
            # Presumably frees gradient memory as soon as the update is
            # done -- TODO confirm with the optimizer implementation.
            if hasattr(param, "grad"):
                del param.grad
