from typing import Any

import numpy as np
import torch


def tensor_to_device(batch: Any, device: torch.device):
    """Recursively move every tensor inside a nested structure to ``device``.

    Supported containers are dicts, lists, tuples and namedtuples. Any other
    object (strings, numbers, None, ...) is returned unchanged.

    Args:
        batch: A tensor, or an arbitrarily nested container of tensors.
        device: Target device, passed straight to ``Tensor.to``.

    Returns:
        A structure of the same shape with tensors moved to ``device``. Lists,
        tuples and namedtuples keep their concrete type. Mappings are rebuilt
        as a plain ``dict`` (so dict subclasses are not preserved).
    """
    if torch.is_tensor(batch):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: tensor_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # A namedtuple's constructor takes one positional argument per field,
        # so the moved elements must be unpacked rather than passed as an iterable.
        return type(batch)(*(tensor_to_device(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(tensor_to_device(v, device) for v in batch)
    return batch


def select_random_points(pr, gt, point_num=9):
    """Sample click prompts from where a binary prediction disagrees with ground truth.

    For each item, ``point_num`` pixels are drawn with replacement from the
    pixels where channel 0 of ``pr`` and ``gt`` differ. If they agree everywhere,
    ``point_num`` pixels are drawn uniformly over the mask instead.

    Args:
        pr: Binary prediction tensor indexed as ``[batch, channel, y, x]``;
            only channel 0 is used.
        gt: Ground-truth tensor with the same layout as ``pr``.
        point_num: Number of points to sample per item.

    Returns:
        ``(points, labels)`` as NumPy arrays. ``points`` has shape
        ``(B, point_num, 2)`` and holds ``(x, y)`` pixel coordinates. ``labels``
        has shape ``(B, point_num)``: 1 where pred is 0 and gt is 1, 0 where
        pred is 1 and gt is 0, and -1 otherwise (e.g. fallback points).
    """
    pred, gt = pr.data.cpu().numpy(), gt.data.cpu().numpy()

    batch_points, batch_labels = [], []
    for one_pred, one_gt in zip(pred[:, 0], gt[:, 0]):
        height, width = one_pred.shape
        indices = np.argwhere(one_pred != one_gt)
        if indices.shape[0] > 0:
            selected = indices[np.random.choice(indices.shape[0], point_num, replace=True)]
        else:
            # No disagreement: fall back to uniformly random pixels. Rows and
            # columns are drawn from their own ranges so non-square masks stay
            # in bounds.
            selected = np.stack(
                [
                    np.random.randint(0, height, size=point_num),
                    np.random.randint(0, width, size=point_num),
                ],
                axis=1,
            )

        ys, xs = selected[:, 0], selected[:, 1]
        p_vals, g_vals = one_pred[ys, xs], one_gt[ys, xs]
        labels = np.full(point_num, -1)
        labels[(p_vals == 0) & (g_vals == 1)] = 1  # missed foreground -> positive click
        labels[(p_vals == 1) & (g_vals == 0)] = 0  # spurious foreground -> negative click

        batch_points.append(np.stack([xs, ys], axis=1))
        batch_labels.append(labels)
    return np.array(batch_points), np.array(batch_labels)


def generate_point(masks, labels, low_res_masks, batched_input, point_num):
    """Refresh the prompts in ``batched_input`` with newly sampled click points.

    ``masks`` are treated as logits: their sigmoid is thresholded at 0.5 and
    compared with ``labels`` by ``select_random_points``, which draws
    ``point_num`` clicks per item. The sigmoid of ``low_res_masks`` (values in
    (0, 1), not raw logits) is stored as ``mask_inputs``, and box prompts are
    cleared.

    ``batched_input`` is modified in place and also returned for convenience.
    """
    coords, coord_labels = select_random_points(
        (torch.sigmoid(masks.clone()) > 0.5).float(),
        labels,
        point_num=point_num,
    )
    batched_input.update(
        mask_inputs=torch.sigmoid(low_res_masks.clone()),
        point_coords=torch.as_tensor(coords, dtype=torch.float),
        point_labels=torch.as_tensor(coord_labels, dtype=torch.long),
        boxes=None,
    )
    return batched_input


def set_trainable_parts(
    model,
    *,
    train_image_encoder: bool = False,
    train_mask_decoder: bool = False,
    train_prompt_encoder: bool = False,
    train_image_encoder_adapter: bool = False,
    train_image_encoder_non_adapter: bool = False,
    adapter_keyword: str = "adapter",
) -> None:
    """Freeze every parameter of ``model``, then re-enable gradients selectively.

    Flags are additive: a parameter ends up trainable if any enabled flag
    covers it.

    Args:
        model: Object exposing ``parameters()``. The optional submodules
            ``image_encoder``, ``mask_decoder`` and ``prompt_encoder`` are
            looked up by attribute name and skipped when absent.
        train_image_encoder: Unfreeze all of ``model.image_encoder``.
        train_mask_decoder: Unfreeze all of ``model.mask_decoder``.
        train_prompt_encoder: Unfreeze all of ``model.prompt_encoder``.
        train_image_encoder_adapter: Unfreeze image-encoder parameters whose
            name contains ``adapter_keyword`` (case-insensitive).
        train_image_encoder_non_adapter: Unfreeze image-encoder parameters
            whose name does not contain ``adapter_keyword``.
        adapter_keyword: Substring identifying adapter parameters. An empty
            string matches every name.
    """
    def _unfreeze(params) -> None:
        # Enable gradients for every parameter yielded by ``params``.
        for param in params:
            param.requires_grad = True

    for param in model.parameters():
        param.requires_grad = False

    if hasattr(model, "image_encoder"):
        encoder = model.image_encoder
        if train_image_encoder:
            _unfreeze(encoder.parameters())
        # Each parameter is governed by exactly one of the two adapter flags,
        # depending on whether its name contains the keyword.
        _unfreeze(
            param
            for param_name, param in encoder.named_parameters()
            if (
                train_image_encoder_adapter
                if adapter_keyword.lower() in param_name.lower()
                else train_image_encoder_non_adapter
            )
        )

    for enabled, attr_name in (
        (train_mask_decoder, "mask_decoder"),
        (train_prompt_encoder, "prompt_encoder"),
    ):
        if enabled and hasattr(model, attr_name):
            _unfreeze(getattr(model, attr_name).parameters())
