import argparse
import itertools
import logging
import random
from argparse import Namespace
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union

import numpy as np  # type: ignore
import torch
from matplotlib.colors import to_rgba  # type: ignore
from torch import nn

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "params_sans", "seed", "set_logger", "str2bool", "flatten",
    "unflatten_like", "get_mixture_mu_var", "gaussian_nll", "ensure_keys", "get_color",
    "to_cpu", "to_device", "get_test_name", "softmax_log_softmax_of_sample"
]


# Short alias for ``torch.Tensor`` used in the type annotations below.
T = torch.Tensor


def params_sans(
    model: nn.Module,
    without: Optional[Union[Type[nn.Module], Tuple[Type[nn.Module], ...]]] = None,
) -> Iterator[nn.Parameter]:
    """Iterate over the parameters of *model*, skipping those owned by excluded modules.

    Every module yielded by ``model.modules()`` (which includes *model* itself)
    contributes only the parameters it owns directly (``recurse=False``). A
    parameter is therefore yielded exactly when the module that directly owns it
    is not an instance of *without*. Children of an excluded module are still
    visited and contribute their own parameters unless they are excluded too.

    Args:
        model: module whose parameters are collected.
        without: module type, or tuple of types as accepted by ``isinstance``,
            whose directly-owned parameters are skipped. Despite the ``None``
            default this argument is required.

    Returns:
        A lazy iterator over the selected parameters. The set of modules is
        snapshotted when this function is called.

    Raises:
        ValueError: if *without* is ``None``.
    """
    if without is None:
        raise ValueError("need an argument for without")

    # Build the per-module generators eagerly (as a list) so the module
    # traversal happens now, then chain them lazily.
    owned = [m.parameters(recurse=False) for m in model.modules() if not isinstance(m, without)]
    return itertools.chain.from_iterable(owned)


def seed(run: int) -> None:
    """Seed torch, Python's ``random`` module and NumPy's global RNG with *run*."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(run)


def softmax_log_softmax_of_sample(x: T) -> Tuple[T, T]:
    """Average class probabilities over the leading sample axis, working in log space.

    *x* holds logit samples, e.g. shaped ``(samples, n, dim)``. Each sample is
    normalised with a log-softmax over the last axis. The samples are then
    averaged as ``log(mean_s softmax(x_s)) = logsumexp_s(log_softmax(x_s)) - log(samples)``,
    so raw logits are never exponentiated.

    Returns:
        ``(probs, log_probs)``: the sample-averaged softmax and its logarithm,
        each with the leading sample axis removed.
    """
    n_samples = x.size(0)
    log_probs_per_sample = torch.log_softmax(x, dim=-1)
    log_mean_probs = torch.logsumexp(log_probs_per_sample, dim=0) - np.log(n_samples)
    return log_mean_probs.exp(), log_mean_probs


def set_logger(level: str) -> logging.Logger:
    """Configure root-logger formatting at *level* and return the root logger.

    This delegates to ``logging.basicConfig``, which has no effect if the root
    logger already has handlers. The value returned is ``logging.getLogger()``.
    """
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    root = logging.getLogger()
    return root


def str2bool(v: Any) -> bool:
    """Parse a human-friendly truth value, e.g. as an ``argparse`` ``type=`` callable.

    ``bool`` values are returned unchanged. Any other value is converted with
    ``str()``, stripped of surrounding whitespace and lower-cased, then matched
    against the accepted spellings: ``yes/true/t/y/1`` give ``True`` and
    ``no/false/f/n/0`` give ``False``. Going through ``str()`` means non-string
    inputs such as the ints ``1``/``0`` are parsed instead of failing with
    ``AttributeError``.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised spelling.
            argparse reports this as a normal usage error.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ("yes", "true", "t", "y", "1"):
        return True
    if text in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def flatten(lst: Iterable[T]) -> T:
    """Concatenate the tensors in *lst* into a single 1-D tensor.

    Each tensor is flattened in row-major order, and the pieces are joined in
    iteration order. An empty *lst* makes ``torch.cat`` raise.
    """
    return torch.cat([piece.reshape(-1) for piece in lst])


def unflatten_like(vector: T, like_tensor_list: Iterable[T]) -> List[T]:
    """Split a flat tensor into pieces shaped like the tensors in *like_tensor_list*.

    Consecutive runs of ``template.numel()`` entries are taken from *vector*
    and viewed with each template's shape, in iteration order. This undoes
    ``flatten`` for tensors of matching shapes.

    Args:
        vector: 1-D tensor, or 2-D tensor whose leading dimension is 1. A 1-D
            input is treated as a single row.
        like_tensor_list: template tensors. Only their ``numel()`` and
            ``shape`` are read.

    Returns:
        A list of views into *vector*, one per template; no data is copied.
        Entries of *vector* beyond the templates' total ``numel()`` are ignored.

    Note:
        ``Tensor.view`` raises ``RuntimeError`` in two cases: when *vector* is
        too short, and when it is 2-D with a leading dimension larger than 1,
        because the slice then holds more elements than the template shape.
    """
    if vector.dim() == 1:
        vector = vector.unsqueeze(0)

    pieces = []
    offset = 0
    for template in like_tensor_list:
        numel = template.numel()
        pieces.append(vector[:, offset : offset + numel].view(template.shape))
        offset += numel
    return pieces


def get_mixture_mu_var(mus: T, vars: T, dim: int = 0) -> Tuple[T, T]:
    """Mean and variance of a uniform mixture of Gaussians, by moment matching.

    Mixture components are indexed along *dim*. By the law of total variance
    with equal weights::

        mu  = mean_k mu_k
        var = mean_k (var_k + (mu_k - mu) ** 2)

    This equals the textbook ``mean_k(var_k + mu_k ** 2) - mu ** 2`` but avoids
    subtracting two large, nearly equal quantities. When the means are large
    relative to their spread, the textbook form loses all precision (and can
    go negative) in float32. This form is a sum of non-negative terms whenever
    *vars* is non-negative.

    Args:
        mus: component means.
        vars: component variances; broadcast against *mus* exactly as before.
            The name shadows the builtin but is kept for keyword callers.
        dim: dimension that indexes mixture components.

    Returns:
        ``(mu, var)`` with *dim* reduced away.
    """
    # keepdim so the mean broadcasts back against every component.
    mu_keep = mus.mean(dim=dim, keepdim=True)
    var = (vars + (mus - mu_keep) ** 2).mean(dim=dim)
    return mu_keep.squeeze(dim), var


def gaussian_nll(mu: T, var: T, y: T) -> T:
    """Element-wise Gaussian negative log-likelihood, up to an additive constant.

    Computes ``(y - mu) ** 2 / (2 * var) + 0.5 * log(var)`` for every element.
    The constant ``0.5 * log(2 * pi)`` is omitted and no reduction is applied.

    Args:
        mu: predicted means.
        var: predicted variances; must be positive for the log to be finite.
        y: targets.

    Returns:
        Per-element NLL values with the broadcast shape of the inputs.

    Raises:
        ValueError: if the three tensors do not have the same number of
            dimensions. Only the rank is checked; sizes are left to broadcasting.
    """
    rank = mu.dim()
    if var.dim() != rank or y.dim() != rank:
        raise ValueError(
            f"mu: {mu.size()} var: {var.size()} and y: {y.size()} "
            "must all have the same number of dimensions"
        )

    return (y - mu) ** 2 / (2 * var) + 0.5 * torch.log(var)  # type: ignore


def ensure_keys(o: Dict[str, Any], keys: List[Any], vals: List[Any]) -> None:
    """Fill in default values for keys missing from *o*, in place.

    *keys* and *vals* are paired positionally. For each pair whose key is not
    already in *o*, the value is inserted. Existing entries are left untouched,
    so explicit settings take precedence over these defaults.

    Note:
        Pairing uses ``zip``, so if the lists differ in length the surplus
        items of the longer one are silently ignored.
    """
    for key, default in zip(keys, vals):
        o.setdefault(key, default)


# Named matplotlib colours handed out, in order, by ``get_color``.
colors = [
    "tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple",
    "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan",
    "mediumseagreen", "teal", "navy", "darkgoldenrod", "darkslateblue",
]


def get_color(i: int) -> Tuple[float, ...]:
    """Return an RGBA colour for series index *i*.

    For ``i < len(colors)``, the entry ``colors[i]`` (ordinary list indexing)
    is converted with ``matplotlib.colors.to_rgba``. Larger indices get an
    opaque, random-looking colour drawn from a generator seeded with *i*. The
    same index therefore always maps to the same colour, and the global NumPy
    RNG (seeded by ``seed`` for experiments) is left untouched.

    Returns:
        ``(r, g, b, a)`` with each channel in ``[0, 1]``.
    """
    if i < len(colors):
        return to_rgba(colors[i])  # type: ignore
    # Use a private, index-seeded generator so plotting never advances the
    # global ``np.random`` stream, which would break run reproducibility.
    r, g, b = np.random.default_rng(i).random(3).tolist()
    return (r, g, b, 1.0)


def to_cpu(*args: Any) -> Any:
    """Call ``.cpu()`` on every positional argument.

    Returns a list of the results in argument order, convenient for unpacking:
    ``x, y = to_cpu(x, y)``.
    """
    on_cpu = []
    for item in args:
        on_cpu.append(item.cpu())
    return on_cpu


def to_device(*args: Any, device: torch.device) -> Any:
    """Call ``.to(device)`` on every positional argument.

    *device* is keyword-only. Returns a list of the results in argument order.
    """
    return list(map(lambda item: item.to(device), args))


def get_test_name(args: Namespace) -> str:
    """Map the test-mode flags on *args* to a test name.

    Returns ``"standard"`` when neither ``args.ood_test`` nor
    ``args.corrupt_test`` is set, ``"ood"`` when only ``ood_test`` is set and
    ``"corrupt"`` when only ``corrupt_test`` is set.

    Raises:
        NotImplementedError: if both flags are set, since no single test name
            describes that combination. Previously this branch was unreachable
            and "ood" was returned silently.
    """
    ood, corrupt = bool(args.ood_test), bool(args.corrupt_test)
    if ood and corrupt:
        raise NotImplementedError(f"this combination of args has no known test name: {args=}")
    if ood:
        return "ood"
    if corrupt:
        return "corrupt"
    return "standard"
