import torch
import math

from lbmqt.qscheme import QScheme
from lbmqt.optim.optimizer import LowBitOptimizer
from lbmqt.conf import config


class Adam(LowBitOptimizer):
    r"""Adam optimizer that updates in full precision but stores low-bit data.

    On every update the parameter (via ``param_scheme``), its accumulated
    gradient (via ``self.grad_scheme``) and the two Adam moments (via
    ``self.mm_scheme`` / ``self.sqmm_scheme``) are dequantized. A standard
    Adam step with L2 weight decay is applied, and the parameter and moments
    are quantized again for storage. The bit widths, group size, enable flags
    and numerical modes of the moment schemes are read from
    ``lbmqt.conf.config``.

    Per-parameter state (``self.state[param_id]``):
        ``exp_avg``: quantized first moment (``None`` before the first update).
        ``exp_avg_sq``: quantized second moment (``None`` before the first
            update). When ``config.unbiased_sqmm_flag`` is set and the
            parameter is quantized by ``sqmm_scheme``, it stores the transform
            ``1 / sqrt(exp_avg_sq)`` instead of ``exp_avg_sq`` itself.

    Args:
        params: parameters or parameter groups, forwarded to ``LowBitOptimizer``.
        param_scheme: quantization scheme used for the parameters themselves.
        lr: learning rate (must be >= 0).
        betas: decay rates of the first and second moment, each in ``[0, 1)``.
        eps: term added to the denominator for numerical stability (>= 0).
        weight_decay: L2 penalty coefficient added to the gradient (>= 0).
        amsgrad: not supported. A ``UserWarning`` is emitted when it is set,
            and the flag is otherwise ignored.
        num_micro_batches: when > 0, only every ``num_micro_batches``-th call to
            :meth:`step` performs an update. Values <= 0 update on every call.
    """

    def __init__(
        self,
        params,
        param_scheme,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        num_micro_batches=-1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if amsgrad:
            # The update below never keeps a running max of exp_avg_sq. Make
            # the otherwise silently-ignored flag visible to the caller.
            import warnings

            warnings.warn(
                "amsgrad=True is not supported by lbmqt Adam and is ignored",
                stacklevel=2,
            )
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(Adam, self).__init__(params, defaults, param_scheme)

        q_param_names = self.get_quantifiable_param_names()
        self.num_micro_batches = num_micro_batches
        self.mm_scheme = QScheme(
            name='momentum',
            param_names=q_param_names,
            bits=config.compression_bits_mm,
            group_size=config.group_size,
            enable=config.enable_quantize_mm,
            num_mode=config.numerical_mode_mm,
        )
        self.sqmm_scheme = QScheme(
            name='square_momentum',
            param_names=q_param_names,
            bits=config.compression_bits_sqmm,
            group_size=config.group_size,
            enable=config.enable_quantize_sqmm,
            num_mode=config.numerical_mode_sqmm,
        )

    @staticmethod
    def _encode_sqmm(exp_avg_sq):
        """Map the second moment ``v`` to ``1 / sqrt(v)`` before quantization.

        Entries with ``v <= 0`` (where the transform is undefined) are set to
        the maximum of the tensor after the positive entries are filled in
        (0 when no entry is positive).
        """
        positive = exp_avg_sq > 0
        non_positive = exp_avg_sq <= 0
        encoded = torch.zeros_like(exp_avg_sq, memory_format=torch.preserve_format)
        encoded[positive] = 1 / exp_avg_sq[positive].sqrt()
        encoded[non_positive] = encoded.max()
        return encoded

    @staticmethod
    def _decode_sqmm(encoded):
        """Invert :meth:`_encode_sqmm`: map ``t = 1 / sqrt(v)`` back to ``1 / t**2``.

        Entries with ``t <= 0`` have no finite preimage. They are set to the
        maximum of the tensor after the positive entries are decoded (0 when
        no entry is positive).
        """
        positive = encoded > 0
        non_positive = encoded <= 0
        exp_avg_sq = torch.zeros_like(encoded, memory_format=torch.preserve_format)
        exp_avg_sq[positive] = 1 / (encoded[positive] ** 2)
        exp_avg_sq[non_positive] = exp_avg_sq.max()
        return exp_avg_sq

    @torch.no_grad()
    def step(self, closure=None):
        r"""Performs a single optimization step with quantization.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.num_steps += 1

        micro_batches = self.num_micro_batches
        if micro_batches > 0 and self.num_steps % micro_batches != 0:
            # Not an update call: leave the accumulated gradients untouched.
            return loss

        # Count of Adam updates applied so far, including this one. With
        # micro-batching only every `micro_batches`-th call updates. Bias
        # correction must count updates, not calls; otherwise it uses
        # 1 - beta ** (k * micro_batches) instead of 1 - beta ** k and
        # under-scales the early steps.
        if micro_batches > 0:
            update_count = self.num_steps // micro_batches
        else:
            update_count = self.num_steps

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']
            # The accumulated gradient spans `micro_batches` micro-batches, so
            # the weight-decay term is scaled by the same factor.
            if micro_batches > 0:
                wd_alpha = micro_batches * weight_decay
            else:
                wd_alpha = weight_decay
            step_size = lr / (1 - beta1 ** update_count)
            bias_correction2_sqrt = math.sqrt(1 - beta2 ** update_count)

            for p in group['params']:
                param_name = self.get_param_name(p)
                param_id = self.get_param_id(p)
                grad = self.grad_accumulator[param_id]
                if grad is None:
                    continue

                state = self.state[param_id]
                if len(state) == 0:
                    state['exp_avg'] = None
                    state['exp_avg_sq'] = None

                # Full-precision views of gradient and parameter
                # (amsgrad / nesterov are not supported).
                fp_grad = self.grad_scheme.dequantize_data(grad, param_name)
                fp_p = self.param_scheme.dequantize_data(p.data, param_name)

                # Dequantize (or lazily create) the optimizer state.
                if state['exp_avg'] is not None:
                    exp_avg = self.mm_scheme.dequantize_data(state['exp_avg'], param_name, signed=True)
                    stored_sq = self.sqmm_scheme.dequantize_data(state['exp_avg_sq'], param_name, signed=False)
                    if config.unbiased_sqmm_flag and self.sqmm_scheme.is_quantized(param_name):
                        # Stored as 1 / sqrt(v); see _encode_sqmm.
                        exp_avg_sq = self._decode_sqmm(stored_sq)
                        del stored_sq
                    else:
                        exp_avg_sq = stored_sq
                else:
                    exp_avg = torch.zeros_like(fp_p, memory_format=torch.preserve_format)
                    exp_avg_sq = torch.zeros_like(fp_p, memory_format=torch.preserve_format)

                if weight_decay != 0:
                    fp_grad = fp_grad.add(fp_p, alpha=wd_alpha)

                # Standard Adam moment updates and parameter step.
                exp_avg.mul_(beta1).add_(fp_grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(fp_grad, fp_grad.conj(), value=1 - beta2)
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
                fp_p.addcdiv_(exp_avg, denom, value=-step_size)

                # Reset quantization config at the requantization points,
                # before the new values are stored.
                if config.steps_requantization_p > 0 and self.num_steps % config.steps_requantization_p == 0:
                    self.param_scheme.reset_quantization_state(param_name)
                if config.steps_requantization_mm > 0 and self.num_steps % config.steps_requantization_mm == 0:
                    self.mm_scheme.reset_quantization_state(param_name)
                    self.sqmm_scheme.reset_quantization_state(param_name)

                # Quantize and store the parameter and optimizer state.
                p.data = self.param_scheme.quantize_data(fp_p, param_name)
                state['exp_avg'] = self.mm_scheme.quantize_data(exp_avg, param_name, signed=True)
                if config.unbiased_sqmm_flag and self.sqmm_scheme.is_quantized(param_name):
                    state['exp_avg_sq'] = self.sqmm_scheme.quantize_data(
                        self._encode_sqmm(exp_avg_sq), param_name, signed=False)
                else:
                    state['exp_avg_sq'] = self.sqmm_scheme.quantize_data(exp_avg_sq, param_name, signed=False)

                # Drop full-precision temporaries before the next parameter.
                del fp_p, fp_grad, exp_avg, exp_avg_sq, denom, grad

                self.grad_accumulator[param_id] = None
                if micro_batches > 1:
                    self.grad_scheme.turn_off_warmup(param_name)

        return loss
