import os
import cv2
import numpy as np
import base64
import random
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt

def image_to_base64(image: np.ndarray) -> str:
    """Encode an OpenCV-style BGR image array as a base64 PNG string.

    The array is converted from BGR to RGB channel order, wrapped in a PIL
    image and written as PNG into an in-memory buffer. The resulting PNG
    bytes are base64-encoded and returned as UTF-8 text.
    """
    rgb_array = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    with BytesIO() as png_stream:
        Image.fromarray(rgb_array).save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')

def encode_patch_to_base64(patch_rgb):
    """Encode an RGB image patch as a base64 PNG string using OpenCV.

    ``cv2.imencode`` expects BGR channel order, so the patch's channels are
    swapped before encoding.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    encoded_ok, png_buffer = cv2.imencode(
        ".png", cv2.cvtColor(patch_rgb, cv2.COLOR_RGB2BGR)
    )
    if encoded_ok:
        return base64.b64encode(png_buffer).decode('utf-8')
    raise ValueError("Encoding patch failed")

def encode_mask_to_base64(mask):
    """Encode a mask as a single-channel black/white PNG, returned as base64.

    Every truthy (non-zero) element of ``mask`` becomes 255 (white); every
    other element becomes 0 (black).

    Args:
        mask: array-like mask; any non-zero entry counts as foreground.

    Returns:
        The PNG bytes, base64-encoded, as a UTF-8 string.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    # Binarise before scaling: a raw ``astype(np.uint8) * 255`` wraps around
    # in uint8 for any value > 1 (e.g. a 0/255 mask would come out as 0/1).
    mask_uint8 = np.asarray(mask).astype(bool).astype(np.uint8) * 255
    success, buffer = cv2.imencode(".png", mask_uint8)
    if not success:
        raise ValueError("Mask encoding failed")
    return base64.b64encode(buffer).decode('utf-8')

def encode_image(image_path):
    """Read a file's raw bytes and return them base64-encoded as a string.

    This is best-effort: instead of raising, it prints a message and
    returns None when the file does not exist or when any other error
    occurs while reading or encoding.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        return base64.b64encode(raw_bytes).decode('utf-8')
    except FileNotFoundError:
        print(f"Error: File {image_path} not found.")
    except Exception as e:
        print(f"Error encoding image: {e}")
    return None

def apply_mask_to_image(image_path, mask):
    """Load an image and black out every pixel outside ``mask``.

    Args:
        image_path: path to an image file readable by ``cv2.imread``.
        mask: array-like indexed against the image's leading axes; truthy
            entries mark pixels to keep. Presumably shaped (H, W) to match
            the image -- confirm with callers.

    Returns:
        An RGB ``np.ndarray`` with the loaded image's shape, where pixels
        outside the mask are 0.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # With the default IMREAD_COLOR flag the result is already 3-channel BGR;
    # this branch is kept as a defensive fallback for 2-D input.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    mask_bool = np.asarray(mask).astype(bool)
    masked_image = np.zeros_like(image)
    masked_image[mask_bool] = image[mask_bool]
    return cv2.cvtColor(masked_image, cv2.COLOR_BGR2RGB)

def extract_patches_from_mask(image_path, mask, visualize=False, seed=0):
    """Crop the region around ``mask`` plus a same-sized random patch.

    The bounding box of the mask's non-zero entries is enlarged by scaling
    its top-left corner by 0.9 and its bottom-right corner by 1.1. This
    scales the *coordinates*, so the margin grows with the box's distance
    from the image origin rather than with the box's size. The enlarged box
    is then clipped to the image.

    A second patch of identical size is cut at a random position. The
    position is drawn from a private ``random.Random(seed)``, so results are
    reproducible without reseeding the global ``random`` module.

    Args:
        image_path: path to an image readable by ``cv2.imread``.
        mask: array whose non-zero entries mark the region of interest; the
            indices from ``np.where(mask)`` are used as (row, col) pixel
            coordinates in the image.
        visualize: if True, show both patches side by side with matplotlib.
        seed: seed for the random-patch position. The default 0 reproduces
            the previously hard-coded seed.

    Returns:
        ``(patch_mask, patch_rand)``: two RGB arrays of the same shape.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
        ValueError: if ``mask`` has no non-zero entries.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h_img, w_img, _ = image.shape

    ys, xs = np.where(mask)
    if len(xs) == 0 or len(ys) == 0:
        raise ValueError("Mask is empty. Cannot compute bounding box.")
    x1, y1, x2, y2 = np.min(xs), np.min(ys), np.max(xs), np.max(ys)
    x1 = int(0.9 * x1)
    y1 = int(0.9 * y1)
    # x2/y2 are inclusive pixel indices, so clip to the last valid index.
    # Clipping to w_img/h_img made patch_w/patch_h one pixel too large, which
    # mismatched the two patch shapes and could make randint's range empty.
    x2 = min(int(1.1 * x2), w_img - 1)
    y2 = min(int(1.1 * y2), h_img - 1)

    patch_w, patch_h = x2 - x1 + 1, y2 - y1 + 1
    patch_mask = image[y1:y2 + 1, x1:x2 + 1]

    # patch_w <= w_img and patch_h <= h_img, so both ranges are non-empty.
    # A local generator avoids mutating the global random state.
    rng = random.Random(seed)
    rand_x = rng.randint(0, w_img - patch_w)
    rand_y = rng.randint(0, h_img - patch_h)
    patch_rand = image[rand_y:rand_y + patch_h, rand_x:rand_x + patch_w]

    if visualize:
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(patch_mask)
        axs[0].set_title("Patch from Mask")
        axs[0].axis("off")
        axs[1].imshow(patch_rand)
        axs[1].set_title("Random Patch")
        axs[1].axis("off")
        plt.tight_layout()
        plt.show()

    return patch_mask, patch_rand

def sample_crops_by_area(image_path, crop_area):
    """Cut nine square crops centred on a 3x3 grid laid over the image.

    The crop side is ``int(sqrt(crop_area))``. Each crop is centred on the
    middle of one cell of an even 3x3 split of the image. Crops are
    returned in row-major order: top-left, top-middle, ..., bottom-right.
    A crop that would overrun the right or bottom edge is slid back inside
    the image where possible. When the side exceeds the image dimension,
    the crop is clipped to the image.

    Returns:
        A list of nine ``PIL.Image.Image`` objects in RGB mode.
    """
    rgb = np.array(Image.open(image_path).convert("RGB"))
    height, width, _ = rgb.shape
    side = int(np.sqrt(crop_area))
    half = side // 2
    cell_w, cell_h = width // 3, height // 3

    def _window(center, length):
        # [start, stop) span of `side` pixels around `center`, clipped to
        # [0, length) and shifted back left/up if it ran past the end.
        start = max(center - half, 0)
        stop = min(start + side, length)
        if stop - start < side:
            start = max(stop - side, 0)
        return start, stop

    result = []
    for row in range(3):
        top, bottom = _window(cell_h // 2 + row * cell_h, height)
        for col in range(3):
            left, right = _window(cell_w // 2 + col * cell_w, width)
            result.append(Image.fromarray(rgb[top:bottom, left:right]))
    return result


def get_gt(image_path, args):
    """Load the ground-truth array paired with ``image_path``.

    The ground truth is read from ``<args.gt_dir>/<stem>.npy``, where
    ``<stem>`` is the image's base name with its final extension removed.

    Args:
        image_path: path to the source image; only its base name is used.
        args: object with a ``gt_dir`` attribute (presumably an argparse
            namespace -- confirm with callers).

    Returns:
        The array stored in the ``.npy`` file.

    Raises:
        FileNotFoundError: if the ``.npy`` file does not exist.
    """
    # splitext strips only the last extension, so "scan.v2.png" -> "scan.v2".
    # The previous split('.')[0] truncated it to "scan", and a hidden file
    # such as ".img.png" to an empty stem.
    stem, _ = os.path.splitext(os.path.basename(image_path))
    gt_path = os.path.join(args.gt_dir, stem + '.npy')
    return np.load(gt_path)

def calc_dice(mask, gt_np):
    """Return the Dice coefficient 2*|A & B| / (|A| + |B|) of two masks.

    Both inputs are converted to float64 before any arithmetic. This matters
    for two reasons:

    - For boolean arrays, NumPy's ``+`` is a logical OR, so the denominator
      would count the union instead of |A| + |B|. The score could then
      exceed 1; two identical masks would score 2.
    - Small integer dtypes could overflow.

    Non-binary float values are used as-is, giving the soft Dice formulation.

    A 1e-8 epsilon in the denominator avoids division by zero, so two empty
    masks score 0.0.

    Args:
        mask: predicted mask (array-like).
        gt_np: ground-truth mask (array-like), broadcast-compatible with mask.

    Returns:
        The Dice score as a NumPy float.
    """
    m = np.asarray(mask, dtype=np.float64)
    g = np.asarray(gt_np, dtype=np.float64)
    return 2 * (m * g).sum() / ((m + g).sum() + 1e-8)

def calc_dice_for_file(image_path, mask, args):
    """Score ``mask`` against the ground truth stored for ``image_path``.

    The ground truth is loaded with ``get_gt(image_path, args)`` and compared
    to ``mask`` with ``calc_dice``; that score is returned.
    """
    return calc_dice(mask, get_gt(image_path, args))

def visualization(image_path, bbox, mask, points, args, save_path=None):
    """Draw a four-panel diagnostic figure for one image's segmentation.

    Panels, left to right:
      1. the RGB image with ``bbox`` outlined in red and each ``(x, y)`` in
         ``points`` marked with a red dot;
      2. the ground truth returned by ``get_gt(image_path, args)``;
      3. the predicted ``mask``;
      4. the image with every pixel outside ``mask`` set to black.

    Args:
        image_path: path to the source image.
        bbox: ``(x1, y1, x2, y2)`` corner coordinates; each is cast to int.
        mask: predicted mask. It is multiplied by 255 for display, so it is
            presumably boolean or 0/1-valued -- confirm with callers.
        points: iterable of ``(x, y)`` pairs to mark on panel 1.
        args: forwarded to ``get_gt``; must provide ``gt_dir``.
        save_path: when None, the figure is shown with ``plt.show()``.
            Otherwise it is saved to this path, creating any missing parent
            directory, and the figure is closed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Draw bounding box; the image is RGB here, so (255, 0, 0) is red.
    image_with_bbox = image.copy()
    x1, y1, x2, y2 = map(int, bbox)
    cv2.rectangle(image_with_bbox, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=3)

    # Scale masks to 0..255 for grayscale display.
    mask_display = (mask * 255).astype(np.uint8)
    gt = get_gt(image_path, args)
    gt_display = (gt * 255).astype(np.uint8)

    # Computed before the figure exists, so a failure here does not leave an
    # orphaned figure open. (The red-overlay image the previous version
    # built was never displayed, so it has been dropped.)
    masked_rgb = apply_mask_to_image(image_path, mask)

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(image_with_bbox)
    axs[0].set_title("Bounding Box")
    axs[0].axis("off")
    for (x, y) in points:
        axs[0].plot(x, y, marker='o', markersize=6, color='red')

    axs[1].imshow(gt_display, cmap='gray')
    axs[1].set_title("Ground Truth")
    axs[1].axis("off")

    axs[2].imshow(mask_display, cmap='gray')
    axs[2].set_title("Predicted Mask")
    axs[2].axis("off")

    axs[3].imshow(masked_rgb)
    axs[3].set_title("Masked Image")
    axs[3].axis("off")

    plt.tight_layout()
    if save_path is None:
        plt.show()
    else:
        save_dir = os.path.dirname(save_path)
        # dirname() is '' for a bare file name, and os.makedirs('') raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
        plt.close(fig)

