import torch
import sys
import numpy as np
from pyquaternion import Quaternion
import pickle

import warnings

warnings.filterwarnings("ignore")

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2_hq_video_predictor
from sam_hq2.grounding_dino.groundingdino.util.inference import load_model

from Video_Depth_Anything.video_depth_anything.video_depth import VideoDepthAnything

from MoGe.moge.model import MoGeModel

from PowerPaint.app import PowerPaintController

from InvSR.sampler_invsr import InvSamplerSR
from InvSR.inference_invsr import get_configs, get_parser

from MOFA.MOFA_Video_Traj.inference import Drag, get_pipeline

from Depth_Anything_V2.metric_depth.depth_anything_v2.dpt import DepthAnythingV2
from openai import OpenAI

from transformers import AutoImageProcessor

# Camera channel names (presumably the six nuScenes surround cameras):
# the three front-facing channels first, then the three rear-facing ones.
CAM = ["CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT"] + [
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]


def build_model(base_path, device, api_key, resize_h=512, resize_w=512, n_frames=10):
    """Construct every model used by the pipeline and return them as a tuple.

    Args:
        base_path: Repository root. Checkpoints are read from
            ``{base_path}/checkpoints`` and the GroundingDINO config from
            ``{base_path}/sam_hq2/...``. Also passed unchanged to ``get_pipeline``.
        device: Device forwarded to each model builder / ``.to()`` call.
        api_key: API key used to construct the ``OpenAI`` client.
        resize_h: Forwarded to ``Drag``.
        resize_w: Forwarded to ``Drag``.
        n_frames: Forwarded to ``Drag``.

    Returns:
        ``(imagesam, videosam, grounding_model, video_depth_anything,
        depth_anything, moge, inpainter, invsr, i2v, gpt, val_preprocess)``.
    """
    base = str(base_path)
    ckpt_dir = f"{base}/checkpoints"

    # DINOv2 image processor, loaded from a local Hugging Face cache snapshot.
    val_preprocess = AutoImageProcessor.from_pretrained(
        f"{ckpt_dir}/models--facebook--dinov2-large/snapshots/"
        "47b73eefe95e8d44ec3623f8890bd894b6ea2d6c"
    )

    # SAM2-HQ image and video predictors share one config and checkpoint.
    sam2_checkpoint = f"{ckpt_dir}/sam2.1_hq_hiera_large.pt"
    model_cfg = "configs/sam2.1/sam2.1_hq_hiera_l.yaml"
    imagesam = SAM2ImagePredictor(build_sam2(model_cfg, sam2_checkpoint, device=device))
    videosam = build_sam2_hq_video_predictor(model_cfg, sam2_checkpoint, device=device)

    # `device` is passed to load_model, but the model is still moved explicitly
    # (kept from the original; load_model may only record the device).
    grounding_model = load_model(
        model_config_path=f"{base}/sam_hq2/grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        model_checkpoint_path=f"{ckpt_dir}/groundingdino_swint_ogc.pth",
        device=device,
    ).to(device)

    video_depth_anything = VideoDepthAnything("vitl")
    video_depth_anything.load_state_dict(
        torch.load(f"{ckpt_dir}/video_depth_anything_vitl.pth", map_location="cpu"),
        strict=True,
    )
    video_depth_anything = video_depth_anything.to(device).eval()

    # Metric Depth-Anything-V2 (ViT-L) with the "metric_vkitti" checkpoint.
    depth_anything = DepthAnythingV2(
        encoder="vitl",
        features=256,
        out_channels=[256, 512, 1024, 1024],
        max_depth=80.0,
    )
    depth_anything.load_state_dict(
        torch.load(
            f"{ckpt_dir}/depth_anything_v2_metric_vkitti_vitl.pth",
            map_location="cpu",
        )
    )
    depth_anything = depth_anything.to(device).eval()

    # NOTE(review): unlike the depth models above, no `.eval()` is called here;
    # confirm MoGeModel.from_pretrained already returns the model in eval mode.
    moge = MoGeModel.from_pretrained(f"{ckpt_dir}/moge-vitl/model.pt").to(device)

    weight_dtype = torch.float16
    inpainter = PowerPaintController(
        weight_dtype, f"{ckpt_dir}/ppt-v2", False, "ppt-v2", device
    )

    # InvSR's get_parser() returns parsed args, so it presumably reads sys.argv;
    # hide the caller's CLI arguments from it. The previous argv is restored
    # afterwards so building models does not permanently clobber global state.
    # MOFA's builders run inside the same window so they see the same argv as
    # before this change.
    saved_argv = sys.argv
    sys.argv = ["script_name"]
    try:
        args = get_parser()
        configs = get_configs(args)
        invsr = InvSamplerSR(configs, device)

        pipeline, cmp = get_pipeline(base_path, device, weight_dtype)
        i2v = Drag(pipeline, cmp, device, resize_h, resize_w, n_frames)
    finally:
        sys.argv = saved_argv

    gpt = OpenAI(api_key=api_key)

    return (
        imagesam,
        videosam,
        grounding_model,
        video_depth_anything,
        depth_anything,
        moge,
        inpainter,
        invsr,
        i2v,
        gpt,
        val_preprocess,
    )


def _sweep_fields(info):
    """Gather per-sweep transforms and annotations from one info record.

    Returns ``(sweep_tr, sweep_gt_boxes, sweep_instance_tokens, sweep_gt_names)``.
    ``sweep_tr`` is an array whose first matrix is ``np.eye(4)``, followed by each
    sweep's ``transform_matrix`` (``np.eye(4)`` substituted when it is None).
    The other three are plain lists with one entry per sweep.
    """
    # Leading identity, presumably standing for the current frame itself.
    sweep_tr = [np.eye(4)]
    sweep_gt_boxes, sweep_instance_tokens, sweep_gt_names = [], [], []
    for sweep in info["sweeps"]:
        tr = sweep["transform_matrix"]
        sweep_tr.append(np.eye(4) if tr is None else tr)
        sweep_gt_boxes.append(sweep["gt_boxes"])
        sweep_instance_tokens.append(sweep["instance_tokens"])
        sweep_gt_names.append(sweep["gt_names"])
    return np.array(sweep_tr), sweep_gt_boxes, sweep_instance_tokens, sweep_gt_names


def _camera_matrices(camera_info):
    """Compute the 4x4 float32 transforms for a single camera entry.

    Returns a dict keyed by the output names used in the FillTheGap pickle.
    """
    sensor2lidar_r = camera_info["sensor2lidar_rotation"]
    sensor2lidar_t = camera_info["sensor2lidar_translation"]

    # NOTE(review): lidar2camera_rt is built in a *transposed* layout: the
    # rotation block holds (lidar->camera R)^T and the translation sits in the
    # bottom row. That is why lidar2image applies `.T` below. It is stored as-is
    # under "lidar2camera", unlike camera2lidar, which uses the standard
    # [R | t] layout. Confirm downstream consumers expect this.
    lidar2camera_r = np.linalg.inv(sensor2lidar_r)
    lidar2camera_t = sensor2lidar_t @ lidar2camera_r.T
    lidar2camera_rt = np.eye(4).astype(np.float32)
    lidar2camera_rt[:3, :3] = lidar2camera_r.T
    lidar2camera_rt[3, :3] = -lidar2camera_t

    # 3x3 intrinsics embedded in a 4x4 homogeneous matrix.
    intrinsic = np.eye(4).astype(np.float32)
    intrinsic[:3, :3] = camera_info["camera_intrinsics"]

    camera2lidar = np.eye(4).astype(np.float32)
    camera2lidar[:3, :3] = sensor2lidar_r
    camera2lidar[:3, 3] = sensor2lidar_t

    camera2ego = np.eye(4).astype(np.float32)
    camera2ego[:3, :3] = Quaternion(camera_info["sensor2ego_rotation"]).rotation_matrix
    camera2ego[:3, 3] = camera_info["sensor2ego_translation"]

    ego2global = np.eye(4).astype(np.float32)
    ego2global[:3, :3] = Quaternion(camera_info["ego2global_rotation"]).rotation_matrix
    ego2global[:3, 3] = camera_info["ego2global_translation"]

    return {
        "lidar2image": intrinsic @ lidar2camera_rt.T,
        "lidar2camera": lidar2camera_rt,
        "camera_intrinsic": intrinsic,
        "camera2lidar": camera2lidar,
        "camera2global": ego2global @ camera2ego,
        "camera2ego": camera2ego,
        "ego2global": ego2global,
    }


def from_openpcdet_to_fillthegap(infos, save_path):
    """Convert OpenPCDet-style info records into a FillTheGap input pickle.

    Args:
        infos: Sequence of info dicts. Each must provide ``sweeps`` (list of dicts
            with ``transform_matrix``, ``gt_boxes``, ``instance_tokens``,
            ``gt_names``), ``gt_boxes``, ``lidar_path``, ``lidarseg_path``,
            ``num_lidar_pts``, ``token``, ``gt_names`` and ``cams`` (a mapping
            whose values carry ``sensor2lidar_*``, ``sensor2ego_*``,
            ``ego2global_*``, ``camera_intrinsics`` and ``data_path``).
        save_path: File path the resulting dict is pickled to.

    Output layout (one entry per info, in input order):
        * ``lidar2image``, ``lidar2camera``, ``camera_intrinsic``: arrays of shape
          (num_infos, num_cams, 4, 4).
        * ``imgpath``: string array of shape (num_infos, num_cams).
        * ``lidarpath``: string array of shape (num_infos,).
        * ``sweep_tr``: stacked array, shape (num_infos, num_sweeps + 1, 4, 4)
          assuming each ``transform_matrix`` is 4x4. All infos must therefore
          have the same number of sweeps.
        * ``camera2lidar``, ``camera2global``, ``camera2ego``, ``ego2global``:
          lists of (num_cams, 4, 4) arrays (not stacked across infos).
        * All other keys: lists of the raw per-info values.

    Raises:
        ValueError: if ``infos`` is empty, or if per-info shapes differ (from
            ``np.stack``).
    """
    # Previously this surfaced as an opaque np.stack error after the loop.
    if len(infos) == 0:
        raise ValueError("infos is empty; nothing to convert")

    input_dict = {
        "sweep_tr": [],
        "gt_boxes": [],
        "sweep_gt_boxes": [],
        "sweep_instance_tokens": [],
        "sweep_gt_names": [],
        "gt_names": [],
        "lidarpath": [],
        "segpath": [],
        "num_lidar_pts": [],
        "token": [],
        "imgpath": [],
        "lidar2image": [],
        "lidar2camera": [],
        "camera_intrinsic": [],
        "camera2lidar": [],
        "camera2global": [],
        "camera2ego": [],
        "ego2global": [],
    }
    # Per-camera matrices that are stacked over the cameras of each info.
    camera_keys = (
        "lidar2image",
        "lidar2camera",
        "camera_intrinsic",
        "camera2lidar",
        "camera2global",
        "camera2ego",
        "ego2global",
    )

    for info in infos:
        sweep_tr, sweep_gt_boxes, sweep_instance_tokens, sweep_gt_names = (
            _sweep_fields(info)
        )
        input_dict["sweep_gt_boxes"].append(sweep_gt_boxes)
        input_dict["sweep_instance_tokens"].append(sweep_instance_tokens)
        input_dict["sweep_gt_names"].append(sweep_gt_names)
        input_dict["sweep_tr"].append(sweep_tr)

        input_dict["gt_boxes"].append(info["gt_boxes"])
        input_dict["lidarpath"].append(info["lidar_path"])
        input_dict["segpath"].append(info["lidarseg_path"])
        input_dict["num_lidar_pts"].append(info["num_lidar_pts"])
        input_dict["token"].append(info["token"])
        input_dict["gt_names"].append(info["gt_names"])

        cameras = list(info["cams"].values())
        matrices = [_camera_matrices(camera_info) for camera_info in cameras]
        for key in camera_keys:
            input_dict[key].append(np.stack([m[key] for m in matrices], axis=0))
        input_dict["imgpath"].append(
            [camera_info["data_path"] for camera_info in cameras]
        )

    # Only these keys are stacked across infos; the rest stay Python lists.
    input_dict["lidar2image"] = np.stack(input_dict["lidar2image"], axis=0)
    input_dict["lidar2camera"] = np.stack(input_dict["lidar2camera"], axis=0)
    input_dict["camera_intrinsic"] = np.stack(input_dict["camera_intrinsic"], axis=0)
    input_dict["imgpath"] = np.stack(input_dict["imgpath"], axis=0)
    input_dict["lidarpath"] = np.stack(input_dict["lidarpath"], axis=0)
    input_dict["sweep_tr"] = np.stack(input_dict["sweep_tr"], axis=0)

    with open(save_path, "wb") as f:
        pickle.dump(input_dict, f)
