import collections
import os
import random
import sys
from contextlib import contextmanager
from typing import Iterable, Tuple, Union

import numpy as np
import torch
from torch.types import Device
from torch.utils.data import Dataset

from .models.utils import ragged_len


class CertifyDataset(Dataset):
    """View over a `Dataset` that yields only the inputs and drops the targets.

    Args:
        dataset: Dataset to wrap; indexing it must return a pair ``(x, y)``.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        # Same number of samples as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        # The wrapped item must unpack into exactly two values; the second
        # one (the target) is discarded.
        features, _target = self.dataset[idx]
        return features


def vrange(starts: torch.Tensor, ends: torch.Tensor) -> torch.Tensor:
    """Create concatenated ranges for multiple starts/ends

    Example:

        >>> starts = torch.tensor([0, 100, 5])
        >>> ends = torch.tensor([2, 104, 6])
        >>> print(vrange(starts, ends))
        tensor([  0,   1, 100, 101, 102, 103,   5])

    Args:
        starts: start (inclusive) for each range
        ends: end (exclusive) for each range

    Returns:
        Concatenated ranges, created on the same device as the inputs
    """
    # Length of every individual range.
    n = ends - starts
    # For each output position: the start of its range minus the number of
    # positions emitted by the preceding ranges. Adding the global running
    # index then yields start_i + local_index.
    offsets = torch.repeat_interleave(ends - n.cumsum(0), n)
    # Build the running index on the inputs' device; a default (CPU) arange
    # cannot be added to tensors living on another device.
    return offsets + torch.arange(n.sum(0), device=n.device)


# From https://stackoverflow.com/a/22434262
def _fileno(file_or_fd):
    fd = getattr(file_or_fd, "fileno", lambda: file_or_fd)()
    if not isinstance(fd, int):
        raise ValueError("Expected a file (`.fileno()`) or a file descriptor")
    return fd


# From https://stackoverflow.com/a/22434262
@contextmanager
def stdout_redirected(to=os.devnull, stdout=None):
    """Context manager that redirects a stream at the file-descriptor level.

    The descriptor behind `stdout` is pointed at `to` with `os.dup2` for the
    duration of the `with` block. On exit (also when the block raises) it is
    pointed back at a duplicate of the original descriptor taken on entry.

    Args:
        to: Redirection target. Either an object with a `.fileno()` method or
            an integer file descriptor; any other value is treated as a
            filename and opened in "wb" mode. Defaults to `os.devnull`.
        stdout: Stream to redirect; it is flushed before the redirection and
            before the restore. Defaults to `sys.stdout`.

    Yields:
        The `stdout` stream object, now writing to `to`.

    Raises:
        ValueError: If `stdout` cannot be resolved to an integer file
            descriptor by `_fileno`.
    """
    if stdout is None:
        stdout = sys.stdout

    stdout_fd = _fileno(stdout)
    # copy stdout_fd before it is overwritten
    # NOTE: `copied` is inheritable on Windows when duplicating a standard stream
    with os.fdopen(os.dup(stdout_fd), "wb") as copied:
        stdout.flush()  # flush library buffers that dup2 knows nothing about
        try:
            os.dup2(_fileno(to), stdout_fd)  # $ exec >&to
        except ValueError:  # `to` is not a file/descriptor: treat it as a filename
            with open(to, "wb") as to_file:
                # stdout_fd keeps referring to the opened file after `to_file`
                # itself is closed at the end of this `with`
                os.dup2(to_file.fileno(), stdout_fd)  # $ exec > to
        try:
            yield stdout  # allow code to be run with the redirected stdout
        finally:
            # restore stdout to its previous value
            # NOTE: dup2 makes stdout_fd inheritable unconditionally
            stdout.flush()
            os.dup2(copied.fileno(), stdout_fd)  # $ exec >&copied


def _collate_tensor_batch(batch, stack):
    """Collate a batch whose first element is a tensor, padding ragged batches."""
    head = batch[0]

    if not all(t.size() == head.size() for t in batch):
        # Sizes differ: densify sparse tensors, then fall back to rnn padding.
        tensors = [t.to_dense() for t in batch] if head.is_sparse else batch
        return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)

    out = None
    if torch.utils.data.get_worker_info() is not None:
        # Inside a DataLoader worker: write the result straight into a
        # shared-memory tensor to avoid an extra copy.
        total = sum(t.numel() for t in batch)
        storage = head.storage()._new_shared(total)
        if stack:
            shape = (len(batch), *head.size())
        else:
            shape = (len(batch) * head.size(0), *head.size()[1:])
        out = head.new(storage).resize_(*shape)

    if stack:
        # Stack along a new leading dimension.
        return torch.stack(batch, dim=0, out=out)
    # Concatenate along the existing 0th dimension instead.
    return torch.cat(batch, dim=0, out=out)


def collate_pad(batch, stack=True):
    r"""Collate a list of samples into a batch, padding tensors of unequal size.

    Minor modification from torch default collate_fn at: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py

    Args:
        batch: Samples to collate; the type of the first sample decides how
            the whole batch is handled.
        stack: If True, equally sized tensors are stacked along a new leading
            dimension; otherwise they are concatenated along dimension 0.

    Returns:
        The collated batch. Tensors of unequal size are padded with
        `pad_sequence(batch_first=True)`. Mappings, namedtuples and sequences
        are collated field by field. Strings are returned unchanged. Sample
        types that are not handled yield `None`.

    Raises:
        RuntimeError: If the samples are sequences of differing lengths.
    """
    first = batch[0]
    first_type = type(first)

    if isinstance(first, torch.Tensor):
        return _collate_tensor_batch(batch, stack)

    type_name = first_type.__name__
    if first_type.__module__ == "numpy" and type_name not in ("str_", "string_"):
        if type_name in ("ndarray", "memmap"):
            return collate_pad([torch.as_tensor(arr) for arr in batch], stack=stack)
        if first.shape == ():  # scalars
            return torch.as_tensor(batch)
        return None

    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(first, int):
        return torch.tensor(batch)
    if isinstance(first, str):
        return batch

    if isinstance(first, collections.abc.Mapping):

        def collate_fields():
            return {
                key: collate_pad([sample[key] for sample in batch], stack=stack)
                for key in first
            }

        try:
            return first_type(collate_fields())
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return collate_fields()

    if isinstance(first, tuple) and hasattr(first, "_fields"):  # namedtuple
        return first_type(
            *(collate_pad(column, stack=stack) for column in zip(*batch))
        )

    if isinstance(first, collections.abc.Sequence):
        # Every sample must contribute the same number of fields.
        expected = len(first)
        if any(len(sample) != expected for sample in batch):
            raise RuntimeError("each element in list of batch should be of equal size")
        # Materialised because it may be traversed twice below.
        columns = list(zip(*batch))

        def collate_columns():
            return [collate_pad(column, stack=stack) for column in columns]

        if isinstance(first, tuple):
            # Backwards compatibility: plain tuples come back as a list.
            return collate_columns()
        try:
            return first_type(collate_columns())
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return collate_columns()

    return None


def inv_collate_pad(batch, pad_value=None, batch_size=None):
    """Inverts the collate_pad's impact. Convert batch representation of inputs into sequences.

    e.g. with x = tensor([[1, 2, 3], [1, 1, 0]]) and meta = {"key": tensor([1, 2])},
    `inv_collate_pad((x, meta))` gives a list with one tuple per sample:
        1. (tensor([1, 2, 3]), {"key": tensor(1)})
        2. (tensor([1, 1]), {"key": tensor(2)})
    (this assumes `ragged_len` reports lengths 3 and 2 for the rows of x).

    Args:
        batch (Any): A batch of data, can contain Sequence, Mapping and torch.Tensor
        pad_value (torch.Tensor, optional): The value used to pad uneven sequence.
            Specify False to disable this behavior. Defaults to None. The value is
            forwarded unchanged to every nested element of `batch`.
        batch_size (int, optional): Number of samples in the batch. When given,
            a Mapping is expanded into at least this many per-sample mappings, so
            that an empty mapping still yields one (empty) mapping per sample.
            Forwarded to nested elements. Defaults to None.

    Raises:
        ValueError: If unhandled type is passed into this function, then it will throw error

    Returns:
        Any: The output disassembled on a sample by sample basis.
    """
    batch_type = type(batch)

    # Base case: the first level of Tensor is split into a list along dim 0.
    if isinstance(batch, torch.Tensor):
        out = list(batch)
        # Remove inserted padding, specify False for pad value to disable.
        if batch.ndim > 1 and pad_value is not False:
            lengths = ragged_len(x=batch, pad_value=pad_value)
            out = [row[:length] for row, length in zip(out, lengths)]
        return out

    if isinstance(batch, collections.abc.Mapping):
        # One *independent* mapping per sample. `[batch_type()] * batch_size`
        # would repeat a single object, so every sample would share (and
        # overwrite) the same entries.
        if batch_size is not None:
            out = [batch_type() for _ in range(batch_size)]
        else:
            out = []
        for key, batch_value in batch.items():
            # The recursive call returns one entry per sample for this key.
            per_sample = inv_collate_pad(
                batch_value, pad_value=pad_value, batch_size=batch_size
            )
            for i, elem in enumerate(per_sample):
                # Add a new entry if not existed yet
                if len(out) <= i:
                    out.append(batch_type())
                out[i][key] = elem
        return out

    # Other iterable types. Strings are excluded: iterating a one-character
    # string yields the same string again, which would recurse forever.
    if isinstance(batch, collections.abc.Iterable) and not isinstance(
        batch, (str, bytes)
    ):
        # Invert every field on its own, then regroup field-wise results into
        # per-sample tuples (`zip(*...)` transposes fields x samples).
        fields = [
            inv_collate_pad(value, pad_value=pad_value, batch_size=batch_size)
            for value in batch
        ]
        return list(zip(*fields))

    raise ValueError(f"Type ({str(batch_type)}) is not implemented.")


def get_gpu_memory(device: Union[Device, int] = None) -> Tuple[float, float]:
    """Report the used and total memory of a GPU, in MiB, as `(used, total)`.

    The figures come from :func:`torch.cuda.mem_get_info`; "used" is computed
    as total minus free.

    Args:
        device: device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        Tuple ``(used, total)`` in MiB.
    """
    bytes_per_mib = 1 << 20
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    total_mib = total_bytes / bytes_per_mib
    used_mib = total_mib - free_bytes / bytes_per_mib
    return used_mib, total_mib

def set_seed(seed: int) -> None:
    """
    Seed PyTorch, Python's `random` module and NumPy with the same value.

    Args:
        seed (int): The value handed to each library's seeding function.
    """
    # Applied in this order: torch, random, numpy.
    for apply_seed in (torch.manual_seed, random.seed, np.random.seed):
        apply_seed(seed)

def seed_worker(worker_id: int) -> None:
    """
    Seed NumPy and Python's `random` module for one data-loading worker.

    The seed is derived from `torch.initial_seed()` offset by the worker's ID
    and reduced modulo 2**32. Intended to be passed as `worker_init_fn` to
    `torch.utils.data.DataLoader` for reproducible multi-process loading.

    Args:
        worker_id (int): The ID of the worker process.
    """
    seed_modulus = 1 << 32  # 2**32
    derived_seed = (torch.initial_seed() + worker_id) % seed_modulus
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

def inv_softmax(x: torch.Tensor, C: float = 0) -> torch.Tensor:
    """
    Map softmax outputs back to logits: element-wise natural log plus a constant.

    Softmax is invariant to adding a constant to every logit, so the inverse
    is only defined up to the additive term `C`.

    Args:
        x (torch.Tensor): The tensor to invert (e.g. softmax probabilities).
        C (float, optional): Constant added to every element of the result. Defaults to 0.

    Returns:
        torch.Tensor: `log(x) + C`, with the same shape as `x`.
    """
    log_x = torch.log(x)
    return log_x + C

