import matplotlib.pyplot as plt
import os
import numpy as np
from PIL import Image

# item_list = ['carpet', 'grid', 'leather', 'tile', 'wood', 'bottle', 'cable', 'capsule',
#              'hazelnut', 'metal_nut', 'pill', 'screw', 'toothbrush', 'transistor', 'zipper']
#
# img_path = './visualize/vitill_mvtec_uni_dinov2br_c392_en29_bn4dp2_de8_elaelu_md2_i1_it10k_sams2e3_wd1e4_w1hcosa2e4_ghmp09f01w01_b16_ev_s1'
# plot_idx = 1
# plt.figure(figsize=(10, 8))
# for item in item_list:
#     imgs = os.listdir(os.path.join(img_path, item))
#     imgs = [img for img in imgs if 'img' in img and 'good' not in img]
#     seeds = np.random.choice(range(len(imgs)), size=2, replace=False)
#
#     for seed in seeds:
#         img = np.array(Image.open(os.path.join(img_path, item, imgs[seed])).convert('RGB'))
#         gt = np.array(Image.open(os.path.join(img_path, item, imgs[seed].replace('img', 'gt'))))
#         gt = np.expand_dims(gt, axis=2)
#         gt = np.repeat(gt, 3, axis=2)
#         cam = np.array(Image.open(os.path.join(img_path, item, imgs[seed].replace('img', 'cam'))))
#         plt.subplot(3, 10, plot_idx)
#
#         white = np.ones(shape=(5, img.shape[0], 3), dtype='uint8') * 255
#         cat = np.concatenate([img, white, gt, white, cam], axis=0)
#
#         plt.axis('off')
#         plt.imshow(cat)
#         plt.tick_params(left=False, right=False, labelleft=False,
#                         labelbottom=False, bottom=False)
#
#         plot_idx += 1
#
# plt.subplots_adjust(wspace=0., hspace=0.05)
# plt.show()


# item_list = ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2',
#              'pcb1', 'pcb2', 'pcb3', 'pcb4', 'pipe_fryum']
#
# img_path = './visualize/vitill_visa_uni_dinov2br_c392r_en29_bn4dp2_de8_elaelu_md2_i1_it10k_sams2e3_wd1e4_w1hcosa_ghmp09f01w01_b16_ev_s1'
# plot_idx = 1
# plt.figure(figsize=(10, 8), dpi=300)
# for item in item_list:
#     imgs = os.listdir(os.path.join(img_path, item))
#     imgs = [img for img in imgs if 'img' in img and 'good' not in img]
#     seeds = np.random.choice(range(len(imgs)), size=2, replace=False)
#
#     for seed in seeds:
#         img = np.array(Image.open(os.path.join(img_path, item, imgs[seed])).convert('RGB'))
#         gt = np.array(Image.open(os.path.join(img_path, item, imgs[seed].replace('img', 'gt'))))
#         gt = np.expand_dims(gt, axis=2)
#         gt = np.repeat(gt, 3, axis=2)
#         cam = np.array(Image.open(os.path.join(img_path, item, imgs[seed].replace('img', 'cam'))))
#         plt.subplot(3, 8, plot_idx)
#
#         white = np.ones(shape=(5, img.shape[0], 3), dtype='uint8') * 255
#         cat = np.concatenate([img, white, gt, white, cam], axis=0)
#
#         plt.axis('off')
#         plt.imshow(cat)
#         plt.tick_params(left=False, right=False, labelleft=False,
#                         labelbottom=False, bottom=False)
#
#         plot_idx += 1
#
# plt.subplots_adjust(wspace=0., hspace=0.05)
# plt.show()


# Real-IAD qualitative figure: for each category, pick two non-'OK' samples
# (preferring ones whose gt mask is non-empty) and show image / gt mask /
# 'cam' map stacked vertically, n_rows x n_cols samples per figure page.
item_list = ['audiojack', 'bottle_cap', 'button_battery', 'end_cap', 'eraser', 'fire_hood',
             'mint', 'mounts', 'pcb', 'phone_battery', 'plastic_nut', 'plastic_plug',
             'porcelain_doll', 'regulator', 'rolled_strip_base', 'sim_card_set', 'switch', 'tape',
             'terminalblock', 'toothbrush', 'toy', 'toy_brick', 'transistor1', 'usb',
             'usb_adaptor', 'u_block', 'vcpill', 'wooden_beads', 'woodstick', 'zipper']

img_path = './visualize/vitill_realiad_uni_dinov2br_c392r_en29_bn4dp4_de8_elaelu_md2_i1_it50k_sams2e3_wd1e4_w1hcosa2e4_ghmp09f01w01_b16_s1'

n_rows, n_cols = 2, 10  # subplot grid of one figure page


def load_array(path, mode=None):
    """Read the image at `path` into a numpy array.

    If `mode` is given (e.g. 'RGB'), the image is converted to that PIL mode
    first. The underlying file is closed before returning.
    """
    with Image.open(path) as im:
        if mode is not None:
            im = im.convert(mode)
        return np.array(im)


def pick_samples(item_dir, names, k=2, max_tries=20, rng=np.random):
    """Pick up to `k` distinct entries of `names`, preferring non-empty masks.

    For each candidate the mask file is `name` with 'img' replaced by 'gt'
    inside `item_dir`. Candidates are visited in random order without
    replacement, reading at most `max_tries` masks. If fewer than `k` masks
    with a non-zero pixel are found, the remaining slots are filled with
    tried samples whose mask is all zero.

    Returns a list of (name, gt_array) pairs. The list is shorter than `k`
    when fewer candidates are available.
    """
    with_mask, without_mask = [], []
    for idx in rng.permutation(len(names))[:max_tries]:
        name = names[idx]
        gt = load_array(os.path.join(item_dir, name.replace('img', 'gt')))
        (with_mask if gt.max() > 0 else without_mask).append((name, gt))
        if len(with_mask) == k:
            break
    return (with_mask + without_mask)[:k]


def stack_panels(img, gt, cam, gap=5):
    """Concatenate `img`, `gt` and `cam` top-to-bottom with white separators.

    The separators are `gap` rows tall. All three arrays must share width and
    channel count after `gt` is normalized. A boolean `gt` is mapped to
    0/255 uint8, and a 2-D `gt` is repeated into 3 channels.
    """
    if gt.dtype == bool:
        gt = gt.astype(np.uint8) * 255
    if gt.ndim == 2:
        gt = np.repeat(gt[:, :, np.newaxis], 3, axis=2)
    # Stacking along axis 0 requires matching WIDTH (shape[1]). Using
    # shape[0] (height) only worked for square images.
    white = np.full((gap, img.shape[1], 3), 255, dtype=np.uint8)
    return np.concatenate([img, white, gt, white, cam], axis=0)


def show_page():
    """Apply the grid spacing to the current figure and display it."""
    plt.subplots_adjust(wspace=0.1, hspace=0.05)
    plt.show()


if __name__ == '__main__':
    per_page = n_rows * n_cols
    plot_idx = 0  # slots used on the current figure; 0 means none open yet
    for item in item_list:
        item_dir = os.path.join(img_path, item)
        # os.listdir order is arbitrary; sort so a fixed np.random seed
        # reproduces the same picks.
        names = sorted(n for n in os.listdir(item_dir)
                       if 'img' in n and 'OK' not in n)

        for name, gt in pick_samples(item_dir, names):
            img = load_array(os.path.join(item_dir, name), 'RGB')
            cam = load_array(os.path.join(item_dir, name.replace('img', 'cam')), 'RGB')

            # Open a figure lazily so no empty trailing figure is created.
            if plot_idx == 0:
                plt.figure(figsize=(12, 6), dpi=300)
            plot_idx += 1
            plt.subplot(n_rows, n_cols, plot_idx)

            plt.axis('off')
            plt.imshow(stack_panels(img, gt, cam))
            plt.tick_params(left=False, right=False, labelleft=False,
                            labelbottom=False, bottom=False)

            if plot_idx == per_page:
                show_page()
                plot_idx = 0

    if plot_idx:
        show_page()  # flush the last, partially filled page

# plt.subplots_adjust(wspace=0., hspace=0.05)
# plt.show()
