import argparse
import os
import re
from glob import glob

import numpy as np
from PIL import Image
from tqdm import tqdm

# Command-line interface. Every option is required: the script builds paths
# from all of them and does arithmetic on --win_size, so a missing value used
# to surface later as an opaque TypeError instead of a usage message.
parser = argparse.ArgumentParser(
    description="Stack baseline frames above TTT frames (segmentation and "
                "reconstruction) into per-frame comparison images."
)
parser.add_argument(
    '--baseline_dir',
    type=str,
    required=True,
    help="Root of the baseline outputs; frames are read from "
         "<baseline_dir>/<video>/ (*_final.png and *_recon.png)."
)
parser.add_argument(
    '--ttt_dir',
    type=str,
    required=True,
    help="Root of the TTT runs; the per-video sub-directory is selected "
         "using --video."
)
parser.add_argument(
    '--video',
    type=str,
    required=True,
    help="Name of the video to process."
)
parser.add_argument(
    '--win_size',
    type=int,
    required=True,
    help="Window size of the TTT run; TTT frames are read from "
         "<run>/<win_size>_win/ and the first win_size - 1 baseline frames "
         "are skipped."
)
parser.add_argument(
    '--save_dir',
    type=str,
    required=True,
    help="Output root; images are written to <save_dir>/<video>/seg and "
         "<save_dir>/<video>/recon."
)
# NOTE: parsing happens at import time because the script body below reads
# the module-level `args`.
args = parser.parse_args()



def natural_sort_key(path):
    """Return a sort key that orders embedded integers numerically.

    ``"f2.png"`` sorts before ``"f10.png"``. For zero-padded names the result
    is identical to plain lexicographic sorting.
    """
    # re.split with a capturing group alternates text / digit-run tokens, so
    # keys of two paths always compare str-with-str and int-with-int.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)]


def find_ttt_video_dir(ttt_root, video):
    """Return the sub-directory of ``ttt_root`` holding the TTT run for ``video``.

    A sub-directory named exactly ``video`` is preferred; otherwise the first
    one (in sorted order) whose name starts with ``video`` is returned. Only
    the sub-directory's own name is inspected, never the components of
    ``ttt_root``.

    Raises:
        FileNotFoundError: if no sub-directory name starts with ``video``.
    """
    # The trailing "" component makes glob match directories only.
    subdirs = sorted(os.path.normpath(p) for p in glob(os.path.join(ttt_root, "*", "")))
    candidates = [p for p in subdirs if os.path.basename(p).startswith(video)]
    if not candidates:
        raise FileNotFoundError(
            "No sub-directory of {!r} starts with {!r}".format(ttt_root, video)
        )
    exact = [p for p in candidates if os.path.basename(p) == video]
    return (exact or candidates)[0]


def stack_with_separator(top, bottom, gap=5):
    """Stack ``top`` above ``bottom`` with a ``gap``-row black band in between.

    Expects two ``(H, W, C)`` image arrays with the same width and channels;
    the band takes its width, channels and dtype from ``bottom``.
    """
    separator = np.zeros((gap,) + bottom.shape[1:], dtype=bottom.dtype)
    return np.concatenate((top, separator, bottom), axis=0)


def write_stacked_frames(ttt_frames, baseline_frames, out_dir):
    """Write baseline-over-TTT comparison images to ``out_dir``.

    Frames are paired by position and written as ``000000_frame.png``,
    ``000001_frame.png``, ... If the two lists differ in length, only the
    common prefix is written and a warning is printed.
    """
    if len(ttt_frames) != len(baseline_frames):
        print("Warning: {} TTT frames vs {} baseline frames for {}; writing "
              "only the first {} pairs".format(
                  len(ttt_frames), len(baseline_frames), out_dir,
                  min(len(ttt_frames), len(baseline_frames))))
    pairs = list(zip(ttt_frames, baseline_frames))
    for i, (ttt_path, baseline_path) in enumerate(tqdm(pairs)):
        # `with` closes the underlying file handle once pixels are loaded.
        with Image.open(ttt_path) as ttt_im:
            ttt_img = np.asarray(ttt_im.convert('RGB'))
        with Image.open(baseline_path) as baseline_im:
            baseline_img = np.asarray(baseline_im.convert('RGB'))

        # Baseline on top, TTT on the bottom.
        concat_img = stack_with_separator(baseline_img, ttt_img)

        save_path = os.path.join(out_dir, format(i, "06d") + '_frame.png')
        Image.fromarray(concat_img).save(save_path)


if __name__ == '__main__':
    missing = [name for name in ("baseline_dir", "ttt_dir", "video", "win_size", "save_dir")
               if getattr(args, name) is None]
    if missing:
        parser.error("missing required arguments: " + ", ".join("--" + m for m in missing))
    if args.win_size < 1:
        parser.error("--win_size must be a positive integer")

    video = args.video
    win_size = args.win_size

    save_dir = os.path.join(args.save_dir, video)
    seg_out_dir = os.path.join(save_dir, "seg")
    recon_out_dir = os.path.join(save_dir, "recon")
    os.makedirs(seg_out_dir, exist_ok=True)
    os.makedirs(recon_out_dir, exist_ok=True)

    ttt_win_dir = os.path.join(find_ttt_video_dir(args.ttt_dir, video), str(win_size) + '_win')
    baseline_dir = os.path.join(args.baseline_dir, video)

    ttt_seg_frames = sorted(glob(os.path.join(ttt_win_dir, '*_final.png')), key=natural_sort_key)
    ttt_recon_frames = sorted(glob(os.path.join(ttt_win_dir, '*_recon.png')), key=natural_sort_key)

    # Drop the first win_size - 1 baseline frames so the two sequences line up
    # (presumably the TTT run emits its first frame only once a full window is
    # available — confirm against the TTT script).
    skip = win_size - 1
    baseline_seg_frames = sorted(glob(os.path.join(baseline_dir, '*_final.png')), key=natural_sort_key)[skip:]
    baseline_recon_frames = sorted(glob(os.path.join(baseline_dir, '*_recon.png')), key=natural_sort_key)[skip:]

    write_stacked_frames(ttt_seg_frames, baseline_seg_frames, seg_out_dir)
    write_stacked_frames(ttt_recon_frames, baseline_recon_frames, recon_out_dir)
