import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from plot.datasets.kitti import read_kitti_calib

import sys
sys.path.insert(0, "UniDepth")
from unidepth.models import UniDepthV1, UniDepthV2
from unidepth.utils.camera import Pinhole


class DepthModel:
    """Thin wrapper around a pretrained UniDepth network for single images.

    ``model == 'unidepthv1'`` loads UniDepthV1; every other name (including
    the default ``'unidepth'`` and ``'unidepthv2'``) loads UniDepthV2. The
    same rule is applied everywhere in the class, so the loaded network, the
    camera format fed to it and the presence of a confidence map always agree.
    """

    def __init__(self, model='unidepth', device="cpu") -> None:
        """Load the pretrained weights and move the network to ``device``.

        Args:
            model: Model name; ``'unidepthv1'`` selects V1, anything else V2.
            device: Device (string or ``torch.device``) to run inference on.
        """
        self.model_name = model
        if self._is_v1():
            self.model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
        else:
            self.model = UniDepthV2.from_pretrained("lpiccinelli/unidepth-v2-vitl14")
            # set resolution level and interpolation mode (only V2)
            self.model.resolution_level = 9
            self.model.interpolation_mode = "bilinear"

        self.model = self.model.to(device).eval()

    def _is_v1(self) -> bool:
        """Return True when the wrapped network is UniDepthV1."""
        return self.model_name == 'unidepthv1'

    def preprocess_unidepth(self, image, intrinsics=None):
        """Convert numpy inputs into the objects passed to ``model.infer``.

        Args:
            image: Channel-last numpy image (3 axes, channels on the last one).
            intrinsics: Optional numpy camera matrix. When ``None`` no camera
                is built and ``None`` is returned in its place.

        Returns:
            ``(image, camera)`` where ``image`` is a channel-first tensor and
            ``camera`` is a raw tensor for V1, a ``Pinhole`` for V2, or
            ``None`` when no intrinsics were given.
        """
        image = torch.from_numpy(image).permute(2, 0, 1)
        if intrinsics is None:
            # Nothing to wrap; the caller forwards None to the model.
            return image, None

        K = torch.from_numpy(intrinsics)
        if self._is_v1():
            return image, K
        return image, Pinhole(K=K.unsqueeze(0))

    def __call__(self, image_path, intrinsics=None):
        """Run depth inference on the image stored at ``image_path``.

        Args:
            image_path: Path of the image file to load.
            intrinsics: Optional numpy camera matrix forwarded to the model.

        Returns:
            ``(depth_pred, conf, intrinsics_pred)`` as numpy arrays. ``conf``
            is ``None`` for UniDepthV1, which is not asked for a confidence
            map. Entries of ``depth_pred`` equal to ``+inf`` are set to 0.
        """
        # Close the file handle deterministically and force three channels so
        # that grayscale / RGBA inputs do not break the channel permutation.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))

        image, camera = self.preprocess_unidepth(image, intrinsics)
        results = self.model.infer(image, camera)
        depth_pred = results['depth'].squeeze().cpu().numpy()
        intrinsics_pred = results['intrinsics'].squeeze().cpu().numpy()

        if self._is_v1():
            conf = None
        else:
            conf = results['confidence'].squeeze().cpu().numpy()

        depth_pred[depth_pred == np.inf] = 0.
        return depth_pred, conf, intrinsics_pred



if __name__ == '__main__':
    # Batch-export UniDepth predictions for every frame of every scene listed
    # in the chosen split file, one compressed .npz per input image.
    root = Path("PATH_TO_DATA")
    split = "val"
    # Single source of truth for both the loaded network and the output
    # folder name, so the two cannot drift apart.
    model_name = "unidepthv1"

    device = torch.device("cuda:2")
    with open(root / "ImageSets" / f"{split}.txt") as f:
        # Skip blank lines: an empty scene name would glob the frames root
        # and point at a non-existent calibration file.
        scene_names = [line.strip() for line in f if line.strip()]

    save_dir = root / f"{model_name}_{split}"
    save_dir.mkdir(parents=True, exist_ok=True)

    depth_model = DepthModel(model_name, device)

    for scene_name in tqdm(scene_names):
        image_paths = sorted((root / "frames" / scene_name).glob("*.png"))
        K = read_kitti_calib(root / "training" / "calib" / f"{scene_name}.txt")
        scene_save_folder = save_dir / scene_name
        scene_save_folder.mkdir(parents=True, exist_ok=True)

        for image_path in image_paths:
            depth_pred, depth_conf, K_pred = depth_model(image_path, K)

            # Collect the arrays once; the confidence map is only stored when
            # the model produced one.
            arrays = {"depth": depth_pred.astype(np.float32)}
            if depth_conf is not None:
                arrays["conf"] = depth_conf.astype(np.float32)
            arrays["intrinsics"] = K_pred.astype(np.float32)

            np.savez_compressed(scene_save_folder / image_path.stem, **arrays)
