from safetensors.torch import load_file,save_file
import torch
import os

def B_t(a_t_plus1, a_t):
    """Return the difference ``a_t - a_t_plus1``.

    Works for any operands supporting ``-`` (Python numbers or tensors).

    Args:
        a_t_plus1: value that is subtracted.
        a_t: value that is subtracted from.

    Returns:
        ``a_t - a_t_plus1``.
    """
    difference = a_t - a_t_plus1
    return difference


def add_adjust_term(error_term, alpha_t_list):
    """Return the negated, coefficient-weighted sum of the error terms.

    With ``m = len(alpha_t_list) - 1`` this computes::

        -sum_{k=0}^{m-1} (alpha_t_list[k+1] - alpha_t_list[k]) * error_term[k]

    where each coefficient difference is obtained via ``B_t``.

    Args:
        error_term: indexable sequence with at least ``m`` entries; entry ``k``
            is scaled by the k-th coefficient difference. Presumably a list of
            tensors such as the one returned by ``find_K_hw`` -- TODO confirm.
        alpha_t_list: sequence of ``m + 1`` coefficients (tensors or plain
            Python numbers). NOTE(review): the variable naming suggests the
            entries are ordered by timestep, but the exact ordering is not
            established here -- confirm with the caller.

    Returns:
        The negated sum. When ``alpha_t_list`` has fewer than two entries the
        sum is empty and ``0`` (an int) is returned.

    Raises:
        IndexError: if ``error_term`` has fewer than ``m`` entries.
    """
    m = len(alpha_t_list) - 1
    # Debug output (message: "alpha_t_list length=...").
    print(f'alpha_t_list长度={len(alpha_t_list)}')
    if len(error_term) < m:
        raise IndexError(
            f"error_term has {len(error_term)} entries but "
            f"len(alpha_t_list) - 1 = {m} are required"
        )
    total = 0
    for k in range(m):
        err = error_term[k]
        coeff = B_t(alpha_t_list[k], alpha_t_list[k + 1])
        # Only tensors carry a device; plain numbers multiply a tensor directly.
        if isinstance(coeff, torch.Tensor):
            coeff = coeff.to(err.device)
        total += coeff * err
    return -total


##### cache adjust K: load cached error terms from a safetensors file
def find_K_hw(path):
    """Load the ``'output'`` tensor from a safetensors file and split it along dim 0.

    Args:
        path: path to a safetensors file that contains an ``'output'`` entry.

    Returns:
        A list with one element per index of the first dimension of the
        ``'output'`` tensor (element ``i`` is ``output[i]``). Presumably these
        are the per-step error terms consumed by ``add_adjust_term`` -- TODO confirm.

    Raises:
        KeyError: if the file has no ``'output'`` entry.
    """
    stacked = load_file(path)['output']
    return [stacked[idx] for idx in range(stacked.shape[0])]