import os, sys

sys.path.append("../SLD")
sys.path.append("")
sys.path.append("../Orient-Anything")

from peft import PeftModelForCausalLM
from PIL import Image
from transformers import BitsAndBytesConfig
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor, AutoProcessor
from qwen_vl_utils import process_vision_info
from sld.llm_template import orientation_detection_prompt
import torch
import cv2
import numpy as np
import cv2
import matplotlib.pyplot as plt
from transformers import DPTImageProcessor, DPTForDepthEstimation
import random

from paths import DINO_LARGE
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
from PIL import Image

import torch.nn.functional as F
from utils_orient import background_preprocess
from inference import get_3angle_infer_aug, get_3angle

from huggingface_hub import hf_hub_download


class OrientationModule():
    """Predict a coarse facing direction label for an object in an image.

    Loads the ``croplargeEX2/dino_weight.pt`` checkpoint from the
    ``Viglong/Orient-Anything`` hub repository into a ``DINOv2_MLP`` model and
    converts the first angle returned by the Orient-Anything inference helpers
    into one of eight direction labels.
    """

    # (exclusive upper bound in degrees, label) for each sector, in ascending
    # order.  Angles at or above the last bound wrap around to "front".
    _SECTORS = (
        (22.5, "front"),
        (67.5, "forward-left"),
        (112.5, "left"),
        (157.5, "backward-left"),
        (202.5, "back"),
        (247.5, "backward-right"),
        (292.5, "right"),
        (337.5, "forward-right"),
    )

    def __init__(self, peft_model_id=None, device="cpu"):
        """Download the checkpoint and build the orientation model.

        Args:
            peft_model_id: Accepted for interface compatibility; not used here.
            device: Device the model is moved to and on which inference runs.
        """
        # NOTE(review): ``resume_download`` appears to be deprecated in recent
        # huggingface_hub releases -- confirm before upgrading the dependency.
        ckpt_path = hf_hub_download(repo_id="Viglong/Orient-Anything", filename="croplargeEX2/dino_weight.pt", repo_type="model", resume_download=True)

        self.dino = DINOv2_MLP(
                    dino_mode   = 'large',
                    in_dim      = 1024,
                    # presumably 360 + 180 + 180 angle bins plus 2 extra
                    # outputs -- TODO confirm against DINOv2_MLP
                    out_dim     = 360+180+180+2,
                    evaluate    = True,
                    mask_dino   = False,
                    frozen_back = False
                )

        self.dino.eval()
        # SECURITY NOTE: torch.load unpickles the checkpoint.  The file comes
        # from a fixed hub repository; consider ``weights_only=True`` if the
        # installed torch version supports it.
        self.dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        self.dino = self.dino.to(device)
        self.val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE)
        self.device = device

        # Only these object names get an orientation prediction.
        self.direction_objects = {'chicken', 'deer', 'car', 'bicycle', 'cow', 'horse', 'cat', 'bus', 'sheep', 'chair', 'bench', 'dog', "laptop", "woman", "clown"}

    @staticmethod
    def _angle_to_direction(angle) -> str:
        """Map an azimuth in degrees to one of the eight direction labels.

        The angle is first wrapped into [0, 360) so that negative or
        over-range values land in the sector they are equivalent to instead
        of silently falling through to "front".  Wrapping is a no-op for
        angles that are already in range.
        """
        azimuth = float(angle) % 360.0
        for upper_bound, label in OrientationModule._SECTORS:
            if azimuth < upper_bound:
                return label
        # [337.5, 360) wraps around to the front sector.
        return "front"

    def __call__(self, image, random_crop=True) -> str:
        """Return the direction label for ``image``.

        Args:
            image: Image of the object; passed to ``background_preprocess``
                and to the Orient-Anything inference helper.
            random_crop: When true use ``get_3angle_infer_aug`` (which also
                receives the background-processed image), otherwise use
                ``get_3angle`` on the raw image.

        Returns:
            One of "front", "forward-left", "left", "backward-left", "back",
            "backward-right", "right", "forward-right".
        """
        rm_bkg_img = background_preprocess(image, True)
        if random_crop:
            angles = get_3angle_infer_aug(image, rm_bkg_img, self.dino, self.val_preprocess, self.device)
        else:
            angles = get_3angle(image, self.dino, self.val_preprocess, self.device)
        # Only the first returned angle is used for the label.
        return self._angle_to_direction(angles[0])

    def predict_orientation(self, object, bbox, image_source):
        """Crop an object out of an image file and predict its direction.

        Args:
            object: Object name; only names in ``self.direction_objects`` are
                processed.
            bbox: ``(x, y, w, h)`` as fractions of the image width/height.
            image_source: Path or file object accepted by ``PIL.Image.open``.

        Returns:
            The direction label, or ``None`` when the object name is not
            supported or the bounding box does not cover any pixel.
        """
        if object not in self.direction_objects:
            return None

        x, y, w, h = [float(data) for data in bbox]
        # Clamp every edge to the image so the crop never includes padding.
        x_min = min(max(x, 0.0), 1.0)
        y_min = min(max(y, 0.0), 1.0)
        x_max = min(max(x + w, 0.0), 1.0)
        y_max = min(max(y + h, 0.0), 1.0)

        # convert() returns an independent image, so the file can be closed.
        with Image.open(image_source) as source:
            image = source.convert('RGB')
        W, H = image.size

        # Same rounding PIL applies to fractional crop boxes, done up front so
        # an empty box can be rejected before resizing.
        left, upper = int(round(x_min * W)), int(round(y_min * H))
        right, lower = int(round(x_max * W)), int(round(y_max * H))
        if right <= left or lower <= upper:
            return None

        # The crop is scaled back up to the full image size before inference.
        image = image.crop((left, upper, right, lower)).resize((W, H), Image.LANCZOS)
        return self(image)


class DepthModule():
    """Depth estimation with the ``Intel/dpt-large`` model plus per-object depth statistics."""

    def __init__(self, cur_device):
        """Load the DPT processor and model.

        Args:
            cur_device: Passed as ``device_map`` when loading the model and
                used as the target device for the model inputs.
        """
        self.processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large", device_map=cur_device)
        self.cur_device = cur_device

    @staticmethod
    def _to_uint8(depth):
        """Scale a depth array so its maximum maps to 255 and cast to uint8.

        Negative values (bicubic interpolation can overshoot below zero) are
        clipped to 0 instead of wrapping around in the uint8 cast, and an
        array whose maximum is not a positive finite number yields all zeros
        rather than dividing by zero.
        """
        peak = np.max(depth) if depth.size else 0.0
        if not np.isfinite(peak) or peak <= 0:
            return np.zeros(depth.shape, dtype="uint8")
        return np.clip(depth * 255 / peak, 0, 255).astype("uint8")

    def extract_depth(self, image):
        """Predict a depth map for ``image`` at the image's own resolution.

        Args:
            image: Array whose ``shape`` unpacks as ``(H, W, channels)``.

        Returns:
            ``(depth, formatted)``: the depth map as a PIL image and as the
            uint8 array of shape ``(H, W)`` it was built from.
        """
        H, W, _ = image.shape
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = inputs.to(self.cur_device)
        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth
        # Resize the model output back to the input resolution.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=(H, W),
            mode="bicubic",
            align_corners=False,
        )
        # Index the batch and channel axes explicitly; a blanket squeeze()
        # would also drop H or W when either equals 1.
        output = prediction[0, 0].cpu().numpy()
        formatted = self._to_uint8(output)
        depth = Image.fromarray(formatted)
        return depth, formatted

    def get_object_depth(self, object_mask, depth_map):
        """Compute the masked depth and the average depth of one object.

        Args:
            object_mask: Array of shape ``(H, W, channels)``; a pixel belongs
                to the object when its channel mean exceeds 0.05.
            depth_map: Depth values in the 0-255 range (array or PIL image),
                matching the mask's ``(H, W)``.

        Returns:
            ``(object_depth, obj_mask, avg_depth)``: depth scaled to [0, 1]
            and zeroed outside the object, the binary 0.0/1.0 mask, and the
            mean scaled depth over the object pixels (0.0 for an empty mask).
        """
        channel_mean = np.mean(object_mask, axis=2)
        # Fully binarise: sub-threshold values must not count as object
        # pixels in the average below.
        obj_mask = np.where(channel_mean > 0.05, 1.0, 0.0)

        object_depth = obj_mask * (np.array(depth_map) / 255)

        pixel_count = np.sum(obj_mask)
        # An empty mask would otherwise give 0/0 -> NaN with a warning.
        avg_depth = np.sum(object_depth) / pixel_count if pixel_count > 0 else 0.0

        return object_depth, obj_mask, avg_depth
