import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np
from data.load import Camera, InstanceMasks3D, Images, PointCloud, get_number_of_images
from utils import get_free_gpu, create_out_folder
from mask_features_computation.features_extractor import FeaturesExtractor
import torch
import os

def _absolutize_config_paths(ctx: DictConfig) -> None:
    """Rewrite every filesystem path in ``ctx`` as an absolute path, in place.

    Relative entries are resolved against the current working directory, so
    the caller is expected to have switched to the intended base directory
    before calling this helper.
    """
    ctx.data.masks.masks_path = os.path.abspath(ctx.data.masks.masks_path)
    # Only the combined poses+intrinsics location is resolved here; a layout
    # with separate ``poses_path`` / ``intrinsic_path`` entries is not handled.
    ctx.data.camera.poses_intrinsic_path = os.path.abspath(ctx.data.camera.poses_intrinsic_path)
    ctx.data.depths.depths_path = os.path.abspath(ctx.data.depths.depths_path)
    ctx.data.images.images_path = os.path.abspath(ctx.data.images.images_path)
    ctx.data.point_cloud_path = os.path.abspath(ctx.data.point_cloud_path)
    ctx.external.sam_checkpoint = os.path.abspath(ctx.external.sam_checkpoint)
    ctx.output.output_directory = os.path.abspath(ctx.output.output_directory)


def _build_adapter_cfg(ctx: DictConfig):
    """Build the config handed to ``FeaturesExtractor`` as ``cfg``.

    Returns an OmegaConf config with a ``MODEL.OPENDAS`` section when
    ``ctx.external.use_opendas`` is set, otherwise one with a ``MODEL.VPT``
    section when ``ctx.external.use_vpt`` is set, otherwise ``None``.
    OpenDAS takes precedence when both flags are set.
    """
    if ctx.external.use_opendas:
        section = "OPENDAS"
        settings = {
            "DIR": "./multimodal-prompt-learning/output/scannetpp_similar_negative_v2/OpenDAS/vit_l14_c2_ep10_batch16_2+2ctx_d24_use_both_losses_0shots/seed429",
            "LOAD_EPOCH": 8,
            "PROMPT_DEPTH_VISION": 24,
            "PROMPT_DEPTH_TEXT": 24,
            "N_CTX_TEXT": 4,
            "N_CTX_VISION": 8,
            "CTX_INIT": "a photo of a",
            "INPUT_SIZE": (224, 224),
        }
    elif ctx.external.use_vpt:
        section = "VPT"
        settings = {
            "DIR": "./multimodal-prompt-learning/output/scannetpp/VPT/vit_l14_c2_ep5_batch32_4_depth_24_0shots/seed428",
            "LOAD_EPOCH": 5,
            "PROMPT_DEPTH_VISION": 24,
            "N_CTX_VISION": 8,
            "CTX_INIT": "a photo of a",
            "INPUT_SIZE": (224, 224),
        }
    else:
        return None

    return OmegaConf.create({
        "MODEL": {
            "CLIP_ADAPTER": {
                "CLIP_MODEL_NAME": ctx.external.clip_model
            },
            section: settings,
        }
    })


# TIP: pass version_base=None to hydra.main below if you encounter an error
# or warning about it.
@hydra.main(config_path="configs", config_name="openmask3d_inference")
def main(ctx: DictConfig):
    """Compute per-mask CLIP features for one scene and save them as ``.npy``.

    The pipeline loads the 3D instance masks, a subsampled set of images, the
    point cloud and the camera data named in ``ctx``, runs
    ``FeaturesExtractor.extract_features`` on them, and writes the result to
    ``<output_directory>/<scene_name>_openmask3d_features.npy``, where
    ``scene_name`` is the name of the directory containing the masks file.

    Args:
        ctx: Hydra configuration. Its path entries are rewritten in place as
            absolute paths.

    Raises:
        FileNotFoundError: If the masks path does not exist.
        ValueError: If no images could be loaded.
    """
    device = get_free_gpu(min_mem=7000) if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Hydra may have switched the working directory to its run directory;
    # go back to the launch directory so relative paths (including the
    # hard-coded adapter checkpoint directories) resolve as the user expects.
    os.chdir(hydra.utils.get_original_cwd())
    _absolutize_config_paths(ctx)
    out_folder = ctx.output.output_directory

    # 1. Load the masks
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.exists(ctx.data.masks.masks_path):
        raise FileNotFoundError(
            f"Path to masks does not exist: {ctx.data.masks.masks_path} - first run compute_masks_single_scene.sh!"
        )
    # Create the output folder up front so a missing directory is not
    # discovered only after the expensive feature extraction has finished.
    os.makedirs(out_folder, exist_ok=True)
    masks = InstanceMasks3D(ctx.data.masks.masks_path)
    print(f"[INFO] Masks loaded. {masks.num_masks} masks found.")

    # 2. Load the images (every ``frequency``-th frame)
    num_frames = get_number_of_images(ctx.data.camera.poses_intrinsic_path)
    indices = np.arange(0, num_frames, step=ctx.openmask3d.frequency)
    images = Images(images_path=ctx.data.images.images_path,
                    extension=ctx.data.images.images_ext,
                    indices=indices)
    print(f"[INFO] Images loaded. {len(images.images)} images found at {ctx.data.images.images_path}.")
    if len(images.images) == 0:
        raise ValueError(f"No images found at {ctx.data.images.images_path}.")

    # 3. Load the pointcloud
    pointcloud = PointCloud(ctx.data.point_cloud_path)
    print(f"[INFO] Pointcloud loaded. {pointcloud.num_points} points found.")

    # 4. Load the camera configurations
    camera = Camera(intrinsic_resolution=ctx.data.camera.intrinsic_resolution,
                    poses_intrinsic_path=ctx.data.camera.poses_intrinsic_path,
                    depths_path=ctx.data.depths.depths_path,
                    extension_depth=ctx.data.depths.depths_ext,
                    depth_scale=ctx.data.depths.depth_scale)
    print("[INFO] Camera configurations loaded.")

    # 5. Run extractor
    features_extractor = FeaturesExtractor(camera=camera,
                                           clip_model=ctx.external.clip_model,
                                           images=images,
                                           masks=masks,
                                           pointcloud=pointcloud,
                                           sam_model_type=ctx.external.sam_model_type,
                                           sam_checkpoint=ctx.external.sam_checkpoint,
                                           vis_threshold=ctx.openmask3d.vis_threshold,
                                           rotation_deg_apply=ctx.data.rotation_degrees,
                                           device=device,
                                           cfg=_build_adapter_cfg(ctx),
                                           use_opendas=ctx.external.use_opendas,
                                           use_vpt=ctx.external.use_vpt)
    print("[INFO] Computing per-mask CLIP features.")
    features = features_extractor.extract_features(topk=ctx.openmask3d.top_k,
                                                   multi_level_expansion_ratio=ctx.openmask3d.multi_level_expansion_ratio,
                                                   num_levels=ctx.openmask3d.num_of_levels,
                                                   num_random_rounds=ctx.openmask3d.num_random_rounds,
                                                   num_selected_points=ctx.openmask3d.num_selected_points,
                                                   save_crops=ctx.output.save_crops,
                                                   out_folder=out_folder,
                                                   optimize_gpu_usage=ctx.gpu.optimize_gpu_usage)
    print("[INFO] Features computed.")

    # 6. Save features
    # The scene name is the directory that contains the masks file; os.path is
    # used instead of splitting on "/" so this also works with other separators.
    scene_name = os.path.basename(os.path.dirname(ctx.data.masks.masks_path))
    filename = f"{scene_name}_openmask3d_features.npy"
    output_path = os.path.join(out_folder, filename)
    np.save(output_path, features)
    print(f"[INFO] Masks features saved to {output_path}.")
    
# Script entry point: main() is called without arguments because the
# @hydra.main decorator is what supplies the config to it.
if __name__ == "__main__":
    main()