import os
import re

import numpy as np
import torch
from scipy.fft import dct, idct
import matplotlib.pyplot as plt

import torch.nn.functional as F

from utils.frequency import transform_ifft, transform_fft


# Side length of the square grid that each saved SHAP vector is laid out on.
SHAP_GRID = 28


def _positive_shap_mask(shap, img_w, img_h, channels=3, grid=SHAP_GRID):
    """Binarize a SHAP vector into a per-channel mask at image resolution.

    Args:
        shap: SHAP values (tensor or array-like) with ``grid * grid`` elements.
        img_w: first spatial size of the returned mask.
        img_h: second spatial size of the returned mask.
        channels: number of channels the mask is replicated over.
        grid: side length of the square grid the SHAP values are laid out on.

    Returns:
        ``(mask, price)``. ``mask`` is a float32 CPU tensor of shape
        ``(img_w, img_h, channels)``. It is 1.0 where the SHAP value is
        positive (nearest-neighbour upsampled from ``grid x grid``) and 0.0
        elsewhere. ``price`` is the number of positive SHAP values.
    """
    # Keep the mask on the CPU, like the original list-built tensor, so it
    # combines with a CPU frequency map even if the SHAP file was saved on GPU.
    positive = (torch.as_tensor(shap).cpu() > 0).reshape(grid, grid)
    price = int(positive.sum())
    # Cast to float *before* interpolating: nearest-neighbour upsampling may
    # not be implemented for integer tensors in some PyTorch versions.
    grid_mask = positive.float().reshape(1, 1, grid, grid).expand(1, channels, grid, grid)
    mask = F.interpolate(grid_mask, size=[img_w, img_h], mode="nearest")
    # (1, C, W, H) -> (W, H, C). Index with [0] instead of squeeze() so that
    # size-1 spatial/channel dims are preserved.
    return mask[0].permute(1, 2, 0), price


def main():
    """Accumulate SHAP-masked frequency maps per class and save one tensor per class.

    For every class sub-directory of ``shap_path`` that exists:
    - read each ``*.png`` image whose name does not contain "shap";
    - load its companion ``<stem>_freq.pt`` SHAP vector;
    - add ``freq * mask * price / (img_w * img_h)`` to a running sum.

    The sum is saved to ``<save_path>/<class>_result.pt``. Classes whose
    directory is missing are skipped and produce no file.
    """
    # Set manually to match the dataset's image size and channel count.
    img_w, img_h, colors = 28, 28, 3

    shap_path = "./CMnist/ERM_50/"
    save_path = "../pos/cm/erm/"
    classes = ["0.0", "1.0"]

    for c in classes:
        result_path = os.path.join(shap_path, c)
        if not os.path.isdir(result_path):
            continue

        # Running sum of the weighted, masked frequency maps for this class.
        result = torch.zeros(img_w, img_h, colors)
        for file in os.listdir(result_path):
            if not file.endswith(".png") or "shap" in file:
                continue

            img = plt.imread(os.path.join(result_path, file))
            if img.ndim == 3:
                # PNGs may decode as RGBA; keep only the channels `result` holds.
                img = img[..., :colors]
            # Frequency-domain representation of the image. Assumed to have the
            # same (img_w, img_h, colors) layout as `result` — TODO confirm
            # against utils.frequency.transform_fft.
            freq = transform_fft(torch.tensor(img))

            # NOTE(review): the companion name uses the text before the first
            # '.' (kept for compatibility with however the *_freq.pt files
            # were written). Dotted image names would collide here.
            stem = file.split(".")[0]
            # torch.load unpickles data; only use on trusted, locally generated files.
            shap = torch.load(os.path.join(result_path, stem + "_freq.pt"))

            # Weight each image by its number of positive SHAP entries,
            # normalized by the number of pixels.
            mask, price = _positive_shap_mask(shap, img_w, img_h, channels=colors)
            result = result + freq * mask * price / (img_w * img_h)

        # torch.save fails if the output directory does not exist yet.
        os.makedirs(save_path, exist_ok=True)
        torch.save(result, os.path.join(save_path, c + "_result.pt"))


if __name__ == "__main__":
    main()
           