import os
import collections
import copy
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from matplotlib.colors import Normalize
import matplotlib as mlt
import math
from mpl_toolkits.mplot3d import Axes3D
import torch.nn.functional as F

def visual_shap(shap_value, w, h, img_path):
    """Save a heatmap of ``shap_value`` to ``img_path`` (PNG), axes hidden.

    The colour scale is symmetric around zero, ``[-m, m]``, where ``m`` is
    the largest absolute value in ``shap_value``. As a result, zero always
    maps to the neutral middle of the ``bwr`` colormap. No colorbar is drawn.

    Args:
        shap_value: tensor with exactly ``w * h`` elements (any shape). It is
            reshaped to ``(w, h)``, so ``w`` becomes the number of image rows
            and ``h`` the number of columns.
        w: first dimension of the reshaped map (image rows).
        h: second dimension of the reshaped map (image columns).
        img_path: destination path; the figure is written in PNG format.
    """
    # detach/cpu so grad-tracking or GPU tensors can be handed to matplotlib;
    # reshape (unlike view) also accepts non-contiguous input.
    values = shap_value.detach().cpu().reshape(w, h)
    # Symmetric colour limit: largest magnitude over *all* elements
    # (equivalent to max(|max|, |min|), but valid for any input shape).
    maximum = values.abs().max().item()

    fig = plt.figure()
    try:
        norm = Normalize(vmin=-maximum, vmax=maximum)
        plt.imshow(values.numpy(), norm=norm, cmap=mlt.cm.bwr)
        plt.gca().get_yaxis().set_visible(False)  # hide the y axis
        plt.gca().get_xaxis().set_visible(False)  # hide the x axis
        plt.tight_layout()
        plt.savefig(img_path, format='png')
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)

def visual_shap_w_tick(shap_value, w, h, img_path):
    """Save a heatmap of ``shap_value`` with fftshift-relabelled ticks and a colorbar.

    The colour scale is symmetric around zero, ``[-m, m]``, where ``m`` is
    the largest absolute value. Each axis position ``i`` is labelled with
    ``np.fft.fftshift(range(n))[i]``, where ``n`` is the length of that axis.

    Args:
        shap_value: tensor with exactly ``w * h`` elements (any shape). It is
            reshaped to ``(w, h)``, so ``w`` rows run along the y axis and
            ``h`` columns run along the x axis.
        w: first dimension of the reshaped map (image rows / y axis).
        h: second dimension of the reshaped map (image columns / x axis).
        img_path: destination path; the figure is written in PNG format.
    """
    values = shap_value.detach().cpu().reshape(w, h)
    maximum = values.abs().max().item()

    # imshow draws shape[0] along y and shape[1] along x; derive the ticks
    # from the displayed array so labels match the axes even when w != h.
    n_rows, n_cols = values.shape
    x_pos = list(range(n_cols))
    y_pos = list(range(n_rows))

    fig = plt.figure()
    try:
        norm = Normalize(vmin=-maximum, vmax=maximum)
        plt.imshow(values.numpy(), norm=norm, cmap=mlt.cm.bwr)
        plt.xticks(x_pos, np.fft.fftshift(x_pos))
        plt.yticks(y_pos, np.fft.fftshift(y_pos))
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(img_path, format='png')
    finally:
        plt.close(fig)
    
def visual_shap_pillar(shap_value, w, h, img_path):
    """Save a bar chart of per-column sums of ``shap_value``, fftshift-reordered.

    ``shap_value`` is reshaped to ``(w, h)``, and each of the ``h`` columns is
    summed over its ``w`` rows. The resulting sequence is reordered with
    ``np.fft.fftshift``; for 16 columns this is ``data[8:] + data[:8]``. It is
    then drawn as bars. The per-column sums and the shifted sequence are also
    printed to stdout.

    Args:
        shap_value: tensor with exactly ``w * h`` elements (any shape).
        w: first dimension of the reshaped map (rows summed per column).
        h: second dimension of the reshaped map (number of bars).
        img_path: destination path; the figure is written in PNG format.
    """
    values = shap_value.detach().cpu().reshape(w, h)

    # One sum per column. Iterating range(w) here (as before) broke when w != h.
    data = values.sum(dim=0).tolist()
    print(data)

    # Rotate by half the length, exactly as np.fft.fftshift does, instead of
    # the previously hard-coded split at index 8 (valid only for 16 columns).
    new_data = np.fft.fftshift(data).tolist()
    print(new_data)

    positions = list(range(len(new_data)))
    fig = plt.figure()
    try:
        plt.bar(positions, new_data, width=0.5)
        # NOTE(review): ticks keep the unshifted labels 0..n-1 although the
        # bars are fftshift-reordered — confirm this labelling is intended.
        plt.xticks(positions, positions)
        plt.tight_layout()
        plt.savefig(img_path, format='png')
    finally:
        plt.close(fig)

if __name__ == "__main__":
    # Command-line entry point: load a saved SHAP tensor and plot it.
    # Defaults reproduce the original hard-coded experiment paths.
    parser = argparse.ArgumentParser(
        description="Visualise a saved frequency-domain SHAP tensor.")
    parser.add_argument(
        "--input",
        default="D:/2.phd/project/0426/实验结果/0501erm/right/0/4_freq.pt",
        help="path of the tensor saved with torch.save")
    parser.add_argument(
        "--output",
        default="D:/2.phd/project/0426/实验结果/0501erm/right/0/4_freq_show.png",
        help="path of the PNG to write")
    parser.add_argument(
        "--size", type=int, default=16,
        help="side length of the square map (w = h = size)")
    parser.add_argument(
        "--mode", choices=("plain", "tick", "pillar"), default="plain",
        help="plain heatmap, heatmap with fftshift ticks, or column-sum bars")
    args = parser.parse_args()

    # Show how fftshift reorders indices for the chosen size.
    x_spec = list(range(args.size))
    print(x_spec)
    print(np.fft.fftshift(x_spec))

    # map_location="cpu" lets tensors saved from a GPU run be plotted here.
    freq_shap = torch.load(args.input, map_location="cpu")

    plotters = {
        "plain": visual_shap,
        "tick": visual_shap_w_tick,
        "pillar": visual_shap_pillar,
    }
    plotters[args.mode](freq_shap, args.size, args.size, args.output)