import os
import json
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from abc import ABC, abstractmethod
from sklearn.cluster import KMeans
from torchvision import transforms


def get_refined_prompt_coords(bbox, original_size, dinov2_feat, patch_size=14, n_clusters=3, exclude_out_of_box=True):
    """Derive point prompts from patches that resemble the bbox-centre patch.

    The patch under the bbox centre is used as a query; the ``3 * n_clusters``
    patches with the highest cosine similarity to it are mapped back to
    original-image pixel coordinates and condensed into ``n_clusters`` points
    with KMeans.

    Args:
        bbox: ``(x1, y1, x2, y2)`` in original-image pixel coordinates.
        original_size: ``(height, width)`` of the original image.
        dinov2_feat: Patch features; ``dinov2_feat[0]`` is reshaped to
            ``(grid, grid, C)`` with ``grid = 518 // patch_size``, so it must
            hold exactly ``grid * grid`` tokens of dimension ``C``.
        patch_size: Side length of one patch in the 518x518 resized image.
        n_clusters: Number of KMeans clusters (points) to produce.
        exclude_out_of_box: If true, drop cluster centres lying outside
            ``bbox``.

    Returns:
        When ``exclude_out_of_box`` is true, a list of ``(x, y)`` tuples that
        may hold fewer than ``n_clusters`` entries (possibly none). Otherwise
        the raw KMeans ``cluster_centers_`` array of shape ``(n_clusters, 2)``.
    """
    H_orig, W_orig = original_size
    input_size = 518  # side length the image is resized to before feature extraction
    # Patch grid side; 518 // 14 == 37 for the default patch size.
    H_feat = W_feat = input_size // patch_size
    C = dinov2_feat.shape[-1]

    # Convert bbox center to feature map coordinates
    scale_x = input_size / W_orig
    scale_y = input_size / H_orig
    x_center = (bbox[0] + bbox[2]) / 2 * scale_x
    y_center = (bbox[1] + bbox[3]) / 2 * scale_y
    # Clamp to the grid: a centre on/past the right or bottom border would
    # index one past the end, and a negative centre would silently wrap
    # around to the opposite side of the feature map.
    x_patch = min(max(int(x_center // patch_size), 0), W_feat - 1)
    y_patch = min(max(int(y_center // patch_size), 0), H_feat - 1)

    # Compute feature similarity between every patch and the centre patch
    feat_map = dinov2_feat[0].reshape(H_feat, W_feat, C)
    center_feat = feat_map[y_patch, x_patch]
    flat_feat = feat_map.reshape(-1, C)
    sim = F.cosine_similarity(flat_feat, center_feat.unsqueeze(0), dim=-1)

    # Select top-k similar patches and map back to original image coordinates
    top_k = 3 * n_clusters
    topk_indices = torch.topk(sim, top_k).indices.cpu().numpy()
    half_patch = patch_size // 2
    coords_orig = []
    for flat_idx in topk_indices:
        row, col = divmod(int(flat_idx), W_feat)
        # Patch centre in the resized image, rescaled to the original image.
        coords_orig.append((
            (col * patch_size + half_patch) / scale_x,
            (row * patch_size + half_patch) / scale_y,
        ))

    # Apply KMeans to refine point distribution
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_centers = kmeans.fit(coords_orig).cluster_centers_

    # Optionally exclude points outside the bbox
    if exclude_out_of_box:
        x1, y1, x2, y2 = bbox
        cluster_centers = [(x, y) for x, y in cluster_centers if x1 <= x <= x2 and y1 <= y <= y2]

    return cluster_centers


class BaseBooster(ABC):
    """
    Abstract interface shared by booster models.

    Instantiating a concrete subclass immediately runs its ``_load_model``
    implementation; ``predict`` is then the per-image entry point.
    """

    def __init__(self):
        # Model loading is delegated to the subclass hook.
        self._load_model()

    @abstractmethod
    def _load_model(self):
        """Load the underlying model; invoked once from ``__init__``."""

    @abstractmethod
    def predict(self, args: dict, image_path: str, hp: dict, bbox: tuple) -> list:
        """Return the prediction for ``image_path`` within the given ``bbox``."""

    
class BstModelDINOKMeans(BaseBooster):
    """Booster that proposes point prompts from DINOv2 patch features + KMeans."""

    def _load_model(self):
        """Load the DINOv2 backbone and build the image preprocessing pipeline."""
        # Use the GPU when present, but fall back to CPU instead of crashing
        # on hosts without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
        # Inference only: put the backbone in eval mode.
        self.model_bst = dino_model.to(self.device).eval()
        # NOTE(review): mean/std of 0.5 are kept as-is to preserve existing
        # outputs; confirm whether ImageNet statistics were intended for this
        # backbone.
        self.transform = transforms.Compose([
            transforms.Resize((518, 518)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)])

    def predict(self, args: dict, image_path: str, hp: dict, bbox: tuple) -> list:
        """Return point prompts for ``bbox`` in the image at ``image_path``.

        Args:
            args: Unused by this implementation.
            image_path: Path of the image to read.
            hp: Hyper-parameters; ``hp['bst_n_clusters']`` selects the number
                of points to produce.
            bbox: ``(x1, y1, x2, y2)`` in original-image pixel coordinates.

        Returns:
            ``[]`` when ``bst_n_clusters`` is 0, the bbox centre as a single
            ``(x, y)`` tuple in a list when it is 1, and otherwise the result
            of ``get_refined_prompt_coords``.
        """
        n_clusters = hp['bst_n_clusters']
        if n_clusters == 0:
            return []
        if n_clusters == 1:
            # A single point needs no feature extraction: use the bbox centre.
            x_center = (bbox[0] + bbox[2]) / 2
            y_center = (bbox[1] + bbox[3]) / 2
            return [(x_center, y_center)]

        # Context manager releases the file handle deterministically;
        # convert() yields an independent in-memory image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        width, height = image.size

        input_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feat = self.model_bst.get_intermediate_layers(
                input_tensor, n=1, reshape=False, return_class_token=False)[0]

        return get_refined_prompt_coords(bbox, (height, width), feat, n_clusters=n_clusters)