import argparse
import os
import torch
import torch.nn.functional as F

from utils.utils import (
    image_to_tensor,
    disparity_to_tensor,
    render_3dphoto,
)
from model.AdaMPI import MPIPredictor


def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"0" to a bool.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so
    every non-empty value (including "False") becomes True. This converter
    interprets the text instead. The empty string maps to False, which is
    what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options. Unknown arguments are tolerated (parse_known_args) so
# the script can be launched from wrappers that pass extra flags.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# NOTE(review): the defaults of --img_path / --disp_path look like single-file
# paths, but the processing loop calls os.listdir() on --img_path and joins
# sub-folder names onto both -- confirm the intended defaults.
parser.add_argument('--img_path', type=str, default="images/0810.png",
                    help="root directory whose sub-directories contain the input images")
parser.add_argument('--disp_path', type=str, default="images/depth/0810.png",
                    help="root directory of disparity maps; expects <sub-directory>/<image name>.png")
parser.add_argument('--width', type=int, default=384,
                    help="width the inputs are resized to and the model is built for")
parser.add_argument('--height', type=int, default=256,
                    help="height the inputs are resized to and the model is built for")
parser.add_argument('--save_path', type=str, default="debug/0810.mp4",
                    help="output root; a suffix made of num_frames, r_x, r_y and r_z is appended to it")
parser.add_argument('--ckpt_path', type=str, default="adampiweight/adampi_64p.pth",
                    help="checkpoint file providing the 'num_planes' and 'weight' entries")
parser.add_argument('--num_frames', type=int, default=90,
                    help="passed to render_3dphoto via the options object; part of the output name")
parser.add_argument('--save_frames', type=_str2bool, default=True,
                    help="true: one output directory per image; false: one .mp4 path per image")
parser.add_argument('--r_x', type=float, default=0.14,
                    help="passed to render_3dphoto via the options object; part of the output name")
parser.add_argument('--r_y', type=float, default=0.0,
                    help="passed to render_3dphoto via the options object; part of the output name")
parser.add_argument('--r_z', type=float, default=0.10,
                    help="passed to render_3dphoto via the options object; part of the output name")
opt, _ = parser.parse_known_args()

# Read the checkpoint: it provides the plane count used to build the network
# as well as the trained weights.
# NOTE: torch.load unpickles the file, so only point --ckpt_path at
# checkpoints from a trusted source.
ckpt = torch.load(opt.ckpt_path)

# Build the predictor at the requested resolution, restore its weights, then
# move it to the GPU and switch it to evaluation mode.
model = MPIPredictor(width=opt.width, height=opt.height, num_planes=ckpt["num_planes"])
model.load_state_dict(ckpt["weight"])
model = model.cuda().eval()

# 3x3 matrix K, created on the GPU. The first row is scaled by the target
# width and the second by the target height, then a leading dimension of
# size 1 is added.
# NOTE(review): presumably camera intrinsics given in normalized image
# coordinates -- confirm against render_3dphoto.
K = torch.tensor(
    [
        [0.58, 0, 0.5],
        [0, 0.58, 0.5],
        [0, 0, 1],
    ]
).cuda()
K[0].mul_(opt.width)
K[1].mul_(opt.height)
K = K[None]


# Tag the output root with the rendering settings so that runs using
# different parameters write to different locations.
opt.save_path += f"_{opt.num_frames}_{opt.r_x}_{opt.r_y}_{opt.r_z}"

# Visit every sub-folder of the image root in reverse alphabetical order and
# process each entry found inside it.
for folder in sorted(os.listdir(opt.img_path), reverse=True):
    folder_path = os.path.join(opt.img_path, folder)
    if not os.path.isdir(folder_path):
        continue
    print(folder)

    for image_path in sorted(os.listdir(folder_path)):
        file_name = os.path.splitext(image_path)[0]
        cur_image_path = os.path.join(folder_path, image_path)

        # Decide where the result goes before doing any expensive work, and
        # skip inputs whose output already exists. The check looks at the
        # path this mode actually writes: a per-image directory when saving
        # frames, a per-folder .mp4 file otherwise.
        if opt.save_frames:
            save_path = os.path.join(opt.save_path, file_name)
            already_done = os.path.isdir(save_path)
        else:
            save_path = os.path.join(opt.save_path, folder, file_name + ".mp4")
            already_done = os.path.isfile(save_path)
        if already_done:
            print(f"Skipping {cur_image_path}")
            continue

        # The disparity map is looked up under the same sub-folder and base
        # name, always with a .png extension.
        depth_map_path = os.path.join(opt.disp_path, folder, file_name + ".png")
        print(cur_image_path, depth_map_path)

        # Load the inputs and resize both to the model resolution.
        image = image_to_tensor(cur_image_path).cuda()  # [1,3,h,w]
        disp = disparity_to_tensor(depth_map_path).cuda()  # [1,1,h,w]
        image = F.interpolate(image, size=(opt.height, opt.width), mode='bilinear', align_corners=True)
        disp = F.interpolate(disp, size=(opt.height, opt.width), mode='bilinear', align_corners=True)

        # Predict the MPI planes. A failure on one input is reported and the
        # batch moves on to the next file (deliberate best-effort).
        try:
            with torch.no_grad():
                pred_mpi_planes, pred_mpi_disp = model(image, disp)  # [b,s,4,h,w]
        except Exception as e:
            print(f"Skipping {cur_image_path} : {e}")
            continue

        # Create the output location only once prediction has succeeded, so a
        # failed input does not leave a directory that a later run would
        # treat as finished.
        if opt.save_frames:
            os.makedirs(save_path, exist_ok=True)
        else:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Render. K is passed twice, as in the original call.
        # NOTE(review): presumably source and target intrinsics -- confirm
        # against render_3dphoto.
        render_3dphoto(
            image,
            pred_mpi_planes,
            pred_mpi_disp,
            K,
            K,
            save_path,
            opt,
        )
