import numpy as np
import torch
import random
from quant.quant_layer import QuantModule, lp_loss
from quant.quant_model import Quant_Model
from quant.quant_block import BaseQuantBlock, QuantAttnBlock, QuantQKMatMul, QuantSMVMatMul, QuantAttentionBlock, QuantBasicTransformerBlock
from quant.adaptive_rounding import AdaRoundQuantizer
from quant_control.data_utils import save_inp_oup_data
import torch.nn.functional as F

def _to_parameter(value) -> torch.nn.Parameter:
    """Return a new ``torch.nn.Parameter`` holding a detached copy of ``value``.

    For tensors this matches ``torch.tensor(value)`` (documented as equivalent to
    ``value.clone().detach()``) without the copy-construct UserWarning; any other
    value still goes through ``torch.tensor``.
    """
    if isinstance(value, torch.Tensor):
        return torch.nn.Parameter(value.detach().clone())
    return torch.nn.Parameter(torch.tensor(value))


def _attn_act_quantizers(module):
    """Return the q/k/v/w activation quantizers of ``module.attn1`` followed by ``module.attn2``."""
    return [getattr(attn, 'act_quantizer_' + name)
            for attn in (module.attn1, module.attn2)
            for name in ('q', 'k', 'v', 'w')]


def _collect_trainable_params(block, act_quant, recon_rw, recon_w, recon_a, recon_smooth):
    """Re-wrap the selected tensors of ``block``'s sub-modules as Parameters and group them.

    Returns:
        Five lists ``(w_para, scale_w_para, a_para, a_zero_para, smooth_para)``:
        weights/biases, weight-quantizer scales, activation scales, activation
        zero points and smoothing scales.
    """
    w_para, scale_w_para, a_para, a_zero_para, smooth_para = [], [], [], [], []
    for module in block.modules():
        if isinstance(module, QuantModule):
            # Frozen layers are skipped entirely, including their activation quantizers.
            if module.stop_train:
                continue
            if recon_w:
                weight_quantizers = [module.weight_quantizer]
                if module.split != 0:
                    weight_quantizers.append(module.weight_quantizer_0)
                for quantizer in weight_quantizers:
                    quantizer.delta = _to_parameter(quantizer.delta)
                    scale_w_para.append(quantizer.delta)
            if recon_rw:
                module.weight = _to_parameter(module.weight)
                w_para.append(module.weight)
                if module.bias is not None:
                    module.bias = _to_parameter(module.bias)
                    w_para.append(module.bias)
                module.weight_quantizer.delta.requires_grad = True
            if recon_smooth:
                module.smooth_quantizer.scales = _to_parameter(module.smooth_quantizer.scales)
                smooth_para.append(module.smooth_quantizer.scales)

        # Everything below concerns activation quantizers only.
        if not act_quant or not isinstance(module, (QuantModule, BaseQuantBlock)):
            continue

        if isinstance(module, QuantBasicTransformerBlock):
            attn_quantizers = _attn_act_quantizers(module)
            for quantizer in attn_quantizers:
                quantizer.delta = _to_parameter(quantizer.delta)
            if recon_a:
                a_para.extend(quantizer.delta for quantizer in attn_quantizers)
                # NOTE(review): unlike the act_quantizer branch below, these zero
                # points are NOT re-wrapped as Parameters; if they are plain tensors
                # without requires_grad they never get gradients, so the zero-point
                # optimizer leaves them unchanged — confirm this is intended.
                a_zero_para.extend(quantizer.zero_point for quantizer in attn_quantizers)
                for quantizer in attn_quantizers:
                    quantizer.is_training = True

        act_quantizer = module.act_quantizer
        if act_quantizer.delta is not None and act_quantizer.stop_train is False:
            quantizers = [act_quantizer]
            if module.split != 0:
                quantizers.append(module.act_quantizer_0)
            for quantizer in quantizers:
                quantizer.delta = _to_parameter(quantizer.delta)
                quantizer.zero_point = _to_parameter(quantizer.zero_point)
                if recon_a:
                    a_para.append(quantizer.delta)
                    a_zero_para.append(quantizer.zero_point)
                    quantizer.is_training = True
    return w_para, scale_w_para, a_para, a_zero_para, smooth_para


def _adam_with_cosine(params, lr, iters):
    """Create an Adam optimizer over ``params`` with a cosine lr schedule annealed to 0 over ``iters`` steps."""
    optimizer = torch.optim.Adam(params, lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.)
    return optimizer, scheduler


def _cache_block_io(model, block, cali_data, cali_t, asym, act_quant, weight_quant,
                    batch_size1, keep_gpu, batch_cali):
    """Cache ``block``'s inputs/outputs with ``save_inp_oup_data``, one calibration group at a time.

    Only the first ``(len(cali_t) // batch_cali) * batch_cali`` samples are used.

    Returns:
        ``(is_resblock, cached_inps, cached_outs, cached_t)`` where ``cached_inps`` is
        ``[[x, temb], [sym_x, sym_temb]]`` for resblocks and ``[x, sym_x]`` otherwise,
        and ``cached_t`` holds the timestep of every row of ``cached_outs``.

    Raises:
        ValueError: if fewer than ``batch_cali`` calibration samples are provided.
    """
    num_groups = cali_t.size(0) // batch_cali
    if num_groups == 0:
        raise ValueError(
            f'block_w_reconstruction needs at least batch_cali={batch_cali} '
            f'calibration samples, got {cali_t.size(0)}')
    # With batch_size1 > batch_cali the timestep loop below would build no chunks
    # and torch.cat([]) would raise; clamp so a whole group forms one chunk.
    # Values <= batch_cali behave exactly as before.
    chunk = min(batch_size1, batch_cali)

    inps_x, inps_time, syms_x, syms_time, outs, all_t = [], [], [], [], [], []
    is_resblock = False
    for group in range(num_groups):
        start, stop = group * batch_cali, (group + 1) * batch_cali
        # Iterated element-wise: presumably a list/tuple of tensors despite the
        # annotation on the public function — TODO confirm against callers.
        group_data = [tensor[start:stop] for tensor in cali_data]
        group_t = cali_t[start:stop]
        is_resblock, group_inps, group_outs = save_inp_oup_data(
            model, block, group_data, group_t, asym, act_quant=act_quant, weight_quant=weight_quant,
            batch_size=chunk, input_prob=True, keep_gpu=keep_gpu)
        del group_data
        # Each chunk's timesteps are repeated twice to stay aligned with the cached
        # rows; assumes save_inp_oup_data emits two rows per sample per chunk
        # (e.g. two guidance branches) — TODO confirm.
        all_t.append(torch.cat([
            torch.cat([group_t[j * chunk:(j + 1) * chunk]] * 2)
            for j in range(batch_cali // chunk)
        ]))
        if is_resblock:
            inps_x.append(group_inps[0][0])
            inps_time.append(group_inps[0][1])
            syms_x.append(group_inps[1][0])
            syms_time.append(group_inps[1][1])
        else:
            inps_x.append(group_inps[0])
            syms_x.append(group_inps[1])
        outs.append(group_outs)
    del group_inps, group_outs

    # Concatenate one list at a time and drop it right away to limit peak memory.
    cached_outs = torch.cat(outs)
    del outs
    cached_t = torch.cat(all_t)
    if is_resblock:
        inp_pair = [torch.cat(inps_x), torch.cat(inps_time)]
        del inps_x, inps_time
        sym_pair = [torch.cat(syms_x), torch.cat(syms_time)]
        del syms_x, syms_time
        cached_inps = [inp_pair, sym_pair]
    else:
        cached_x = torch.cat(inps_x)
        del inps_x
        cached_sym = torch.cat(syms_x)
        del syms_x
        cached_inps = [cached_x, cached_sym]
    return is_resblock, cached_inps, cached_outs, cached_t


def _swap_in_sym_inputs(input_prob, pairs):
    """Overwrite ``int(n * (1 - input_prob))`` random rows of each ``inp`` with the same rows of ``sym``.

    ``pairs`` is a list of ``(inp, sym)`` tensors sharing dim 0; the same row
    indices are used for every pair and ``inp`` is modified in place.
    """
    if input_prob <= 1.0:
        num_batch = pairs[0][0].size(0)
        sym_index = int(num_batch * (1 - input_prob))
        if sym_index > 0:
            indices = torch.randperm(num_batch)[:sym_index]
            for inp, sym in pairs:
                inp[indices] = sym[indices]


def _finish_reconstruction(block):
    """Clear ``soft_targets`` / ``is_training`` on the quantizers of ``block``'s sub-modules."""
    for module in block.modules():
        if isinstance(module, QuantModule):
            module.weight_quantizer.soft_targets = False
            module.act_quantizer.is_training = False
            if module.split != 0:
                module.weight_quantizer_0.soft_targets = False
                module.act_quantizer_0.is_training = False
        if isinstance(module, QuantBasicTransformerBlock):
            for quantizer in _attn_act_quantizers(module):
                quantizer.is_training = False


def block_w_reconstruction(model: Quant_Model, block: BaseQuantBlock, cali_data: torch.Tensor, cali_t: torch.Tensor,
                           batch_size: int = 32, batch_size1: int = 1024, iters: int = 20000, weight: float = 0.01,
                           opt_mode: str = 'mse', asym: bool = False, b_range: tuple = (20, 2),
                           warmup: float = 0.0, act_quant: bool = False, weight_quant: bool = False,
                           lr_za: float = 1e-1, lr_a: float = 4e-5, lr_w: float = 1e-2, lr_rw: float = 1e-2,
                           p: float = 2.0, input_prob: float = 1.0, keep_gpu: bool = True,
                           recon_rw: bool = False, recon_w: bool = False, recon_a: bool = False,
                           recon_smooth: bool = False, lr_smooth: float = 1e-5, batch_cali: int = 256):
    """Reconstruct the quantization parameters of ``block`` against cached block outputs.

    Procedure:

    1. Depending on ``act_quant`` and the ``recon_*`` flags, re-wrap selected
       quantizer scales (``delta``), zero points, weights/biases and smoothing
       scales of ``block``'s sub-modules as fresh ``torch.nn.Parameter`` objects.
    2. Build one Adam optimizer + ``CosineAnnealingLR`` (``T_max=iters``,
       ``eta_min=0``) per non-empty parameter group.
    3. Split the calibration set into groups of ``batch_cali`` samples and cache
       the block inputs/outputs of each group with ``save_inp_oup_data``.
    4. For ``iters`` steps, sample ``batch_size`` cached rows, replace a
       ``1 - input_prob`` fraction of the inputs with the second cached input
       set, and minimise the MSE between ``block``'s output and the cached output.
    5. Clear ``soft_targets`` / ``is_training`` on the quantizers.

    Args:
        model: Quantized model; forwarded to ``save_inp_oup_data`` and its
            ``block_count`` attribute is incremented by one.
        block: Sub-module being reconstructed; mutated in place.
        cali_data: Calibration inputs, iterated element-wise and sliced along
            dim 0 in lock-step with ``cali_t``.
        cali_t: Calibration timesteps; dim 0 is the sample axis.
        batch_size: Number of cached rows per optimization step.
        batch_size1: Chunk size for ``save_inp_oup_data``; clamped to ``batch_cali``.
        iters: Number of optimization steps (also the cosine ``T_max``).
        asym, act_quant, weight_quant, keep_gpu: Forwarded to ``save_inp_oup_data``;
            ``act_quant`` also enables handling of activation quantizers.
        lr_za: lr for activation zero points.
        lr_a: Base lr for activation scales, multiplied by their mean value.
        lr_w: lr for weight-quantizer scales.
        lr_rw: Base lr for weights/biases, multiplied by their mean absolute value.
        input_prob: Approximate fraction of each batch kept from the first cached input set.
        recon_rw, recon_w, recon_a, recon_smooth: Select which parameter groups are trained.
        lr_smooth: lr for smoothing scales.
        weight, opt_mode, b_range, warmup, p: Not used by this function; kept for
            interface compatibility.
        batch_cali: Calibration group size used when caching block I/O.

    Raises:
        ValueError: if fewer than ``batch_cali`` calibration samples are given.
    """
    # Kept before caching so the cached forward passes already see the
    # re-wrapped quantizers and ``is_training`` flags.
    w_para, scale_w_para, a_para, a_zero_para, smooth_para = _collect_trainable_params(
        block, act_quant, recon_rw, recon_w, recon_a, recon_smooth)

    w_opt = scale_w_opt = a_opt = a_zero_opt = smooth_opt = None
    w_scheduler = scale_w_scheduler = a_scheduler = a_zero_scheduler = smooth_scheduler = None
    with torch.no_grad():
        if w_para:
            avg_w_para = torch.tensor([w.abs().mean().item() for w in w_para]).mean().item()
            w_opt, w_scheduler = _adam_with_cosine(w_para, lr_rw * avg_w_para, iters)
        if a_para:
            avg_a_scales = torch.tensor([a.mean().item() for a in a_para]).mean().item()
            a_opt, a_scheduler = _adam_with_cosine(a_para, lr_a * avg_a_scales, iters)
    if scale_w_para:
        scale_w_opt, scale_w_scheduler = _adam_with_cosine(scale_w_para, lr_w, iters)
    if a_zero_para:
        a_zero_opt, a_zero_scheduler = _adam_with_cosine(a_zero_para, lr_za, iters)
    if smooth_para:
        # Previously used lr_a here, silently ignoring the lr_smooth argument.
        smooth_opt, smooth_scheduler = _adam_with_cosine(smooth_para, lr_smooth, iters)
    schedulers = [s for s in (w_scheduler, scale_w_scheduler, a_scheduler, a_zero_scheduler, smooth_scheduler)
                  if s is not None]

    is_resblock, cached_inps, cached_outs, cali_t = _cache_block_io(
        model, block, cali_data, cali_t, asym, act_quant, weight_quant, batch_size1, keep_gpu, batch_cali)

    device = 'cuda'  # hard-coded: requires a CUDA device
    model.block_count = model.block_count + 1
    loss_func = torch.nn.MSELoss()
    num_rows = cached_outs.size(0)
    for _ in range(iters):
        idx = torch.randperm(num_rows)[:batch_size]
        block.set_t(t=cali_t[idx].to(device))
        cur_out = cached_outs[idx].to(device)
        if is_resblock:
            cur_inp = cached_inps[0][0][idx].to(device)
            cur_sym = cached_inps[1][0][idx].to(device)
            temb_inp = cached_inps[0][1][idx].to(device)
            temb_sym = cached_inps[1][1][idx].to(device)
            _swap_in_sym_inputs(input_prob, [(cur_inp, cur_sym), (temb_inp, temb_sym)])
        else:
            cur_inp = cached_inps[0][idx].to(device)
            cur_sym = cached_inps[1][idx].to(device)
            _swap_in_sym_inputs(input_prob, [(cur_inp, cur_sym)])

        # Rebuilt every step because a_opt / a_zero_opt are replaced below.
        optimizers = [o for o in (w_opt, scale_w_opt, a_opt, a_zero_opt, smooth_opt) if o is not None]
        for optimizer in optimizers:
            optimizer.zero_grad()

        out_quant = block(cur_inp, temb_inp) if is_resblock else block(cur_inp)
        loss = loss_func(out_quant, cur_out)
        loss.backward()

        for optimizer in optimizers:
            optimizer.step()
        for scheduler in schedulers:
            scheduler.step()

        # The activation optimizers are re-created every step with the lr reported
        # by their scheduler (which keeps driving the first optimizer instance).
        # This discards Adam's moment estimates, so each update is a "first" Adam
        # step (roughly lr * sign(grad)). Kept to preserve results.
        # NOTE(review): confirm this reset is intentional.
        if a_para:
            a_opt = torch.optim.Adam(a_para, lr=a_scheduler.optimizer.param_groups[0]['lr'])
        if a_zero_para:
            a_zero_opt = torch.optim.Adam(a_zero_para, lr=a_zero_scheduler.optimizer.param_groups[0]['lr'])
    torch.cuda.empty_cache()

    _finish_reconstruction(block)
