# Description: This script searches for the hidden-state dimensions that are most likely to be positional (i.e. to encode token position).

import numpy as np
import scipy
from tqdm import tqdm

def fit_and_monotic(data,degree=3,skip=50):
    """Fit a polynomial to a 1-D series and classify the fit's monotonicity.

    The first ``skip`` samples are dropped. The remainder is min-max
    normalized to [0, 1], and a polynomial of degree ``degree`` is fitted
    (least squares, ``np.polyfit``) against the sample index 0..n-1.

    Args:
        data: 1-D array-like series (e.g. one hidden dimension of one layer
            along the sequence axis, as passed by ``get_mono_layers``).
        degree: Degree of the fitted polynomial.
        skip: Number of leading samples to discard before fitting.

    Returns:
        ``(flag, r2)``:

        - ``flag`` is 1 if the fitted curve is non-decreasing, 2 if it is
          non-increasing, and 0 otherwise.
        - ``r2`` is the coefficient of determination of the fit on the
          normalized data.

        A series that is constant after skipping yields ``(0, 0.0)``.

    Raises:
        ValueError: If no samples remain after skipping (from ``np.min``).
    """
    data = np.asarray(data)[skip:]
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        # Min-max normalization would divide 0/0 here, and the resulting NaNs
        # break np.polyfit. A flat series carries no trend, so report it as
        # non-monotonic with zero goodness of fit.
        return 0, 0.0
    data = (data - lo) / span

    # Fit the polynomial in float32 against the sample index.
    x = np.arange(len(data), dtype=np.float32)
    y_fit = np.poly1d(np.polyfit(x, data.astype(np.float32), degree))(x)

    # Goodness of fit (coefficient of determination).
    r2 = 1 - np.sum((data - y_fit) ** 2) / np.sum((data - np.mean(data)) ** 2)

    # The fit is monotonic iff all consecutive differences share a sign.
    steps = np.diff(y_fit)
    if np.all(steps >= 0):
        return 1, r2
    if np.all(steps <= 0):
        return 2, r2
    return 0, r2

def get_mono_layers(hidden_states_dim:np.ndarray,only_inc_or_dec=0,r2_chose_type="mean",degree=3,skip=100)->tuple:
    """Find the layers in which one hidden dimension varies monotonically.

    Every layer index in ``range(2, num_layers - 1)`` is passed to
    ``fit_and_monotic``. The first two layers and the last layer are not
    scanned.

    Args:
        hidden_states_dim: Array whose axis 0 indexes layers. Each
            ``hidden_states_dim[layer]`` is fitted as a 1-D series.
        only_inc_or_dec: 1 selects increasing layers and 2 selects decreasing
            layers. Any other value selects whichever direction has more
            layers (ties go to increasing).
        r2_chose_type: ``"mean"`` aggregates the selected layers' R² by mean.
            Any other value uses the minimum. Applied to every selection mode.
        degree: Polynomial degree forwarded to ``fit_and_monotic``.
        skip: Leading samples to drop, forwarded to ``fit_and_monotic``.

    Returns:
        ``(layers, r2)``: the list of selected layer indices and their
        aggregated R², which is 0 when no layer was selected.
    """
    def _aggregate_r2(r2_values):
        # Score an empty selection as 0 rather than the NaN (plus
        # RuntimeWarning) that np.mean([]) would produce.
        if not r2_values:
            return 0
        return np.mean(r2_values) if r2_chose_type == "mean" else np.min(r2_values)

    inc_layer_num = []
    dec_layer_num = []
    inc_r2_list = []
    dec_r2_list = []
    for layer in range(2, hidden_states_dim.shape[0] - 1):
        m, r2 = fit_and_monotic(hidden_states_dim[layer], degree=degree, skip=skip)
        if m == 1:
            inc_layer_num.append(layer)
            inc_r2_list.append(r2)
        elif m == 2:
            dec_layer_num.append(layer)
            dec_r2_list.append(r2)

    if only_inc_or_dec == 1:
        use_increasing = True
    elif only_inc_or_dec == 2:
        use_increasing = False
    else:
        # Pick the more frequent direction; ties go to increasing.
        use_increasing = len(inc_layer_num) >= len(dec_layer_num)

    if use_increasing:
        return inc_layer_num, _aggregate_r2(inc_r2_list)
    return dec_layer_num, _aggregate_r2(dec_r2_list)

# The smoothness of a curve is defined as the (discrete) integral of the squared second derivative; lower values mean smoother.
def get_smoothness(x, window_length=10, polyorder=3):
    """Return a roughness score: the sum of squared second differences after smoothing.

    ``x`` is first min-max normalized to [0, 1] using its global minimum and
    maximum, taken over all elements even for N-D input. It is then smoothed
    along axis 0 with a Savitzky-Golay filter. Finally the squared
    second-order differences along axis 0 are summed. This is a discrete
    stand-in for the integral of the squared second derivative, so lower
    values mean a smoother curve. A constant ``x`` produces NaN, because the
    normalization divides by zero.

    Args:
        x: Array whose axis 0 is the curve parameter. Axis 0 needs at least
            ``window_length`` samples (SciPy's ``mode='interp'`` requirement).
        window_length: Savitzky-Golay window length (default 10, the value
            previously hard-coded).
        polyorder: Savitzky-Golay polynomial order (default 3, previously
            hard-coded).

    Returns:
        A scalar for 1-D input, otherwise an array of shape ``x.shape[1:]``.
    """
    # Import the submodule explicitly: on older SciPy releases a bare
    # `import scipy` does not make `scipy.signal` available.
    from scipy.signal import savgol_filter

    lo = np.min(x)
    x = (x - lo) / (np.max(x) - lo)
    # NOTE(review): with an even window_length, SciPy appears to evaluate the
    # interior fit at the window's half-sample midpoint. The 'interp' edges are
    # evaluated at integer positions, which can add small artificial curvature
    # near the ends. An odd window avoids that but changes the scores.
    x = savgol_filter(x, window_length=window_length, polyorder=polyorder, axis=0)
    ddx = np.diff(x, n=2, axis=0)
    return np.sum(np.square(ddx), axis=0)


# Find the minimum smoothness score among the monotonic layers
def get_smoothness_layers(hidden_states_all_mean, layers, skip=300):
    """Return the smallest ``get_smoothness`` score among the given layers.

    Args:
        hidden_states_all_mean: Indexable by layer. For each layer ``i``,
            ``hidden_states_all_mean[i][skip:]`` is scored. The ``__main__``
            block passes a per-dimension slice ``[:, :, dim]``.
        layers: Iterable of layer indices to score.
        skip: Number of leading samples of each layer's series to drop first.

    Returns:
        The minimum score, or ``np.inf`` when ``layers`` is empty (instead of
        ``np.min`` raising on an empty list). Such entries therefore sort last
        when candidates are ranked by ascending smoothness.
    """
    scores = [get_smoothness(hidden_states_all_mean[i][skip:]) for i in layers]
    if not scores:
        return np.inf
    return np.min(scores)


if __name__ == '__main__':
    # Script entry point: rank hidden dimensions by how "positional" they look.

    # Path to the saved hidden-state array (placeholder; set before running).
    save_path_hidden = "PATH to Hidden States"
    hidden_states_path = save_path_hidden
    # Assumed layout (TODO confirm against the producer of this file):
    # (num_layers, seq_len, hidden_size), or the 4-D per-head layout handled below.
    hidden_states_all_mean = np.load(hidden_states_path).astype(np.float32)
    print("已读取", hidden_states_path)  # prints "loaded <path>"

    if hidden_states_all_mean.ndim == 4:
        # Per-head layout (num_layers, num_heads, seq_len, head_dim): put the
        # sequence axis ahead of the heads, then merge heads into one hidden
        # axis -> (num_layers, seq_len, num_heads * head_dim). Done in a single
        # expression so no extra reference keeps the 4-D array alive.
        hidden_states_all_mean = np.transpose(hidden_states_all_mean, (0, 2, 1, 3)).reshape(
            hidden_states_all_mean.shape[0], hidden_states_all_mean.shape[2], -1)

    shape = hidden_states_all_mean.shape
    num_layers, seq_len, hidden_size = shape[0], shape[1], shape[2]

    # For every hidden dimension, collect the layers whose curve is
    # monotonically increasing (only_inc_or_dec=1) together with their mean R^2.
    each_dim_mono_layers_and_r2 = {}
    for dim in tqdm(range(hidden_size)):
        each_dim_mono_layers_and_r2[dim] = get_mono_layers(
            hidden_states_all_mean[:, :, dim], only_inc_or_dec=1, r2_chose_type="mean", skip=300)

    # Keep the dimensions that are monotonic in more than a quarter of the layers.
    min_layer_count = num_layers // 4
    dim_cand1 = np.array(
        [d for d, (mono_layers, _) in each_dim_mono_layers_and_r2.items()
         if len(mono_layers) > min_layer_count]
    ).astype(np.int32)

    # Score each candidate by its smoothest monotonic layer and keep the top-k
    # dimensions with the lowest (smoothest) score.
    topk = 10
    dim_cand_smoothness = []
    for d in dim_cand1:
        dim_cand_smoothness.append(
            get_smoothness_layers(hidden_states_all_mean[:, :, d], each_dim_mono_layers_and_r2[d][0], skip=200))
    ranking = np.argsort(dim_cand_smoothness)
    dim_cand = [int(d) for d in dim_cand1[ranking[:topk]]]

    print("dim_candidate", dim_cand)

