import torch

from lbmqt.conf import config
from lbmqt.qscheme import QGradScheme, QScheme


class LowBitOptimizer(torch.optim.Optimizer):
    """Base optimizer that can keep accumulated gradients in quantized form.

    Every param group must provide a ``'names'`` sequence parallel to its
    ``'params'``. The names are used to look parameters up and are passed to
    the gradient quantization scheme. Groups whose ``'quantifiable'`` option is
    truthy (it defaults to ``False``) store their gradient accumulators through
    ``self.grad_scheme``; all other groups accumulate plain tensors.

    Subclasses must implement :meth:`step`.
    """

    def __init__(self, params, defaults, param_scheme: 'QScheme'):
        """Build the optimizer and its parameter lookup tables.

        Args:
            params: iterable of param-group dicts, each carrying ``'params'`` and
                ``'names'``; forwarded to ``torch.optim.Optimizer``.
            defaults: default group options. The dict is copied, then
                ``'quantifiable'`` is forced to ``False`` so groups must opt in.
            param_scheme: stored as ``self.param_scheme``; this base class does
                not use it otherwise.
        """
        # Copy so the caller's dict is not mutated.
        defaults = dict(defaults)
        defaults['quantifiable'] = False
        super().__init__(params, defaults)

        # Step counter; this base class only initializes it (presumably
        # advanced by subclasses' step()).
        self.num_steps = 0

        # Quantization schemes; the gradient scheme covers only the names of
        # quantifiable groups and is configured from the global lbmqt config.
        self.param_scheme = param_scheme
        self.grad_scheme = QGradScheme(
            'grad',
            param_names=self.get_quantifiable_param_names(),
            bits=config.compression_bits_grad,
            group_size=config.group_size,
            enable=config.enable_quantize_grad,
            num_mode=config.numerical_mode_grad,
            stochastic=config.stochastic,
        )

        # Lookup tables built around a dense integer id assigned in group order:
        #   param_id:                      id(tensor) -> dense id
        #   param_dict / id2group / id2name: dense id -> tensor / group / name
        #   name2param:                    name -> tensor
        #   grad_accumulator:              dense id -> accumulated grad (None until first use)
        self.param_id = {}
        self.param_dict = {}
        self.id2group = {}
        self.id2name = {}
        self.name2param = {}
        self.grad_accumulator = {}

        dense_id = 0
        for group in self.param_groups:
            for name, param in zip(group['names'], group['params']):
                # Keyed on object identity, so the exact tensor objects must be
                # used for later lookups.
                self.param_id[id(param)] = dense_id
                self.param_dict[dense_id] = param
                self.id2group[dense_id] = group
                self.id2name[dense_id] = name
                self.name2param[name] = param
                self.grad_accumulator[dense_id] = None
                dense_id += 1

    def get_quantifiable_param_names(self):
        """Return the names of all params in groups whose 'quantifiable' option is truthy."""
        return [
            name
            for group in self.param_groups
            if group['quantifiable']
            for name in group['names']
        ]

    def get_param_id(self, param):
        """Return the dense integer id assigned to ``param`` (KeyError if unknown)."""
        return self.param_id[id(param)]

    def get_param_name(self, param):
        """Return the name registered for ``param`` in its param group."""
        return self.id2name[self.get_param_id(param)]

    def get_param_group(self, key):
        """Return the param group owning ``key`` (a param name or the tensor itself).

        Raises:
            ValueError: if ``key`` is neither a ``str`` nor a ``torch.Tensor``.
        """
        if isinstance(key, str):
            return self.get_param_group_by_name(key)
        if isinstance(key, torch.Tensor):
            return self.get_param_group_by_param(key)
        raise ValueError('key must be a parameter name (str) or a torch.Tensor, got {}'.format(
            type(key).__name__
        ))

    def get_param_group_by_param(self, param):
        """Return the param group that ``param`` was registered in."""
        return self.id2group[self.get_param_id(param)]

    def get_param_group_by_name(self, name):
        """Return the param group of the parameter registered under ``name``."""
        return self.get_param_group_by_param(self.name2param[name])

    def get_param_state(self, key, val_key):
        """Return option ``val_key`` for ``key`` (a param name or the tensor itself).

        Per-param optimizer state takes precedence over the group's option.

        Raises:
            ValueError: if ``key`` has an unsupported type, or no non-None value
                exists for ``val_key``.
        """
        if isinstance(key, str):
            return self.get_param_state_by_name(key, val_key)
        if isinstance(key, torch.Tensor):
            return self.get_param_state_by_param(key, val_key)
        raise ValueError('key must be a parameter name (str) or a torch.Tensor, got {}'.format(
            type(key).__name__
        ))

    def get_param_state_by_param(self, param, name):
        """Return option ``name`` for ``param``: per-param state first, then its group.

        Raises:
            ValueError: if neither ``self.state[param]`` nor the param's group
                holds a non-None value for ``name``.
        """
        val = self.state[param].get(name)
        if val is not None:
            return val
        # .get() so a missing group option reports the ValueError below
        # instead of a bare KeyError.
        val = self.get_param_group_by_param(param).get(name)
        if val is not None:
            return val
        raise ValueError('param {} does not have state {} in optimizer'.format(
            self.get_param_id(param), name
        ))

    def get_param_state_by_name(self, param_name, name):
        """Return option ``name`` for the parameter registered under ``param_name``."""
        return self.get_param_state_by_param(self.name2param[param_name], name)

    def accumulate_gradient(self, param_name, grad):
        """Add ``grad`` into the running gradient accumulator of ``param_name``.

        For params whose effective 'quantifiable' option is truthy, the
        accumulator holds the output of ``grad_scheme.quantize_data``. On each
        call it is dequantized, summed with ``grad`` and re-quantized. Otherwise
        the accumulator holds a plain tensor sum.

        Args:
            param_name: name registered for the parameter in its param group.
            grad: gradient to add; ``None`` is ignored.
        """
        if grad is None:
            return

        param = self.name2param[param_name]
        param_id = self.get_param_id(param)
        accumulated = self.grad_accumulator[param_id]
        quantified = self.get_param_state(param, 'quantifiable')

        if accumulated is None:
            if quantified:
                # Accumulating from scratch: reset the scheme's quantization
                # state for this name before the first quantize.
                self.grad_scheme.reset_quantization_state(param_name)
                self.grad_accumulator[param_id] = self.grad_scheme.quantize_data(grad, param_name)
            else:
                # NOTE(review): stores the caller's tensor by reference; if the
                # caller later modifies it in place (e.g. zeroing p.grad), the
                # accumulator changes too — confirm callers never do that.
                self.grad_accumulator[param_id] = grad
            return

        if quantified:
            new_grad = self.grad_scheme.dequantize_data(accumulated, param_name) + grad
            self.grad_scheme.midst_process_inside_microbatch(new_grad, param_name)
            self.grad_accumulator[param_id] = self.grad_scheme.quantize_data(new_grad, param_name)
        else:
            # Out-of-place add: produces a new tensor rather than modifying the stored one.
            self.grad_accumulator[param_id] = accumulated + grad

    @torch.no_grad()
    def step(self, closure=None):
        r"""Perform a single optimization step; must be overridden by subclasses.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('step function need overriding')

    def add_param_group(self, param_group):
        """Validate and normalize ``'params'``/``'names'``, then register the group.

        ``'params'`` may be a single tensor or any ordered iterable (generators
        included). ``'names'`` may be a single str or any iterable of str. Both
        are stored back as lists of equal length.

        Raises:
            ValueError: if ``'params'`` or ``'names'`` is missing, or their
                lengths differ.
        """
        params = param_group.get('params')
        names = param_group.get('names')
        if params is None or names is None:
            raise ValueError("param group must provide both 'params' and 'names'")

        if isinstance(params, torch.Tensor):
            params = [params]
        elif not isinstance(params, set):
            # Materialize one-shot iterables so len() works and the parent
            # class sees the same items. Sets are left as-is: torch.optim.Optimizer
            # rejects them because their iteration order is not stable.
            params = list(params)
        names = [names] if isinstance(names, str) else list(names)

        if len(params) != len(names):
            raise ValueError('names of each parameter should exist in the param group')

        param_group['params'] = params
        param_group['names'] = names
        super().add_param_group(param_group)