import torch
import random
from quant.quant_layer import QuantModule, StraightThrough, lp_loss
from quant.quant_model import Quant_Model
from quant.adaptive_rounding import AdaRoundQuantizer
from quant.data_utils import save_inp_oup_data
import torch.nn.functional as F

def _to_parameter(value) -> torch.nn.Parameter:
    """Return a fresh trainable ``nn.Parameter`` holding a detached copy of *value*."""
    if isinstance(value, torch.Tensor):
        # Same result as torch.tensor(value), without the copy-construct UserWarning.
        data = value.detach().clone()
    else:
        data = torch.tensor(value)
    return torch.nn.Parameter(data)


def _adam_with_cosine(params, lr, iters):
    """Build an Adam optimizer and a cosine-annealing scheduler (to 0 over *iters* steps)."""
    opt = torch.optim.Adam(params, lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=iters, eta_min=0.)
    return opt, sched


def layer_w_reconstruction(model: "Quant_Model", layer: "QuantModule", cali_data: torch.Tensor, cali_t: torch.Tensor,
                         batch_size: int = 32, batch_size1: int = 1024, iters: int = 20000, weight: float = 0.001, opt_mode: str = 'mse',
                         asym: bool = False, b_range: tuple = (20, 2),
                         warmup: float = 0.0, act_quant: bool = False, weight_quant: bool = False, lr_za: float = 1e-1, lr_a: float = 4e-5, lr_w=1e-2, lr_rw=1e-2, p: float = 2.0,
                         input_prob: float = 1.0, keep_gpu: bool = True,
                         recon_rw: bool = False, recon_w: bool = False, recon_a: bool = False, recon_smooth: bool = False, lr_smooth=1e-5,
                         ):
    """Tune one quantized layer so its output matches cached reference outputs (MSE).

    The layer's inputs/outputs are cached once via ``save_inp_oup_data`` and the
    selected parameter groups are then optimized with Adam + cosine annealing for
    ``iters`` iterations on random mini-batches of the cache.

    Args:
        model: Model owning ``layer``; its ``block_count`` is incremented by one.
        layer: Layer to reconstruct. If ``layer.stop_train`` is truthy the layer
            is skipped.
        cali_data: Calibration inputs. It is iterated and every element is sliced
            along dim 0 (presumably a tuple/list of tensors -- verify against caller).
        cali_t: Per-sample values handed to ``layer.set_t``; indexed like the cache.
        batch_size: Mini-batch size of the optimization loop.
        batch_size1: Batch size forwarded to ``save_inp_oup_data``.
        iters: Number of optimization iterations (also the cosine ``T_max``).
        asym, act_quant, weight_quant, keep_gpu: Forwarded to ``save_inp_oup_data``;
            ``act_quant`` additionally enables the activation-quantizer setup.
        lr_rw: Learning rate for weight/bias, scaled by their mean absolute value.
        lr_w: Learning rate for the weight-quantizer ``delta``.
        lr_a: Learning rate for activation ``delta``, scaled by its mean value.
        lr_za: Learning rate for activation ``zero_point``.
        lr_smooth: Learning rate for ``smooth_quantizer.scales``.
        input_prob: Fraction of each mini-batch that keeps the first cached input;
            the remainder is replaced by the second cached input.
        recon_rw, recon_w, recon_a, recon_smooth: Select which groups are trained
            (raw weight/bias, weight delta, activation delta/zero-point, smoothing
            scales).
        weight, opt_mode, b_range, warmup, p: Not used by this function; kept for
            interface compatibility.

    Returns:
        ``0`` when the layer is skipped because of ``stop_train``, otherwise ``None``.
    """
    if layer.stop_train:
        model.block_count = model.block_count + 1
        return 0

    is_split = layer.split != 0
    w_para, a_para, a_zero_para, smooth_para, scale_w_para = [], [], [], [], []

    # --- weight quantizer step sizes ---
    if recon_w:
        layer.weight_quantizer.delta = _to_parameter(layer.weight_quantizer.delta)
        if is_split:
            layer.weight_quantizer_0.delta = _to_parameter(layer.weight_quantizer_0.delta)
        scale_w_para.append(layer.weight_quantizer.delta)
        if is_split:
            scale_w_para.append(layer.weight_quantizer_0.delta)

    # --- raw weight / bias ---
    if recon_rw:
        layer.weight = _to_parameter(layer.weight)
        w_para.append(layer.weight)
        # Layers without a bias must not crash here.
        if getattr(layer, 'bias', None) is not None:
            layer.bias = _to_parameter(layer.bias)
            w_para.append(layer.bias)
        layer.weight_quantizer.delta.requires_grad = True

    # --- smoothing scales ---
    if recon_smooth:
        layer.smooth_quantizer.scales = _to_parameter(layer.smooth_quantizer.scales)
        smooth_para.append(layer.smooth_quantizer.scales)

    # --- activation quantizer(s) ---
    if act_quant and layer.act_quantizer.delta is not None and layer.act_quantizer.stop_train is False:
        act_quantizers = [layer.act_quantizer]
        if is_split:
            act_quantizers.append(layer.act_quantizer_0)
        # delta / zero_point become Parameters even when recon_a is off.
        for quantizer in act_quantizers:
            quantizer.delta = _to_parameter(quantizer.delta)
        for quantizer in act_quantizers:
            quantizer.zero_point = _to_parameter(quantizer.zero_point)
        if recon_a:
            for quantizer in act_quantizers:
                a_para.append(quantizer.delta)
                a_zero_para.append(quantizer.zero_point)
                quantizer.is_training = True

    # --- optimizers / schedulers, keyed by parameter group (insertion order = step order) ---
    optimizers, schedulers = {}, {}
    if w_para:
        avg_w = sum(float(w.detach().abs().mean()) for w in w_para) / len(w_para)
        optimizers['w'], schedulers['w'] = _adam_with_cosine(w_para, lr_rw * avg_w, iters)
    if scale_w_para:
        optimizers['scale_w'], schedulers['scale_w'] = _adam_with_cosine(scale_w_para, lr_w, iters)
    if a_para:
        avg_a = sum(float(a.detach().mean()) for a in a_para) / len(a_para)
        optimizers['a'], schedulers['a'] = _adam_with_cosine(a_para, lr_a * avg_a, iters)
    if a_zero_para:
        optimizers['a_zero'], schedulers['a_zero'] = _adam_with_cosine(a_zero_para, lr_za, iters)
    if smooth_para:
        optimizers['smooth'], schedulers['smooth'] = _adam_with_cosine(smooth_para, lr_smooth, iters)

    loss_func = torch.nn.MSELoss()

    # --- cache layer inputs / outputs, one calibration chunk at a time ---
    chunk_size = 512
    # Only full chunks are used (the tail is dropped), but always process at least
    # one chunk so calibration sets smaller than chunk_size do not crash torch.cat.
    num_chunks = max(1, cali_t.size(0) // chunk_size)
    chunk_inps_x, chunk_syms_x, chunk_outs = [], [], []
    for chunk_idx in range(num_chunks):
        lo, hi = chunk_idx * chunk_size, (chunk_idx + 1) * chunk_size
        data_chunk = [part[lo:hi] for part in cali_data]
        t_chunk = cali_t[lo:hi]
        _, inps, outs = save_inp_oup_data(model, layer, data_chunk, t_chunk, asym, act_quant=act_quant, weight_quant=weight_quant,
                                          batch_size=batch_size1, input_prob=True, keep_gpu=keep_gpu)
        del data_chunk, t_chunk
        torch.cuda.empty_cache()
        chunk_inps_x.append(inps[0])
        chunk_syms_x.append(inps[1])
        chunk_outs.append(outs)
        del inps, outs
    cached_outs = torch.cat(chunk_outs)
    del chunk_outs
    cached_inps = [torch.cat(chunk_inps_x), torch.cat(chunk_syms_x)]
    del chunk_inps_x, chunk_syms_x

    # NOTE(review): the device is hard-coded; this assumes the layer lives on the
    # current CUDA device.
    device = 'cuda'
    sz = cached_outs.size(0)
    model.block_count = model.block_count + 1
    for _ in range(iters):
        idx = torch.randperm(sz)[:batch_size]
        layer.set_t(t=cali_t[idx].to(device))
        cur_out = cached_outs[idx].to(device)
        cur_inp = cached_inps[0][idx].to(device)

        if input_prob < 1.0:
            num_batch = cur_inp.size(0)
            sym_index = int(num_batch * (1 - input_prob))
            if sym_index > 0:
                # Replace a random subset of the batch with the second cached input.
                cur_sym = cached_inps[1][idx].to(device)
                indices = torch.randperm(num_batch)[:sym_index]
                cur_inp[indices] = cur_sym[indices]

        for opt in optimizers.values():
            opt.zero_grad()

        out_quant = layer(cur_inp)
        loss = loss_func(out_quant, cur_out)
        # Back-propagate whenever *any* group is being trained (previously the
        # weight-delta-only and smooth-only configurations never got gradients).
        # The requires_grad guard avoids a crash if the forward pass does not
        # depend on the trainable parameters.
        if optimizers and loss.requires_grad:
            loss.backward()

        for opt in optimizers.values():
            opt.step()
        for sched in schedulers.values():
            sched.step()

        # NOTE(review): the activation optimizers are rebuilt every iteration with
        # the scheduler's current lr, which resets Adam's moment estimates each
        # step. Kept as-is to preserve training behaviour -- confirm it is intended.
        for key, params in (('a', a_para), ('a_zero', a_zero_para)):
            if params:
                current_lr = schedulers[key].optimizer.param_groups[0]['lr']
                optimizers[key] = torch.optim.Adam(params, lr=current_lr)
    torch.cuda.empty_cache()

    # --- leave the quantizers in inference state ---
    layer.weight_quantizer.soft_targets = False
    if is_split:
        layer.weight_quantizer_0.soft_targets = False
    layer.act_quantizer.is_training = False
    if is_split:
        layer.act_quantizer_0.is_training = False
