import functools

import numpy as np
import torch.nn
from numpy.core.multiarray import ndarray
from torch.distributions.utils import _standard_normal
from torch.optim import RMSprop


# Device used by the helpers in this module when they create tensors.
device: torch.device = torch.device("cpu")


# TODO should inherit from torch.distributions.Distribution?
class SimpleNormal:
    """Elementwise (diagonal) Gaussian parameterized by means and log-variances.

    Args:
        means: tensor of means.
        logvars: tensor of natural-log variances, broadcast-compatible with
            ``means``.
    """

    # 0.5 * log(2*pi) as a plain Python float: subtracting it never changes the
    # dtype or device of the result (a float64 shape-(1,) tensor would promote
    # float32 inputs to float64).
    _HALF_LOG_TWO_PI = 0.5 * float(np.log(2 * np.pi))

    def __init__(self, means, logvars):
        self._means = means
        self._logvars = logvars

        # Kept for backward compatibility; log_prob no longer reads it.
        self.two_times_pi = torch.tensor(
            np.ones(1) * 2 * np.pi, requires_grad=False, device=means.device
        )

    @property
    def mean(self):
        return self._means

    @property
    def logvar(self):
        return self._logvars

    @property
    def variance(self):
        return torch.exp(self._logvars)

    @property
    def stddev(self):
        # exp(0.5 * logvar) overflows later than sqrt(exp(logvar)).
        return torch.exp(0.5 * self._logvars)

    @property
    def logstd(self):
        # Computed directly from logvar: log(sqrt(exp(logvar))) underflows to
        # -inf for very negative logvar.
        return 0.5 * self._logvars

    @property
    def inv_var(self):
        return torch.exp(-self._logvars)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):  # noqa: B008
        """Draw a sample without tracking gradients.

        Returns a new tensor of shape ``sample_shape + means.shape``; neither
        ``means`` nor ``logvars`` is modified.
        """
        return self.rsample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):  # noqa: B008
        """Reparameterized sample ``mean + eps * stddev`` with ``eps ~ N(0, 1)``.

        Gradients flow to ``means`` and ``logvars``. The result has shape
        ``sample_shape + means.shape`` (broadcast with ``logvars``).
        """
        shape = torch.Size(sample_shape) + self._means.shape
        eps = _standard_normal(
            shape, dtype=self._means.dtype, device=self._means.device
        )
        return self._means + eps * self.stddev

    def log_prob(self, value):
        """Elementwise Gaussian log-density of ``value``."""
        return (
            -((value - self._means) ** 2) / (2 * self.variance)
            - self.logstd
            - self._HALF_LOG_TWO_PI
        )


def to_tensor(x):
    """Convert numpy arrays and numeric scalars to float32 torch tensors.

    Strings (which ``np.isscalar`` also accepts) and all other types are
    returned unchanged. The data is first copied with ``np.array``, so the
    result does not share memory with ``x``.
    """
    convertible = not isinstance(x, str) and (
        isinstance(x, np.ndarray) or np.isscalar(x)
    )
    if not convertible:
        return x
    return torch.from_numpy(np.array(x)).float()


def to_numpy(x):
    """Return ``x`` as a detached CPU numpy array if it is a tensor; otherwise return it unchanged."""
    if not isinstance(x, torch.Tensor):
        return x
    return x.cpu().detach().numpy()


class Decorator:
    """Base class for class-based decorators.

    Subclasses implement ``__call__`` and reach the decorated callable through
    ``self.func``. Decorating with ``@Subclass`` stores the function here and
    copies its metadata (``__name__``, ``__doc__``, ...). When the decorated
    object is a method, ``__get__`` binds the instance as the first positional
    argument, as for a plain function.
    """

    def __init__(self, func=None):
        self.func = func
        if func is not None:
            functools.update_wrapper(self, func)

    def __get__(self, obj, objtype=None):
        # Looked up on the class itself: behave like an unbound function.
        if obj is None:
            return self
        # Looked up on an instance: pass it as ``self`` to the wrapped method.
        return functools.partial(self.__call__, obj)


# noinspection PyPep8Naming
class input_to_tensors(Decorator):
    """Decorator that passes every positional and keyword argument through
    ``to_tensor`` before calling the wrapped function."""

    def __call__(self, *args, **kwargs):
        tensor_args = tuple(map(to_tensor, args))
        tensor_kwargs = {name: to_tensor(val) for name, val in kwargs.items()}
        return self.func(*tensor_args, **tensor_kwargs)


# noinspection PyPep8Naming
class output_to_tensors(Decorator):
    """Decorator that converts the wrapped function's result with ``to_tensor``.

    A numpy-array result is converted. A tuple result is converted item by item
    and returned as a plain tuple. Anything else is returned unchanged.
    """

    def __call__(self, *args, **kwargs):
        result = self.func(*args, **kwargs)
        if isinstance(result, tuple):
            return tuple(to_tensor(item) for item in result)
        if isinstance(result, np.ndarray):
            return to_tensor(result)
        return result


# noinspection PyPep8Naming
class input_to_numpy(Decorator):
    """Decorator that passes every positional and keyword argument through
    ``to_numpy`` before calling the wrapped function."""

    def __call__(self, *args, **kwargs):
        numpy_args = tuple(map(to_numpy, args))
        numpy_kwargs = {name: to_numpy(val) for name, val in kwargs.items()}
        return self.func(*numpy_args, **numpy_kwargs)


# noinspection PyPep8Naming
class output_to_numpy(Decorator):
    """Decorator that converts the wrapped function's result with ``to_numpy``.

    A tensor result is converted. A tuple result is converted item by item and
    returned as a plain tuple. Anything else is returned unchanged.
    """

    def __call__(self, *args, **kwargs):
        result = self.func(*args, **kwargs)
        if isinstance(result, tuple):
            return tuple(to_numpy(item) for item in result)
        if isinstance(result, torch.Tensor):
            return to_numpy(result)
        return result


class Swish(torch.nn.Module):
    """Swish activation, ``x * sigmoid(x)`` applied elementwise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


def torch_truncated_normal_initializer(w: torch.Tensor):
    """Fill a 2-D or 3-D weight tensor in place with truncated-normal values.

    The standard deviation is ``1 / (2 * sqrt(fan_in))``.

    - 2-D ``w``: fan-in is ``w.shape[0]``.
    - 3-D ``w``: fan-in is ``w.shape[1]``, and each slice along dim 0 is filled
      separately.
    - Any other rank: ``w`` is left untouched.
    """
    # NOTE(review): the 2-D branch reads fan-in from dim 0, i.e. it assumes an
    # (in, out) layout like the 3-D (members, in, out) branch; torch.nn.Linear
    # stores its weight as (out, in) — confirm against the callers.
    if w.ndim == 2:
        slices, fan_in = [w.data], w.data.shape[0]
    elif w.ndim == 3:
        slices, fan_in = list(w.data), w.data.shape[1]
    else:
        return
    scale = 1 / (2 * np.sqrt(fan_in))
    for piece in slices:
        torch_truncated_normal_initializer_(piece, std=scale)


# In-place truncated normal function for PyTorch.
# Adapted from https://github.com/Xingyu-Lin/mbpo_pytorch/blob/main/model.py#L64.
# That version's resampling loop rebinds the local ``tensor``, so the caller's
# tensor kept its out-of-range values; torch.nn.init.trunc_normal_ writes into
# the given tensor directly.
def torch_truncated_normal_initializer_(
    tensor: torch.Tensor, mean: float = 0, std: float = 1
):
    """Samples from a truncated normal distribution in-place.

    Values follow N(mean, std**2) restricted to
    ``[mean - 2 * std, mean + 2 * std]``.

    Args:
        tensor (tensor): the tensor in which sampled values will be stored.
        mean (float): the desired mean (default = 0).
        std (float): the standard deviation of the underlying (untruncated)
            normal (default = 1).

    Returns:
        (tensor): the tensor with the stored values. Note that this modifies the
            input tensor in place, so this is just a pointer to the same object.
    """
    torch.nn.init.trunc_normal_(
        tensor, mean=mean, std=std, a=mean - 2 * std, b=mean + 2 * std
    )
    return tensor


def initializer_from_string(initializer_str, bias_initializer_str):
    """Build a per-module initialization callback from initializer names.

    Args:
        initializer_str: weight initializer, one of ``"xavier_uniform"``,
            ``"torch_truncated_normal"``, ``"kaiming_uniform"``, ``"uniform"``,
            ``"big_uniform"`` (uniform on [-5, 35]).
        bias_initializer_str: bias initializer, one of ``"constant_zero"``,
            ``"kaiming_uniform"``, ``"xavier_uniform"``, ``"uniform"``,
            ``"big_uniform"`` (uniform on [-2, 2]).

    Returns:
        A function of one module (suitable for ``torch.nn.Module.apply``). It
        initializes the module's ``weight`` and ``bias`` independently, skipping
        whichever is missing or None.

    Raises:
        NotImplementedError: if either name is unknown.
    """
    if initializer_str == "xavier_uniform":
        weight_initializer = torch.nn.init.xavier_uniform_
    elif initializer_str == "torch_truncated_normal":
        weight_initializer = torch_truncated_normal_initializer
    elif initializer_str == "kaiming_uniform":
        weight_initializer = torch.nn.init.kaiming_uniform_
    elif initializer_str == "uniform":
        weight_initializer = torch.nn.init.uniform_
    elif initializer_str == "big_uniform":

        def weight_initializer(w):
            return torch.nn.init.uniform_(w, -5, 35)

    else:
        raise NotImplementedError(
            f"Weight initializer {initializer_str} does not exist."
        )

    # NOTE: the fan-based initializers (kaiming/xavier) need tensors with at
    # least 2 dims, so they fail on ordinary 1-D biases.
    if bias_initializer_str == "constant_zero":

        def bias_initializer(w):
            return torch.nn.init.constant_(w, 0.0)

    elif bias_initializer_str == "kaiming_uniform":
        bias_initializer = torch.nn.init.kaiming_uniform_
    elif bias_initializer_str == "xavier_uniform":
        bias_initializer = torch.nn.init.xavier_uniform_
    elif bias_initializer_str == "uniform":
        bias_initializer = torch.nn.init.uniform_
    elif bias_initializer_str == "big_uniform":

        def bias_initializer(w):
            return torch.nn.init.uniform_(w, -2, 2)

    else:
        raise NotImplementedError(
            f"Bias initializer {bias_initializer_str} does not exist."
        )

    def _weight_initializer(m):
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        # Weight and bias are handled independently: layers such as Linear
        # carry both, and a None parameter (e.g. bias=False) is skipped.
        if weight is not None:
            weight_initializer(weight)
        if bias is not None:
            bias_initializer(bias)
        if (
            weight is None
            and bias is None
            and not isinstance(
                m,
                (
                    torch.nn.Tanh,
                    torch.nn.ReLU,
                    torch.nn.Sequential,
                    Swish,
                    torch.nn.SiLU,
                ),
            )
        ):
            # Nothing was initialized and the type is not a known
            # parameter-free module: report it for inspection.
            print("type model", type(m))

    return _weight_initializer


def activation_from_string(act_str):
    """Look up an activation module class by name.

    Returns the class itself (not an instance); ``"none"`` maps to ``None``.

    Raises:
        NotImplementedError: if ``act_str`` is not a known name.
    """
    # noinspection SpellCheckingInspection
    activations = {
        "relu": torch.nn.ReLU,
        "tanh": torch.nn.Tanh,
        "none": None,
        "swish": Swish,
        "silu": torch.nn.SiLU,
        "sigmoid": torch.nn.Sigmoid,
    }
    if act_str not in activations:
        raise NotImplementedError(f"Add activation function {act_str} to dictionary")
    return activations[act_str]


def optimizer_from_string(opt_str):
    """Look up a torch optimizer class by name (``"Adam"`` or ``"RMSprop"``).

    Raises:
        NotImplementedError: if ``opt_str`` is not a known name.
    """
    # noinspection SpellCheckingInspection
    known_optimizers = {"Adam": torch.optim.Adam, "RMSprop": RMSprop}
    optimizer_cls = known_optimizers.get(opt_str)
    if optimizer_cls is None:
        raise NotImplementedError(
            f"Implement optimizer {opt_str} and add it to dictionary"
        )
    return optimizer_cls


def torch_clip(x, min_val, max_val):
    """Clip ``x`` elementwise to ``[min_val, max_val]``.

    Either bound may be ``None`` (no bound on that side), but not both. Bounds
    may be tensors, which are broadcast against ``x``, or Python/numpy scalars.
    Scalars are turned into 0-dim tensors with ``x``'s dtype and device, because
    ``torch.min``/``torch.max`` only accept a tensor as the second operand: a
    bare int would be read as a ``dim`` argument.

    Raises:
        ValueError: if both bounds are ``None``.
    """
    if min_val is None and max_val is None:
        raise ValueError("One of max or min must be given")

    def _as_bound(bound):
        if isinstance(bound, torch.Tensor):
            return bound
        return torch.as_tensor(bound, dtype=x.dtype, device=x.device)

    res = x
    if max_val is not None:
        res = torch.min(res, _as_bound(max_val))
    if min_val is not None:
        res = torch.max(res, _as_bound(min_val))
    return res


class Normalizer:
    """Running mean/std normalizer for numpy arrays and torch tensors.

    Statistics are accumulated by :meth:`update` from per-feature sums and sums
    of squares. numpy inputs are transformed with ``mean``/``std``; tensor
    inputs use the cached ``mean_tensor``/``std_tensor`` on the module-level
    ``device``.

    Args:
        shape: shape of the accumulators, i.e. of one sample.
        eps: lower bound on ``std``; also added under the square root.
        clip_range: ``(min, max)`` applied after normalizing. ``None`` on one
            side leaves that side open; ``(None, None)`` disables clipping.
    """

    count: float
    sum_of_squares: ndarray
    sum: ndarray

    def __init__(self, shape, eps=1e-6, clip_range=(None, None)):
        self.mean = 0.0
        self.std = 1.0
        self.eps = eps
        self.shape = shape
        self.clip_range = clip_range

        # Identity transform for tensors until the first update().
        self.mean_tensor = torch.zeros(1).to(device)
        self.std_tensor = torch.ones(1).to(device)

        self.re_init()

    def re_init(self):
        """Reset the accumulators (mean/std and their tensors are left as they are)."""
        self.sum = np.zeros(self.shape)
        self.sum_of_squares = np.zeros(self.shape)
        # NOTE(review): starting at 1.0 avoids division by zero but behaves like
        # one extra all-zero sample in the statistics — confirm intended.
        self.count = 1.0

    def update(self, data):
        """Accumulate a batch and recompute mean/std.

        Args:
            data: array whose axis 0 indexes samples (it is summed over axis 0).
        """
        self.sum += np.sum(data, axis=0)
        self.sum_of_squares += np.sum(np.square(data), axis=0)
        self.count += data.shape[0]

        mean = self.sum / self.count
        # E[x^2] - E[x]^2 can round to a tiny negative value; clamp at 0 so the
        # square root never produces NaN.
        variance = np.maximum(self.sum_of_squares / self.count - np.square(mean), 0.0)
        # np.asarray: for shape == () numpy ufuncs return scalars, which
        # torch.from_numpy rejects (load_state_dict converts the same way).
        self.mean = np.asarray(mean)
        self.std = np.asarray(np.maximum(self.eps, np.sqrt(variance + self.eps)))

        self.mean_tensor = torch.from_numpy(self.mean).float().to(device)
        self.std_tensor = torch.from_numpy(self.std).float().to(device)

    def normalize(self, data, out=None):
        """Return ``(data - mean) / std``, clipped to ``clip_range`` if set.

        ``out`` is only honored for tensor ``data``: the result is written into
        it and ``None`` is returned.
        """
        if isinstance(data, torch.Tensor):
            if out is None:
                res = (data - self.mean_tensor) / self.std_tensor
                if not tuple(self.clip_range) == (None, None):
                    # torch.clip accepts scalar bounds and None for an open
                    # side, matching the in-place branch below.
                    return torch.clip(
                        res, min=self.clip_range[0], max=self.clip_range[1]
                    )
                else:
                    return res
            else:
                torch.sub(data, self.mean_tensor, out=out)
                torch.divide(out, self.std_tensor, out=out)
                if not tuple(self.clip_range) == (None, None):
                    torch.clip(
                        out, min=self.clip_range[0], max=self.clip_range[1], out=out
                    )
        else:
            res = (data - self.mean) / self.std
            if not tuple(self.clip_range) == (None, None):
                return np.clip(res, *self.clip_range)
            else:
                return res

    def denormalize(self, data, out=None):
        """Inverse of the affine part of :meth:`normalize`: ``data * std + mean``.

        ``out`` is only honored for tensor ``data``: the result is written into
        it and ``None`` is returned. Clipping is not undone.
        """
        if isinstance(data, torch.Tensor):
            if out is None:
                return data * self.std_tensor + self.mean_tensor
            else:
                torch.multiply(data, self.std_tensor, out=out)
                torch.add(out, self.mean_tensor, out=out)
        else:
            return data * self.std + self.mean

    def state_dict(self):
        """Return the statistics and accumulators as a plain dict."""
        return {
            "mean": self.mean,
            "std": self.std,
            "sum": self.sum,
            "sum_of_squares": self.sum_of_squares,
            "count": self.count,
        }

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and rebuild the cached tensors."""
        self.__dict__.update(state_dict)

        self.mean_tensor = torch.from_numpy(np.asarray(self.mean)).float().to(device)
        self.std_tensor = torch.from_numpy(np.asarray(self.std)).float().to(device)
