import torch
import torch.nn.functional as F

from lbmqt.qscheme import QScheme
from lbmqt.optim.optimizer import LowBitOptimizer
from lbmqt.conf import config
from lbmqt.utils import GOR


class SGD(LowBitOptimizer):
    """SGD (optionally with momentum) built on ``LowBitOptimizer``.

    Parameters are read and written through ``self.param_scheme``, accumulated
    gradients are read through ``self.grad_scheme`` and the momentum buffer is
    read and written through ``self.mm_scheme`` (created in ``__init__`` from
    the ``*_mm`` entries of ``config``).

    NOTE(review): ``self.param_scheme``, ``self.grad_scheme``,
    ``self.grad_accumulator``, ``self.num_steps``, ``get_param_name``,
    ``get_param_id`` and ``get_quantifiable_param_names`` are not defined in
    this class -- presumably they come from ``LowBitOptimizer``; confirm.
    """

    def __init__(
        self,
        params,
        param_scheme,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        num_micro_batches=-1,
    ):
        """Validate hyper-parameters and create the momentum quantization scheme.

        Args:
            params: forwarded unchanged to ``LowBitOptimizer.__init__``.
            param_scheme: forwarded unchanged to ``LowBitOptimizer.__init__``.
            lr: learning rate; must not be negative.
            momentum: momentum factor; must not be negative.
            dampening: stored in the group defaults but never read by ``step``.
            weight_decay: weight-decay coefficient; must not be negative.
            nesterov: accepted but not used anywhere in this class.
            num_micro_batches: when positive, ``step`` only applies an update
                on calls where the step counter is a multiple of this value.

        Raises:
            ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative.
        """
        # Checked in this order so the first offending argument is reported.
        checks = (
            ("Invalid learning rate: {}", lr),
            ("Invalid momentum value: {}", momentum),
            ("Invalid weight_decay value: {}", weight_decay),
        )
        for template, value in checks:
            if value < 0.0:
                raise ValueError(template.format(value))

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'weight_decay': weight_decay,
        }
        super().__init__(params, defaults, param_scheme)

        quantifiable_names = self.get_quantifiable_param_names()
        print(quantifiable_names)
        self.num_micro_batches = num_micro_batches
        # Scheme used to store/load the momentum buffer; every knob is read
        # from the global ``config`` object.
        self.mm_scheme = QScheme(
            name='momentum',
            param_names=quantifiable_names,
            bits=config.compression_bits_mm,
            group_size=config.group_size,
            enable=config.enable_quantize_mm,
            num_mode=config.numerical_mode_mm,
            stochastic=config.stochastic,
        )

    @torch.no_grad()
    def step(self, closure=None):
        r"""Apply one SGD update, going through the quantization schemes.

        Args:
            closure (callable, optional): if given, it is called with gradients
                enabled and its return value becomes the return value of
                ``step``.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.num_steps += 1

        # With micro-batching enabled, only every ``num_micro_batches``-th call
        # updates anything. The test does not depend on the parameter, so it is
        # done once up front instead of once per parameter.
        if self.num_micro_batches > 0 and self.num_steps % self.num_micro_batches != 0:
            return loss

        for group in self.param_groups:
            for p in group['params']:
                name = self.get_param_name(p)
                pid = self.get_param_id(p)
                accumulated = self.grad_accumulator[pid]
                if accumulated is None:
                    # nothing accumulated for this parameter
                    continue

                state = self.state[pid]
                if not state:
                    state['momentum_buffer'] = None

                # Working copies handed back by the schemes.
                # Dampening and nesterov are not considered.
                grad_fp = self.grad_scheme.dequantize_data(accumulated, name)
                weights_fp = self.param_scheme.dequantize_data(p.data, name)
                lr = group['lr']
                momentum = group['momentum']
                weight_decay = group['weight_decay']
                use_momentum = momentum != 0

                # Weight decay; the coefficient is multiplied by the number of
                # micro batches when micro-batching is enabled.
                if weight_decay != 0:
                    decay = weight_decay
                    if self.num_micro_batches > 0:
                        decay = self.num_micro_batches * weight_decay
                    grad_fp = grad_fp.add(weights_fp, alpha=decay)

                # Direction of the update: the plain gradient, or
                # momentum * buffer + gradient once a buffer exists.
                delta = grad_fp
                if use_momentum:
                    mm_buf = state['momentum_buffer']
                    if mm_buf is not None:
                        mm_buf = self.mm_scheme.dequantize_data(mm_buf, name)
                        # In-place on whatever dequantize_data returned; if that
                        # shares storage with the stored buffer, the stored
                        # buffer changes as well -- TODO confirm.
                        delta = mm_buf.mul_(momentum).add_(grad_fp)

                # Debug metrics (1/2): update norm relative to parameter norm,
                # plus a snapshot of the stored parameter before the update.
                if config.debug_quantization_difference and self.param_scheme.is_quantifiable(name):
                    norm_ratio = (delta.norm() / weights_fp.norm()).cpu().item()
                    GOR.log_metric(name, "q_diff_norm", norm_ratio)
                    GOR.log_metric(name, "q_diff_norm_with_lr", norm_ratio * lr)
                    old_p = p.data.clone()

                weights_fp.data.add_(delta, alpha=-lr)

                # Reset the quantization state of parameters / momentum right
                # before the values are stored again.
                if (
                    config.steps_requantization_p > 0
                    and self.num_steps % config.steps_requantization_p == 0
                    and GOR.epoch < config.epochs_fix_scale
                ):
                    self.param_scheme.reset_quantization_state(name)
                if (
                    use_momentum
                    and config.steps_requantization_mm > 0
                    and self.num_steps % config.steps_requantization_mm == 0
                ):
                    self.mm_scheme.reset_quantization_state(name)

                # Store the updated parameter.
                log_diff = config.debug_quantization_difference and self.param_scheme.is_quantifiable(name)
                # From epoch ``epochs_fix_scale`` on, ``params_freeze_dilation``
                # is passed as ``dilate``; before that, 1.
                dilation = config.params_freeze_dilation if GOR.epoch >= config.epochs_fix_scale else 1
                if log_diff:
                    # Debug metrics (2/2): compare stored values before/after.
                    new_p = self.param_scheme.quantize_data(weights_fp, name, dilate=dilation)
                    n_elems = weights_fp.numel()
                    abs_change = (new_p.data - old_p.data).abs()

                    bin_change_num = abs_change.sum().cpu().item()
                    GOR.log_metric(name, "q_diff_bin_change_num", bin_change_num)
                    GOR.log_metric(name, "q_diff_bin_change_ratio", bin_change_num / n_elems)

                    p_change_num = (abs_change > 0).sum().cpu().item()
                    GOR.log_metric(name, "q_diff_p_change_num", p_change_num)
                    GOR.log_metric(name, "q_diff_p_change_ratio", p_change_num / n_elems)

                    # Entries whose magnitude equals 2**(bits-1) - 1.
                    bits = self.param_scheme.defaults['bits']
                    hit_max_num = (new_p.data.abs() == (2 ** (bits - 1)) - 1).sum().cpu().item()
                    GOR.log_metric(name, "q_diff_hit_max_num", hit_max_num)
                    GOR.log_metric(name, "q_diff_hit_max_ratio", hit_max_num / n_elems)

                    if config.debug_proba_hist and self.param_scheme.is_quantized(name):
                        hist = self.param_scheme.calculate_bin_change_num(name, delta, signed=True)
                        GOR.log_metric(name, "q_diff_proba_hist", hist)

                    # Cosine similarity between the intended update (-lr * delta)
                    # and the change observed after the quantize round trip.
                    ideal_dir = (delta * -lr).view(-1)
                    ideal_dir = ideal_dir / ideal_dir.norm()
                    old_fp = self.param_scheme.dequantize_data(old_p.data, name)
                    new_fp = self.param_scheme.dequantize_data(new_p.data, name)
                    real_dir = (new_fp - old_fp).view(-1)
                    real_dir = real_dir / real_dir.norm()
                    assert ideal_dir.shape == real_dir.shape
                    cos_sim = ideal_dir.dot(real_dir).cpu().item()
                    GOR.log_metric(name, "q_diff_cos_sim_update", cos_sim)

                    p.data = new_p.data

                    del old_p
                    del new_p
                else:
                    p.data = self.param_scheme.quantize_data(weights_fp, name, dilate=dilation)

                # Store the new momentum buffer.
                if use_momentum:
                    state['momentum_buffer'] = self.mm_scheme.quantize_data(delta, name)
                    del mm_buf

                # Drop the local references to the working tensors and clear
                # the accumulated gradient for this parameter.
                del weights_fp
                del grad_fp
                del delta

                self.grad_accumulator[pid] = None
                if self.num_micro_batches > 1:
                    self.grad_scheme.turn_off_warmup(name)

        return loss

    # not used
    def state_dict(self):
        """Return the base optimizer state together with the gradient and momentum schemes."""
        return {
            'optimizer': super().state_dict(),
            'grad_scheme': self.grad_scheme,
            'mm_scheme': self.mm_scheme,
        }

    def load_state_dict(self, state_dict):
        """Restore what ``state_dict`` produced: base optimizer state, then both schemes."""
        super().load_state_dict(state_dict['optimizer'])
        for key in ('grad_scheme', 'mm_scheme'):
            setattr(self, key, state_dict[key])
