from typing import Callable

# from MicroOptimizer import MicroOptimizer
import adam_microbatch_grad as adam_microbatch_grad
import torch


class MicroGradScheduler:
    """Drive a per-layer ("micro") optimizer step from module backward hooks.

    The model's direct children are split into schedulable *units*, in forward
    order:

    * the child at ``FIRST_BLOCK_IDX`` and the child at ``LAST_BLOCK_IDX`` are
      each a single unit;
    * every other child contributes each of its own sub-modules (e.g.
      bottlenecks) as a separate unit.

    Each unit owns a contiguous range of flat parameter indices, assigned in
    that forward order. A full backward hook on every unit calls::

        micro_optimizer.step(first_block=..., param_group_idx_container=[...])

    with the unit's indices. All units except the last one then set to ``None``
    the gradients of the unit that follows them in forward order. That
    following unit was presumably already stepped by its own hook earlier in
    the backward pass.

    NOTE(review): the layout (stem at child 0, head at child 5, stages of
    sub-modules in between) presumably matches a ResNet-style model; confirm
    against the model actually passed in.
    """

    # Positions in ``model.children()`` of the children treated as a single
    # unit instead of being expanded into their sub-modules.
    FIRST_BLOCK_IDX = 0
    LAST_BLOCK_IDX = 5

    def __init__(self, model, micro_optimizer):
        """Build the unit graph and attach the backward hooks.

        Args:
            model: ``torch.nn.Module`` whose children are scheduled.
            micro_optimizer: object exposing
                ``step(first_block: bool, param_group_idx_container: list[int])``.
                Presumably the optimizer from ``adam_microbatch_grad``, whose
                param groups must follow the same flat index order; confirm.
        """
        self.model = model
        self.micro_optim = micro_optimizer
        # Handles returned by register_full_backward_hook; see remove_hooks().
        self.hook_handlers = []

        # unit module -> list of flat parameter indices owned by that unit.
        self.param_index_mapping = {}

        # unit module -> the unit that directly follows it in forward order.
        self.next_layer_mapping = {}
        self.build_comp_graph()
        self.register_layer_hook()

    def _iter_units(self):
        """Yield ``(child_idx, unit)`` for every schedulable unit, in forward order."""
        for child_idx, block in enumerate(self.model.children()):
            if child_idx in (self.FIRST_BLOCK_IDX, self.LAST_BLOCK_IDX):
                yield child_idx, block
            else:
                for sub_module in block.children():
                    yield child_idx, sub_module

    @staticmethod
    def _clear_grads(module):
        """Drop (set to ``None``) the ``.grad`` of every parameter of ``module``."""
        for p in module.parameters():
            p.grad = None

    def build_comp_graph(self):
        """Fill ``param_index_mapping`` and ``next_layer_mapping``.

        Flat parameter indices are handed out consecutively, unit by unit, in
        the order produced by ``_iter_units``. Each unit is linked to the unit
        that follows it.
        """
        param_group_idx = 0
        prev_unit = None
        for _, unit in self._iter_units():
            if prev_unit is not None:
                self.next_layer_mapping[prev_unit] = unit
            prev_unit = unit

            n_params = sum(1 for _ in unit.parameters())
            self.param_index_mapping[unit] = list(
                range(param_group_idx, param_group_idx + n_params)
            )
            param_group_idx += n_params

    def register_layer_hook(self):
        """Attach a full backward hook to every unit and keep its handle."""
        for child_idx, unit in self._iter_units():
            handle = unit.register_full_backward_hook(self.Layer_hook_factory(child_idx))
            self.hook_handlers.append(handle)

    def remove_hooks(self):
        """Detach every hook registered by this scheduler."""
        for handle in self.hook_handlers:
            handle.remove()
        self.hook_handlers.clear()

    def windup_optim_step(self, closure: Callable = None):
        """Clear the gradients left on ``model.conv1`` after the backward pass.

        No hook clears the first unit's gradients, because nothing precedes it
        in forward order. ``closure`` is accepted for an optimizer-like
        signature but is not used.

        NOTE(review): assumes ``model.conv1`` is the first unit (child
        ``FIRST_BLOCK_IDX``). The original author flagged this shortcut as "for
        faster speed, but may be logically flaw"; confirm.
        """
        self._clear_grads(self.model.conv1)

    def Layer_hook_factory(self, layer_num):
        """Return the backward hook for a unit that lives under child ``layer_num``.

        The hook steps the micro optimizer for the unit's parameters.
        ``first_block`` is ``True`` only for child ``FIRST_BLOCK_IDX``. Except
        for child ``LAST_BLOCK_IDX``, the hook then clears the gradients of the
        unit that follows it in forward order, if there is one.
        """
        is_first = layer_num == self.FIRST_BLOCK_IDX
        is_last = layer_num == self.LAST_BLOCK_IDX

        def layer_hook(module, grad_input, grad_output):
            self.micro_optim.step(
                first_block=is_first,
                param_group_idx_container=self.param_index_mapping[module],
            )
            if is_last:
                return
            next_unit = self.next_layer_mapping.get(module)
            if next_unit is not None:
                self._clear_grads(next_unit)

        return layer_hook