import numpy as np
import torch

def decay_linear(total_steps, step_T, beta=1.0):
    """Return the cumulative product of a linearly decreasing alpha schedule.

    The k-th factor is ``alpha_k = (1 - 1e-2) - k * interval`` with
    ``interval = beta * (1 - 1e-2) / total_steps``; the result is
    ``alpha_1 * alpha_2 * ... * alpha_{step_T}``.

    Args:
        total_steps: the total number of decay steps (sets the slope).
        step_T: the current step number, i.e. how many factors are multiplied.
        beta: multiplier applied to the slope of the schedule.

    Returns:
        The cumulative product as a NumPy float; ``1.0`` when ``step_T`` is 0.
    """
    slope = beta * (1 - 1e-2) / total_steps

    # Start every entry at the common offset, then subtract k * slope for
    # k = 1..step_T in one vectorized operation.
    alphas = np.full(step_T, 1 - 1e-2)
    alphas -= np.arange(1, alphas.size + 1) * slope

    # Multiply in log space; the 1e-12 offset keeps log() finite when a
    # factor reaches zero.
    log_alphas = np.log(alphas + 1e-12)
    return np.exp(np.sum(log_alphas))

def calculate_Q_bar(step_T, decay_f, feature_dim_list, total_steps):
    """Build the cumulative transition matrix Q_bar for each feature dimension.

    For every ``dim`` in ``feature_dim_list`` the matrix is
    ``bar_alpha_T * I + (1 - bar_alpha_T) / dim * ones((dim, dim))``, where
    ``bar_alpha_T`` is the cumulative alpha product returned by
    ``decay_linear(total_steps, step_T)``. Each row of the result sums to 1.

    Args:
        step_T: the current step number, in [0, total_steps).
        decay_f: the type of decay function; only ``'linear'`` is supported.
        feature_dim_list: iterable of feature dimensions (one matrix each).
        total_steps: the total number of decay steps.

    Returns:
        A list of ``(dim, dim)`` torch tensors, one per entry of
        ``feature_dim_list``, in the same order.

    Raises:
        ValueError: if ``decay_f`` is not a supported decay function.
    """
    # Fail loudly up front: previously an unknown decay_f only printed a
    # message and then crashed on the undefined bar_alpha_T.
    if decay_f != 'linear':
        raise ValueError(
            f'sorry we do not support the decay function now: {decay_f!r}'
        )
    bar_alpha_T = decay_linear(total_steps, step_T)

    Q_list = []
    for dim in feature_dim_list:
        # Keep the current value with weight bar_alpha_T ...
        keep = torch.eye(dim) * bar_alpha_T
        # ... and spread the remaining mass uniformly over all dim entries.
        spread = torch.ones((dim, dim)) * (1 - bar_alpha_T) / dim
        Q_list.append(keep + spread)
    return Q_list


def calculate_Q(step_T, decay_f, feature_dim_list, total_steps):
    """Build the single-step transition matrix Q for each feature dimension.

    With the linear schedule,
    ``alpha_T = (1 - 1e-2) - (step_T + 1) * (1 - 1e-2) / total_steps``, and
    for every ``dim`` in ``feature_dim_list`` the matrix is
    ``alpha_T * I + (1 - alpha_T) / dim * ones((dim, dim))``. Each row of the
    result sums to 1.

    Args:
        step_T: the current step number, in [0, total_steps).
        decay_f: the type of decay function; only ``'linear'`` is supported.
        feature_dim_list: iterable of feature dimensions (one matrix each).
        total_steps: the total number of decay steps.

    Returns:
        A list of ``(dim, dim)`` torch tensors, one per entry of
        ``feature_dim_list``, in the same order.

    Raises:
        ValueError: if ``decay_f`` is not a supported decay function.
    """
    # Fail loudly up front: previously an unknown decay_f only printed a
    # message and then crashed on the undefined alpha_T.
    if decay_f != 'linear':
        raise ValueError(
            f'sorry we do not support the decay function now: {decay_f!r}'
        )
    interval = (1 - 1e-2) / total_steps
    alpha_T = 1 - 1e-2 - (step_T + 1) * interval

    Q_list = []
    for dim in feature_dim_list:
        # Keep the current value with weight alpha_T ...
        keep = torch.eye(dim) * alpha_T
        # ... and spread the remaining mass uniformly over all dim entries.
        spread = torch.ones((dim, dim)) * (1 - alpha_T) / dim
        Q_list.append(keep + spread)
    return Q_list