import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
def func_dic(dic, x):
    """Append five nonlinear transforms of ``x`` to ``dic`` and return it.

    The transforms are appended in this order:

    1. ``log(x)``
    2. ``(log(x) - min(log(x))) ** 2``
    3. ``(log(x) - min(log(x))) ** 0.5``
    4. ``(x / max(x)) ** 0.15``
    5. ``(x / max(x)) ** 0.3``

    Args:
        dic: list that is extended in place with the five arrays.
        x: array-like of per-class values. Presumably strictly positive,
            since ``log`` of zero/negative entries yields ``-inf``/``nan``
            (TODO confirm with callers).

    Returns:
        The same ``dic`` list, so calls can be chained.
    """
    log_x = np.log(x)
    # Shift so the smallest log-value is 0 before the power transforms;
    # this keeps the base of ``** 0.5`` non-negative.
    shifted = log_x - np.min(log_x)
    dic.append(log_x)
    dic.append(shifted ** 2)
    dic.append(shifted ** 0.5)
    # Scale to (.., 1] by the maximum value, then compress with small powers.
    scaled = np.divide(x, np.max(x))
    dic.append(scaled ** 0.15)
    dic.append(scaled ** 0.3)
    return dic
def dic_PCA(dic, n_dic):
    """Standardize the dictionary features and optionally reduce them with PCA.

    Args:
        dic: sequence of equal-length 1-D arrays, one per feature. After
            ``np.array(dic).T`` each row is a sample and each column a feature.
        n_dic: ``None`` to skip PCA and return the standardized features.
            Otherwise it is forwarded unchanged to ``PCA(n_components=...)``,
            so an int component count or any other value sklearn accepts
            (e.g. a float variance fraction) works.

    Returns:
        2-D array with one row per (standardized or principal) component and
        one column per sample.
    """
    # Samples must be rows for StandardScaler / PCA, hence the transpose.
    x = StandardScaler().fit_transform(np.array(dic).T)
    if n_dic is not None:
        # fit_transform returns (n_samples, n_components); taking x.T below
        # yields every component row without needing an integer loop bound.
        x = PCA(n_components=n_dic).fit_transform(x)
    # Back to one row per component (copied, as before).
    return np.array(x.T)


def _inverse_min_ratio(values):
    """Return ``min(1 - values) / (1 - values)`` elementwise.

    When every ``1 - values`` entry is positive the result lies in (0, 1], with
    1 at the entry whose ``1 - values`` is smallest. An entry equal to 1 makes
    a denominator zero and yields nan/inf.
    """
    complement = np.ones_like(values) - values
    return np.divide(np.ones_like(values) * np.min(complement), complement)


def make_dic(args, class_freq, class_wise_acc_dic, class_weight_norm, class_noise_dic, num_classes):
    """Build the per-class feature dictionary from the enabled components.

    Each flag in ``args["dic_component"]`` (``dic_freq``, ``dic_error``,
    ``dic_norm``, ``dic_noise``) that equals ``True`` contributes the five
    ``func_dic`` transforms of one per-class statistic:

    - ``dic_freq``: ``class_freq`` as an array.
    - ``dic_error``: ``min(1 - acc) / (1 - acc)`` for ``class_wise_acc_dic``.
    - ``dic_norm``: ``class_weight_norm / max(class_weight_norm)``.
    - ``dic_noise``: ``min(1 - noise) / (1 - noise)`` for ``class_noise_dic``.

    The collected features go through ``dic_PCA`` with
    ``args["dic_component"]["n_dic"]``, then a row of ``num_classes`` ones is
    appended as the last row.

    Returns:
        2-D array whose last row is all ones.

    Raises:
        ValueError: if no dictionary component is enabled.
    """
    comp = args["dic_component"]

    # Builders are lambdas so a disabled component's input is never touched
    # (it may be a placeholder); order matches the row order of the output.
    sources = (
        ("dic_freq", lambda: np.array(class_freq)),
        ("dic_error", lambda: _inverse_min_ratio(class_wise_acc_dic)),
        ("dic_norm", lambda: class_weight_norm / np.max(class_weight_norm)),
        # NOTE(review): a plain (1 - noise) / (1 - noise) is identically 1
        # (nan where noise == 1), which made this component constant. It now
        # mirrors the dic_error transform; confirm this is the intended formula.
        ("dic_noise", lambda: _inverse_min_ratio(class_noise_dic)),
    )

    features = []
    for key, build in sources:
        # Strict ``== True`` (not truthiness) keeps the existing flag semantics.
        if comp[key] == True:  # noqa: E712
            features = func_dic(features, build())

    if not features:
        raise ValueError(
            "make_dic: no dictionary component enabled in args['dic_component']"
        )

    params = dic_PCA(features, comp["n_dic"])
    # Constant row of ones appended after the (reduced) features.
    return np.vstack((params, np.ones([num_classes])))