import numpy as np
import torch
import torch.nn as nn
from .hook import DataSaverHook, StopForwardException
from efficientvit.models.qnn import QuantLayer, QuantBlock
from efficientvit.models.qnn import QConv2d, QLinear
from .fake_quant import LSQFakeQuantize, LSQPlusFakeQuantize, MSAFinetuneFakeQuantize, QuantizeBase


def save_inp_oup_data(model, module, cali_data: torch.Tensor, store_inp=False, store_oup=False, bs: int = 32, keep_gpu: bool = True):
    """Run `model` over `cali_data` in mini-batches and cache what `module` sees.

    A ``DataSaverHook`` (constructed with ``stop_forward=True``) is attached to
    `module`; a ``StopForwardException`` raised during each forward pass is
    swallowed. After every batch, element 0 of the hook's ``input_store``
    and/or its ``output_store`` are detached and appended to the cache.

    :param model: network to run; batches are moved to the device of its first parameter
    :param module: submodule whose input/output is captured via the hook
    :param cali_data: calibration tensor, batched along dim 0
    :param store_inp: cache the module input
    :param store_oup: cache the module output
    :param bs: batch size; every sample is processed (the last batch may be smaller)
    :param keep_gpu: keep cached tensors where they were produced (True) or move them to CPU
    :return: ``[inputs, outputs]``; each entry is a tensor concatenated along dim 0
             when requested, otherwise an empty list
    """
    device = next(model.parameters()).device
    data_saver = DataSaverHook(store_input=store_inp, store_output=store_oup, stop_forward=True)
    handle = module.register_forward_hook(data_saver)
    cached = [[], []]
    try:
        with torch.no_grad():
            # Step through the data in slices so the trailing partial batch is
            # included and data smaller than `bs` still yields one batch.
            for start in range(0, cali_data.size(0), bs):
                try:
                    _ = model(cali_data[start:start + bs].to(device))
                except StopForwardException:
                    pass
                if store_inp:
                    inp = data_saver.input_store[0].detach()
                    cached[0].append(inp if keep_gpu else inp.cpu())
                if store_oup:
                    oup = data_saver.output_store.detach()
                    cached[1].append(oup if keep_gpu else oup.cpu())
    finally:
        # Always detach the hook, even if the forward pass raised something else.
        handle.remove()
    if store_inp:
        cached[0] = torch.cat(cached[0])
    if store_oup:
        cached[1] = torch.cat(cached[1])
    torch.cuda.empty_cache()
    return cached


class LinearTempDecay:
    """Piecewise-linear schedule: hold `start_b`, decay linearly to `end_b`, then hold.

    For ``t < warm_up * t_max`` the value is ``start_b``; for ``t >= t_max`` it
    is ``end_b``; in between it is interpolated linearly from ``start_b`` down
    (or up) to ``end_b``.
    """

    def __init__(self, t_max=20000, warm_up=0.2, start_b=20, end_b=2):
        self.t_max = t_max
        # Step at which the decay begins, expressed as a fraction of t_max.
        self.start_decay = warm_up * t_max
        self.start_b = start_b
        self.end_b = end_b

    def __call__(self, t):
        """Return the scheduled value at step `t`."""
        if t < self.start_decay:
            return self.start_b
        # `>=` rather than `>`: at t == t_max the interpolation below would give
        # end_b anyway, and skipping it avoids a ZeroDivisionError when
        # start_decay == t_max (e.g. warm_up == 1.0 or t_max == 0).
        if t >= self.t_max:
            return self.end_b
        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
        return self.end_b + (self.start_b - self.end_b) * max(0.0, 1 - rel_t)


class LossFunction:
    r'''loss function to calculate mse reconstruction loss and relaxation loss
    use some tempdecay to balance the two losses.
    '''

    def __init__(self,
                 module: QuantLayer or QuantBlock,
                 p: float = 2.,
                 name = None,
                 logger = None):

        self.module = module
        self.p = p
        self.name = name
        self.logger = logger

        self.count = 0

    def __call__(self, pred, tgt):
        """
        Compute the total loss for dsqfinetune

        :param pred: output from quantized model
        :param tgt: output from FP model
        :return: total loss function
        """
        self.count += 1
        loss = lp_loss(pred, tgt, p=self.p)

        if self.count % 200 == 0:
            self.logger.info('Loss:\t{:.3f}\tcount={}'.format(
                float(loss), self.count))
        return loss


def lp_loss(pred, tgt, p=2.0):
    """Return the L_p^p reconstruction error between `pred` and `tgt`.

    The element-wise ``|pred - tgt| ** p`` is summed over dim 1 and the result
    is averaged over all remaining elements.
    """
    abs_err = torch.abs(pred - tgt)
    summed = torch.sum(abs_err ** p, dim=1)
    return summed.mean()


class LinearTempBeta:
    """Push a linearly decaying beta into every weight quantizer of `module`.

    Wraps a ``LinearTempDecay`` schedule going from ``beta_range[0]`` to
    ``beta_range[1]`` over `iters` steps (after `warm_up`). Each call computes
    beta for step `t` and passes it to ``weight_fake_quant.init(beta)`` of every
    ``QLinear`` / ``QConv2d`` found in `module`.
    """

    def __init__(self,
                 module: "QuantLayer | QuantBlock",
                 iters: int = 20000,
                 beta_range: tuple = (20, 2),
                 warm_up: float = 0.0):
        self.module = module
        self.temp_beta = LinearTempDecay(iters, warm_up=warm_up,
                                         start_b=beta_range[0], end_b=beta_range[1])

    def __call__(self, t):
        """Compute beta for iteration `t` and apply it to all weight quantizers."""
        beta = self.temp_beta(t)
        for layer in self.module.modules():
            if isinstance(layer, (QLinear, QConv2d)):
                layer.weight_fake_quant.init(beta)


def _adamax_with_cosine(params, lr, iters):
    """Return (Adamax optimizer, cosine-annealed LR scheduler) for `params`, or (None, None) if empty."""
    if not params:
        return None, None
    opt = torch.optim.Adamax(params, lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=iters, eta_min=0.)
    return opt, scheduler


def msafinetune_reconstruction(model, fp_model, module, fp_module, cali_data, config, model_name, logger):
    """Fine-tune one quantized block so its output matches the full-precision block.

    The input of `module` is captured from the quantized `model`; the target
    output (and, when mixing, the FP input) of `fp_module` from `fp_model`.
    The following are then optimized with Adamax + cosine LR for `config.iters`
    steps against an L_p loss:

    - the weight of every QLinear/QConv2d whose weight quantizer is enabled;
    - that weight quantizer's scale;
    - the scale of enabled LSQ activation quantizers;
    - the scale and zero point of enabled LSQ+/MSAFinetune activation quantizers.

    When ``config.drop_prob < 1.0`` each input element is taken from the
    quantized input with probability ``drop_prob`` and from the FP input
    otherwise. Afterwards every weight is replaced by its quantizer's
    ``get_hard_value`` and the activation quantizers' ``drop_prob`` is set to 1.0.

    :param config: reads batch_size, keep_gpu, iters, beta_range, warm_up,
                   drop_prob, w_lr, w_scale_lr, p, and a_scale_lr (only if
                   activation parameters are found)
    :raises ValueError: if `module` contains no trainable quantization parameters
    """
    device = next(module.parameters()).device
    mix_fp_inp = config.drop_prob < 1.0
    # Cache block inputs/targets once; the FP input is only needed when mixing.
    quant_inp, _ = save_inp_oup_data(model, module, cali_data, store_inp=True, store_oup=False,
                                     bs=config.batch_size, keep_gpu=config.keep_gpu)
    fp_inp, fp_oup = save_inp_oup_data(fp_model, fp_module, cali_data, store_inp=mix_fp_inp, store_oup=True,
                                       bs=config.batch_size, keep_gpu=config.keep_gpu)

    # Collect trainable parameters: weights, weight scales, activation params.
    qw, w_para, a_para = [], [], []
    linear_beta = LinearTempBeta(module, iters=config.iters, beta_range=config.beta_range, warm_up=config.warm_up)
    for name, layer in module.named_modules():
        if isinstance(layer, (QLinear, QConv2d)) and layer.weight_fake_quant.fake_quant_enabled:
            weight_quantizer = layer.weight_fake_quant
            weight_quantizer.init(config.beta_range[0])
            logger.info('learn the weight for {}'.format(name))
            qw.append(layer.module.weight)
            logger.info('learn the scale for {}'.format(name))
            w_para.append(weight_quantizer.scale)
        if isinstance(layer, QuantizeBase) and 'act_fake_quantize' in name:
            layer.drop_prob = config.drop_prob
            # A tuple is required here: `A or B` would evaluate to just `A`.
            # The branches are exclusive so no parameter is registered twice.
            if isinstance(layer, (LSQPlusFakeQuantize, MSAFinetuneFakeQuantize)):
                if layer.fake_quant_enabled:
                    a_para += [layer.scale, layer.zero_point]
            elif isinstance(layer, LSQFakeQuantize) and layer.fake_quant_enabled:
                a_para.append(layer.scale)

    if a_para:
        a_opt, a_scheduler = _adamax_with_cosine(a_para, config.a_scale_lr, config.iters)
    else:
        a_opt, a_scheduler = None, None
    qw_opt, qw_scheduler = _adamax_with_cosine(qw, config.w_lr, config.iters)
    w_opt, w_scheduler = _adamax_with_cosine(w_para, config.w_scale_lr, config.iters)
    optimizers = [opt for opt in (qw_opt, w_opt, a_opt) if opt is not None]
    schedulers = [sch for sch in (qw_scheduler, w_scheduler, a_scheduler) if sch is not None]
    if not optimizers:
        raise ValueError('msafinetune: no trainable quantization parameters found in module')

    loss_func = LossFunction(module=module, p=config.p, name=model_name, logger=logger)

    # Start training.
    logger.info('start tuning by msafintune')
    sz = quant_inp.size(0)
    for i in range(config.iters):
        linear_beta(i)  # update beta
        idx = torch.randint(0, sz, (config.batch_size,))
        if mix_fp_inp:
            cur_quant_inp = quant_inp[idx].to(device)
            cur_fp_inp = fp_inp[idx].to(device)
            cur_inp = torch.where(torch.rand_like(cur_quant_inp) < config.drop_prob, cur_quant_inp, cur_fp_inp)
        else:
            cur_inp = quant_inp[idx].to(device)
        cur_fp_oup = fp_oup[idx].to(device)
        for opt in optimizers:
            opt.zero_grad()
        cur_quant_oup = module(cur_inp)
        err = loss_func(cur_quant_oup, cur_fp_oup)
        err.backward()
        for opt in optimizers:
            opt.step()
        for scheduler in schedulers:
            scheduler.step()
    torch.cuda.empty_cache()

    # Bake the quantizers' hard values into the weights and set drop_prob to 1.0.
    for name, layer in module.named_modules():
        if isinstance(layer, (QLinear, QConv2d)):
            weight_quantizer = layer.weight_fake_quant
            layer.module.weight.data = weight_quantizer.get_hard_value(layer.module.weight.data)
            weight_quantizer.finetune = False
        if isinstance(layer, QuantizeBase) and 'act_fake_quantize' in name:
            layer.drop_prob = 1.0