import os
import warnings
import logging
import contextlib
from io import StringIO
import torch
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import time
from harmonic_blending import harmonic_blend_of_depths
import my_utils
from my_utils import printc
import ldi_inpaiting as ldi
from sky_segmentation import SkyMaskDetector

# Quiet down noisy output: glog / Ceres verbosity is controlled through
# environment variables, Python logging and FutureWarnings through the stdlib.
os.environ.update(
    {
        "GLOG_minloglevel": "2",
        "GLOG_logtostderr": "0",
        "CERES_MINIMIZER_PROGRESS_TO_STDOUT": "0",
    }
)
# Disabling at a level above CRITICAL suppresses every standard logging call.
logging.disable(logging.CRITICAL + 1)
# Two "ignore" filters for FutureWarning are registered, as in the original setup.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)

# local imports
# Append the 360monodepth Python sources to sys.path. The path is relative, so
# it is resolved against the current working directory at import time.
# NOTE(review): presumably `sphericaldreamer` needs this path to import its
# dependencies, and the script must be launched from a directory where
# ../LayerPano3D exists -- confirm.
_360monodepth_install_dir = "../LayerPano3D/submodules/360monodepth/code/python/src/"
sys.path.append(_360monodepth_install_dir) 
# stdout is redirected into a throwaway buffer while importing, so anything
# printed at import time is discarded.
with contextlib.redirect_stdout(StringIO()):
    from sphericaldreamer import SphericalDreamer

# Identifiers of the pipeline phases. In this script they are used as
# sub-directory names below each "align_XX" folder (phase 2a's cache is read,
# the current phase's outputs are written).
_phase_1a = "1a"
_phase_1b = "1b"
_phase_2a = "2a"
_phase_2b = "2b"
_phase_2c = "2c"

# Phase implemented by this script; used in log messages and output paths.
_phase_current = _phase_2b

def get_save_viz_path(dream):
    """Return the directory that receives LDI visualisations for one dream.

    The result is ``save_dir_ / "align_XX" / _phase_current / "ldi_insights"``
    where ``XX`` is ``dream`` zero-padded to two digits.

    Both ``save_dir_`` and ``_phase_current`` are read from module globals;
    ``save_dir_`` is only assigned inside the ``__main__`` block, so calling
    this before that assignment raises ``NameError``.
    """
    dream_folder = f"align_{dream:02d}"
    phase_root = save_dir_ / dream_folder / _phase_current
    return phase_root / "ldi_insights"

if __name__ == "__main__":
    # Phase 2b driver: for every dream produced by phase 2a, segment the
    # foreground, inpaint the background RGB (LaMa, then FLUX), detect the sky,
    # inpaint the background depth, and save the resulting background layer.
    #
    # Indexing convention used throughout: loops run over dreams
    # i = 1 .. config.num_dreams - 1 and slot [i-1] of every list holds dream i.
    # Dream 0 is never processed here (presumably the input panorama -- TODO confirm).
    config = my_utils.fetch_config_via_parser(
        debug=False
    )
    # save_dir_ is intentionally a module-level name: get_save_viz_path() reads it.
    seeds, width, height, save_dir_, pose_init, pose_end, translation_direction = my_utils.setup(config)
    plot_results = config.ldi.save_plots

    spherical_dreamer = SphericalDreamer(
        pano_width=width,
        pano_height=height,
        pano_depth_temp_dir='/tmp/pano_depth_temp',
        depth_model=config.depth_model,
    )

    # ----------------------------------------------------- #
    # ------------ PHASE 2-B. LDI INPAINTING -------------- #
    # ----------------------------------------------------- #
    printc(f"=== [PHASE {_phase_current}] EXPERIMENT: {config.expname} ===", color='cyan')

    if config.phase2.apply_ldi:
        if not config.load_phase2b_from:
            printc(f"=== PHASE {_phase_current}: LDI INPAINTING ===", color='green')

            # Supported depth-inpainting methods and the suffix each one gets in
            # the visualisation keyword ("depth_inpainted_<suffix>").
            depth_suffix_by_method = {
                "horizontal_min_prior": "hminprior",
                "harmonic_blending": "hblending",
                "infusion": "infusion",
                "nearest": "nn",
                "bilinear_plus_nn": "bilinear_nn",
            }
            depth_method = config.ldi.depth_inpainting.method
            # Fail fast: reject a bad method before the expensive SAM / LaMa /
            # FLUX stages instead of after them.
            if depth_method not in depth_suffix_by_method:
                raise ValueError(f"Unknown depth inpainting method: {depth_method}")

            # -----------------------------
            # 0. LOAD INPUT IMAGES + DEPTH
            # -----------------------------
            list_img = []
            list_depth_origin = []
            for i in range(1, config.num_dreams):
                printc(f"--- {_phase_current}: load image  {i:02d} / {config.num_dreams} ---", color='yellow')

                save_dir__ = save_dir_ / f"align_{i:02d}"

                # NOTE: allow_pickle=True unpickles arbitrary objects; only load
                # caches written by a trusted phase-2a run.
                data = np.load(save_dir__ / _phase_2a / ".cache" / "other_data.npy", allow_pickle=True).item()

                depth_estimated = data['depth_estimated']
                pano_rgb_inpainted = data['pano_rgb_inpainted']

                list_img.append(my_utils.opencv_resize(my_utils.PIL_to_numpy(pano_rgb_inpainted), config.height, config.width, mode="bilinear"))
                list_depth_origin.append(my_utils.opencv_resize(depth_estimated, config.height, config.width, mode="bilinear"))

            # ---------------------------------------
            # 1. COMPUTE MASK FOR FOREGROUND OBJECTS
            # ---------------------------------------
            t0 = time.time()
            list_mask = []
            sam, mask_generator = ldi.instanciate_sam(config)

            for i in range(1, config.num_dreams):
                printc(f"--- {_phase_current}: Compute mask for foreground object  {i:02d} / {config.num_dreams} ---", color='yellow')

                # Create the per-dream visualisation directory once; later
                # stages write into the same location.
                save_viz_path = get_save_viz_path(i)
                os.makedirs(save_viz_path, exist_ok=True)

                mask = ldi.get_foreground_segmask(
                    config,
                    mask_generator,
                    list_img[i-1],
                    list_depth_origin[i-1],
                    plot_results=plot_results,
                    save_path=save_viz_path,
                )
                list_mask.append(mask)

            # Drop the segmentation model before the next model is loaded.
            del sam
            del mask_generator
            torch.cuda.empty_cache()
            print(f"Foreground mask computed in {time.time() - t0:.1f} seconds for {config.num_dreams} images.")

            # ------------------------
            # 2. INPAINTING WITH LAMA
            # ------------------------
            t0 = time.time()
            list_prompt = []
            list_mask_smooth_pil = []
            list_inpaint_pano_lama_pil = []
            list_viz_kwargs = []
            llm_model, processor = ldi.instanciate_llm_and_processor()

            for i in range(1, config.num_dreams):
                printc(f"--- {_phase_current}: Lama Inpainting  {i:02d} / {config.num_dreams} ---", color='yellow')
                prompt, mask_smooth_pil, inpaint_pano_lama_pil, viz_kwargs = ldi.lama_flux_double_inpainting_p1(
                    config,
                    spherical_dreamer,
                    llm_model,
                    processor,
                    image=list_img[i-1],
                    mask=list_mask[i-1],
                )
                list_prompt.append(prompt)
                list_mask_smooth_pil.append(mask_smooth_pil)
                list_inpaint_pano_lama_pil.append(inpaint_pano_lama_pil)
                list_viz_kwargs.append(viz_kwargs)

            spherical_dreamer._release_lama_memory()
            del llm_model
            del processor
            torch.cuda.empty_cache()
            print(f"Lama inpainting done in {time.time() - t0:.1f} seconds for {config.num_dreams} images.")

            # --------------------------
            # 3. INPAINTING WITH FLUX
            # --------------------------

            t0 = time.time()
            list_inpaint_mask_pil = []
            list_inpaint_pano_pil = []

            for i in range(1, config.num_dreams):
                printc(f"--- {_phase_current}: Flux Inpainting  {i:02d} / {config.num_dreams} ---", color='yellow')
                inpaint_pano_pil, inpaint_mask_pil = ldi.lama_flux_double_inpainting_p2(
                    config,
                    spherical_dreamer,
                    list_prompt[i-1],
                    list_mask_smooth_pil[i-1],
                    list_inpaint_pano_lama_pil[i-1],
                    list_viz_kwargs[i-1],
                    plot_results=plot_results,
                    save_path=get_save_viz_path(i),
                )
                list_inpaint_mask_pil.append(inpaint_mask_pil)
                list_inpaint_pano_pil.append(inpaint_pano_pil)

            spherical_dreamer._release_flux_inpainting_memory()
            torch.cuda.empty_cache()
            print(f"FLUX inpainting done in {time.time() - t0:.1f} seconds for {config.num_dreams} images.")


            # -----------------------
            #   3.5 GET SKY MASKS
            # -----------------------
            list_sky_mask = []
            # Must iterate over the same dream range as every other stage:
            # starting at 0 would read index -1 first and shift every sky mask
            # by one dream relative to the panoramas it is saved with.
            for i in range(1, config.num_dreams):
                printc(f"--- {_phase_current}: Get sky mask  {i:02d} / {config.num_dreams} ---", color='yellow')
                sky_mask = SkyMaskDetector.get_mask(
                    image=list_inpaint_pano_pil[i-1],
                )
                list_sky_mask.append(sky_mask)


            # -------------------------------------------------
            # 4. DEPTH INPAINTING (at resolution 1024 * 2048)
            # -------------------------------------------------
            t0 = time.time()
            list_depth_inpainted = []
            list_mask_inpaint_resized = []
            list_depth_origin_resized = []
            list_img_pil_resized = []
            # preparation
            if depth_method == "infusion":
                pipe_dp = ldi.instanciate_pipe_dp()

            for i in range(1, config.num_dreams):
                printc(f"--- {_phase_current}: Depth Inpainting  {i:02d} / {config.num_dreams} ---", color='yellow')

                # Always use the per-dream values below; variables left over
                # from earlier loops refer to the LAST dream only.
                save_viz_path = get_save_viz_path(i)
                inpaint_pano_pil = list_inpaint_pano_pil[i-1]
                inpaint_mask_pil = list_inpaint_mask_pil[i-1]

                img_pil, depth_origin, _, inpaint_mask_bool_ = ldi.prepare_inpainting(
                    config,
                    list_img[i-1],
                    list_depth_origin[i-1],
                    inpaint_mask_pil,
                )

                if depth_method == "horizontal_min_prior":
                    _, depth_prior = ldi.remove_low_freq(depth_origin, config=config.ldi.masking.depth_mean_based.remove_depth_low_freq)
                    inpaint_mask_bool_ = np.ones_like(inpaint_mask_bool_, dtype=bool)  # in this method, we inpaint everywhere
                    depth_inpainted = depth_prior

                elif depth_method == "harmonic_blending":
                    depth_360_mono = spherical_dreamer.estimate_pano_depth(inpaint_pano_pil)
                    inpaint_pano = np.array(inpaint_pano_pil) / 255.0
                    # The fourth return value is the blended depth; it has to be
                    # bound to depth_inpainted so the shared code below sees it.
                    _, _, _, depth_inpainted = harmonic_blend_of_depths(
                        colors=inpaint_pano,
                        warped_depth_interp=depth_origin,  # gt depth
                        depth_estimated=depth_360_mono,  # new depth
                        missing_info_mask=inpaint_mask_bool_,
                        pose=np.eye(4).astype(np.float32),
                        sphere_radius=1.0,
                        height=inpaint_pano.shape[0],
                        width=inpaint_pano.shape[1],
                        phase=_phase_current,
                        logging=plot_results,
                        where_save=save_viz_path,
                    )

                elif depth_method == "infusion":
                    depth_inpainted = ldi.inpaint_bg_depth_infusion(
                        image=img_pil,
                        depth=depth_origin,
                        image_bg=inpaint_pano_pil,
                        bg_mask=inpaint_mask_pil,
                        pipe_dp=pipe_dp,
                        rescale_to_min_depth=True,
                        pad_width=config.ldi.depth_inpainting.pad_width,
                        plot_results=plot_results,
                        save_path=save_viz_path,
                    )

                elif depth_method == "nearest":
                    depth_inpainted = ldi.interpolate_depth_nearest(
                        depth=depth_origin,
                        bg_mask=inpaint_mask_bool_,
                        pad_width=config.ldi.depth_inpainting.pad_width,
                    )

                elif depth_method == "bilinear_plus_nn":
                    depth_inpainted = ldi.interpolate_depth_bilinear_plus_nn(
                        depth=depth_origin,
                        bg_mask=inpaint_mask_bool_,
                        pad_width=config.ldi.depth_inpainting.pad_width,
                    )

                else:
                    # Guards against the suffix table and this chain drifting apart.
                    raise ValueError(f"Unknown depth inpainting method: {depth_method}")

                # Keep depth only where the background was inpainted.
                depth_inpainted[~inpaint_mask_bool_] = np.nan
                if config.ldi.depth_inpainting.apply_post_processing:
                    depth_inpainted = ldi.post_process_inpainted_depth(
                        depth_bg=depth_inpainted,
                        depth_fg=depth_origin,
                        bg_mask=inpaint_mask_bool_,
                        plot=plot_results,
                        save_path=save_viz_path,
                    )
                list_depth_inpainted.append(depth_inpainted)
                list_mask_inpaint_resized.append(inpaint_mask_bool_)
                list_depth_origin_resized.append(depth_origin)
                list_img_pil_resized.append(img_pil)

            if depth_method == "infusion":
                del pipe_dp
            torch.cuda.empty_cache()
            print(f"Depth inpainting done in {time.time() - t0:.1f} seconds for {config.num_dreams} images with method {config.ldi.depth_inpainting.method}.")

            # SAVE RESULTS
            suffix = depth_suffix_by_method[depth_method]
            for i in range(1, config.num_dreams):
                kwargs = {
                    "img_pil": list_img_pil_resized[i-1],
                    "inpaint_pano_pil": list_inpaint_pano_pil[i-1],
                    "inpaint_mask_pil": list_mask_inpaint_resized[i-1],
                    "depth_origin": list_depth_origin_resized[i-1],
                    f"depth_inpainted_{suffix}": list_depth_inpainted[i-1]
                }
                ldi.visualize_depth_inpainting(
                    **kwargs,
                    save_path=get_save_viz_path(i) / "10_depth_inpainting_visualization.png",
                )

                my_utils.save_rgbd_ldi_pano(
                    pano_rgb_bg=list_inpaint_pano_pil[i-1],
                    depth_bg=list_depth_inpainted[i-1],
                    mask_bg=list_mask_inpaint_resized[i-1],
                    sky_mask_bg=list_sky_mask[i-1],
                    dream=i,
                    save_dir_=save_dir_,
                    phase=_phase_current,
                )
            printc(f"PHASE {_phase_current} SUCCESSFULLY COMPLETED!", color='green')
        else:
            # Reuse the outputs of a previous run instead of recomputing them.
            printc(f"SKIPPING PHASE {_phase_current}: BACKGROUND RGBD INPAINTING", color='magenta')
            printc(f"Loading instead from {config.load_phase2b_from}", color='magenta')

            source_phase2b_path = Path(config.save_dir) / config.load_phase2b_from
            dest_phase2b_path = Path(save_dir_)

            my_utils.copy_phase_folders(
                source_dir=source_phase2b_path,
                dest_dir=dest_phase2b_path,
                phase=_phase_current,
            )
    else:
        printc(f"PHASE {_phase_current}: LDI INPAINTING NOT APPLIED AS PER CONFIGURATION", color='magenta')