from safetensors.torch import load_file,save_file
import torch
import os
    
def B_t(a_t_plus1, a_t):
    """Return the epsilon coefficient of a deterministic DDIM step.

    Computes ``sqrt(1 - a_t) - sqrt(a_t * (1 - a_t_plus1) / a_t_plus1)``,
    written as ``sqrt(1 - a_t) - sqrt(a_t * (1 - a_t_plus1)) / sqrt(a_t_plus1)``.
    Assuming ``a_*`` are cumulative alpha-bar values, this is the factor
    applied to the predicted noise when stepping from ``a_t_plus1`` to
    ``a_t`` (eta = 0). Works with Python floats and with tensors, since only
    arithmetic operators are used.
    """
    sigma_now = (1 - a_t) ** 0.5
    carried_noise = (a_t * (1 - a_t_plus1)) ** 0.5 / a_t_plus1 ** 0.5
    return sigma_now - carried_noise


def B_t_vpred(a_t_plus1, a_t):
    """Return ``sqrt(a_t_plus1 * (1 - a_t)) - sqrt(a_t * (1 - a_t_plus1))``.

    Assuming ``a_*`` are cumulative alpha-bar values and the
    v-parameterisation ``x0 = sqrt(a) x - sqrt(1 - a) v``, this is the
    coefficient on the predicted ``v`` in a deterministic DDIM step from
    ``a_t_plus1`` to ``a_t``. It is zero when both arguments are equal.
    """
    cross_next = (a_t_plus1 * (1 - a_t)) ** 0.5
    cross_now = (a_t * (1 - a_t_plus1)) ** 0.5
    return cross_next - cross_now

def A_t_vpred(a_t_plus1, a_t):
    """Return ``sqrt(a_t_plus1 * a_t) + sqrt((1 - a_t) * (1 - a_t_plus1))``.

    Assuming ``a_*`` are cumulative alpha-bar values and the
    v-parameterisation, this is the coefficient on the current sample in a
    deterministic DDIM step from ``a_t_plus1`` to ``a_t``. It equals 1 when
    both arguments are equal.
    """
    signal_part = (a_t_plus1 * a_t) ** 0.5
    noise_part = ((1 - a_t) * (1 - a_t_plus1)) ** 0.5
    return signal_part + noise_part

def add_adjust_term_ddim_eplison(error_term, alpha_t_list):
    """Return the negated, ``B_t``-weighted sum of error terms (epsilon prediction).

    For every consecutive pair ``(alpha_t_list[k], alpha_t_list[k + 1])`` the
    DDIM noise coefficient ``B_t(alpha_t_list[k], alpha_t_list[k + 1])``
    multiplies ``error_term[k]``. The function returns the negation of the
    sum of those products.

    Args:
        error_term: Indexable sequence. ``error_term[k]`` pairs with the k-th
            consecutive alpha pair, so it needs at least
            ``len(alpha_t_list) - 1`` entries. Elements are presumably tensors
            (TODO: confirm with caller).
        alpha_t_list: Sequence of alpha coefficients. ``alpha_t_list[k]`` is
            passed as the "t+1" argument of ``B_t`` and ``alpha_t_list[k + 1]``
            as the "t" argument, so the list runs from the latest step (t+1)
            backwards. Presumably cumulative alpha-bar values; verify against
            the caller.

    Returns:
        ``-sum_k B_t(alpha_t_list[k], alpha_t_list[k + 1]) * error_term[k]``.
        Returns ``0`` when fewer than two alphas are given.
    """
    # The original debug print of len(alpha_t_list) and the unused read of
    # alpha_t_list[-1] were removed. The print spammed stdout on every call,
    # and the read crashed on an empty list for no reason.
    total = 0
    pairs = zip(alpha_t_list, alpha_t_list[1:])
    for k, (alpha_next, alpha_cur) in enumerate(pairs):
        total += B_t(alpha_next, alpha_cur) * error_term[k]
    return -total


def add_adjust_term_ddim_vpred(error_term, alpha_t_list):
    """Return the negated, ``B_t_vpred``-weighted sum of error terms (v prediction).

    ``alpha_t_list`` holds the coefficients for the time steps from t+1
    backwards. The t+1 entry is included because the ``B`` coefficient of the
    first pair needs it. For every consecutive pair
    ``(alpha_t_list[k], alpha_t_list[k + 1])`` the coefficient
    ``B_t_vpred(alpha_t_list[k], alpha_t_list[k + 1])`` multiplies
    ``error_term[k]``.

    Args:
        error_term: Indexable sequence with at least ``len(alpha_t_list) - 1``
            entries. Elements are presumably tensors (TODO: confirm with
            caller).
        alpha_t_list: Sequence of alpha coefficients, latest step first.
            Presumably cumulative alpha-bar values; verify against the
            caller.

    Returns:
        ``-sum_k B_t_vpred(alpha_t_list[k], alpha_t_list[k + 1]) * error_term[k]``.
        Returns ``0`` when fewer than two alphas are given.
    """
    # The original's unused read of alpha_t_list[-1] was dropped. It only
    # made an empty list raise IndexError instead of yielding 0.
    total = 0
    pairs = zip(alpha_t_list, alpha_t_list[1:])
    for k, (alpha_next, alpha_cur) in enumerate(pairs):
        total += B_t_vpred(alpha_next, alpha_cur) * error_term[k]
    return -total




def add_adjust_term_ddim_vpred_final(error_term, alpha_t_list):
    """Return the negated sum of damped, rescaled error terms (v prediction).

    ``alpha_t_list`` holds the coefficients for the time steps from t+1
    backwards. The t+1 entry is included because the first pair's
    coefficients need it. For every consecutive pair
    ``(alpha_t_list[k], alpha_t_list[k + 1])``, ``error_term[k]`` is
    multiplied by three factors:

    * ``sqrt(alpha_t_list[k + 1] / alpha_t_list[-1])``, a rescaling relative
      to the last entry of the list;
    * ``1 / (1 + A_t_vpred(alpha_t_list[k], alpha_t_list[k + 1]))``;
    * ``1 / (1 + |B_t_vpred(alpha_t_list[k], alpha_t_list[k + 1])|)``.

    Args:
        error_term: Indexable sequence with at least ``len(alpha_t_list) - 1``
            entries. Elements are presumably tensors (TODO: confirm with
            caller).
        alpha_t_list: Sequence of alpha coefficients, latest step first.
            Presumably cumulative alpha-bar values; verify against the
            caller.

    Returns:
        The negated weighted sum. Returns ``0`` when ``alpha_t_list`` has
        exactly one entry.

    Raises:
        IndexError: if ``alpha_t_list`` is empty, because its last entry is
            read before any accumulation.
    """
    alpha_ref = alpha_t_list[-1]
    acc = 0
    pairs = zip(alpha_t_list, alpha_t_list[1:])
    for k, (alpha_next, alpha_cur) in enumerate(pairs):
        scale = (alpha_cur / alpha_ref) ** 0.5
        damp_a = 1 / (1 + A_t_vpred(alpha_next, alpha_cur))
        damp_b = 1 / (1 + abs(B_t_vpred(alpha_next, alpha_cur)))
        # Same left-to-right multiplication order as before, so float
        # results are bit-identical.
        acc += scale * damp_a * damp_b * error_term[k]
    return -acc
        