from glob import glob
from tqdm import tqdm
import os
from argparse import ArgumentParser
from PIL import Image
import torch
import numpy as np
import cv2
import pickle as pkl
import math

def image_mask_resize(args):
    """Save masked copies of every multi-view image at 1/2 and 1/4 size.

    For each camera directory under ``{args.data_path}/images`` and each
    image inside it, the hair, body and face masks with the same file name
    are read from ``{args.data_path}/masks/{part}/{cam}``. Each mask is
    binarised (values above 0.01 become 1, everything else 0), multiplied
    into the image, and the product is written with bicubic resampling to
    ``{args.data_path}/image_masks_{2,4}/{part}/{cam}/{img_name}``.

    Entries of ``images/`` that are not directories are skipped instead of
    crashing the directory listing.

    Args:
        args: Object exposing a ``data_path`` attribute (dataset root).
    """
    parts = ('hair', 'body', 'face')
    scales = (2, 4)
    root = args.data_path

    # Create the per-part output roots even when there are no cameras.
    for scale in scales:
        for part in parts:
            os.makedirs(f'{root}/image_masks_{scale}/{part}', exist_ok=True)

    # multi view
    multi_cam_names = [
        cam for cam in os.listdir(f'{root}/images')
        if os.path.isdir(f'{root}/images/{cam}')
    ]
    for cam in tqdm(multi_cam_names):
        cam_path = f'{root}/images/{cam}'
        for scale in scales:
            for part in parts:
                os.makedirs(f'{root}/image_masks_{scale}/{part}/{cam}', exist_ok=True)

        for img_name in os.listdir(cam_path):
            with Image.open(f'{cam_path}/{img_name}') as src:
                img = np.asarray(src)
            h_old, w_old = img.shape[:2]

            for part in parts:
                with Image.open(f'{root}/masks/{part}/{cam}/{img_name}') as src:
                    mask = np.asarray(src)
                binary = (mask > 0.01).astype(np.uint8)
                # NOTE(review): the three-way stack assumes a single-channel
                # H x W mask and an H x W x 3 image -- confirm for this dataset.
                masked = np.stack([binary, binary, binary], axis=-1) * img
                masked_img = Image.fromarray(masked)
                for scale in scales:
                    # Target size is derived from the image, not the mask.
                    size = (w_old // scale, h_old // scale)
                    masked_img.resize(size, Image.BICUBIC).save(
                        f'{root}/image_masks_{scale}/{part}/{cam}/{img_name}'
                    )
# from torchvision.transforms import Resize, InterpolationMode
def main(args):
    """Downsample every multi-view image and its part masks by 2 and by 4.

    For each camera directory under ``{args.data_path}/images`` the images
    are written to ``{args.data_path}/images_{2,4}/{cam}``, and the hair,
    body and face masks with the same file name (read from
    ``{args.data_path}/masks/{part}/{cam}``) are written to
    ``{args.data_path}/masks_{2,4}/{part}/{cam}``. All outputs use bicubic
    resampling and a target size of the image's width and height
    floor-divided by the scale.

    Entries of ``images/`` that are not directories are skipped instead of
    crashing the directory listing.

    Args:
        args: Object exposing a ``data_path`` attribute (dataset root).
    """
    parts = ('hair', 'body', 'face')
    scales = (2, 4)
    root = args.data_path

    # Create the output roots even when there are no cameras.
    for scale in scales:
        os.makedirs(f'{root}/images_{scale}', exist_ok=True)
        for part in parts:
            os.makedirs(f'{root}/masks_{scale}/{part}', exist_ok=True)

    # multi view
    multi_cam_names = [
        cam for cam in os.listdir(f'{root}/images')
        if os.path.isdir(f'{root}/images/{cam}')
    ]
    for cam in tqdm(multi_cam_names):
        cam_path = f'{root}/images/{cam}'
        for scale in scales:
            os.makedirs(f'{root}/images_{scale}/{cam}', exist_ok=True)
            for part in parts:
                os.makedirs(f'{root}/masks_{scale}/{part}/{cam}', exist_ok=True)

        for img_name in os.listdir(cam_path):
            with Image.open(f'{cam_path}/{img_name}') as src:
                img = np.asarray(src)
            h_old, w_old = img.shape[:2]

            # Round-trip through NumPy, as before, so outputs keep the same mode.
            pil_img = Image.fromarray(img)
            for scale in scales:
                size = (w_old // scale, h_old // scale)
                pil_img.resize(size, Image.BICUBIC).save(
                    f'{root}/images_{scale}/{cam}/{img_name}'
                )

            for part in parts:
                with Image.open(f'{root}/masks/{part}/{cam}/{img_name}') as src:
                    pil_mask = Image.fromarray(np.asarray(src))
                for scale in scales:
                    # Masks are resized to the image-derived size, not their own.
                    size = (w_old // scale, h_old // scale)
                    pil_mask.resize(size, Image.BICUBIC).save(
                        f'{root}/masks_{scale}/{part}/{cam}/{img_name}'
                    )
    # Disabled: the same downsampling for the flat COLMAP layout (hair/body masks only).
    # colmap
    # data_path = args.data_path+'/colmap'
    # os.makedirs(f'{data_path}/images_2', exist_ok=True)
    # os.makedirs(f'{data_path}/images_4', exist_ok=True)
    # os.makedirs(f'{data_path}/masks_2/hair', exist_ok=True)
    # os.makedirs(f'{data_path}/masks_2/body', exist_ok=True)
    # os.makedirs(f'{data_path}/masks_4/hair', exist_ok=True)
    # os.makedirs(f'{data_path}/masks_4/body', exist_ok=True)
    # img_names = os.listdir(f'{args.data_path}/colmap/images')
    # for img_name in tqdm(img_names):
    #     img = np.asarray(Image.open(f'{data_path}/images/{img_name}'))
    #     h_old, w_old = img.shape[:2]
    #     mask_hair = np.asarray(Image.open(f'{data_path}/masks/hair/{img_name}'))
    #     mask_body = np.asarray(Image.open(f'{data_path}/masks/body/{img_name}'))
    #     Image.fromarray(img).resize((w_old // 2, h_old // 2), Image.BICUBIC).save(f'{data_path}/images_2/{img_name}')
    #     Image.fromarray(img).resize((w_old // 4, h_old // 4), Image.BICUBIC).save(f'{data_path}/images_4/{img_name}')
    #     Image.fromarray(mask_hair).resize((w_old // 2, h_old // 2), Image.BICUBIC).save(f'{data_path}/masks_2/hair/{img_name}')
    #     Image.fromarray(mask_body).resize((w_old // 2, h_old // 2), Image.BICUBIC).save(f'{data_path}/masks_2/body/{img_name}')
    #     Image.fromarray(mask_hair).resize((w_old // 4, h_old // 4), Image.BICUBIC).save(f'{data_path}/masks_4/hair/{img_name}')
    #     Image.fromarray(mask_body).resize((w_old // 4, h_old // 4), Image.BICUBIC).save(f'{data_path}/masks_4/body/{img_name}')
    
    
if __name__ == "__main__":
    # Script entry point: parse --data_path, then write the downsampled
    # images/masks followed by the downsampled masked images.
    parser = ArgumentParser(conflict_handler='resolve')

    parser.add_argument('--data_path', default='', type=str)
    # parser.add_argument('--img_size', default=1024, type=int)

    # A preceding parse_known_args() call was dropped: its result was
    # immediately overwritten by parse_args(), which rejects unknown flags.
    args = parser.parse_args()

    # With an empty data_path every f'{data_path}/...' path becomes
    # root-anchored ('/images', '/images_2', ...), so fail early instead.
    if not args.data_path:
        parser.error('--data_path must be a non-empty dataset directory')

    main(args)
    image_mask_resize(args)