import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader as tDataLoader

from typing import Callable, List, Union, Tuple, Dict


def xonly(batch):
    '''
    Extracts the input part of a batch produced by an arbitrary dataloader.

    :param batch: a bare tensor, a dict holding the input under key 'x',
        or an indexable sequence whose first element is the input.
    :return: the tensor itself, ``batch['x']`` for dicts, otherwise ``batch[0]``.
    '''
    if isinstance(batch, torch.Tensor):
        return batch
    if isinstance(batch, dict):
        return batch['x']
    return batch[0]


def yonly(batch):
    '''
    Extracts the target/label part of a batch from an arbitrary dataloader.

    :param batch: a bare tensor, a dict, or an indexable sequence.
    :return: ``torch.Tensor([-1])`` for a bare tensor (no label available);
        for dicts, the value of the first present key among 'target', 'y',
        'label' (or -1 if none is present); otherwise the last element of the
        sequence passed through ``ensure_tensor`` with ``dtype=torch.long``.
    '''
    if isinstance(batch, torch.Tensor):
        # A lone tensor carries no labels; return a sentinel instead.
        return torch.Tensor([-1])
    if isinstance(batch, dict):
        label_keys = ['target', 'y', 'label']
        return get_whichever(batch, label_keys)
    last = batch[-1]
    return ensure_tensor(last, dtype=torch.long)


def get_whichever(src: dict, options: list):
    '''
    Returns the value stored under the first key of ``options`` present in ``src``.

    :param src: mapping to look keys up in.
    :param options: candidate keys, checked in order.
    :return: the first matching value, or -1 when none of the keys is present.
    '''
    return next((src[key] for key in options if key in src), -1)


def ensure_tensor(rand, dtype=torch.float):
    '''
    Converts ``rand`` to a tensor unless it already is one.

    :param rand: a tensor (returned unchanged, dtype untouched) or any data
        accepted by ``torch.as_tensor`` (scalar, nested list, numpy array, ...).
    :param dtype: dtype used when a conversion is needed.
    :return: a tensor.
    '''
    if isinstance(rand, torch.Tensor):
        return rand
    # torch.Tensor(...) is the legacy constructor: it rejects the ``dtype``
    # keyword and interprets a bare int as a *size* (uninitialized memory).
    # torch.as_tensor builds a tensor from the data with the requested dtype.
    return torch.as_tensor(rand, dtype=dtype)


def dict_to_cpu(d: dict):
    '''
    Returns a shallow copy of ``d`` with tensor and nn.Module values moved to CPU.

    Other values are copied over as-is. Note that ``nn.Module.cpu()`` moves the
    module in place and returns it, so module values are shared with ``d``.
    '''
    movable = (torch.Tensor, nn.Module)
    result = {}
    for key, value in d.items():
        result[key] = value.cpu() if isinstance(value, movable) else value
    return result


def dict_detach(d: dict):
    '''
    Returns a shallow copy of ``d`` with every tensor value detached from the autograd graph.

    Non-tensor values (including ``nn.Module`` instances, which have no
    ``.detach()`` method) are copied over unchanged.
    '''
    return {k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in d.items()}


def flatten_dicts(d: dict, context=None):
    '''
    Flattens a nested dict into one level, joining key paths with '_'.

    Example: ``{'a': {'b': 1}, 'c': 2}`` -> ``{'a_b': 1, 'c': 2}``.

    :param d: dict whose dict-valued entries are descended into recursively.
    :param context: optional sequence of key-path prefixes prepended to every
        produced key (used internally by the recursion).
    :return: a new flat dict. Keys and prefixes are converted with ``str()``
        before joining; if two paths join to the same key, the later one wins.
        Empty nested dicts contribute no entries.
    '''
    # Avoid a shared mutable default; copy so the caller's sequence is never touched.
    prefix = [] if context is None else [str(c) for c in context]
    flat = {}
    for key, value in d.items():
        path = prefix + [str(key)]
        if isinstance(value, dict):
            # update() instead of rebuilding the dict on every merge.
            flat.update(flatten_dicts(value, path))
        else:
            flat['_'.join(path)] = value
    return flat


# def any_to_device_nested(whatev, device=None, dtype=None, respect_floats=False):
#     # this one enforces
#     if isinstance(whatev, torch.Tensor): # or isinstance(whatev, nn.Module):
#         return whatev.to(device=device, dtype=dtype) if whatev.dtype.is_floating_point or not respect_floats else whatev.to(device=self.device)
#     elif isinstance(whatev, Dict):
#         return {k: any_to_device_nested(v, device=device, dtype=dtype, respect_floats=respect_floats) for k,v in whatev.items()}
#     elif isinstance(whatev, List) or isinstance(whatev, Tuple):
#         return [any_to_device_nested(v, device=device, dtype=dtype, respect_floats=respect_floats) for v in whatev]
#     else:
#         raise TypeError(f'Cannot attempt to put stuff ona device, stuff in question: {type(whatev)}')


def any_to_device_nested(whatev, device=None, dtype=None, respect_ints=False, raise_if_no_match=True):
    '''
    Moves (and optionally casts) every tensor inside a nested dict/list/tuple structure.

    Traversal is delegated to ``any_apply_recursive``; tensors are detected with
    ``torch.is_tensor``.

    :param whatev: a tensor or a nested structure of dicts/lists/tuples holding tensors.
    :param device: target device passed to ``Tensor.to``; ``None`` keeps the current device.
    :param dtype: target dtype; ``None`` means tensors are only moved, never cast.
    :param respect_ints: if True, tensors whose dtype is not floating point keep
        their dtype and are only moved to ``device``.
    :param raise_if_no_match: if True, any leaf that is neither a tensor nor a
        dict/list/tuple raises ``ValueError``; if False, such leaves are returned as-is.
    :return: the transformed structure.
    :raises ValueError: on an unsupported leaf type when ``raise_if_no_match`` is True.
    '''
    def apply_tensor_device_dtype(t: torch.Tensor):
        if dtype is not None and (t.dtype.is_floating_point or not respect_ints):
            return t.to(device=device, dtype=dtype)
        # No dtype requested, or a non-floating tensor protected by respect_ints.
        return t.to(device=device)

    def rinm(t):
        raise ValueError(f"Encountered a type that is neither a dict/list/tuple nor a tensor: {type(t)}")

    return any_apply_recursive(
        whatev,
        func_if_other_type=rinm if raise_if_no_match else lambda a: a,
        kwargs={torch.is_tensor: apply_tensor_device_dtype},
    )


def any_apply_recursive(whatev, func_if_other_type: Callable = lambda a: a, kwargs: Union[Dict, None] = None):
    '''
    Recursively walks dicts/lists/tuples and transforms their leaves.

    :param whatev: arbitrary object; dicts, lists and tuples are descended into.
    :param func_if_other_type: applied to any object that no entry of ``kwargs``
        matches and that is not a dict/list/tuple. Defaults to identity; pass a
        raising function to enforce that every leaf is handled by ``kwargs``.
    :param kwargs: mapping ``matcher -> transform``, checked in insertion order.
        A matcher is either a type (or tuple of types), tested with
        ``isinstance``, or a predicate called with the object. The first match
        wins: its transform's result is returned without further recursion.
        Matchers are checked before the container branches, so they can also
        intercept whole dicts/lists.
    :return: the transformed structure. Dicts stay dicts; lists and tuples are
        both rebuilt as lists.
    '''
    if kwargs is None:
        kwargs = {}
    for matcher, transform in kwargs.items():
        # Classes are themselves callable, so type matchers must be recognised
        # before the predicate branch; otherwise e.g. ``int`` would be *called*
        # on the object instead of being used in an isinstance check.
        if isinstance(matcher, (type, tuple)):
            matched = isinstance(whatev, matcher)
        else:
            matched = matcher(whatev)
        if matched:
            return transform(whatev)

    def recurse(item):
        return any_apply_recursive(item, func_if_other_type=func_if_other_type, kwargs=kwargs)

    if isinstance(whatev, dict):
        return {k: recurse(v) for k, v in whatev.items()}
    if isinstance(whatev, (list, tuple)):
        return [recurse(v) for v in whatev]
    # Not a matched type and not a container: hand off to the fallback.
    return func_if_other_type(whatev)

