import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F

def str2kernel(s):
    """Map a string of '0'/'1' characters to a float32 tensor of -1/+1 values."""
    return torch.tensor(list(map(int, s)), dtype=torch.float32) * 2 - 1


# Powers of two 2**0 .. 2**8 as an int32 row vector, shape [1, 9]. Used by
# patterns2bnum to pack a binarized 3x3 kernel into an integer id.
# NOTE: position j gets weight 2**j (least-significant bit first), whereas
# patterns_str below writes the most-significant bit first, so the packed id of
# patterns_all[i] is the 9-bit reversal of i, not i itself.
bnum_base = torch.tensor([2 ** i for i in range(9)]).to(torch.int32).unsqueeze(0)

# All 512 zero-padded nine-bit strings: '000000000' ... '111111111'.
patterns_str = [format(i, "09b") for i in range(512)]
# Every possible binary 3x3 kernel with entries in {-1, +1}; row i is the
# binary expansion of i (MSB first). Stacked once instead of growing the
# tensor with repeated torch.cat (which re-copies it on every iteration).
patterns_all = torch.stack([str2kernel(s) for s in patterns_str])  # [512, 9]

# 4-connected neighbourhood of every cell of a row-major 3x3 kernel:
#     0 1 2
#     3 4 5
#     6 7 8
# neighbors[cell] lists, in ascending order, the cells that share an edge with
# it (corners have 2 neighbours, edge midpoints 3, the centre 4).
neighbors = {
    cell: [
        other
        for other in range(9)
        if abs(cell // 3 - other // 3) + abs(cell % 3 - other % 3) == 1
    ]
    for cell in range(9)
}

# Adjacency matrix, shape [9, 9]: row i has a 1 in column j iff j neighbours i.
neighbor_matrix = torch.tensor(
    [[1.0 if other in neighbors[cell] else 0.0 for other in range(9)] for cell in range(9)]
)

# Neighbour count of every cell as a row vector, shape [1, 9].
neighbor_counts = neighbor_matrix.sum(dim=1)[None, :]

def patterns2bnum(k):
    """Pack every 3x3 kernel of ``k`` into an integer id.

    ``k`` is viewed as [-1, 9]. Entry j contributes ``bnum_base[0, j]``
    (= 2 ** j) when it is positive; negative entries and exact zeros
    contribute nothing.

    Returns:
        numpy integer array with one id per kernel, computed on the CPU.
    """
    # sign -> {-1, 0, +1}; (+1) / 2 -> {0, 0.5, 1}; int cast truncates 0.5 to 0.
    positive = ((torch.sign(k) + 1) / 2).view(-1, 9)
    bits = positive.to(torch.int32).cpu()
    weighted = bits * bnum_base
    return weighted.sum(dim=1).numpy()


def one_hot(labels, n_class):
    """Encode integer class labels as one-hot rows.

    Args:
        labels: 1-D int64 tensor of class ids, each expected in [0, n_class).
        n_class: number of classes (columns of the result).

    Returns:
        Tensor of shape [labels.shape[0], n_class] (default float dtype) on
        ``labels.device``; row i holds a 1 in column labels[i], 0 elsewhere.
    """
    row_count = labels.shape[0]
    encoded = torch.zeros(row_count, n_class, device=labels.device)
    return encoded.scatter_(1, labels.unsqueeze(1), 1)


def get_sorted_patterns(weights, bit_num):
    """Return the 2 ** bit_num kernels most often chosen as nearest patterns.

    Every 3x3 kernel in ``weights`` (viewed as [-1, 9]) is matched against all
    512 candidates in ``patterns_all`` with weights2patterns_l2_iqr. The
    candidates are then ranked by how many kernels selected them, highest
    first; equal counts are ordered by larger candidate index first.

    Args:
        weights: float tensor whose element count is a multiple of 9.
        bit_num: 2 ** bit_num patterns are returned.

    Returns:
        Tensor [2 ** bit_num, 9] of rows taken from ``patterns_all``.
    """
    _, idxs = weights2patterns_l2_iqr(weights.view(-1, 9), patterns_all.to(weights.device))
    # One O(N) histogram instead of 512 Python list.count() scans (O(512 * N)).
    counts = torch.bincount(idxs.detach().cpu(), minlength=512).tolist()
    # Same ordering as sorting (count, index) pairs in descending order.
    ranked = sorted(range(512), key=lambda i: (counts[i], i), reverse=True)
    return patterns_all[ranked[:2 ** bit_num]]  # [2 ** bit_num, 9]


def get_random_patterns(bit_num):
    """Pick 2 ** bit_num distinct kernels uniformly at random.

    E.g. bit_num=5 draws 32 of the 512 rows of ``patterns_all`` without
    replacement, using the global ``random`` state.

    Returns:
        Tensor [2 ** bit_num, 9] of the chosen kernels, in sample order.
    """
    chosen_rows = random.sample(range(512), k=2 ** bit_num)
    return patterns_all[chosen_rows]


# def weights2patterns_l2(weights, patterns):
#     # weights: [cin * cout, 9], patterns: [2 ** bit_num, 9]; cin*cout个都必须在可用的32个中进行选择
#     norm = torch.norm(weights.unsqueeze(1) - patterns.unsqueeze(0), dim=2)  # [cin * cout, 2 ** bit_num]
#     idxs = norm.argmin(dim=1)  # [cin * cout]; 挑选出可用卷积核中距离最近的一个，返回的是index
#     return patterns[idxs], idxs  # [cin * cout, 9], [cin * cout]

def detect_singular_values(weights, k=1.5):
    """Flag per-row outliers using interquartile-range fences.

    For each row of ``weights`` (documented as [cin * cout, 9], one flattened
    3x3 kernel per row) compute Q1 and Q3 along dim 1 and flag every entry
    strictly below Q1 - k * IQR or strictly above Q3 + k * IQR, where
    IQR = Q3 - Q1.

    Args:
        weights: float tensor; quartiles are taken along dim 1.
        k: fence multiplier.

    Returns:
        Bool tensor shaped like ``weights``; True marks an outlier.
    """
    quartile_low = torch.quantile(weights, 0.25, dim=1, keepdim=True)
    quartile_high = torch.quantile(weights, 0.75, dim=1, keepdim=True)
    fence = k * (quartile_high - quartile_low)
    too_high = weights > quartile_high + fence
    too_low = weights < quartile_low - fence
    return too_high | too_low

def local_scaling(weights, singular_mask):
    """Rescale flagged weights by their mean difference to neighbouring cells.

    For each row of ``weights`` and each position i, ``avg_diff[i]`` is the
    mean of ``w[i] - w[j]`` over the neighbours j of i. The neighbours come
    from the module-level ``neighbor_matrix`` and ``neighbor_counts``. Entries
    where ``singular_mask`` is True become ``w[i] * (1 / |avg_diff[i]|)``; all
    other entries are returned unchanged.

    Where ``avg_diff`` is exactly 0 the scale factor 1 / |avg_diff| is
    undefined and would turn the weight into inf (or nan for a zero weight).
    That poisons every L2 distance computed from the row. Such entries are
    therefore left unscaled.

    Args:
        weights: float tensor, documented upstream as [cin * cout, 9].
        singular_mask: bool tensor shaped like ``weights``
            (e.g. from detect_singular_values).

    Returns:
        Tensor shaped like ``weights``.
    """
    device = weights.device
    # diff_matrix[n, i, j] = w[n, i] - w[n, j]                        [N, 9, 9]
    diff_matrix = weights.unsqueeze(2) - weights.unsqueeze(1)
    masked_diff = diff_matrix * neighbor_matrix.to(device)            # neighbour pairs only
    avg_diff = masked_diff.sum(dim=2) / neighbor_counts.to(device)    # [N, 9]

    abs_diff = torch.abs(avg_diff)
    has_scale = abs_diff != 0
    # Substitute 1 for a zero denominator so no inf/nan is ever formed:
    # torch.where evaluates both branches (and inf there would also poison
    # gradients), so guarding only the final selection is not enough.
    scalingf = 1 / torch.where(has_scale, abs_diff, torch.ones_like(abs_diff))
    return torch.where(singular_mask & has_scale, weights * scalingf, weights)

def weights2patterns_l2_iqr(weights, patterns):
    """Snap each kernel to its nearest pattern after rescaling IQR outliers.

    Outliers in every row of ``weights`` are found with detect_singular_values
    (default k) and rescaled with local_scaling. Each adjusted row is then
    matched to the row of ``patterns`` at the smallest Euclidean distance.

    Args:
        weights: [cin * cout, 9] kernels; every one must be mapped onto one of
            the available patterns.
        patterns: [P, 9] candidate kernels (e.g. P = 2 ** bit_num).

    Returns:
        (patterns[idxs], idxs): the chosen patterns [cin * cout, 9] and their
        row indices [cin * cout] into ``patterns``.
    """
    outlier_mask = detect_singular_values(weights)
    adjusted = local_scaling(weights, outlier_mask)
    residual = adjusted[:, None, :] - patterns[None, :, :]  # [cin * cout, P, 9]
    distances = residual.norm(dim=2)                         # [cin * cout, P]
    nearest = distances.argmin(dim=1)                        # ties -> first index
    return patterns[nearest], nearest

# def weights2patterns_hamming(weights, patterns):
#
#     weights_binary = torch.sign(weights)  # [cin * cout, 9]
#
#     hamming_dist = torch.sum((weights_binary.unsqueeze(1) != patterns.unsqueeze(0)).float(), dim=2)  # [cin * cout, 2 ** bit_num]
#
#     # 选择汉明距离最小的pattern
#     idxs = hamming_dist.argmin(dim=1)  # [cin * cout]
#     return patterns[idxs], idxs  # [cin * cout, 9], [cin * cout]

def remove_repetitive_patterns(patterns, bit_num):
    """Drop duplicate kernels, keeping the first occurrence of each.

    Only the first 2 ** bit_num rows of ``patterns`` are considered (fewer if
    ``patterns`` is shorter). Two rows count as duplicates when
    patterns2bnum maps them to the same id. Survivors keep their original
    relative order.

    Returns:
        New tensor with the unique rows (same dtype/device as ``patterns``).
    """
    candidates = patterns[:2 ** bit_num]
    # One batched id computation (a single device->host transfer) instead of
    # one per row, and a single gather instead of a quadratic torch.cat chain.
    bnums = patterns2bnum(candidates)
    seen, keep = set(), []
    for row, bnum in enumerate(bnums):
        if bnum not in seen:
            seen.add(bnum)
            keep.append(row)
    return candidates[keep]


def conpensate_patterns(weights, patterns, bnum_set, bit_num):
    """Top up ``patterns`` with random kernels until 2 ** bit_num ids are covered.

    Args:
        weights: tensor whose ``device`` the new rows are moved to (its values
            are not used).
        patterns: kernels already selected; new rows are appended after them.
        bnum_set: ids (as produced by patterns2bnum) already present in
            ``patterns``. It is read only, not modified.
        bit_num: target codebook size is 2 ** bit_num.

    Returns:
        ``patterns`` with up to ``2 ** bit_num - len(bnum_set)`` rows appended.
        The new rows are drawn from get_random_patterns(bit_num) in sample
        order, skipping ids already in ``bnum_set``.
    """
    needed = 2 ** bit_num - len(bnum_set)
    # Return early when the set is already full; checking the target only
    # after an append would otherwise overshoot and add every unseen candidate.
    if needed <= 0:
        return patterns
    # Alternative resampling by frequency:
    # patterns_resample = get_sorted_patterns(weights, bit_num).to(weights.device).detach()
    patterns_resample = get_random_patterns(bit_num).to(weights.device).detach()
    bnums = patterns2bnum(patterns_resample)  # one host transfer for all candidates
    fresh = [row for row, bnum in enumerate(bnums) if bnum not in bnum_set][:needed]
    if not fresh:
        return patterns
    return torch.cat([patterns, patterns_resample[fresh]], dim=0)
