import os
import sys

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.demo.sv3d_helpers import *

SAVE_PATH = "outputs/demo/vid/"


def _sampling_spec(
    *, T, W, config, ckpt, cfg, guider, num_steps, min_cfg=None, decoding_t=None
):
    """Build one ``VERSION2SPECS`` entry.

    All models share H=576, C=4, f=8 and the same sigma schedule
    (sigma_min=0.002, sigma_max=700.0, rho=7.0). Only the arguments passed in
    differ. ``min_cfg`` and ``decoding_t`` are included in ``options`` only
    when given. A fresh ``options`` dict and a fresh
    ``force_uc_zero_embeddings`` list are built on every call, so entries
    never alias one another even if a caller mutates one entry's options.

    NOTE(review): ``discretization`` and ``guider`` look like indices into the
    sampler choice lists used by ``init_sampling``. Confirm before changing.
    """
    options = {"discretization": 1, "cfg": cfg}
    if min_cfg is not None:
        options["min_cfg"] = min_cfg
    options["sigma_min"] = 0.002
    options["sigma_max"] = 700.0
    options["rho"] = 7.0
    options["guider"] = guider
    options["force_uc_zero_embeddings"] = ["cond_frames", "cond_frames_without_noise"]
    options["num_steps"] = num_steps
    if decoding_t is not None:
        options["decoding_t"] = decoding_t
    return {
        "T": T,
        "H": 576,
        "W": W,
        "C": 4,
        "f": 8,
        "config": config,
        "ckpt": ckpt,
        "options": options,
    }


# Model name -> default frame count / resolution / latent layout, config and
# checkpoint paths, and sampler options. The insertion order is the order
# shown in the "Model Version" selector.
VERSION2SPECS = {
    "svd": _sampling_spec(
        T=14,
        W=1024,
        config="configs/inference/svd.yaml",
        ckpt="checkpoints/svd.safetensors",
        cfg=2.5,
        guider=2,
        num_steps=25,
    ),
    "svd_image_decoder": _sampling_spec(
        T=14,
        W=1024,
        config="configs/inference/svd_image_decoder.yaml",
        ckpt="checkpoints/svd_image_decoder.safetensors",
        cfg=2.5,
        guider=2,
        num_steps=25,
    ),
    "svd_xt": _sampling_spec(
        T=25,
        W=1024,
        config="configs/inference/svd.yaml",
        ckpt="checkpoints/svd_xt.safetensors",
        cfg=3.0,
        min_cfg=1.5,
        guider=2,
        num_steps=30,
        decoding_t=14,
    ),
    "svd_xt_image_decoder": _sampling_spec(
        T=25,
        W=1024,
        config="configs/inference/svd_image_decoder.yaml",
        ckpt="checkpoints/svd_xt_image_decoder.safetensors",
        cfg=3.0,
        min_cfg=1.5,
        guider=2,
        num_steps=30,
        decoding_t=14,
    ),
    "sv3d_u": _sampling_spec(
        T=21,
        W=576,
        config="configs/inference/sv3d_u.yaml",
        ckpt="checkpoints/sv3d_u.safetensors",
        cfg=2.5,
        guider=3,
        num_steps=50,
        decoding_t=14,
    ),
    "sv3d_p": _sampling_spec(
        T=21,
        W=576,
        config="configs/inference/sv3d_p.yaml",
        ckpt="checkpoints/sv3d_p.safetensors",
        cfg=2.5,
        guider=3,
        num_steps=50,
        decoding_t=14,
    ),
}


def _same_elevation_orbit(elev_deg, num_frames):
    """Return ``(polars_rad, azimuths_rad)`` for a constant-elevation orbit.

    Every frame uses the polar angle ``90 - elev_deg`` degrees, converted to
    radians. The azimuth advances in ``num_frames`` equal steps around the
    full circle. The starting angle 0 is dropped, so the first frame sits at
    ``2*pi / num_frames`` and the last at ``2*pi``.
    """
    polars_rad = np.array([np.deg2rad(90 - elev_deg)] * num_frames)
    azimuths_rad = np.linspace(0, 2 * np.pi, num_frames + 1)[1:]
    return polars_rad, azimuths_rad


if __name__ == "__main__":
    st.title("Stable Video Diffusion / SV3D")
    version = st.selectbox(
        "Model Version",
        list(VERSION2SPECS),
        0,
    )
    version_dict = VERSION2SPECS[version]
    # Nothing is loaded or sampled until the user opts in.
    mode = "img2vid" if st.checkbox("Load Model") else "skip"

    H = st.sidebar.number_input(
        "H", value=version_dict["H"], min_value=64, max_value=2048
    )
    W = st.sidebar.number_input(
        "W", value=version_dict["W"], min_value=64, max_value=2048
    )
    # At least one frame: T == 0 would produce empty camera trajectories and
    # an empty sampling request.
    T = st.sidebar.number_input(
        "T", value=version_dict["T"], min_value=1, max_value=128
    )
    C = version_dict["C"]
    F = version_dict["f"]
    # Shallow copy, so the per-run "num_frames" override below does not write
    # into the shared VERSION2SPECS table.
    options = dict(version_dict["options"])

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]

        ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner))

        value_dict = init_embedder_options(
            ukeys,
            {},
        )

        # Use a fixed fps when the conditioner does not embed one. The value
        # is also the default fps for the saved mp4.
        if "fps" not in ukeys:
            value_dict["fps"] = 10

        value_dict["image_only_indicator"] = 0

        if mode == "img2vid":
            img = load_img_for_prediction(W, H)
            if img is None:
                # NOTE(review): assumes load_img_for_prediction returns None
                # until an image is provided (confirm). Without this guard the
                # noise expression below raises on None.
                st.stop()
            # SV3D uses a fixed, near-zero conditioning augmentation. SVD
            # models let the user choose it.
            if "sv3d" in version:
                cond_aug = 1e-5
            else:
                cond_aug = st.number_input(
                    "Conditioning augmentation:", value=0.02, min_value=0.0
                )
            value_dict["cond_frames_without_noise"] = img
            value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
            value_dict["cond_aug"] = cond_aug

        # Camera trajectory for SV3D: one polar/azimuth pair per frame.
        if "sv3d_p" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            trajectory = st.selectbox(
                "Trajectory",
                ["same elevation", "dynamic"],
                0,
            )
            if trajectory == "same elevation":
                (
                    value_dict["polars_rad"],
                    value_dict["azimuths_rad"],
                ) = _same_elevation_orbit(elev_deg, T)
            elif trajectory == "dynamic":
                # Uses T rather than a hard-coded 21, so the trajectory length
                # matches the requested frame count, as in the "same elevation"
                # branch. NOTE(review): assumes gen_dynamic_loop returns
                # `length` samples; confirm.
                azim_rad, elev_rad = gen_dynamic_loop(length=T, elev_deg=elev_deg)
                value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
                value_dict["azimuths_rad"] = azim_rad
        elif "sv3d_u" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            (
                value_dict["polars_rad"],
                value_dict["azimuths_rad"],
            ) = _same_elevation_orbit(elev_deg, T)

        seed = st.sidebar.number_input(
            "seed", value=23, min_value=0, max_value=int(1e9)
        )
        seed_everything(seed)

        save_locally, save_path = init_save_locally(
            os.path.join(SAVE_PATH, version), init_value=True
        )

        if "sv3d" in version:
            # NOTE(review): init_save_locally may return no path when local
            # saving is disabled; confirm. Fall back to the default output
            # directory so the trajectory plot can still be rendered.
            plot_dir = save_path or os.path.join(SAVE_PATH, version)
            os.makedirs(plot_dir, exist_ok=True)
            plot_save_path = os.path.join(plot_dir, "plot_3D.png")
            plot_3D(
                azim=value_dict["azimuths_rad"],
                polar=value_dict["polars_rad"],
                save_path=plot_save_path,
                dynamic=("sv3d_p" in version),
            )
            st.image(
                plot_save_path,
                "3D camera trajectory",
            )

        options["num_frames"] = T

        sampler, num_rows, num_cols = init_sampling(options=options)
        num_samples = num_rows * num_cols

        decoding_t = st.number_input(
            "Decode t frames at a time (set small if you are low on VRAM)",
            value=options.get("decoding_t", T),
            min_value=1,
            max_value=int(1e9),
        )

        if st.checkbox("Overwrite fps in mp4 generator", False):
            saving_fps = st.number_input(
                "saving video at fps:", value=value_dict["fps"], min_value=1
            )
        else:
            saving_fps = value_dict["fps"]

        if st.button("Sample"):
            out = do_sample(
                model,
                sampler,
                value_dict,
                num_samples,
                H,
                W,
                C,
                F,
                T=T,
                batch2model_input=["num_video_frames", "image_only_indicator"],
                force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
                force_cond_zero_embeddings=options.get(
                    "force_cond_zero_embeddings", None
                ),
                return_latents=False,
                decoding_t=decoding_t,
            )

            # do_sample may return (samples, latents); only the decoded
            # samples are used here.
            samples = out[0] if isinstance(out, (tuple, list)) else out

            if save_locally:
                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
