import os
import sys
import torch
import argparse
import numpy as np
import os.path as osp
import torch.nn.functional as F

sys.path.append('.')
from utils.utils import read, write
from flow_generation.liteflownet.run import estimate

parser = argparse.ArgumentParser(
                prog = 'AMT',
                description = 'Flow generation',
                )
parser.add_argument('-r', '--root', default='data/vimeo_triplet') 
args = parser.parse_args()

vimeo90k_dir = args.root
vimeo90k_sequences_dir = osp.join(vimeo90k_dir, 'sequences')
vimeo90k_flow_dir = osp.join(vimeo90k_dir, 'flow')

def pred_flow(img1, img2):
    img1 = torch.from_numpy(img1).float().permute(2, 0, 1) / 255.0
    img2 = torch.from_numpy(img2).float().permute(2, 0, 1) / 255.0

    flow = estimate(img1, img2)

    flow = flow.permute(1, 2, 0).cpu().numpy()
    return flow

print('Built Flow Path')
if not osp.exists(vimeo90k_flow_dir):
    os.makedirs(vimeo90k_flow_dir)

for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)):
    vimeo90k_sequences_path_dir = osp.join(vimeo90k_sequences_dir, sequences_path)
    vimeo90k_flow_path_dir = osp.join(vimeo90k_flow_dir, sequences_path)
    if not osp.exists(vimeo90k_flow_path_dir):
        os.mkdir(vimeo90k_flow_path_dir)
        
    for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)):
        vimeo90k_flow_id_dir = osp.join(vimeo90k_flow_path_dir, sequences_id)
        if not osp.exists(vimeo90k_flow_id_dir):
            os.mkdir(vimeo90k_flow_id_dir)

for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)):
    vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path)
    vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path)
    
    for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)):
        vimeo90k_sequences_id_dir = os.path.join(vimeo90k_sequences_path_dir, sequences_id)
        vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id)
        
        img0_path = vimeo90k_sequences_id_dir + '/im1.png'
        imgt_path = vimeo90k_sequences_id_dir + '/im2.png'
        img1_path = vimeo90k_sequences_id_dir + '/im3.png'
        flow_t0_path = vimeo90k_flow_id_dir + '/flow_t0.flo'
        flow_t1_path = vimeo90k_flow_id_dir + '/flow_t1.flo'
        
        img0 = read(img0_path)
        imgt = read(imgt_path)
        img1 = read(img1_path)
        
        flow_t0 = pred_flow(imgt, img0)
        flow_t1 = pred_flow(imgt, img1)
        
        write(flow_t0_path, flow_t0)
        write(flow_t1_path, flow_t1)
        
    print('Written Sequences {}'.format(sequences_path))