import cv2
import numpy as np

'''
A helper script that cuts selected rows and columns of tiles out of two large
image grids and stitches them into smaller comparison images.
'''

im1 = cv2.imread('./ours.png')
im2 = cv2.imread('./stargan-v2.png')
# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise only surface later inside cut() as an obscure
# "'NoneType' object is not subscriptable". Fail fast with the path instead.
if im1 is None:
    raise FileNotFoundError("could not read image './ours.png'")
if im2 is None:
    raise FileNotFoundError("could not read image './stargan-v2.png'")
# Edge length in pixels of one square tile in both grids; cut() slices rows
# and columns with this step unless it is given an explicit tile_size.
size = 256

def cut(im1, im2, row, col, fname, tile_size=None):
    """Cut selected tiles out of two grid images and save one comparison image.

    Both inputs are treated as grids of square ``tile_size`` x ``tile_size``
    tiles (assumed to be (H, W[, C]) arrays such as those returned by
    ``cv2.imread`` -- TODO confirm for other callers). The output is built in
    two passes:

    1. Rows: keep the first tile row of ``im1``, then for every index in
       ``row`` append that tile row from ``im2`` followed by the same tile
       row from ``im1``, so the two grids alternate row by row.
    2. Columns: from that stack keep the first tile column, then append the
       tile columns listed in ``col`` in the given order.

    Args:
        im1: First grid image; it also supplies the top tile row.
        im2: Second grid image; must match ``im1`` in every dimension
            except height.
        row: 1-based tile-row indices (1 is the top tile row).
        col: 1-based tile-column indices (1 is the leftmost tile column).
        fname: Output path handed to ``cv2.imwrite``.
        tile_size: Tile edge length in pixels; defaults to the module-level
            ``size``.

    Returns:
        The assembled image array that was written to ``fname``.

    Raises:
        ValueError: if an image is None, the tile size is not positive, the
            images are smaller than one tile, their non-height dimensions
            differ, or any index lies outside the grid. Plain slicing would
            otherwise turn a bad index into an empty or truncated tile with
            no error at all.
        IOError: if ``cv2.imwrite`` reports that it could not write ``fname``.
    """
    tile = size if tile_size is None else tile_size
    # Materialize once: indices are iterated for validation and again for
    # assembly, which would silently exhaust a generator argument.
    row = list(row)
    col = list(col)

    if im1 is None or im2 is None:
        raise ValueError('cut(): got None instead of an image '
                         '(did cv2.imread fail?)')
    if tile <= 0:
        raise ValueError('cut(): tile size must be positive, got {!r}'.format(tile))
    # Rows from both images are stacked along axis 0, so every other
    # dimension (width, channels) must agree.
    if im1.shape[1:] != im2.shape[1:]:
        raise ValueError('cut(): image shapes {} and {} differ beyond height'
                         .format(im1.shape, im2.shape))

    # Count only complete tiles. Each row index is read from both images, so
    # it must exist in the shorter one.
    n_rows = min(im1.shape[0], im2.shape[0]) // tile
    n_cols = im1.shape[1] // tile
    if n_rows < 1 or n_cols < 1:
        raise ValueError('cut(): images {} / {} are smaller than one {}px tile'
                         .format(im1.shape, im2.shape, tile))
    bad_rows = [r for r in row if not 1 <= r <= n_rows]
    if bad_rows:
        raise ValueError('cut(): row indices {} outside 1..{}'
                         .format(bad_rows, n_rows))
    bad_cols = [c for c in col if not 1 <= c <= n_cols]
    if bad_cols:
        raise ValueError('cut(): column indices {} outside 1..{}'
                         .format(bad_cols, n_cols))

    # Pass 1: top row of im1, then im2/im1 pairs for each requested row.
    stacked = [im1[:tile]]
    for r in row:
        band = slice((r - 1) * tile, r * tile)
        stacked.append(im2[band])
        stacked.append(im1[band])
    stacked = np.concatenate(stacked, axis=0)

    # Pass 2: first tile column, then each requested column.
    columns = [stacked[:, :tile]]
    columns.extend(stacked[:, (c - 1) * tile:c * tile] for c in col)
    result = np.concatenate(columns, axis=1)

    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(fname, result):
        raise IOError('cut(): cv2.imwrite failed to write {!r}'.format(fname))
    return result

# 1-based tile indices chosen for each domain; *_row entries select tile
# rows (taken from both grids), *_col entries select tile columns.
cat_row = [2, 5, 7, 8]
cat_col = [3, 4, 9, 10]
dog_row = [11, 17, 18, 19]
dog_col = [12, 13, 14, 16]
wild_row = [22, 26, 27, 28]
wild_col = [20, 23, 24, 25]

# One comparison image per (rows, columns, output file) job, written in
# this order. The two "mix" jobs recombine indices from different domains.
for _rows, _cols, _out in (
    (cat_row, cat_col, 'compare_cat.png'),
    (dog_row, dog_col, 'compare_dog.png'),
    (wild_row, wild_col, 'compare_wild.png'),
    (cat_col[:2] + dog_col[:2], wild_row, 'compare_mix_1.png'),
    ([21, 23, 26, 27], cat_row[:2] + dog_row[:2], 'compare_mix_2.png'),
):
    cut(im1, im2, _rows, _cols, _out)