import torch
# from monitor_logger import monitor

# Compress the input tensor by top-k magnitude sparsification.
# Returns a tensor of the same shape with the lowest (1 - topk) * 100% of values (by absolute value) zeroed out.
def compress_topk(input, ctx = None, topk=0.1):
    """Sparsify ``input`` in place, keeping only its largest-magnitude entries.

    One magnitude threshold is computed over the whole tensor (not per
    sample): the ``round(numel * (1 - topk))`` smallest absolute values, plus
    anything tied with the largest of them, are overwritten with zero.

    Args:
        input: Tensor to compress. It is modified in place; any shape and
            stride layout is accepted.
        ctx: Unused by this function.
        topk: Fraction of entries to keep (``0.1`` keeps roughly the top 10%).
            Values that would yield a count outside ``[0, numel]`` are clamped
            to that range.

    Returns:
        A view that shares storage with ``input`` and has the same shape and
        strides.

    Raises:
        Exception: Errors from the underlying torch calls are re-raised after
            the input's shape and stride have been printed.
    """
    original_shape = input.shape
    original_stride = input.stride()

    try:
        # Work elementwise on the tensor itself so the mask always lines up
        # with the real elements, whatever the stride layout is.
        magnitudes = torch.abs(input)
        numel = magnitudes.numel()
        n_lowest = min(max(int(round(numel * (1 - topk))), 0), numel)
        if n_lowest == 0:
            # Nothing to drop; a zero pivot only matches entries already zero.
            pivot = 0
        else:
            pivot = torch.kthvalue(magnitudes.flatten(), k=n_lowest).values
        input.masked_fill_(magnitudes <= pivot, 0)

        return torch.as_strided(input, size=original_shape, stride=original_stride)

    except Exception as e:
        print(f"Error during compression: {e}")
        print(f"Original shape: {original_shape}, stride: {original_stride}")
        raise

def compress_topk_sync(input, ctx = None, topk=0.1):
    """Zero the smallest-magnitude entries of each sample, reusing a cached mask.

    ``input`` is treated as ``(input.size(0), -1)``. When ``ctx`` has no
    ``mask`` attribute, a per-row threshold is taken at the
    ``int(row_len * (1 - topk))``-th smallest absolute value and every entry
    at or below it is masked; the mask is then stored as ``ctx.mask``. When
    ``ctx.mask`` already exists it is applied as-is and ``topk`` is ignored.

    Args:
        input: Tensor with at least one dimension. It is modified in place.
        ctx: Optional object used to cache the mask between calls. With
            ``None`` the mask is computed but not cached.
        topk: Fraction of entries to keep in each row.

    Returns:
        A view that shares storage with ``input`` and has the same shape and
        strides.
    """
    original_shape = input.shape
    original_stride = input.stride()

    if input.numel() == 0:
        # Nothing to mask; also avoids reshaping a tensor with no elements.
        return torch.as_strided(input, size=original_shape, stride=original_stride)

    if hasattr(ctx, 'mask'):
        mask = ctx.mask
    else:
        magnitudes = torch.abs(input).reshape(input.size(0), -1)
        row_len = magnitudes.size(1)
        n_lowest = min(max(int(row_len * (1 - topk)), 0), row_len)
        if n_lowest == 0:
            # kthvalue rejects k=0; a zero pivot only matches existing zeros.
            pivot = 0
        else:
            pivot = torch.kthvalue(magnitudes, k=n_lowest, keepdim=True).values
        mask = magnitudes <= pivot
        if ctx is not None:
            ctx.mask = mask

    # Apply the (batch, row_len) mask in the input's own shape so the fill is
    # correct for any stride layout.
    input.masked_fill_(mask.reshape(original_shape), 0)
    return torch.as_strided(input, size=original_shape, stride=original_stride)

def quantize_custom(input, ctx = None, k=8,*args,**kwargs):
    """Quantize each sample of ``input`` to ``2**k`` evenly spaced levels.

    ``input`` is treated as ``(input.size(0), -1)``. Every row is min-max
    normalised to ``[0, 1]``, rounded to the nearest of ``2**k`` levels, and
    mapped back to the row's original range, so the result holds quantized
    values in the input's own scale rather than integer codes.

    Args:
        input: Tensor with at least one dimension. It is not modified.
        ctx: Unused by this function.
        k: Number of bits; the row range is split into ``2**k - 1`` steps.
        *args, **kwargs: Accepted and ignored.

    Returns:
        A new tensor with the same shape as ``input``.
    """
    original_shape = input.shape
    if input.numel() == 0:
        # No values to quantize (and no well-defined row length).
        return input.clone()

    # reshape copes with any stride layout, copying only when it has to.
    x = input.reshape(input.size(0), -1)
    levels = 2**k - 1

    x_min = x.min(dim=1, keepdim=True).values
    x_max = x.max(dim=1, keepdim=True).values
    span = x_max - x_min
    # The epsilon keeps constant rows from dividing by zero.
    x_norm = (x - x_min) / (span + 1e-8)

    val = torch.round(levels * x_norm)
    val = torch.clamp(val, 0, levels)
    val = val / levels
    out = val * span + x_min

    return out.reshape(original_shape)

def natural_quantize(input, ctx=None, k=8,*args,**kwargs):
    """Quantize ``input`` with ``quantize_custom``, then round to powers of two.

    After the ``k``-bit quantization, every value is stochastically rounded
    to one of the two powers of two that bracket its magnitude, keeping its
    sign. The chance of rounding up grows linearly with the value's position
    between the two, so the rounding is unbiased in expectation. Zeros stay
    zero.

    Args:
        input: Tensor to compress; passed on to ``quantize_custom``.
        ctx: Passed through to ``quantize_custom``.
        k: Number of bits for the quantization step.
        *args, **kwargs: Accepted and ignored.

    Returns:
        A tensor with the same shape as the quantized tensor.
    """
    quantized = quantize_custom(input, ctx, k)

    # Everything below is elementwise, so no reshaping or re-striding is
    # needed to keep the output aligned with the quantized tensor.
    mantissa, exponent = torch.frexp(quantized)
    sign = torch.sign(quantized)

    # |mantissa| lies in [0.5, 1), so 2 * |mantissa| - 1 is the fractional
    # position between 2**(exponent - 1) and 2**exponent.
    p = (mantissa * sign * 2 - 1).clip(min=0)
    shift = torch.bernoulli(p)

    return sign * 2.0**(exponent + shift - 1)

def natural_compress(input, ctx = None, *args,**kwargs):
    """Stochastically round every element of ``input`` to a signed power of two.

    A value with magnitude in ``[2**(e - 1), 2**e)`` becomes ``2**e`` with
    probability equal to its fractional position in that interval and
    ``2**(e - 1)`` otherwise, keeping its sign, so the rounding is unbiased in
    expectation. Exact powers of two and zeros are returned unchanged.

    Args:
        input: Floating-point tensor of any shape. It is not modified.
        ctx: Unused by this function.
        *args, **kwargs: Accepted and ignored.

    Returns:
        A new tensor with the same shape as ``input``.
    """
    # All operations are elementwise, so the tensor is used as-is; this keeps
    # the result aligned with the input for any shape or stride layout.
    mantissa, exponent = torch.frexp(input)
    sign = torch.sign(input)

    # |mantissa| lies in [0.5, 1), so 2 * |mantissa| - 1 is the fractional
    # position between 2**(exponent - 1) and 2**exponent.
    p = (mantissa * sign * 2 - 1).clip(min=0)
    shift = torch.bernoulli(p)

    return sign * 2.0**(exponent + shift - 1)

def quantize_ACSGD(x, ctx = None, nbits=8, scale_method='max', scale_dims=(0,)):
    """Quantize ``x`` to unsigned ``nbits``-bit integer codes plus a scale.

    ``x`` is divided by a scale reduced over ``scale_dims``, multiplied by
    ``2**(nbits - 1)``, rounded, clipped to the signed ``nbits``-bit range and
    shifted to be non-negative.

    Args:
        x: Tensor to quantize.
        ctx: Unused by this function.
        nbits: Code width in bits; must be at least 1.
        scale_method: ``'max'`` uses the maximum absolute value; ``'l2'`` uses
            twice the root mean square.
        scale_dims: Dimensions reduced when computing the scale (kept with
            size 1 in the returned scale).

    Returns:
        ``(codes, scale)``. ``codes`` is ``torch.uint8`` when ``nbits <= 8``,
        otherwise a wider integer dtype so the codes do not wrap around.
        ``scale`` is a ``torch.float16`` tensor.

    Raises:
        ValueError: If ``nbits`` is less than 1 or ``scale_method`` is not
            recognised.
    """
    if nbits < 1:
        raise ValueError(f"nbits must be at least 1, got {nbits}")

    fbits = nbits - 1

    if scale_method == 'max':
        # issue: sensitive to outlier points
        scale = x.abs().amax(scale_dims, keepdim=True)
    elif scale_method == 'l2':
        # ~95% confidence interval for normal distribution
        scale = x.pow(2).mean(scale_dims, keepdim=True).sqrt() * 2
    else:
        raise ValueError(f"unknown scale method: {scale_method!r}")
    # fp16 should be enough
    scale = scale.half()
    # The epsilon keeps an all-zero slice from dividing by zero.
    x = x / (scale + 1e-6)

    x = x.ldexp(torch.tensor(fbits))
    clip_min = -(1 << fbits)
    clip_max = (1 << fbits) - 1

    x = x.round().clip(clip_min, clip_max)

    # Shift the signed range [clip_min, clip_max] to [0, 2**nbits - 1].
    x = x - clip_min

    # uint8 only holds codes up to 255; widen for larger bit widths.
    if nbits <= 8:
        code_dtype = torch.uint8
    elif nbits <= 15:
        code_dtype = torch.int16
    else:
        code_dtype = torch.int32

    return x.to(code_dtype), scale

def dequantize_ACSGD(x, nbits, scale):
    """Map unsigned ``nbits``-bit codes back to scaled floating-point values.

    The codes are converted to float, shifted down by ``2**(nbits - 1)`` to
    restore the sign, divided by ``2**(nbits - 1)`` and multiplied by
    ``scale``.

    Args:
        x: Tensor of integer codes.
        nbits: Bit width the codes were produced with.
        scale: Scale to multiply by; broadcast against ``x``.

    Returns:
        The dequantized floating-point tensor.
    """
    half_range = 1 << (nbits - 1)
    signed_codes = x.float() + (-half_range)
    return signed_codes / half_range * scale

# XXX: works only for vectors on GPU
def quantize_dequantize_ACSGD(x, nbits, scale_method='max', scale_dims=(0,)):
    """Quantize ``x`` with ``quantize_ACSGD`` and immediately dequantize it.

    Args:
        x: Tensor with at least two dimensions; everything after the first
            dimension is flattened before quantization.
        nbits: Code width in bits.
        scale_method: Scale method forwarded to ``quantize_ACSGD``.
        scale_dims: Scale reduction dimensions forwarded to
            ``quantize_ACSGD``.

    Returns:
        The dequantized tensor, in the flattened ``(x.size(0), -1)`` shape.
    """
    xx = x.flatten(start_dim=1)
    # Pass by keyword: the second positional parameter of quantize_ACSGD is
    # ``ctx``, so positional arguments would all land one slot too early.
    xx, scale = quantize_ACSGD(
        xx, nbits=nbits, scale_method=scale_method, scale_dims=scale_dims
    )
    c_x = dequantize_ACSGD(xx, nbits, scale)
    return c_x

if __name__ == "__main__":
    # Smoke test: compress a two-element vector, keeping the top half.
    sample = torch.tensor([-0.5, 0.3])
    print(compress_topk(sample, topk=0.5))



