import os
import numpy as np
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
from skimage import io, transform
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from augmentation import image_augmentation, select_hp

class BaseSegmentor(ABC):
    """Common interface shared by every segmentation backend.

    Concrete subclasses implement two hooks: ``_load_model``, which is invoked
    exactly once from ``__init__``, and ``predict``. Because both hooks are
    abstract, ``BaseSegmentor`` itself cannot be instantiated.
    """

    def __init__(self):
        # Construction is nothing more than delegating to the subclass loader.
        self._load_model()

    @abstractmethod
    def _load_model(self):
        """Load the backend model; called once from ``__init__``."""
        ...

    @abstractmethod
    def predict(self, args: dict, image_path: str, hp: dict, bbox: tuple, points: list) -> np.ndarray:
        """Segment the image at ``image_path`` inside the bounding box ``bbox``.

        Args:
            args: Run-level options (backend specific).
            image_path: Path of the image to segment.
            hp: Hyperparameter dictionary (backend specific).
            bbox: Prompt box, interpreted by the concrete backend.
            points: Optional point prompts, interpreted by the concrete backend.

        Returns:
            The predicted mask as a NumPy array.
        """
        ...

class SegModelSAM(BaseSegmentor):
    """Segment Anything (ViT-H) backend prompted by a box plus optional foreground points."""

    def _load_model(self):
        """Load the SAM ViT-H checkpoint onto CUDA and wrap it in a ``SamPredictor``."""
        checkpoint_path = os.path.join("checkpoints", "sam_vit_h_4b8939.pth")
        sam = sam_model_registry['vit_h'](checkpoint=checkpoint_path).to('cuda')
        self.model_seg = SamPredictor(sam)

    def predict(self, args: dict, image_path: str, hp: dict, bbox: tuple, points: list) -> np.ndarray:
        """Segment the object inside ``bbox`` on the augmented image.

        Args:
            args: Unused by this backend; kept for interface compatibility.
            image_path: Path of the image; it is passed through
                ``image_augmentation`` with the ``'seg'`` selection of ``hp``.
            hp: Hyperparameter dictionary consumed by ``select_hp``.
            bbox: Box prompt as ``[x1, y1, x2, y2]`` in augmented-image pixels.
            points: Iterable of ``(x, y)`` pairs, each treated as a foreground
                click. May be empty or ``None`` for a box-only prompt.

        Returns:
            The single mask produced by ``SamPredictor.predict`` with
            ``multimask_output=False``.
        """
        image_rgb = image_augmentation(image_path, select_hp(hp, 'seg'))
        self.model_seg.set_image(np.array(image_rgb))

        # `len()` rather than truthiness so NumPy arrays of points also work;
        # the `None` guard allows callers to omit point prompts entirely.
        if points is not None and len(points) > 0:
            point_coords = np.array([[x, y] for x, y in points])
            # Every supplied point is a positive (foreground) click.
            point_labels = np.ones(len(point_coords), dtype=int)
        else:
            point_coords = None
            point_labels = None

        masks, _, _ = self.model_seg.predict(
            box=np.array(bbox),  # [x1, y1, x2, y2]
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )
        return masks[0]

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Decode a binary mask from a precomputed MedSAM image embedding and a box prompt.

    Args:
        medsam_model: SAM-style model exposing ``prompt_encoder`` (callable, with
            ``get_dense_pe()``) and ``mask_decoder``.
        img_embed: Image-encoder output tensor; assumed (B, 256, 64, 64) for the
            ViT-B encoder — confirm for other backbones.
        box_1024: Box prompt(s) in 1024x1024 input coordinates. A 2-D (B, 4)
            array gets a prompt axis inserted; other shapes pass through as-is.
        H: Output mask height.
        W: Output mask width.

    Returns:
        uint8 array of 0/1 values, shaped (H, W) when the batch size is 1 and
        (B, H, W) otherwise.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.ndim == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )  # low-resolution logits, one mask channel per prompt

    prob = torch.sigmoid(low_res_logits)
    prob = F.interpolate(
        prob,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, 1, H, W)

    # Drop only the mask-channel axis (and the batch axis when B == 1).
    # A bare .squeeze() would also collapse H or W whenever either equals 1.
    prob = prob[:, 0]
    if prob.shape[0] == 1:
        prob = prob[0]
    return (prob.cpu().numpy() > 0.5).astype(np.uint8)

class SegModelMedSAM(BaseSegmentor):
    """MedSAM (ViT-B) backend prompted by a single bounding box."""

    def _load_model(self):
        """Load the MedSAM ViT-B checkpoint onto CUDA and put it in eval mode."""
        checkpoint_path = os.path.join("checkpoints", "medsam_vit_b.pth")
        self.model_seg = sam_model_registry['vit_b'](checkpoint=checkpoint_path).to('cuda')
        self.model_seg.eval()

    def predict(self, args: dict, image_path: str, hp: dict, bbox: tuple, points: list) -> np.ndarray:
        """Segment the object inside ``bbox`` on the augmented image.

        Args:
            args: Unused by this backend; kept for interface compatibility.
            image_path: Path of the image; it is passed through
                ``image_augmentation`` with the ``'seg'`` selection of ``hp``.
            hp: Hyperparameter dictionary consumed by ``select_hp``.
            bbox: Box prompt as ``[x1, y1, x2, y2]`` in augmented-image pixels.
            points: Ignored; this backend prompts with the box only.

        Returns:
            (H, W) uint8 array of 0/1 values at the augmented image's resolution.
        """
        image_rgb = image_augmentation(image_path, select_hp(hp, 'seg'))
        img_3c = np.array(image_rgb)

        # The image encoder expects exactly 3 channels: replicate grayscale
        # (2-D or single-channel) input and drop any alpha channel.
        if img_3c.ndim == 2:
            img_3c = img_3c[:, :, None]
        if img_3c.shape[-1] == 1:
            img_3c = np.repeat(img_3c, 3, axis=-1)
        img_3c = img_3c[:, :, :3]
        H, W = img_3c.shape[:2]

        # Resize to the encoder's 1024x1024 input, then min-max normalize to [0, 1].
        img_1024 = transform.resize(
            img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(
            img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
        )

        # Run on whichever device the model was loaded onto.
        device = next(self.model_seg.parameters()).device
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
        )  # (1, 3, 1024, 1024)

        # Rescale the box from image pixels to the 1024x1024 input frame.
        box_np = np.array([bbox], dtype=float)
        box_1024 = box_np / np.array([W, H, W, H]) * 1024

        with torch.no_grad():
            image_embedding = self.model_seg.image_encoder(img_1024_tensor)

        return medsam_inference(self.model_seg, image_embedding, box_1024, H, W)
    

