import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader as tDataLoader

from typing import Callable, List, Union, Tuple, Dict


def bce_softmax_pos(t:torch.Tensor, single_task_subselect = False):
    """Turn BCE-style logits into a two-class probability pair.

    ``-t`` and ``t`` are stacked on a new trailing axis and passed through a
    sigmoid, so index 0 of the last axis holds the negative-class probability
    and index 1 the positive-class probability.

    Args:
        t: Tensor of logits.
        single_task_subselect: When True, the axis that precedes the new
            class axis must have size 1; it is summed away so the output
            looks like a softmax over two logits.

    Returns:
        Tensor of shape ``t.shape + (2,)``; with ``single_task_subselect``
        the second-to-last axis is removed.

    Raises:
        AssertionError: If ``single_task_subselect`` is set and the
            second-to-last axis of the stacked result is not of size 1.
    """
    # The negative class goes first, the positive class second.
    retval = torch.sigmoid(torch.stack([-t, t], dim=-1))
    if not single_task_subselect:
        return retval

    # Report the axis that is actually checked (the task axis), not the class axis.
    assert retval.shape[-2] == 1, f"With single_task_subselect enabled,no multitask stuff is supported (given dim {retval.shape[-2]})"
    # Summing over a size-1 axis simply drops it.
    return torch.sum(retval, dim=-2)


def xonly(batch):
    '''
    Pull the input part out of a batch.

    :param batch: a tensor (returned unchanged), a dict (its ``'x'`` entry is
        returned) or any other indexable (its first element is returned)
    :return: the input part of the batch
    '''
    if isinstance(batch, torch.Tensor):
        return batch
    if isinstance(batch, dict):
        return batch['x']
    return batch[0]


def yonly(batch):
    '''
    Pull the target part out of a batch.

    :param batch: a tensor, a dict or any other indexable
    :return: a one-element tensor holding -1 for a bare tensor batch; the
        first of the 'target' / 'y' / 'label' entries for a dict batch;
        otherwise the last element of the batch passed through ensure_tensor
    '''
    # A bare tensor carries no target, so a -1 placeholder is returned.
    if isinstance(batch, torch.Tensor):
        return torch.Tensor([-1])

    if isinstance(batch, dict):
        candidate_keys = ['target', 'y', 'label']
        return get_whichever(batch, candidate_keys)

    last_item = batch[-1]
    return ensure_tensor(last_item, dtype=torch.long)


def get_whichever(src: dict, options: list):
    '''
    Look up the first key from ``options`` that is present in ``src``.

    :param src: mapping to search
    :param options: candidate keys, in priority order
    :return: the value stored under the first matching key, or None when
        none of the candidates is present
    '''
    return next((src[key] for key in options if key in src), None)


def ensure_tensor(rand, dtype=torch.float):
    """Return ``rand`` as a tensor.

    Args:
        rand: A tensor, or array-like data (list, scalar, numpy array, ...).
        dtype: Dtype used when a new tensor has to be built from ``rand``.

    Returns:
        ``rand`` itself when it already is a tensor (its dtype is left
        untouched), otherwise a tensor of the requested ``dtype``.
    """
    if isinstance(rand, torch.Tensor):
        return rand
    # The legacy torch.Tensor(...) constructor rejects a dtype keyword;
    # torch.as_tensor is the factory that honours it.
    return torch.as_tensor(rand, dtype=dtype)


def dict_to_dev(d: dict, device, dtype=None):
    """Build a new dict whose tensor / module values are moved via ``.to``.

    Args:
        d: Source mapping; it is not modified.
        device: Target device handed to ``.to``.
        dtype: Target dtype handed to ``.to`` (None leaves it unspecified).

    Returns:
        A dict with the same keys; values that are neither tensors nor
        modules are carried over unchanged.
    """
    moved = {}
    for key, value in d.items():
        if isinstance(value, (torch.Tensor, nn.Module)):
            value = value.to(device=device, dtype=dtype)
        moved[key] = value
    return moved

def dict_to_dev_dtype_adaptive(d: dict, device, dtype=None):  # only converts the floating dtypes to that of a model
    """Move tensor / module values to ``device``, casting only floating data.

    Args:
        d: Source mapping; it is not modified.
        device: Target device.
        dtype: Dtype applied to floating-point tensors. Non-floating tensors
            keep their own dtype. Modules are handed ``dtype`` as-is via
            ``nn.Module.to``.

    Returns:
        A dict with the same keys; values that are neither tensors nor
        modules are carried over unchanged.
    """
    def _move(value):
        if isinstance(value, torch.Tensor):
            target_dtype = dtype if torch.is_floating_point(value) else value.dtype
            return value.to(device=device, dtype=target_dtype)
        if isinstance(value, nn.Module):
            # torch.is_floating_point / .dtype are tensor-only, so modules
            # must not go through the tensor path above.
            return value.to(device=device, dtype=dtype)
        return value

    return {k: _move(v) for k, v in d.items()}

def dict_to_cpu(d: dict):
    """Return a copy of ``d`` with tensor / module values moved to the CPU.

    Thin wrapper around ``dict_to_dev`` with the device fixed to ``'cpu'``.
    """
    return dict_to_dev(d, device='cpu')

def dict_detach(d: dict):
    """Return a copy of ``d`` whose tensor values are detached from autograd.

    Args:
        d: Source mapping; it is not modified.

    Returns:
        A dict with the same keys. Tensor values are replaced by their
        ``.detach()``-ed version; every other value is carried over unchanged.
    """
    # nn.Module has no .detach(), so modules are passed through untouched
    # instead of raising AttributeError.
    return {k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in d.items()}


def list_to_dev(l: list, device, dtype=None):
    """Build a new list whose tensor / module items are moved via ``.to``.

    Args:
        l: Source sequence; it is not modified.
        device: Target device handed to ``.to``.
        dtype: Target dtype handed to ``.to`` (None leaves it unspecified).

    Returns:
        A list of the same length; items that are neither tensors nor
        modules are carried over unchanged.
    """
    moved = []
    for item in l:
        if isinstance(item, (torch.Tensor, nn.Module)):
            item = item.to(device=device, dtype=dtype)
        moved.append(item)
    return moved

def list_to_dev_dtype_adaptive(l: list, device, dtype=None):
    """Move tensor / module items to ``device``, casting only floating data.

    Args:
        l: Source sequence; it is not modified.
        device: Target device.
        dtype: Dtype applied to floating-point tensors. Non-floating tensors
            keep their own dtype. Modules are handed ``dtype`` as-is via
            ``nn.Module.to``.

    Returns:
        A list of the same length; items that are neither tensors nor
        modules are carried over unchanged.
    """
    def _move(value):
        if isinstance(value, torch.Tensor):
            target_dtype = dtype if torch.is_floating_point(value) else value.dtype
            return value.to(device=device, dtype=target_dtype)
        if isinstance(value, nn.Module):
            # torch.is_floating_point / .dtype are tensor-only, so modules
            # must not go through the tensor path above.
            return value.to(device=device, dtype=dtype)
        return value

    return [_move(v) for v in l]


def _forward_result_maybe_ensemble(
    x,
    model: nn.Module,
    reduce_outputs = True,
    **kwargs
):
    """Run ``x`` through a single model or through every member of an ensemble.

    Args:
        x: Input forwarded to ``_forward_result``.
        model: A module, or an ``nn.ModuleList`` that is treated as an ensemble.
        reduce_outputs: For an ensemble, average the member outputs over the
            stacking axis when True; otherwise return the stacked outputs.
        **kwargs: Passed through to ``_forward_result``.

    Returns:
        The model output; for an ensemble either the mean over members or a
        tensor with the members stacked along a new leading axis.
    """
    # Ordinary callers never hand in a ModuleList, so this path stays the default.
    if not isinstance(model, nn.ModuleList):
        return _forward_result(model, x, **kwargs)

    # Ensemble: one forward pass per member, collected on a new axis 0.
    stacked = torch.stack([_forward_result(member, x, **kwargs) for member in model], 0)
    if reduce_outputs:
        return stacked.mean(0)
    return stacked


def __exemplar_tensor(*args, **kwargs):
    """Return the first tensor among the given arguments.

    Positional arguments are scanned before keyword values.

    Raises:
        ValueError: If none of the arguments is a tensor.
    """
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, torch.Tensor):
            return candidate
    raise ValueError('No tensor found!')


def _forward_result(
    model: nn.Module,
    *xs,
    device_movements_async = False,
    **named_xs,
):
    """Move the inputs next to the model and run a forward pass.

    The target device and dtype are taken from the model's first parameter.
    For a model without parameters they are taken from the first tensor found
    among the inputs.

    Args:
        model: Module to evaluate.
        *xs: Positional inputs for the model.
        device_movements_async: Accepted for interface compatibility; it is
            not used by this function.
        **named_xs: Keyword inputs for the model.

    Returns:
        Whatever the model returns.

    Raises:
        ValueError: If the model has no parameters and no input is a tensor.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
        dtype = first_param.dtype
    else:
        # Parameter-less model: mirror the first tensor among the inputs.
        yeti = __exemplar_tensor(*xs, **named_xs)
        dtype = yeti.dtype
        device = yeti.device

    named_xs = dict_to_dev_dtype_adaptive(named_xs, device=device, dtype=dtype)
    xs = list_to_dev_dtype_adaptive(xs, device=device, dtype=dtype)

    # Call the module itself rather than .forward so registered hooks run.
    return model(*xs, **named_xs)


def _forward(
    x, y,
    model: nn.Module,
    loss: nn.Module,
    device_movements_async = False,
):
    """Run a forward pass and evaluate the loss against the targets.

    Args:
        x: Model input.
        y: Target tensor; it is moved to the model's device before the loss
            is computed.
        model: Module, or ``nn.ModuleList`` ensemble, to evaluate.
        loss: Callable invoked as ``loss(y_hat, y)``.
        device_movements_async: Used as ``non_blocking`` when moving ``y``.

    Returns:
        Tuple ``(loss_value, y_hat)``.
    """
    y_hat = _forward_result_maybe_ensemble(
        x,
        model,
        device_movements_async=device_movements_async
    )

    # Parameter-less models are supported by the forward helper, so do not
    # let a bare next() raise StopIteration here; fall back to the device of
    # the prediction when it is a tensor.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    elif isinstance(y_hat, torch.Tensor):
        target_device = y_hat.device
    else:
        target_device = None

    if target_device is not None:
        y = y.to(target_device, non_blocking=device_movements_async)
    l = loss(y_hat, y)

    return l, y_hat


@torch.enable_grad()
def train_episode(
    model: nn.Module,
    opt: optim.Optimizer,
    loss: nn.Module,
    dl: tDataLoader,
    logger_callback: Callable = None
) -> List[torch.Tensor]:
    """Train the model for one pass over a data loader.

    Args:
        model (nn.Module): The model to train; it is put into training mode.
        opt (optim.Optimizer): Optimizer stepped once per batch.
        loss (nn.Module): The loss function used for training.
        dl (tDataLoader): Data loader yielding ``(x, y)`` pairs.
        logger_callback (Callable, optional): Called with the detached CPU
            loss of every batch. Defaults to None.

    Returns:
        List[torch.Tensor]: The detached CPU loss of every batch.
    """
    model.train()
    batch_losses = []

    for inputs, targets in dl:
        opt.zero_grad(set_to_none=True)

        batch_loss, _ = _forward(inputs, targets, model, loss)
        batch_loss.backward()
        opt.step()

        # Keep only a graph-free CPU copy of the loss.
        batch_loss = batch_loss.detach().cpu()
        batch_losses.append(batch_loss)

        if logger_callback:
            logger_callback(batch_loss)

    # Drop the gradients of the last batch as well.
    opt.zero_grad(set_to_none=True)

    return batch_losses


@torch.no_grad()
def validate(
    model: nn.Module, 
    loss: nn.Module, 
    dl: tDataLoader, 
    logger_callback: Callable = None,
    ) -> List[torch.Tensor]:
    """Evaluate the model on every batch of a data loader.

    Args:
        model (nn.Module): The model to evaluate; it is put into eval mode.
        loss (nn.Module): The loss function used for validation.
        dl (tDataLoader): Data loader yielding ``(x, y)`` pairs.
        logger_callback (Callable, optional): Called with the loss of every
            batch exactly as returned by the loss function. Defaults to None.

    Returns:
        List[torch.Tensor]: One loss per batch; tensor losses are stored as
        detached CPU copies, anything else is stored as-is.
    """
    model.eval()
    batch_losses = []

    for inputs, targets in dl:
        batch_loss, _ = _forward(inputs, targets, model, loss)

        if isinstance(batch_loss, torch.Tensor):
            stored = batch_loss.detach().cpu()
        else:
            stored = batch_loss
        batch_losses.append(stored)

        if logger_callback:
            # The callback gets the raw loss, not the stored CPU copy.
            logger_callback(batch_loss)

    return batch_losses


def reduce_losses(losses, reduction='mean', batch_dim=False, keep_grad=False):
    """Average a list of loss tensors into a single scalar tensor.

    Args:
        losses: List of loss tensors.
        reduction: Accepted for interface compatibility; the result is always
            the mean.
        batch_dim: ``False`` (or ``None``) stacks the losses on a new leading
            axis before averaging. An integer concatenates them along that
            axis instead, which allows per-sample losses of different batch
            sizes; ``0`` is a valid axis. Zero-dimensional losses are treated
            as one-element vectors when concatenating.
        keep_grad: When False the reduction runs under ``torch.no_grad()``.

    Returns:
        Scalar tensor holding the mean over all loss values.
    """
    def _mean():
        # Compare against the sentinel explicitly: a plain truthiness test
        # would also swallow batch_dim=0.
        if batch_dim is False or batch_dim is None:
            return torch.stack(losses, dim=0).mean()
        # torch.cat rejects 0-d tensors, so lift them to shape (1,).
        parts = [item.reshape(1) if item.dim() == 0 else item for item in losses]
        return torch.cat(parts, dim=batch_dim).mean()

    if keep_grad:
        return _mean()
    with torch.no_grad():
        return _mean()


def validate_fancy(
    val_dl: tDataLoader, 
    model: nn.Module, 
    val_losses, 
    reduction: Callable = lambda a: torch.stack(a).mean()
    ):
    """Validate with a single loss or with a dict of named losses.

    Args:
        val_dl: Data loader used for validation.
        model: The model to validate.
        val_losses: Either a callable loss, or a dict whose values are
            themselves callables or nested dicts of losses.
        reduction: Callable that collapses the list of per-batch losses
            into one value.

    Returns:
        For a callable: the reduced loss, or the empty list when the loader
        produced no batches. For a dict: a dict with the same keys holding
        the result for every entry. ``None`` for any other input.
    """
    if callable(val_losses):
        vlosses = validate(dl=val_dl, model=model, loss=val_losses)
        if len(vlosses) == 0:
            # Nothing to reduce; hand the empty list back unchanged.
            return vlosses
        return reduction(vlosses)

    if isinstance(val_losses, dict):
        # Recurse per entry so every named loss gets its own reduced value.
        return {
            k: validate_fancy(val_dl, model, v, reduction=reduction)
            for k, v in val_losses.items()
        }

    return None


from copy import deepcopy

def train_net_es(
    train_dl: tDataLoader, 
    val_dl: tDataLoader, 
    model: nn.Module, 
    loss: nn.Module, 
    val_loss: Union[Dict[str, Callable], Callable], 
    opt: optim.Optimizer, 
    max_len: int = 32, 
    es_len: int = 16, 
    es_tol: float = 1e-2, 
    least_best: bool = True,
    verbose: bool = True,
    which_loss = None
    ):
    """
    Trains a neural network with early stopping.

    Args:
        train_dl: A DataLoader object for the training data.
        val_dl: A DataLoader object for the validation data.
        model: The neural network model to train.
        loss: The loss function to use for training.
        val_loss: The loss function to use for validation, or a dict of
            named loss functions.
        opt: The optimizer to use for training.
        max_len: The maximum number of training epochs.
        es_len: The number of consecutive epochs without improvement after
            which training stops.
        es_tol: Margin by which the monitored score has to beat the best
            score so far to count as an improvement.
        least_best: If True a lower monitored score is better, otherwise a
            higher one is.
        verbose: Whether to print progress.
        which_loss: Key of the entry of ``val_loss`` that drives early
            stopping when ``val_loss`` is a dict.
    Returns:
        Tuple ``(best_model, (train_losses, validation_losses))``:
        a deep copy of the model at its best monitored score, the mean
        training loss per epoch and the validation result per epoch.
    Raises:
        ValueError: If ``val_loss`` is a dict and ``which_loss`` is not given.
    """
    if isinstance(val_loss, dict) and not which_loss:
        raise ValueError('which_loss must be given when val_loss is a dict')

    def _monitored(result):
        # Pick the score that early stopping is based on.
        return result[which_loss] if which_loss else result

    train_losses = []
    validation_losses = []

    best_model = deepcopy(model)
    es_left = es_len

    vloss = validate_fancy(val_dl=val_dl, model=model, val_losses=val_loss)
    if verbose:
        print(f'Initial val loss {vloss}')
    # The baseline is the score of the untrained model, not the loss function.
    best_scores = _monitored(vloss)

    for i in range(max_len):
        tlosses = train_episode(dl=train_dl, model=model, loss=loss, opt=opt)
        vloss = validate_fancy(val_dl=val_dl, model=model, val_losses=val_loss)
        tloss = torch.stack(tlosses).mean()

        if verbose:
            print(f'Epoch: {i}, Train loss: {tloss}, Val loss: {vloss}', end='\r', flush=True)

        train_losses.append(tloss)
        validation_losses.append(vloss)

        score = _monitored(vloss)
        if least_best:
            improved = score + es_tol < best_scores
        else:
            improved = score - es_tol > best_scores

        if improved:
            best_scores = score
            es_left = es_len
            best_model = deepcopy(model)
            if verbose:
                print(f'\nNew best validation loss!')
        else:
            es_left -= 1
            if es_left == 0:
                # ran out of early-stopping patience
                break

    return best_model, (train_losses, validation_losses)


from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.normalization import LayerNorm, GroupNorm

def set_module_state_recursive(model: nn.Module, training=False, dropout=False, xnorm=False):
    """
    Sets the state of the dropout and/or normalization layers within a PyTorch model to training or evaluation mode.
    Useful for when optimizing only learnable modules.

    Args:
        model (nn.Module): The PyTorch model to modify (in place).
        training (bool, optional): Whether to set the state to training or evaluation mode. Defaults to False.
        dropout (bool, optional): Whether to include dropout layers (all
            variants derived from torch's common dropout base class) in the
            search. Defaults to False.
        xnorm (bool, optional): Whether to include LayerNorm, GroupNorm and
            batch-norm layers in the search. Defaults to False.
    """
    # Dropout1d/2d/3d, AlphaDropout etc. do not inherit from nn.Dropout; they
    # share this base class with it.
    dropout_types = nn.modules.dropout._DropoutNd
    norm_types = (LayerNorm, GroupNorm, _BatchNorm)

    for module in model.modules():
        if dropout and isinstance(module, dropout_types):
            module.train(training)
        elif xnorm and isinstance(module, norm_types):
            # TODO: consider other module types that might need that
            module.train(training)


import numpy as np


@torch.no_grad()
def encode_array_batchwise(model: nn.Module, data, batchsize=1024, verbose=False):
    """
    Run the model over ``data`` in consecutive chunks and collect the outputs.

    data: [n_samples, *features], a numpy array or a tensor
    batchsize: maximum number of samples per forward pass
    verbose: accepted for interface compatibility; not used here

    Returns the concatenated outputs (along axis 0) as a numpy array.
    """
    n_samples = data.shape[0]
    # Step through the data so that no empty trailing chunk is forwarded when
    # n_samples is a multiple of batchsize. For empty data a single (empty)
    # chunk is still forwarded so that the output is well-defined.
    starts = list(range(0, n_samples, batchsize)) or [0]

    all_chunks = []
    for start in starts:
        chunk_in = data[start:start + batchsize]
        if isinstance(chunk_in, np.ndarray):
            chunk_in = torch.from_numpy(chunk_in)
        chunk_out = _forward_result(model, chunk_in)
        all_chunks.append(chunk_out.detach().cpu())

    out = torch.cat(all_chunks, dim=0)
    return out.numpy()


class InMemoryBestModelCacher():
    """Keeps an in-memory copy of the state dict belonging to the best score seen.

    ``higher_better`` decides the direction of "best"; the stored score
    starts at minus or plus infinity accordingly.
    """

    def __init__(self, higher_better=True):
        self.model = None
        self.higher_better = higher_better
        if higher_better:
            self.val = -torch.inf
        else:
            self.val = torch.inf

    def _assign(self, model: nn.Module, val):
        """Store a deep copy of the model's state dict together with its score."""
        snapshot = deepcopy(model.state_dict())
        self.model = snapshot
        self.val = val

    def restore(self, model):
        """Load the cached state dict into ``model``; returns ``(model, best score)``."""
        model.load_state_dict(self.model)
        return model, self.val

    def update(self, model: nn.Module, val, tol=0.):
        """Cache ``model`` if ``val`` beats the stored score by more than ``tol``.

        Returns True when the cache was replaced, False otherwise.
        """
        if self.higher_better:
            improved = self.val + tol < val
        else:
            improved = self.val - tol > val

        if not improved:
            return False

        # substitute the existing
        self._assign(model, val)
        return True


class InMemoryBestModelCacherWithEpochCounter(InMemoryBestModelCacher):
    """Best-model cacher that also tracks how many updates ago the best score occurred."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.best_model_episode = 0
        self.episode_count = 0

    def check_tolerance(self, es_tol):
        """Return whether more than ``es_tol`` updates have passed since the best one."""
        episodes_since_best = self.episode_count - self.best_model_episode
        return episodes_since_best > es_tol

    def update(self, model: nn.Module, val, tol=0.):
        """Delegate to the base update, then advance the episode counter.

        The index of the current episode is remembered when it produced a
        new best score. Returns the result of the base update.
        """
        new_best = super().update(model, val, tol)
        if new_best:
            self.best_model_episode = self.episode_count
        self.episode_count += 1
        return new_best
