import math
import numpy as np
import torch as th
import torch.nn as nn


class PopArtLayer(nn.Module):
    """Linear output layer paired with running target-normalisation statistics.

    The layer outputs *normalised* predictions ``x @ weights.T + bias``.
    Alongside the parameters it keeps exponential moving averages of the
    targets' first moment (``mu``) and second moment (``nu``), plus the
    derived scale ``sigma = sqrt(nu - mu**2)``.

    Whenever the statistics are updated, ``weights`` and ``bias`` are rescaled
    so that the de-normalised prediction ``sigma * forward(x) + mu`` is left
    unchanged by the update.

    NOTE: ``mu``, ``nu`` and ``sigma`` are plain float32 NumPy arrays of shape
    ``(output_dim,)``. They are not registered as parameters or buffers, so
    they are not part of ``state_dict()`` and are not moved by ``.to()``.
    """

    mu: np.ndarray
    sigma: np.ndarray
    nu: np.ndarray

    def __init__(self, input_dim: int, output_dim: int, beta: float = 3.0e-4) -> None:
        """Create the layer.

        Args:
            input_dim: Size of the last dimension of the inputs.
            output_dim: Number of outputs; each has its own statistics.
            beta: Step size of the exponential moving averages.
        """
        super().__init__()

        self.beta = beta
        self.output_dim = output_dim

        self.weights = nn.Parameter(th.empty(output_dim, input_dim))
        self.bias = nn.Parameter(th.empty(output_dim))

        self.mu = np.zeros(output_dim, dtype=np.float32)
        self.sigma = np.ones(output_dim, dtype=np.float32)
        self.nu = np.zeros(output_dim, dtype=np.float32)

        # Initialisation order (weights, then bias) is kept so the RNG stream
        # is consumed exactly as before.
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        # For a 2-D (output_dim, input_dim) weight the fan-in is input_dim.
        bound = 1 / math.sqrt(input_dim) if input_dim > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Return normalised predictions for ``x`` of shape ``(..., input_dim)``."""
        return nn.functional.linear(x, self.weights, self.bias)

    @th.no_grad()
    def denormalise(self, normalised_y: th.Tensor) -> th.Tensor:
        """Map normalised values back to the target scale (no gradient)."""
        sigma = th.from_numpy(self.sigma).to(normalised_y.device)
        mu = th.from_numpy(self.mu).to(normalised_y.device)

        return normalised_y * sigma + mu

    def normalise(self, unnormalised_y: np.ndarray) -> np.ndarray:
        """Normalise a NumPy array with the current statistics."""
        return (unnormalised_y - self.mu) / self.sigma

    def normalise_tensor(self, unnormalised_y: th.Tensor) -> th.Tensor:
        """Normalise a tensor with the current statistics."""
        mu = th.from_numpy(self.mu).to(unnormalised_y.device)
        sigma = th.from_numpy(self.sigma).to(unnormalised_y.device)
        return (unnormalised_y - mu) / sigma

    @th.no_grad()
    def _apply_batch_moments(self, batch_mu: np.ndarray, batch_nu: np.ndarray) -> None:
        """Fold per-output batch moments into the statistics and rescale params.

        Args:
            batch_mu: Batch mean per output, shape ``(output_dim,)``.
            batch_nu: Batch mean of squares per output, shape ``(output_dim,)``.
        """
        old_mu, old_sigma = self.mu, self.sigma
        keep = 1 - self.beta

        # Cast back to float32 so a float64 batch cannot silently promote the
        # stored statistics (and, through them, the parameters).
        new_mu = (keep * old_mu + self.beta * batch_mu).astype(np.float32)
        new_nu = (keep * self.nu + self.beta * batch_nu).astype(np.float32)

        # nu - mu**2 can come out negative through float32 cancellation;
        # clamp it so sqrt never produces NaN.
        variance = np.maximum(new_nu - new_mu**2, 0.0)
        new_sigma = np.clip(np.sqrt(variance + 1e-8), 1e-4, 1e6).astype(np.float32)

        device, dtype = self.weights.device, self.weights.dtype
        old_sigma_t = th.from_numpy(old_sigma).to(device, dtype)
        new_sigma_t = th.from_numpy(new_sigma).to(device, dtype)
        old_mu_t = th.from_numpy(old_mu).to(device, dtype)
        new_mu_t = th.from_numpy(new_mu).to(device, dtype)

        # Row i of `weights` produces output i, so the per-output factor has
        # to broadcast over the input axis (dim 1).
        self.weights.data = (
            self.weights.data * old_sigma_t.unsqueeze(1) / new_sigma_t.unsqueeze(1)
        )
        self.bias.data = (
            old_sigma_t * self.bias.data + old_mu_t - new_mu_t
        ) / new_sigma_t

        self.mu, self.nu, self.sigma = new_mu, new_nu, new_sigma

    def update_stats_and_params(self, target: np.ndarray) -> None:
        """Update the statistics from ``target`` and rescale the parameters.

        ``target`` is reshaped to ``(-1, output_dim)``; statistics are taken
        per output column. An empty ``target`` leaves everything unchanged,
        because its batch moments would be NaN.
        """
        target = np.asarray(target)
        if target.size == 0:
            return

        target = target.reshape(-1, self.output_dim)
        self._apply_batch_moments(target.mean(axis=0), (target**2).mean(axis=0))

    def update_and_normalise(self, target: np.ndarray) -> np.ndarray:
        """Update statistics/parameters from ``target``, then normalise it."""
        self.update_stats_and_params(target)
        return self.normalise(target)

    @th.no_grad()
    def update_and_normalise_tensor(self, target: th.Tensor) -> th.Tensor:
        """Tensor counterpart of :meth:`update_and_normalise`.

        ``target`` is viewed as ``(-1, output_dim)`` for the statistics, which
        are taken per output column. The returned tensor has the shape of
        ``target`` and carries no gradient. An empty ``target`` does not
        change the statistics or parameters.
        """
        if target.numel() > 0:
            flat = target.detach().reshape(-1, self.output_dim)
            # Accumulate in at least float32; half-precision means of squares
            # overflow easily, and bfloat16 cannot be converted to NumPy.
            if flat.dtype not in (th.float32, th.float64):
                flat = flat.float()
            self._apply_batch_moments(
                flat.mean(dim=0).cpu().numpy(),
                (flat**2).mean(dim=0).cpu().numpy(),
            )

        # Floating targets keep their dtype; anything else is normalised in
        # float32.
        dtype = target.dtype if target.is_floating_point() else th.float32
        mu = th.from_numpy(self.mu).to(target.device, dtype)
        sigma = th.from_numpy(self.sigma).to(target.device, dtype)
        return (target - mu) / sigma
