import json
import os
from dataclasses import dataclass, field
import time
from typing import List, Optional
import torch
import numpy as np
import h5py
from os.path import join

from src.utils import pylogger
from src.utils.permutations import perm_w_start_idcs

def get_probe_idcs(probe_idcs_or_len, img_size, static_probes=None):
    """Return probe pixel coordinates, drawing random ones if only a count is given.

    Args:
        probe_idcs_or_len: Either the total number of probes to produce (a Python
            or NumPy integer), or explicit probe coordinates (array-like of
            (x, y) pairs). Explicit coordinates are returned as a NumPy array
            without any further processing.
        img_size: Exclusive upper bound for randomly drawn coordinates; both x
            and y are drawn uniformly from ``[0, img_size)``.
        static_probes: Optional array-like of fixed (x, y) pairs, shape (M, 2).
            They count towards the requested total, so only
            ``probe_idcs_or_len - M`` coordinates are drawn at random, and the
            static pairs are appended after the random ones. ``None`` (the
            default) or an empty sequence means no static probes.

    Returns:
        np.ndarray: For a count, an integer array of shape (N, 2) holding
        (x, y) pairs. For explicit coordinates, those coordinates as an array.

    Raises:
        ValueError: If more static probes are given than the requested total.
    """
    if not isinstance(probe_idcs_or_len, (int, np.integer)):
        # Explicit coordinates were provided: normalize to an ndarray and return.
        if not isinstance(probe_idcs_or_len, np.ndarray):
            probe_idcs_or_len = np.array([np.array(sublist) for sublist in probe_idcs_or_len])
        return probe_idcs_or_len

    # Normalize static probes to an (M, 2) array so column slicing works for
    # lists as well as arrays (and an empty input becomes shape (0, 2)).
    if static_probes is None:
        static_probes = np.empty((0, 2), dtype=int)
    else:
        static_probes = np.asarray(static_probes).reshape(-1, 2)

    num_random = probe_idcs_or_len - len(static_probes)
    if num_random < 0:
        raise ValueError(
            f"Requested {probe_idcs_or_len} probes, but {len(static_probes)} static probes were given."
        )

    x_indices = np.random.randint(0, img_size, num_random)
    y_indices = np.random.randint(0, img_size, num_random)

    if len(static_probes) > 0:
        # np.concatenate (not np.concat, which only exists in NumPy >= 2.0).
        x_indices = np.concatenate([x_indices, static_probes[:, 0]])
        y_indices = np.concatenate([y_indices, static_probes[:, 1]])

    # TODO: incorporate different sampling strategies (e.g. box_sample).
    return np.column_stack((x_indices, y_indices))

def set_values_by_indices(tensor, indices):
    """Set entries of ``tensor`` to 1 at per-batch (h, w) locations, in place.

    For every batch element ``i`` and every index pair ``j``, the entries at
    spatial position ``indices[i, j]`` are set to 1 for every slice along
    dimension 1 (``l``) and across any dimensions after ``w``.

    Args:
        tensor (torch.Tensor): Tensor of shape (b, l, h, w) or
            (b, l, h, w, *rest). Modified in place.
        indices (torch.Tensor): Integer tensor of shape (b, n, 2) holding
            (h, w) index pairs for each batch element.

    Returns:
        torch.Tensor: The same ``tensor`` object, after modification.

    Raises:
        ValueError: If ``tensor`` has fewer than 4 dimensions.
    """
    if tensor.ndim < 4:
        raise ValueError(f"Expected a tensor with at least 4 dims (b, l, h, w), got shape {tuple(tensor.shape)}")

    b, l = tensor.shape[:2]

    # Broadcastable index grids: (b, 1, 1), (1, l, 1), (b, 1, n) -> (b, l, n).
    b_idx = torch.arange(b, device=tensor.device)[:, None, None]
    l_idx = torch.arange(l, device=tensor.device)[None, :, None]
    h_idx = indices[..., 0].unsqueeze(1)
    w_idx = indices[..., 1].unsqueeze(1)

    # Any trailing dims beyond w are not indexed, so they are set wholesale.
    tensor[b_idx, l_idx, h_idx, w_idx] = 1

    return tensor

def select_from_video(video, indices):
    """
    Select pixels from a video tensor using a stack of (x, y) indices.

    Indices are applied to every frame. Out-of-range indices are clamped to the
    nearest valid position rather than raising.

    Args:
        video: Array/tensor of shape (B, frames, width, height, channels),
            (frames, width, height, channels) or (width, height, channels).
        indices: Array-like of shape (B, x_indices, 2) (one index set per batch
            element) or (x_indices, 2) (one index set shared by all batch
            elements). The last axis holds (x, y) = (width, height) positions.

    Returns:
        Selected pixels with shape (B, frames, x_indices, channels),
        (frames, x_indices, channels) or (x_indices, channels); leading
        singleton dims that were added for an unbatched ``video`` are removed.
    """
    if not isinstance(indices, np.ndarray):
        indices = np.array([np.array(sublist) for sublist in indices])

    # Track how many leading dimensions are added so they can be removed again.
    added_dims = 0
    if video.ndim == 4:  # (frames, width, height, channels)
        video = video[None, ...]
        added_dims = 1
    elif video.ndim == 3:  # (width, height, channels)
        video = video[None, None, ...]
        added_dims = 2

    if indices.ndim == 2:  # (x_indices, 2) -> (1, x_indices, 2)
        indices = indices[None, ...]

    num_batches, frames, width, height, _ = video.shape

    # Clamp into bounds, then insert a frame axis so the per-batch index sets
    # broadcast as (B, 1, N) against batch_idx (B, 1, 1) and frame_idx (1, F, 1).
    # Without the frame axis, (B, N) would align as (1, B, N) and mix up the
    # index-batch axis with the frame axis.
    x_indices = np.clip(indices[..., 0], 0, width - 1)[:, None, :]
    y_indices = np.clip(indices[..., 1], 0, height - 1)[:, None, :]

    batch_idx = np.arange(num_batches)[:, None, None]
    frame_idx = np.arange(frames)[None, :, None]

    # Advanced indexing -> (B, frames, N, channels)
    selected = video[batch_idx, frame_idx, x_indices, y_indices, :]

    # Remove the dimensions added above (only while they are still singleton).
    for _ in range(added_dims):
        if selected.shape[0] == 1:
            selected = selected[0]

    return selected



def box_sample(mesh_centers, probe_mask_box, num_probes=None):
    """Randomly choose distinct mesh-point indices lying outside an exclusion box.

    Args:
        mesh_centers: NumPy array or torch tensor of shape (N, 2) holding (x, y)
            coordinates.
        probe_mask_box: ``((x_min, y_min), (x_max, y_max))`` region to exclude.
            Points on the box boundary count as inside and are excluded.
            ``None`` means every point is eligible.
        num_probes: Number of distinct indices to draw. Required; it defaults
            to ``None`` only so the original two-argument signature still parses.

    Returns:
        np.ndarray: ``num_probes`` distinct indices into ``mesh_centers``.

    Raises:
        TypeError: If ``num_probes`` is not provided.
        ValueError: If fewer than ``num_probes`` points are eligible.
    """
    if num_probes is None:
        raise TypeError("box_sample() missing required argument: 'num_probes'")

    if probe_mask_box is None:
        return np.random.choice(np.arange(len(mesh_centers)), size=num_probes, replace=False)

    (x_min, y_min), (x_max, y_max) = probe_mask_box

    xs, ys = mesh_centers[:, 0], mesh_centers[:, 1]
    outside_mask = (xs < x_min) | (xs > x_max) | (ys < y_min) | (ys > y_max)

    # Works for both NumPy and torch inputs (torch.nonzero rejects NumPy arrays).
    if isinstance(outside_mask, torch.Tensor):
        outside_mask = outside_mask.cpu().numpy()
    outside_indices = np.flatnonzero(outside_mask)

    if num_probes > len(outside_indices):
        raise ValueError(f"Requested {num_probes} probe points, but only {len(outside_indices)} points outside mask box.")

    return np.random.choice(outside_indices, size=num_probes, replace=False)
