import collections.abc
import itertools
import math
from typing import List, Mapping, Tuple, Union

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import Dataset


class MyDistributedDataParallel(DistributedDataParallel):
    """``DistributedDataParallel`` that forwards unknown attributes to ``self.module``.

    Lets callers write ``wrapped.some_attr`` instead of
    ``wrapped.module.some_attr`` for attributes the DDP wrapper itself
    does not have.
    """

    def __init__(self, model, **kwargs):
        # Kept as an explicit override so existing callers that pass the
        # wrapped network by the keyword ``model=`` keep working.
        super().__init__(model, **kwargs)

    def __getattr__(self, name):
        """Look ``name`` up on the wrapper first, then on ``self.module``.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module has
                ``name``, or if ``module`` itself is not set on this instance.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Without this guard, a lookup on an instance that has no
            # ``module`` yet would evaluate ``self.module`` -> __getattr__("module")
            # -> ``self.module`` ... until RecursionError, which ``hasattr``
            # does not swallow. Report a plain AttributeError instead.
            if name == "module":
                raise
            return getattr(self.module, name)
        

def nested_to_device(
        inputs: Union[List[torch.Tensor], Mapping[str, torch.Tensor]],
        device: torch.device
) -> Union[List[torch.Tensor], Mapping[str, torch.Tensor]]:
    """Move every tensor in a list or mapping to ``device``, in place.

    Args:
        inputs: Either a mutable mapping (``dict``, ``collections.UserDict``,
            ...) whose values are tensors, or a mutable sequence of tensors.
            Each element is replaced in place by its ``.to(device)`` result.
        device: Target device (anything ``Tensor.to`` accepts).

    Returns:
        The same ``inputs`` object, now holding the moved tensors.
    """
    if isinstance(inputs, collections.abc.Mapping):
        # Test against the abstract Mapping rather than ``dict`` so dict-like
        # containers take this branch instead of being enumerated key-by-key
        # (which would call ``.to`` on the keys). Snapshot the items so we
        # never iterate a view of a mapping while assigning into it.
        for key, tensor in list(inputs.items()):
            inputs[key] = tensor.to(device)
    else:
        for i, tensor in enumerate(inputs):
            inputs[i] = tensor.to(device)
    return inputs


def get_train_val_dataset(
        dataset: Dataset,
        train_fraction: float,
        seed: int = 42,
) -> Tuple[Dataset, Dataset]:
    """Split a dataset into training and validation sets.

    The shuffle is driven by a dedicated ``torch.Generator`` seeded with
    ``seed``, so the same arguments always yield the same split.

    Args:
        dataset: Dataset to split; must support ``len()``.
        train_fraction: Fraction of items that go to the training split,
            in the closed range [0, 1].
        seed: Seed for the shuffling generator. Defaults to 42, the value
            that used to be hard-coded.

    Returns:
        ``(train_dataset, val_dataset)`` produced by ``random_split``.

    Raises:
        ValueError: if ``train_fraction`` is not within [0, 1] (or is NaN).
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be between 0 and 1, got {train_fraction}")

    n_total = len(dataset)
    exact = n_total * train_fraction
    nearest = round(exact)
    # Passing the fractions [f, 1 - f] straight to random_split lets float
    # error in ``1 - f`` leak through its floor(): f=0.8 on 10 items gave a
    # 9/1 split instead of 8/2. Compute integer lengths ourselves: snap
    # products that are an integer up to rounding noise, otherwise give the
    # leftover item to the training split (ceil), as random_split does.
    n_train = nearest if math.isclose(exact, nearest) else math.ceil(exact)

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [n_train, n_total - n_train],
        generator=torch.Generator().manual_seed(seed))
    return train_dataset, val_dataset


def extract(
    vals: torch.Tensor,
    index: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    """Gather entries of ``vals`` at ``index`` and reshape them for broadcasting.

    Args:
        vals: Tensor to gather from along its last dimension. Must have the
            same number of dims as ``index`` (a ``torch.gather`` requirement);
            presumably a 1-D per-step table -- TODO confirm against callers.
        index: Integer index tensor; its first dimension is the batch size.
            The gathered result must contain exactly ``index.shape[0]`` items.
        shape: Shape of the tensor the result will broadcast against; only
            ``len(shape)`` is used.

    Returns:
        Tensor of shape ``(batch_size, 1, ..., 1)`` with ``len(shape)`` dims,
        placed on ``index.device``.
    """
    batch_size = index.shape[0]
    # Gather on the device that holds ``vals`` instead of forcing the index to
    # the CPU: ``index.cpu()`` failed whenever ``vals`` lived on another device
    # and added a host round-trip; for CPU ``vals`` this is unchanged.
    out = vals.gather(-1, index.to(vals.device))
    return out.reshape(batch_size, *((1,) * (len(shape) - 1))).to(index.device)


def _fmt_stats(T):
    """Format min-max, mean and std of a non-empty real floating-point tensor."""
    return (f"min-max={T.min().item():.4}..{T.max().item():.4}, "
            f"mean={T.mean().item():.4}, std={T.std().item():.4}")


def describe_tensor(X, name):
    """Print ``name``, shape, dtype and a short value summary of tensor ``X``.

    The summary depends on the dtype: min/max/mean/std for floating-point
    tensors (per real and imaginary part for complex ones), min/max for
    integer tensors, and the unique values for anything else (e.g. bool).
    Empty tensors get an "(empty)" marker instead of statistics.

    Args:
        X: Tensor to describe.
        name: Label printed in front of the shape/dtype line.
    """
    assert isinstance(X, torch.Tensor)

    # Print basic info: name, shape, and dtype
    print(name, ": shape=", X.shape, ", dtype=", X.dtype, sep='')

    # min()/max() raise on tensors with no elements, so stop here.
    if X.numel() == 0:
        print("  (empty)")
        return

    if X.is_complex():
        print(f"  real part: {_fmt_stats(X.real)}")
        print(f"  imag part: {_fmt_stats(X.imag)}")
    elif torch.is_floating_point(X):
        print(f"  {_fmt_stats(X)}")
    elif X.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
        print(f"  min={X.min().item()}, max={X.max().item()}")
    else:
        # Other types (like bool)
        print(f"  unique values: {torch.unique(X)}")


def to_np(tensor):
    """Convert ``tensor`` to a NumPy array, detached from autograd and on CPU.

    ``bfloat16`` is upcast to ``float32`` (NumPy has no bfloat16 dtype), and
    lazy conjugate/negative views are materialized because ``Tensor.numpy()``
    refuses tensors with those bits set. When no conversion is needed and the
    tensor is already on CPU, the array shares memory with ``tensor``.
    """
    t = tensor.detach().cpu()
    if t.dtype == torch.bfloat16:
        t = t.float()
    return t.resolve_conj().resolve_neg().numpy()


def limit_iterable(iterable, k):
    """Lazily yield at most the first ``k`` items of ``iterable``.

    ``k=None`` means no limit; the returned iterator is an ``itertools.islice``.
    """
    return itertools.islice(iterable, 0, k)


def torch_randint(l, r):
    """Return one Python int drawn uniformly from [l, r) via torch's global RNG."""
    draw = torch.randint(low=l, high=r, size=(1,))
    return int(draw[0])


def torch_rand():
    """Return one Python float drawn uniformly from [0, 1) via torch's global RNG."""
    (value,) = torch.rand(1).tolist()
    return value

