import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
import math
from utils.attention_processor_quant import batch_mse,batch_max,lp_loss,round_ste
import matplotlib.pyplot as plt
import numpy as np
from utils.quant_layer import QuantLayerNormal
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

def round(x, rounding='deterministic'):
    """Round the tensor ``x`` to integer values.

    ``rounding='deterministic'`` uses ``Tensor.round``; ``rounding='stochastic'``
    rounds up with probability equal to the fractional part ``x - floor(x)``.

    NOTE: this shadows the builtin ``round`` inside this module and only works
    on tensors (it calls tensor methods on ``x``).
    """
    assert rounding in ['deterministic', 'stochastic']
    if rounding != 'stochastic':
        return x.round()
    lower = x.floor()
    # Bernoulli draw with p = fractional part decides whether to round up.
    return lower + torch.bernoulli(x - lower)



def get_shift_and_sign(x, rounding='deterministic'):
    """Split ``x`` into a base-2 exponent and a sign.

    Returns ``(shift, sign)`` where ``sign = torch.sign(x)``. ``shift`` is
    ``log2(|x|)``, rounded down when ``rounding == "floor"`` and rounded via
    ``round_ste`` for any other value (including the default). Zero entries
    give ``log2 = -inf`` before rounding.
    """
    log2_abs = torch.log(torch.abs(x)) / np.log(2)
    shift = torch.floor(log2_abs) if rounding == "floor" else round_ste(log2_abs)
    return shift, torch.sign(x)



def round_power_of_2(x, rounding='deterministic', q_bias=None, scale=None):
    """Snap ``x`` to a signed power of two, optionally in a shifted/scaled space.

    When given, ``q_bias`` and ``scale`` are unsqueezed at dim 1 and expanded
    to ``x``'s shape (presumably 1-D with one entry per row of a 2-D ``x`` --
    confirm against callers). The value is mapped to ``(x - q_bias) / scale``,
    replaced by ``2**shift * sign`` from ``get_shift_and_sign`` and mapped back
    with ``* scale + q_bias``.
    """
    bias_full = None if q_bias is None else q_bias.unsqueeze(1).expand_as(x)
    if bias_full is not None:
        x = x - bias_full

    scale_full = None if scale is None else scale.unsqueeze(1).expand_as(x)
    if scale_full is not None:
        x = x / scale_full

    shift, sign = get_shift_and_sign(x, rounding)
    result = (2.0 ** shift) * sign

    # Undo the normalization in the reverse order it was applied.
    if scale_full is not None:
        result = result * scale_full
    if bias_full is not None:
        result = result + bias_full
    return result



def additive_power_of_2(x, log_s):
    """Round ``x`` to a signed power of two after offsetting its exponent.

    Returns ``2**round_ste(log2|x| + log_s) * sign(x)``. The result is not
    rescaled by ``2**-log_s``.
    """
    exponent = round_ste(torch.log(torch.abs(x)) / np.log(2) + log_s)
    return (2.0 ** exponent) * torch.sign(x)





class StraightThrough(nn.Module):
    """Identity module; ``channel_num`` is accepted for API parity but ignored.

    ``QuantUnetWarp.quant_module_refactor`` leaves instances of this class alone.
    """

    def __init__(self, channel_num: int = 1):
        super(StraightThrough, self).__init__()

    def forward(self, input):
        # Hand the input back unchanged (same object, no copy).
        return input

def check_in_special(ss, special_list):
    """Return True if any entry of ``special_list`` occurs as a substring of ``ss``."""
    return any(token in ss for token in special_list)


class QuantUnetWarp(nn.Module):
    """Wrap a model and replace its conv/linear layers with quantized layers in place.

    Every ``nn.Conv2d`` / ``nn.Conv1d`` / ``nn.Linear`` child is swapped for:

    * ``QuantLayerProcesser`` (IntLoRA) when its underscore-joined module path
      contains one of ``'to_q'``, ``'to_k'``, ``'to_v'``, ``'to_out'``;
    * ``QuantLayerNormal`` otherwise.

    Layers whose own name contains both ``'downsample'`` and ``'conv'`` are left
    in full precision, and ``StraightThrough`` children are skipped.
    """

    def __init__(self, model: nn.Module, args):
        super().__init__()
        self.model = model
        # ``args.nbits`` sets the bit-width of both the frozen base weights and
        # the IntLoRA path.
        lora_quant_params = {'n_bits': args.nbits, 'lora_bits':args.nbits, 'symmetric':False, 'channel_wise':True, 'rank':4}
        other_quant_params = {'n_bits': args.nbits, 'symmetric': False, 'channel_wise': True, 'scale_method': 'mse'}
        special_list = ['to_q','to_k','to_v','to_out']
        self.quant_module_refactor(self.model, lora_quant_params, other_quant_params, special_list)

    def quant_module_refactor(self, module: nn.Module, lora_quant_params, other_quant_params, sepcial_list, prev_name=''):
        """Recursively replace the quantizable children of ``module`` in place.

        Args:
            module: module whose children are inspected.
            lora_quant_params: keyword arguments forwarded to ``QuantLayerProcesser``.
            other_quant_params: dict passed positionally to ``QuantLayerNormal``.
            sepcial_list: substrings that select the IntLoRA layers (the
                misspelled name is kept for keyword compatibility).
            prev_name: underscore-joined path of ``module`` used for matching.
        """
        for name, child_module in module.named_children():
            tmp_name = prev_name + '_' + name
            is_quantizable = isinstance(child_module, (nn.Conv2d, nn.Conv1d, nn.Linear)) \
                and not ('downsample' in name and 'conv' in name)
            if is_quantizable:
                if check_in_special(tmp_name, sepcial_list):
                    setattr(module, name, QuantLayerProcesser(child_module, **lora_quant_params))
                else:
                    setattr(module, name, QuantLayerNormal(child_module, other_quant_params))
                    print(f'{tmp_name} are quantinized')
            elif isinstance(child_module, StraightThrough):
                continue
            else:
                self.quant_module_refactor(child_module, lora_quant_params, other_quant_params, sepcial_list, prev_name=tmp_name)

    def set_lora_weights_qat(self):
        """Enable power-of-two quantization of the adapted scale in every IntLoRA layer."""
        for name, module in self.named_modules():
            if isinstance(module, QuantLayerProcesser):
                module.quant_lora_weights = True

    def set_inference_ptq_on(self):
        """Set ``inference_ptq = True`` on every IntLoRA layer.

        The original code tested against ``QLoRAProcessor``. That name is not
        defined or imported in this module, so the call raised NameError.
        """
        for name, module in self.named_modules():
            if isinstance(module, QuantLayerProcesser):
                module.inference_ptq = True

    def set_inference_ptq_off(self):
        """Clear the inference-PTQ state on every IntLoRA layer."""
        for name, module in self.named_modules():
            if isinstance(module, QuantLayerProcesser):
                module.inference_ptq = False
                # ``double_quant_init`` is kept for backward compatibility; the
                # flag QuantLayerProcesser itself defines is ``double_inited``.
                module.double_quant_init = False
                module.double_inited = False
                module.adapted_dequant_weight = None

    def forward(self, image, t, context=None):
        """Delegate to the wrapped model as ``model(image, t, context)``."""
        return self.model(image, t, context)




class QuantLayerProcesser(nn.Module):
    """IntLoRA layer: an integer base weight whose LoRA update is folded into the scale.

    On the first forward pass the frozen weight ``W`` (flattened to
    ``[out_features, -1]``) is shifted by a rescaled random Laplace matrix
    ``lora_A0B0``. ``W - lora_A0B0`` is then quantized to integers ``W_int``
    with step ``weight_quant_delta`` (one per output row when ``channel_wise``).
    The trainable adapter ``loraB @ loraA`` is applied through the scale.
    The effective weight is ``(delta + lora_weight) * W_int``, where
    ``lora_weight = (lora_A0B0 + B @ A) / W_int``. When ``quant_lora_weights``
    is True, the adapted scale is rounded to a signed power of two.

    NOTE(review): ``loraA``, ``loraB`` and ``lora_A0B0`` are created only for
    ``nn.Linear``. Wrapping an ``nn.Conv2d`` constructs fine, but the first
    forward fails on the missing ``lora_A0B0`` attribute.
    """

    def __init__(self, org_module, n_bits=8, lora_bits=8, symmetric=False,channel_wise=True, rank=4):
        super(QuantLayerProcesser, self).__init__()

        self.n_bits = n_bits
        self.lora_bits = lora_bits
        self.sym = symmetric
        self.scale_method = 'mse'
        self.always_zero = False
        # Integer levels: 2**n_bits (asymmetric) or 2**(n_bits-1)-1 (symmetric).
        self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
        self.channel_wise = channel_wise
        self.inited = False
        self.lora_levels = self.n_levels
        if isinstance(org_module, nn.Conv2d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv2d
            self.in_features = org_module.in_channels
            self.out_features = org_module.out_channels
        else:
            self.fwd_kwargs = dict()
            self.fwd_func = F.linear
            self.in_features = org_module.in_features
            self.out_features = org_module.out_features

        # Frozen copies of the original weight (flattened to [out, -1]) and bias.
        self.ori_weight_shape = org_module.weight.shape
        self.ori_weight = org_module.weight.view(self.out_features,-1).data.clone()
        self.ori_bias = None if org_module.bias is None else org_module.bias.data.clone()

        # LoRA-weight quantization state -----------------------------------
        self.quant_lora_weights = True
        self.double_inited = False
        self.double_quant_delta = torch.nn.Parameter(torch.zeros(self.out_features,1),requires_grad=False)
        self.double_quant_zero_point = torch.nn.Parameter(torch.zeros(self.out_features,1),requires_grad=False)

        r = rank
        lora_dropout = 0.0
        if lora_dropout > 0.0:
            self.lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout_layer = nn.Identity()

        if isinstance(org_module, nn.Linear):
            self.loraA = nn.Linear(org_module.in_features, r, bias=False)
            self.loraB = nn.Linear(r, org_module.out_features, bias=False)
            nn.init.kaiming_uniform_(self.loraA.weight, a=math.sqrt(5))
            # loraB starts at zero, so B @ A contributes nothing initially.
            nn.init.zeros_(self.loraB.weight)
            # Random auxiliary matrix ~ Laplace(0, 0.5). It is subtracted from W
            # before quantization and re-added through the scale in forward().
            m = torch.distributions.laplace.Laplace(loc=torch.tensor([0.]),scale=torch.tensor([0.5]))
            self.lora_A0B0 = m.sample((org_module.out_features,org_module.in_features))[:,:,0]
            self.lora_A0B0 = self.lora_A0B0.to(self.ori_weight.device).detach()

    def forward(self, input: torch.Tensor):
        """Return ``fwd_func(input, weight, ori_bias, **fwd_kwargs)`` with the IntLoRA weight.

        The first call lazily quantizes the stored weight, registers
        ``weight_quant_delta`` / ``weight_quant_zero_point`` as buffers and
        drops the full-precision copy.
        """
        if self.inited is False:
            # Rescale the auxiliary matrix row-wise by (max|W| / m_A)**1.5.
            # NOTE(review): m_A is the *smaller* of |max(A0B0)| and |min(A0B0)|
            # (torch.minimum) despite the "abs_max" name, while W uses
            # torch.maximum -- confirm this asymmetry is intended.
            lora_A0B0_abs_max  = torch.minimum(self.lora_A0B0.max(dim=-1,keepdim=True)[0].abs(),self.lora_A0B0.min(dim=-1,keepdim=True)[0].abs()).detach()
            ori_weight_abs_max = torch.maximum(self.ori_weight.max(dim=-1,keepdim=True)[0].abs(),self.ori_weight.min(dim=-1,keepdim=True)[0].abs()).detach()
            self.lora_A0B0 = ((ori_weight_abs_max)**1.5/(lora_A0B0_abs_max+1e-8)**1.5)*self.lora_A0B0

            # Quantize W - A0B0 to integers.
            ori_weight = self.ori_weight - self.lora_A0B0
            delta, zero_point = self.init_quantization_scale(ori_weight, self.channel_wise,self.n_bits,self.sym)
            self.register_buffer('weight_quant_delta', delta) # TODO 1e-3
            self.register_buffer('weight_quant_zero_point', zero_point)
            ori_weight_round = round_ste(ori_weight / self.weight_quant_delta) + self.weight_quant_zero_point
            if self.sym:
                ori_weight_round = torch.clamp(ori_weight_round, -self.n_levels - 1, self.n_levels)
            else:
                ori_weight_round = torch.clamp(ori_weight_round, 0, self.n_levels - 1)

            # Keep only the integer weight; free the full-precision copy.
            self.ori_weight_round = ori_weight_round
            self.ori_weight = None
            torch.cuda.empty_cache()
            self.inited = True

        ori_weight_int = self.ori_weight_round - self.weight_quant_zero_point

        # Fold the LoRA update into the quantization scale -------------------
        if self.fwd_func is F.linear:
            # Wherever W_int != 0: lora_weight * W_int == A0B0 + B @ A, hence
            # (delta + lora_weight) * W_int == delta * W_int + A0B0 + B @ A.
            # Zero integer entries divide by 1 instead, avoiding inf/nan.
            safe_int = torch.where(ori_weight_int == 0, torch.ones_like(ori_weight_int), ori_weight_int)
            lora_weight = (self.lora_A0B0 + (self.loraB.weight @ self.loraA.weight)) / safe_int
            weight_updates = self.weight_quant_delta + lora_weight # broadcast
        elif self.fwd_func is F.conv2d:
            # NOTE(review): loraA/loraB are never created for nn.Conv2d in
            # __init__, so this branch cannot run as written.
            lora_weight = self.loraB.weight.squeeze(-1).squeeze(-1) @ self.loraA.weight.permute(2, 3, 0,1)  ## (cout, r) @ (3, 3, r, cin)
            lora_weight = lora_weight.permute(2, 3, 0, 1) # TODO
            weight_updates = self.weight_quant_delta + lora_weight # broadcast
        else:
            weight_updates = self.weight_quant_delta

        # QAT for the adapted scale -------------------------------------------
        if self.quant_lora_weights:
            # log2 quantization: weight = sign(u) * 2**round(log2|u|) * W_int,
            # rounding through round_ste; 1e-16 keeps log2 finite at u == 0.
            weight_updates_sign = weight_updates.sign()
            weight_updates_abs = torch.abs(weight_updates)
            lora_shift = round_ste(torch.log2(weight_updates_abs+1e-16))
            lora_rounded = 2.0**lora_shift
            weight = weight_updates_sign * lora_rounded * ori_weight_int
            if torch.any(torch.isnan(weight_updates)):
                print('There is nan in the weight-updates for log2 quantization')

            # Alternative INT x INT quantization path (disabled):
            # if self.double_inited is False:
            #     delta, zero_point = self.init_quantization_scale(weight_updates, True, self.lora_bits)
            #     with torch.no_grad():
            #         self.double_quant_delta.copy_(delta)
            #         self.double_quant_zero_point.copy_(zero_point)
            #     self.double_inited = True
            #
            # weight_updates_round = round_ste(weight_updates / self.double_quant_delta) + self.double_quant_zero_point
            # if self.sym:
            #     weight_updates_round = torch.clamp(weight_updates_round, -self.lora_levels - 1, self.lora_levels)
            # else:
            #     weight_updates_round = torch.clamp(weight_updates_round, 0, self.lora_levels - 1)
            # weight_updates_int = (weight_updates_round - self.double_quant_zero_point)
            # weight_int_mul = weight_updates_int * ori_weight_int # INT multiply
            # weight = self.double_quant_delta * weight_int_mul

        else:
            weight = weight_updates*ori_weight_int

        bias = self.ori_bias
        out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)

        return out

    def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False, n_bits: int = 8, sym: bool= False):
        """Compute the quantization step ``delta`` and ``zero_point`` for ``x``.

        Args:
            x: tensor to quantize. When ``channel_wise``, its rank (4, 3 or
                other) decides how the per-channel results are reshaped.
            channel_wise: one (delta, zero_point) per slice along dim 0,
                computed by ``batch_max`` or ``batch_mse`` per ``self.scale_method``.
            n_bits: bit-width.
            sym: symmetric quantization.

        Returns:
            ``(delta, zero_point)``. In the per-tensor ``'max'`` branch, ``delta``
            is a 0-d tensor and ``zero_point`` a Python number.

        Raises:
            NotImplementedError: per-tensor mode whose ``scale_method``
                contains neither ``'max'`` nor ``'mse'``.
        """
        n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
        delta, zero_point = None, None
        if channel_wise:
            x_clone = x.clone().detach()
            n_channels = x_clone.shape[0]
            if len(x.shape) == 4:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
            elif len(x.shape) == 3:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
            else:
                x_max = x_clone.abs().max(dim=-1)[0]
            delta = x_max.clone()
            zero_point = x_max.clone()
            # Determine the scale and zero point channel by channel.
            if 'max' in self.scale_method:
                delta, zero_point = batch_max(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
                                              self.always_zero)

            elif 'mse' in self.scale_method:
                delta, zero_point = batch_mse(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
                                              self.always_zero)

            # Reshape so the results broadcast against x.
            if len(x.shape) == 4:
                delta = delta.view(-1, 1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1, 1)
            elif len(x.shape) == 3:
                delta = delta.view(-1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1)
            else:
                delta = delta.view(-1, 1)
                zero_point = zero_point.view(-1, 1)
        else:
            if 'max' in self.scale_method:
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (n_bits + 2) / 8
                    x_max = x_max * (n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if sym:
                    delta = x_absmax / n_levels
                else:
                    delta = float(x.max().item() - x.min().item()) / (n_levels - 1)
                if delta < 1e-8:
                    delta = 1e-8

                # Use np.round: the module-level ``round`` helper only accepts
                # tensors and raised AttributeError on this Python float.
                zero_point = float(np.round(-x_min / delta)) if not (sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)

            elif 'mse' in self.scale_method:
                # Grid-search a shrinking [min, max] range minimizing the Lp loss.
                x_max = x.max()
                x_min = x.min()
                best_score = 1e+10
                for i in range(80):
                    new_max = x_max * (1.0 - (i * 0.01))
                    new_min = x_min * (1.0 - (i * 0.01))
                    x_q = self.quantize(x, new_max, new_min,n_bits,sym)
                    score = lp_loss(x, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** n_bits - 1) \
                            if not self.always_zero else new_max / (2 ** n_bits - 1)
                        zero_point = (- new_min / delta).round() if not self.always_zero else 0
            else:
                raise NotImplementedError

        return delta, zero_point

    def quantize(self, x, max, min,n_bits,sym):
        """Fake-quantize ``x`` onto the grid spanned by ``[min, max]``.

        ``max`` and ``min`` are expected to be tensors, because
        ``(-min / delta).round()`` is a tensor method. Returns the
        de-quantized tensor.

        NOTE(review): for ``sym=True`` the clamp range uses
        ``2**(n_bits-1)-1`` levels, while ``delta`` assumes ``2**n_bits - 1``.
        """
        n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
        delta = (max - min) / (2 ** n_bits - 1) if not self.always_zero else max / (2 ** n_bits - 1)
        zero_point = (- min / delta).round() if not self.always_zero else 0
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0,n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q
