import math
import logging
from functools import partial
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final

from timm.layers import to_2tuple

import matplotlib.pyplot as plt

import numpy as np
from .probe import probe
from .token_probe import norm_probing_not_sorted
from .token_select import token_select
import pandas as pd
import os

def round_pass(x):
    """Round ``x`` in the forward pass while passing gradients straight through.

    The returned tensor has the values of ``x.round()``, but because the rounding
    residual is detached, the gradient w.r.t. ``x`` is the identity
    (straight-through estimator).
    """
    residual = (x.round() - x).detach()
    return residual + x

class Quantizer():
    """Simulated ("fake") uniform quantizer with straight-through rounding.

    Calling an instance returns ``(q, scale)``: ``q`` holds integer-valued floats
    on the quantization grid and ``scale`` is the grid step. In symmetric mode
    ``q * scale`` approximates the input. In asymmetric mode ``q`` also contains
    a zero-point offset that is NOT returned, so ``q * scale`` alone does not
    reconstruct the input.

    Args:
        N_bits: bit width; ``None`` disables quantization (``forward`` returns the
            input unchanged with scale ``1``).
        type: scale granularity -- ``'per_tensor'`` (one scale for the whole
            tensor), ``'per_token'`` (reduce over the last dim) or
            ``'per_channel'`` (reduce over dim 0).
        signed: integer range ``[-2**(N-1), 2**(N-1)-1]`` if True, else
            ``[0, 2**N - 1]``.
        symmetric: if True the scale comes from ``max(|x|)``; otherwise from
            ``max(x) - min(x)`` together with a zero point.
    """

    def __init__(self, N_bits: int, type: str = "per_tensor", signed: bool = True, symmetric: bool = True):
        super().__init__()
        self.N_bits = N_bits
        self.signed = signed
        self.symmetric = symmetric
        self.q_type = type
        # Lower bound on the range, so the scale can never be zero.
        self.minimum_range = 1e-6

        if self.N_bits is None:
            # Full precision: Qn / Qp are never needed.
            return

        if self.signed:
            self.Qn = - 2 ** (self.N_bits - 1)
            self.Qp = 2 ** (self.N_bits - 1) - 1
        else:
            self.Qn = 0
            self.Qp = 2 ** self.N_bits - 1

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        """Quantize ``x`` and return ``(q, scale)`` (see the class docstring).

        Raises:
            ValueError: if ``self.q_type`` is not a supported granularity.
        """
        if self.N_bits is None:
            return x, 1

        if self.symmetric:
            if self.q_type == 'per_tensor':
                max_x = x.abs().max()
            elif self.q_type == 'per_token':  # largest magnitude per token (last dim)
                max_x = x.abs().amax(dim=-1, keepdim=True)
            elif self.q_type == 'per_channel':  # largest magnitude per channel (dim 0)
                max_x = x.abs().amax(dim=0, keepdim=True)
            else:
                raise ValueError(f"Unsupported quantization type: {self.q_type!r}")
            # Out-of-place clamp: max/amax save their output for backward, so an
            # in-place clamp_ would break autograd when x requires grad.
            max_x = max_x.clamp(min=self.minimum_range)
            scale = max_x / self.Qp
            # NOTE(review): no clamp to [Qn, Qp] here, so with signed=False any
            # negative inputs map below Qn=0 -- confirm that is intended.
            x = round_pass(x / scale)

        else:  # Asymmetric
            if self.q_type == 'per_tensor':
                min_x = x.min().detach()
                max_x = x.max().detach()
            elif self.q_type == 'per_token':
                # amin/amax return plain tensors; min/max(dim=...) would return
                # a (values, indices) namedtuple.
                min_x = x.amin(dim=-1, keepdim=True).detach()
                max_x = x.amax(dim=-1, keepdim=True).detach()
            elif self.q_type == 'per_channel':
                min_x = x.amin(dim=0, keepdim=True).detach()
                max_x = x.amax(dim=0, keepdim=True).detach()
            else:
                raise ValueError(f"Unsupported quantization type: {self.q_type!r}")

            range_x = (max_x - min_x).clamp(min=self.minimum_range)
            scale = range_x / (self.Qp - self.Qn)
            # Zero point chosen so that min_x lands on Qn (and max_x near Qp).
            zero_point = self.Qn - torch.round(min_x / scale)
            x = (x / scale) + zero_point
            x = round_pass(x.clamp(self.Qn, self.Qp))

        return x, scale


class QuantAct(nn.Module):
    """``nn.Module`` wrapper that applies a ``Quantizer`` to its input.

    ``forward`` returns the ``(quantized, scale)`` pair produced by the wrapped
    quantizer, unchanged.
    """

    def __init__(self,
                 N_bits: int,
                 type: str,
                 signed: bool = True,
                 symmetric: bool = True):
        super().__init__()
        self.quantizer = Quantizer(
            N_bits=N_bits,
            type=type,
            signed=signed,
            symmetric=symmetric,
        )

    def forward(self, x):
        quantized, scale = self.quantizer(x)
        return quantized, scale

class Quantized_Linear(nn.Linear):
    """``nn.Linear`` whose forward/backward run through ``_quantize_global``.

    The four user-supplied quantizers control the weight, the activation, and
    the two output-gradient quantizations used in backward. A per-token
    ``Quantizer(abits)`` is created for the prefix tokens selected by
    ``prefix_token_num``.
    """

    def __init__(self, weight_quantize_module: Quantizer, act_quantize_module: Quantizer, weight_grad_quantize_module: Quantizer, act_grad_quantize_module: Quantizer,
                 in_features, out_features, abits, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.weight_quantize_module = weight_quantize_module
        self.act_quantize_module = act_quantize_module
        self.weight_grad_quantize_module = weight_grad_quantize_module
        self.act_grad_quantize_module = act_grad_quantize_module
        self.prefix_qmodule = Quantizer(abits, 'per_token')

    def forward(self, input, prefix_token_num=0):
        # Order must match _quantize_global.forward's parameter list.
        quantizers = (
            self.weight_quantize_module,
            self.act_quantize_module,
            self.weight_grad_quantize_module,
            self.act_grad_quantize_module,
            self.prefix_qmodule,
        )
        return _quantize_global.apply(prefix_token_num, input, self.weight, self.bias, *quantizers)
    
class _quantize_global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, prefix_token_num, x, w, bias=None,
                w_qmodule=None, a_qmodule=None, w_g_qmodule=None, a_g_qmodule=None, prefix_qmodule=None):
        #save for backward

        ctx.a_g_qmodule = a_g_qmodule
        ctx.w_g_qmodule = w_g_qmodule 
        ctx.has_bias = bias is not None

        B, S, C = x.shape[0], x.shape[1], x.shape[2]
        ctx.x_size = B,S,C
        x_2d = x.view(-1, C)
        
        #full precision
        if all(x is None for x in (w_qmodule.N_bits, a_qmodule.N_bits, w_g_qmodule.N_bits, a_g_qmodule.N_bits)):
            ctx.fullprecision = True
            output = torch.matmul(x_2d, w.t())
            ctx.weight = w
            ctx.activation = x_2d
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output.view(B, S, -1)
        else: 
            ctx.fullprecision = False

        #Quantization 
        if prefix_token_num == 0:
            input_quant, s_input_quant = a_qmodule(x_2d)
            weight_quant, s_weight_quant = w_qmodule(w)
            ctx.weight = (weight_quant, s_weight_quant)
            ctx.activation = (input_quant, s_input_quant) if w_g_qmodule.N_bits is not None else x_2d
            output = torch.matmul(input_quant, weight_quant.t())
            s_o = s_input_quant * s_weight_quant
            output = output * s_o
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output.view(B, S, -1)
        else: 
            prefix_token = x[:, :(prefix_token_num + 1)] #[32, 9, 768]
            patch_x = x[:, (prefix_token_num + 1):] #[32, 196, 768]

            prefix_token = prefix_token.reshape(-1, C) #[288, 768]
            patch_x = patch_x.reshape(-1, C) #[6272, 768]
            q_prefix_token, s_prefix_token = prefix_qmodule(prefix_token) #per-token 
            q_patch_x, s_patch_x = a_qmodule(patch_x)
            q_prefix_token = q_prefix_token.reshape(B,-1, C)
            q_patch_x = q_patch_x.reshape(B,-1,C)
            input_quant = torch.cat((q_prefix_token, q_patch_x), dim=1)

            if a_qmodule.q_type == 'per_token':
                s_prefix_token = s_prefix_token.reshape(B, -1) #[32, 9]
                s_patch_x = s_patch_x.reshape(B, -1) #[32, 196]
                s_input_quant = torch.cat((s_prefix_token, s_patch_x),dim=1)
            elif a_qmodule.q_type == 'per_tensor':
                s_prefix_token = s_prefix_token.reshape(B, -1) #[32, 9]
                s_patch_x = s_patch_x.expand(B, S-prefix_token_num-1) #[32, 196]
                s_input_quant = torch.cat((s_prefix_token, s_patch_x), dim=1) #[32, 205]

            ctx.activation = (input_quant, s_input_quant) if w_g_qmodule.N_bits is not None else x_2d
            weight_quant, s_weight_quant = w_qmodule(w) 
            ctx.weight = (weight_quant, s_weight_quant)

            s_o = s_weight_quant * s_input_quant
            output = torch.matmul(input_quant, weight_quant.t())
            output = output.view(B, S, -1) #[32, 205, 768])
            s_o = s_o.unsqueeze(-1).expand(-1, -1, output.shape[2]) 
            output = output * s_o
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            return output

    @staticmethod
    def backward(ctx, g_3D):
        g_2D = g_3D.reshape(-1, g_3D.size(-1)) #reshape to 2D
        grad_X = grad_W = grad_bias = None
        B,S,C = ctx.x_size
        
        if ctx.fullprecision:
            w = ctx.weight
            x = ctx.activation
            grad_W = torch.matmul(g_2D.t(), x)
            grad_X = torch.matmul(g_2D, w)
            grad_X = grad_X.view(B,S,-1)
            if ctx.has_bias:
                grad_bias = g_2D.sum(dim=0)
            else:
                grad_bias = None
        else:
            q_w, s_w = ctx.weight
            a_g_qmodule = ctx.a_g_qmodule
            w_g_qmodule = ctx.w_g_qmodule
            a_g_2D_quant, a_s_g_2D_quant = a_g_qmodule(g_2D)
            
            if w_g_qmodule.N_bits is not None:
                w_g_2D_quant, w_s_g_2D_quant = w_g_qmodule(g_2D)
                (q_x, s_x) = ctx.activation
            else:
                x = ctx.activation
            grad_X = torch.matmul(a_g_2D_quant, q_w)
            s_grad_X = a_s_g_2D_quant * s_w
            grad_X = grad_X * s_grad_X

            #Weigth Gradient
            if w_g_qmodule.N_bits is None: 
                grad_W = torch.matmul(g_2D.t(), x)
            else: 
                q_x = q_x.reshape(-1,q_x.size(-1))
                s_x = s_x.reshape(-1,s_x.size(-1))
                grad_W = torch.matmul(w_g_2D_quant.t(), q_x) #([768, 3072])
                s_grad_W = w_s_g_2D_quant * s_x
                grad_W = grad_W * s_grad_W

            if ctx.has_bias:
                grad_bias = g_2D.sum(dim=0)
            else:
                grad_bias = None
            grad_X = grad_X.view(B,S,-1)
        return None, None, None, None, None, None, grad_X, grad_W, grad_bias, None, None, None, None, None
        
        