import os, sys, torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# Locate <parent of this file's directory>/third_party/monodepth2 and put it on
# sys.path before the `networks` import below.
# NOTE(review): presumably `networks` lives inside that directory - confirm.
mono_path  = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                          'third_party', 'monodepth2')
sys.path.append(mono_path)

from networks import ResnetEncoder, DepthDecoder 

def load_monodepth2(model_dir=None, device='cuda'):
    """Load a Monodepth2 encoder/decoder pair, frozen and in eval mode.

    Args:
        model_dir: Directory holding ``encoder.pth`` and ``depth.pth``.
            When ``None``, ``models/mono+stereo_640x192`` under
            ``mono_path`` is used.
        device: Device the checkpoints are mapped to on load and the
            modules are moved to.

    Returns:
        A ``(encoder, depth_decoder)`` tuple. Both modules are in eval mode
        and every parameter has ``requires_grad`` set to ``False``.
    """
    if model_dir is None:
        model_dir = os.path.join(mono_path, 'models', 'mono+stereo_640x192')

    encoder_path = os.path.join(model_dir, 'encoder.pth')
    decoder_path = os.path.join(model_dir, 'depth.pth')

    loaded_dict_enc = torch.load(encoder_path, map_location=device)
    encoder = ResnetEncoder(num_layers=18, pretrained=False)
    # Keep only checkpoint entries the encoder actually has. The state dict
    # is built once here instead of once per checkpoint key.
    expected_enc_state = encoder.state_dict()
    encoder.load_state_dict(
        {k: v for k, v in loaded_dict_enc.items() if k in expected_enc_state}
    )
    encoder.to(device).eval()

    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    depth_decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    depth_decoder.to(device).eval()

    # Inference only: freeze every weight of both modules.
    for m in (encoder, depth_decoder):
        for p in m.parameters():
            p.requires_grad = False

    print(f"✓ Monodepth2 loaded from: {model_dir}")
    print(f"✓ Model moved to: {device}")
    return encoder, depth_decoder

def preprocess_img(img_pil, size=(192, 640)):
    """Resize an image and convert it into a batched input tensor.

    Args:
        img_pil: Image to convert; the caller in this file passes a PIL
            image.
        size: Target size handed to ``transforms.Resize``, given as a
            ``(height, width)`` pair. The default ``(192, 640)`` is the
            value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        The output of ``Resize`` followed by ``ToTensor``, with a leading
        dimension of size 1 inserted at position 0.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # unsqueeze(0) adds the leading batch dimension.
    return transform(img_pil).unsqueeze(0)


@torch.no_grad()
def compute_depth(encoder, decoder, img_tensor, device='cuda', min_depth=0.1, max_depth=100.0):
    """Run encoder and decoder on an image tensor and return clamped depth.

    Args:
        encoder: Callable applied to the input tensor after it is moved to
            ``device``.
        decoder: Callable applied to the encoder output; its result is
            indexed with the key ``("disp", 0)``.
        img_tensor: Input tensor; it is moved to ``device`` before use.
        device: Device the input is moved to.
        min_depth: Lower clamp bound for the returned depth.
        max_depth: Upper clamp bound for the returned depth.

    Returns:
        The element-wise reciprocal of the ``("disp", 0)`` output, clamped
        to ``[min_depth, max_depth]``. Gradients are not tracked.
    """
    net_input = img_tensor.to(device)
    decoded = decoder(encoder(net_input))

    # NOTE(review): depth is the plain reciprocal of the disparity output. If
    # the decoder emits a normalized (sigmoid) disparity, upstream Monodepth2
    # rescales it with min/max depth before inverting - confirm the intent.
    inverse_disp = 1.0 / decoded[("disp", 0)]
    return inverse_disp.clamp(min=min_depth, max=max_depth)


def test_monodepth():
    """Smoke-test the pipeline on one image and display the depth map.

    Picks CUDA when available (CPU otherwise), loads the default model, runs
    it on a single hard-coded KITTI odometry frame, prints the shape and value
    range of the result, and shows it in a matplotlib window.
    """
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    encoder, decoder = load_monodepth2(device=device)

    # NOTE: machine-specific absolute path; the file must exist locally.
    frame_path = '/home/ymd5170/datasets/kitti/data_odometry_color/dataset/sequences/00/image_2/000000.png'
    frame = Image.open(frame_path).convert('RGB')
    net_input = preprocess_img(frame)

    depth = compute_depth(encoder, decoder, net_input, device=device)

    print(f"Depth map shape : {depth.shape}")
    print(f"Depth range     : [{depth.min().item():.2f} m , {depth.max().item():.2f} m]")

    # Drop singleton dimensions and copy to host memory for plotting.
    depth_image = depth.squeeze().cpu().numpy()

    plt.figure(figsize=(10, 4))
    plt.imshow(depth_image, cmap='magma')
    plt.axis('off')
    plt.show()


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_monodepth()
