import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns
import PIL.Image as Image


# Parameter grids. The values are kept as strings because they are spliced
# verbatim into the pickle keys and the output file names below.
Gamma2 = ['-0.3', '-0.2', '-0.1', '+0.0', '+0.1', '+0.2', '+0.3']
Gamma3 = ['0.90', '1.10', '1.30', '1.50', '1.70', '1.90', '2.10']

# Folder that receives the per-(gamma2, gamma3) scatter plots.
QQ = {'FolderName': r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing_1to2'}

# Results dictionary produced by the training run.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
result_pickle_path = r"F:\deep_study\20210824_threelayers_limit_width_zqx\result_1to2/objk.pkl"
with open(result_pickle_path, 'rb') as pkl_file:
    Q = pickle.load(pkl_file)

# print(Q.keys())

# (gamma2, gamma3) panels whose frame is drawn as a thick blue border.
HIGHLIGHTED_PANELS = {
    ('-0.3', '1.50'), ('-0.2', '1.50'), ('-0.1', '1.50'), ('+0.0', '1.50'),
    ('+0.1', '1.70'), ('+0.2', '1.90'), ('+0.3', '2.10'),
}
PLOT_STEP = 10000  # only the snapshots stored under this step number are plotted

# Resolution settings are global state, so set them once instead of per panel.
plt.rcParams['savefig.dpi'] = 200  # saved PNG: 10 in x 200 dpi = 2000 x 2000 px
plt.rcParams['figure.dpi'] = 200   # resolution of newly created figures
os.makedirs(QQ['FolderName'], exist_ok=True)  # savefig fails if the folder is missing

# One borderless scatter plot per (gamma2, gamma3) pair.
for i in Gamma2:
    for j in Gamma3:
        key_prefix = 'gam2_%s_gam3_%s_%d' % (i, j, PLOT_STEP)
        # Index 0 is drawn as 'initial' and index -1 as 'now'; presumably these
        # are the first and last recorded layer-1 (w, b) snapshots -- TODO confirm.
        w_dot_plot_layer1_10000 = Q[key_prefix + '_w_dot_plot_layer1']
        b_dot_plot_layer1_10000 = Q[key_prefix + '_b_dot_plot_layer1']

        fig, ax = plt.subplots(figsize=(10.0, 10.0))

        if (i, j) in HIGHLIGHTED_PANELS:
            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_color('blue')
                ax.spines[side].set_linewidth(8)

        # No ticks and no tick labels on either axis.
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

        ax.scatter(w_dot_plot_layer1_10000[0], b_dot_plot_layer1_10000[0],
                   color='r', s=10, label='initial')
        ax.scatter(w_dot_plot_layer1_10000[-1], b_dot_plot_layer1_10000[-1],
                   color='g', s=10, label='now')

        # Let the axes fill the entire canvas with no padding, so the panels
        # tile edge-to-edge when they are stitched into the mosaic.
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        fig.savefig(r'%s/w_b_of_layer_1to2_gam2_%s_gam3_%s.png' % (QQ['FolderName'], i, j))
        plt.close(fig)


# --- Settings for stitching the panel PNGs into one mosaic -------------------
# Folder holding the panel PNGs, with exactly one trailing separator, because
# image_compose builds paths as IMAGES_PATH + name. The raw literal r'...\\'
# would have produced a doubled trailing backslash.
IMAGES_PATH = QQ['FolderName'] + os.sep
IMAGES_FORMAT = ['.png', '.PNG']  # panel image extensions (panel names below are generated, not filtered by this)
IMAGE_SIZE = 256  # each panel is resized to IMAGE_SIZE x IMAGE_SIZE pixels
IMAGE_ROW = 7  # number of rows in the mosaic
IMAGE_COLUMN = 7  # number of columns in the mosaic
IMAGE_SAVE_PATH = os.path.join(QQ['FolderName'], 'final.jpg')  # where the mosaic is written

# File-name pattern the plotting loop uses for each panel.
PANEL_NAME_FORMAT = 'w_b_of_layer_1to2_gam2_%s_gam3_%s.png'

# Build the expected panel names explicitly rather than listing the folder.
# This way, stray PNGs in the folder (e.g. a whole-figure plot saved there)
# cannot shift the layout. Order expected by image_compose:
#   * negative gamma2 values in Gamma2 order (-0.3, -0.2, -0.1), each with
#     gamma3 descending;
#   * the remaining gamma2 values in Gamma2 order (+0.0 ... +0.3), each with
#     gamma3 ascending.
# This is exactly the order produced by reverse-sorting the first 21 file
# names and ascending-sorting the rest.
negative_gam2 = [g2 for g2 in Gamma2 if g2.startswith('-')]
other_gam2 = [g2 for g2 in Gamma2 if not g2.startswith('-')]
image_names = (
    [PANEL_NAME_FORMAT % (g2, g3) for g2 in negative_gam2 for g3 in reversed(Gamma3)]
    + [PANEL_NAME_FORMAT % (g2, g3) for g2 in other_gam2 for g3 in Gamma3]
)

print("image_names", image_names)

# Sanity checks: the grid must be exactly filled and every panel must exist.
# (Message below: "the mosaic parameters do not match the required count!")
if len(image_names) != IMAGE_ROW * IMAGE_COLUMN:
    raise ValueError("合成图片的参数和要求的数量不能匹配！")
missing_images = [name for name in image_names if not os.path.isfile(IMAGES_PATH + name)]
if missing_images:
    raise ValueError("missing panel images in %s: %s" % (IMAGES_PATH, missing_images))


# Image-stitching function.
def image_compose():
    """Stitch the panels named in ``image_names`` into one grid image and save it.

    Reads module-level settings: IMAGES_PATH, image_names, IMAGE_SIZE,
    IMAGE_ROW, IMAGE_COLUMN and IMAGE_SAVE_PATH. ``image_names`` is consumed
    row by row, IMAGE_COLUMN names per row, and each panel is resized to
    IMAGE_SIZE x IMAGE_SIZE.

    Placement rules:
      * Rows are filled from the bottom of the canvas upward.
      * Within each of the first ``mirrored_rows`` rows of names, the panels
        are placed right-to-left. Those names arrive in descending order and
        the rest in ascending order, so mirroring makes every row read the
        same way.

    The mosaic is written to IMAGE_SAVE_PATH. Returns None.
    """
    mirrored_rows = 3
    # LANCZOS is the filter formerly aliased as ANTIALIAS. The ANTIALIAS alias
    # was removed in Pillow 10.
    resample = Image.LANCZOS
    to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE, IMAGE_ROW * IMAGE_SIZE))
    for row_idx in range(IMAGE_ROW):
        for col_idx in range(IMAGE_COLUMN):
            name = image_names[row_idx * IMAGE_COLUMN + col_idx]
            # Image.open is lazy and keeps the file open; the context manager
            # closes it once the resized copy has been made.
            with Image.open(IMAGES_PATH + name) as src:
                tile = src.resize((IMAGE_SIZE, IMAGE_SIZE), resample)
            if row_idx < mirrored_rows:
                x_slot = IMAGE_COLUMN - 1 - col_idx
            else:
                x_slot = col_idx
            y_slot = IMAGE_ROW - 1 - row_idx
            to_image.paste(tile, (x_slot * IMAGE_SIZE, y_slot * IMAGE_SIZE))
    to_image.save(IMAGE_SAVE_PATH)


image_compose()  # build and save the mosaic







# fig=plt.figure(figsize=(12,8))
# fig.set_tight_layout(True)
# plt.figure(1)
# ax1 = plt.subplot(221)
# plt.scatter(w_dot_plot_layer1_10000[0],b_dot_plot_layer1_10000[0],color = 'r',s = 10, label = 'initial')
# plt.scatter(w_dot_plot_layer1_10000[-1],b_dot_plot_layer1_10000[-1],color = 'g',s = 10,label = 'now')
# plt.xticks([])
# plt.yticks([])


# ax2 = plt.subplot(222)
# plt.scatter(w_dot_plot_layer1_10000[0],b_dot_plot_layer1_10000[0],color = 'r',s = 10, label = 'initial')
# plt.scatter(w_dot_plot_layer1_10000[-1],b_dot_plot_layer1_10000[-1],color = 'g',s = 10,label = 'now')
# plt.xticks([])
# plt.yticks([])

# ax3 = plt.subplot(212)
# plt.scatter(w_dot_plot_layer1_10000[0],b_dot_plot_layer1_10000[0],color = 'r',s = 10, label = 'initial')
# plt.scatter(w_dot_plot_layer1_10000[-1],b_dot_plot_layer1_10000[-1],color = 'g',s = 10,label = 'now')
# plt.xticks([])
# plt.yticks([])

# plt.show()
# fig.savefig(r'%s/w_b_of_layer_1to2_whole.png'%(QQ['FolderName']))
