import torch
import torch.nn as nn
from einops import rearrange
from torch.distributions import Normal
from torch.quasirandom import SobolEngine
from scipy.stats import norm
from torchlambertw import special
import torch.nn.functional as F
import math

try:
    import gq_cuda
    HAS_GQ_CUDA = True
except ImportError:
    HAS_GQ_CUDA = False

def prior_samples(n_samples, n_variable, seed_rec):
    """Draw quasi-random standard-normal samples from a scrambled Sobol sequence.

    Args:
        n_samples: number of points to draw (rows of the result).
        n_variable: dimensionality of each point (columns of the result).
        seed_rec: seed for the Sobol scrambling, so the draw is reproducible.

    Returns:
        A tensor of shape ``(n_samples, n_variable)`` built from the array
        returned by ``scipy.stats.norm.ppf``; callers in this file cast it
        with ``.float()``.
    """
    engine = SobolEngine(n_variable, scramble=True, seed=seed_rec)
    uniform = engine.draw(n_samples)
    # Inverse normal CDF turns the uniform Sobol points into Gaussian ones.
    # NOTE(review): a Sobol coordinate of exactly 0 would map to -inf here;
    # presumably negligible with scrambling -- confirm if it ever matters.
    gaussian = norm.ppf(uniform.numpy())
    return torch.from_numpy(gaussian)

class GaussianRegularizer(nn.Module):
    """Diagonal-Gaussian VAE bottleneck with a standard-normal KL penalty.

    The input carries ``2 * c`` channels: the first half is the posterior
    mean, the second half the posterior log-variance.

    Args:
        format: layout of the input, ``"bchw"`` or ``"blc"``.
        logvar_range: ``[min, max]`` clamp applied to the log-variance.
        group: number of channel groups whose per-element KL is summed
            together for the ``kl2_*`` statistics. Defaults to 1.
    """

    def __init__(self, format, logvar_range=[-30.0, 20.0], group=1):
        super().__init__()

        self.format = format

        assert self.format in ["bchw", "blc"]
        self.logvar_range = logvar_range
        # forward() needs this for the grouped KL statistics; it was
        # previously never set, so every forward call raised AttributeError.
        self.group = group

    def forward(self, z):
        """Sample from the posterior and report KL statistics.

        Args:
            z: tensor of shape ``(b, 2c, h, w)`` for ``"bchw"`` or
                ``(b, l, 2c)`` for ``"blc"``.

        Returns:
            ``(zhat, info)`` where ``zhat`` is the reparameterized sample in
            the input layout with ``c`` channels, and ``info`` holds:
            ``"kl"`` (KL in nats summed over positions and channels, averaged
            over the batch) and ``"kl2_mean"``/``"kl2_min"``/``"kl2_max"``
            (statistics of the per-element KL scaled by 1.4426, i.e.
            approximately log2(e), summed over the group dimension).
        """
        z = z.float()

        if self.format == "bchw":
            b, c, h, w = z.shape
            l = h * w
            zhat = rearrange(z, "b c h w -> b (h w) c")
        else:
            b, l, c = z.shape
            zhat = z
        c = c // 2

        # zhat is (b, l, 2c): split the channel axis into mean / log-variance.
        mu, logvar = zhat.chunk(2, 2)
        logvar = torch.clamp(logvar, self.logvar_range[0], self.logvar_range[1])
        std = torch.exp(0.5 * logvar)
        var = torch.exp(logvar)

        # Reparameterization trick.
        zhat = mu + torch.randn_like(mu) * std

        # Per-element KL(N(mu, var) || N(0, 1)) in nats, shape (b, l, c).
        kl_nats = 0.5 * (torch.pow(mu, 2) + var - 1.0 - logvar)

        # Same quantity scaled by ~log2(e), summed over the group dimension.
        kl2 = 1.4426 * kl_nats
        kl2 = kl2.reshape(b, l, self.group, c // self.group)
        kl2 = torch.sum(kl2, dim=2)
        kl2_mean, kl2_min, kl2_max = torch.mean(kl2), torch.min(kl2), torch.max(kl2)

        # Loss term: total KL per sample.
        kl_loss = torch.sum(kl_nats, dim=[1, 2])

        if self.format == "bchw":
            zhat = rearrange(zhat, "b (h w) c -> b c h w", h=h)

        info = {"kl": torch.mean(kl_loss), "kl2_mean": kl2_mean, "kl2_min": kl2_min, "kl2_max": kl2_max}
        return zhat, info


class TargetAdaptativeGaussianRegularizer(nn.Module):
    """Gaussian VAE bottleneck whose KL weights adapt toward a target rate.

    The per-position KL (scaled by 1.4426, approximately log2(e), and summed
    over the group dimension) is compared against ``target +/- tolerance``.
    Positions above the band are weighted by ``lam_max``, positions below by
    ``lam_min``, positions inside by 1. The three multipliers ``lam``,
    ``lam_max`` and ``lam_min`` are plain Python floats updated
    multiplicatively on every forward call.

    Args:
        format: ``"bchw"`` for ``(b, 2c, h, w)`` input; otherwise the input
            is treated as ``(b, l, 2c)``.
        logvar_range: ``[min, max]`` clamp applied to the log-variance.
        group: number of channel groups whose KL is summed together.
        target: desired grouped KL per position.
    """

    def __init__(self, format, logvar_range=[-30.0, 20.0], group=1, target=16):

        super().__init__()
        self.format = format
        self.logvar_range = logvar_range
        self.group = group
        self.target = target
        self.lam_factor = 1 + 1e-2
        self.lam = 1.0
        self.lam_min = 1.0
        self.lam_max = 1.0
        self.lam_range = (1e-3, 1e3)
        self.tolerance = 0.5

    def get_trainable_parameters(self):
        """Yield nothing: this module has no trainable parameters."""
        yield from ()

    def forward(self, z: torch.Tensor):
        """Sample from the posterior and compute the adaptively weighted KL.

        Args:
            z: ``(b, 2c, h, w)`` for ``"bchw"``, else ``(b, l, 2c)``.

        Returns:
            ``(zhat, info)``; ``zhat`` has ``c`` channels in the input layout
            and ``info`` holds ``"kl"`` (weighted loss times ``lam``) plus the
            detached ``"kl2_mean"``/``"kl2_min"``/``"kl2_max"`` statistics.
        """
        z = z.float()
        if self.format != "bchw":
            # (b, l, 2c) -> (b, 2c, l, 1). Using a unit width instead of a
            # sqrt(l) x sqrt(l) grid means l no longer has to be a square.
            z = z.transpose(1, 2).unsqueeze(-1)

        # z is (b, 2c, h, w): mean and log-variance are split along CHANNELS.
        mu, logvar = z.chunk(2, dim=1)
        logvar = torch.clamp(logvar, self.logvar_range[0], self.logvar_range[1])
        std = torch.exp(0.5 * logvar)
        var = torch.exp(logvar)

        zhat = mu + torch.randn_like(mu) * std

        kls = 1.4426 * (
            0.5
            * (torch.pow(mu, 2) + var - 1.0 - logvar)
        )

        b, c, h, w = kls.shape

        # Sum over the group dimension -> (b, c // group, h, w).
        kls = torch.sum(kls.reshape(b, self.group, c // self.group, h, w), dim=1)

        upper = self.target + self.tolerance
        lower = self.target - self.tolerance

        # Piecewise weights, using the multipliers from BEFORE this step's update.
        ge = (kls > upper).type(kls.dtype) * self.lam_max
        eq = (kls <= upper).type(kls.dtype) * (kls >= lower).type(kls.dtype)
        le = (kls < lower).type(kls.dtype) * self.lam_min
        kl_loss = torch.sum((ge * kls + eq * kls + le * kls), dim=[1, 2, 3])

        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # Global multiplier follows the mean rate.
        # NOTE(review): unlike lam_max / lam_min, lam is never clamped to
        # lam_range -- confirm whether that is intended.
        if torch.mean(kls) > self.target:
            self.lam = self.lam * self.lam_factor
        else:
            self.lam = self.lam / self.lam_factor

        # Over-budget multiplier: grow while any position exceeds the band,
        # otherwise decay (the decay was previously a no-op expression).
        if torch.max(kls) > upper:
            self.lam_max = self.lam_max * self.lam_factor
        else:
            self.lam_max = self.lam_max / self.lam_factor
        self.lam_max = max(min(self.lam_max, self.lam_range[1]), 1.0)

        # Under-budget multiplier: shrink while any position is below the band.
        if torch.min(kls) < lower:
            self.lam_min = self.lam_min / self.lam_factor
        else:
            self.lam_min = self.lam_min * self.lam_factor
        self.lam_min = max(min(self.lam_min, 1.0), self.lam_range[0])

        info = {"kl": torch.mean(kl_loss) * self.lam, "kl2_mean": torch.mean(kls).detach(), "kl2_min": torch.min(kls).detach(), "kl2_max": torch.max(kls).detach()}

        if self.format == "blc":
            # (b, c, l, 1) -> (b, l, c)
            zhat = zhat.squeeze(-1).transpose(1, 2)

        return zhat, info

class GaussianQuantRegularizer(nn.Module):
    """Quantize a Gaussian posterior onto a fixed codebook of normal samples.

    The codebook is ``n_samples`` vectors of dimension ``group`` produced by
    ``prior_samples``. For each group of latent channels, the codebook entry
    maximizing ``log q(x) - beta * log p(x)`` (summed over the group
    dimension) is selected, where ``q`` is the predicted Gaussian and ``p``
    the standard normal.

    Args:
        format: data layout, one of ``"bchw"`` or ``"blc"``.
        group: dimensionality of each codebook vector.
        n_samples: number of codebook entries.
        seed: seed forwarded to ``prior_samples``.
        beta: weight on the prior log-probability in the selection score.
        logvar_range: ``[min, max]`` clamp applied to the log-variance.
        backend: ``"torch"``, ``"raw"`` or ``"cuda"``; ``"cuda"`` falls back
            to ``"torch"`` when the ``gq_cuda`` extension is not importable.
    """

    # The torch/raw backends score the rows in roughly this many slices so
    # that the (rows, n_samples, group) intermediate stays small.
    _NUM_CHUNKS = 8

    def __init__(self, format, group, n_samples, seed, beta=1.0, logvar_range=[-30.0, 20.0], backend="torch"):
        super().__init__()

        self.format = format
        assert self.format in ["bchw", "blc"]

        self.group = group
        self.n_samples = n_samples
        self.beta = beta

        self.seed = seed
        # Non-persistent: the codebook is regenerated from the seed rather
        # than stored in the state dict.
        self.register_buffer("prior_samples", prior_samples(self.n_samples, self.group, self.seed).float(), persistent=False)
        self.normal_dist = Normal(torch.zeros([1, self.group]), torch.ones([1, self.group]))
        self.register_buffer("normal_log_prob", self.normal_dist.log_prob(self.prior_samples).float(), persistent=False)

        self.logvar_range = logvar_range
        # Scratch buffer for the CUDA backend, allocated lazily in forward().
        self.perturbed = None
        if backend == "cuda" and HAS_GQ_CUDA is False:
            print("no gq cuda module is detected, use pytorch backend!")
            backend = "torch"
        self.backend = backend

    def _score_torch(self, mu_q, std_q):
        """Return ``sum_d [log q(x_d) - beta * log p(x_d)]``, shape (rows, n_samples)."""
        q_normal_dist = Normal(mu_q[:, None, :], std_q[:, None, :])
        log_ratios = (
            q_normal_dist.log_prob(self.prior_samples[None])
            - self.normal_log_prob[None] * self.beta
        )
        return torch.sum(log_ratios, dim=2)

    def _score_raw(self, mu_q, std_q):
        """Closed-form score without the constant factors and the log-std term."""
        samples = self.prior_samples[None]
        log_ratios = - ((samples - mu_q[:, None, :]) / std_q[:, None, :]) ** 2 + samples ** 2 * self.beta
        return torch.sum(log_ratios, dim=2)

    def _chunked_argmax(self, mu, std, score_fn):
        """Pick the best codebook entry per row, scoring the rows in slices."""
        n_rows = mu.shape[0]
        # At least one row per slice: with fewer than _NUM_CHUNKS rows the
        # integer division gives 0, which is an invalid range() step.
        bs = max(1, n_rows // self._NUM_CHUNKS)
        zhat = torch.zeros_like(mu)
        indices = torch.zeros([n_rows], device=mu.device, dtype=torch.long)

        for start in range(0, n_rows, bs):
            stop = start + bs
            scores = score_fn(mu[start:stop], std[start:stop])
            best = torch.argmax(scores, dim=1)
            zhat[start:stop] = torch.index_select(self.prior_samples, 0, best)
            indices[start:stop] = best
        return zhat, indices

    def forward(self, z):
        """Quantize ``z`` and return ``(zhat, {"indices": indices})``.

        Args:
            z: ``(b, 2c, h, w)`` for ``"bchw"`` or ``(b, l, 2c)`` for
                ``"blc"``; first half of the channels is the mean, second
                half the log-variance.

        Returns:
            ``zhat`` with ``c`` channels in the input layout, and ``indices``
            with ``c // group`` channels holding the selected codebook rows.

        Raises:
            ValueError: if ``self.backend`` is not a known backend.
        """
        z = z.float()

        if self.format == "bchw":
            b, c, h, w = z.shape
            l = h * w
            z = rearrange(z, "b c h w -> b (h w) c")
        else:
            b, l, c = z.shape
        c = c // 2

        # z is (b, l, 2c)
        mu, logvar = z.chunk(2, 2)
        logvar = torch.clamp(logvar, self.logvar_range[0], self.logvar_range[1])
        std = torch.exp(logvar * 0.5)

        # One row per (batch, position, channel-within-group); columns are
        # the `group` dimensions that are quantized jointly.
        mu = mu.reshape(b, l, self.group, c // self.group).permute(0, 1, 3, 2).reshape(-1, self.group)
        std = std.reshape(b, l, self.group, c // self.group).permute(0, 1, 3, 2).reshape(-1, self.group)

        if self.backend == "cuda":
            n_rows = mu.shape[0]
            stale = (
                self.perturbed is None
                or self.perturbed.shape[0] != n_rows
                or self.perturbed.device != mu.device
            )
            if stale:
                # Row count or device changed: reallocate the scratch buffer.
                self.perturbed = torch.zeros([n_rows, self.n_samples], device=mu.device).contiguous()
            # Presumably fills self.perturbed with one score per
            # (row, codebook entry) -- see the gq_cuda extension to confirm.
            gq_cuda.ops.gq_cuda(
                mu, std, self.prior_samples, self.perturbed, self.group, n_rows, self.n_samples, self.beta
            )
            indices = torch.argmax(self.perturbed, dim=1)
            zhat = torch.index_select(self.prior_samples, 0, indices)
        elif self.backend == "torch":
            zhat, indices = self._chunked_argmax(mu, std, self._score_torch)
        elif self.backend == "raw":
            zhat, indices = self._chunked_argmax(mu, std, self._score_raw)
        else:
            raise ValueError(f"unknown backend: {self.backend!r}")

        # Undo the row layout: back to (b, l, c) / (b, l, c // group).
        zhat = zhat.reshape(b, l, c // self.group, self.group).permute(0, 1, 3, 2).reshape(b, l, c).float()
        indices = indices.reshape(b, l, c // self.group)
        if self.format == "bchw":
            zhat = rearrange(zhat, "b (h w) c -> b c h w", h=h)
            indices = rearrange(indices, "b (h w) c -> b c h w", h=h)

        return zhat, {"indices": indices}

    def dequant(self, indices):
        """Rebuild the quantized latent from codebook indices.

        Args:
            indices: integer tensor as returned by ``forward`` --
                ``(b, ng, h, w)`` for ``"bchw"`` or ``(b, l, ng)`` for
                ``"blc"``, where ``ng`` is the number of groups per position.

        Returns:
            Float tensor with ``ng * group`` channels in the same layout.
        """
        if self.format == "bchw":
            b, ng, h, w = indices.shape
            l = h * w
            indices = rearrange(indices, "b c h w -> b (h w) c")
        else:
            b, l, ng = indices.shape

        zhat = torch.index_select(self.prior_samples, 0, indices.reshape(-1)).float()
        zhat = zhat.reshape(b, l, ng, self.group).permute(0, 1, 3, 2).reshape(b, l, ng * self.group)

        if self.format == "bchw":
            zhat = rearrange(zhat, "b (h w) c -> b c h w", h=h)
        return zhat

class GroupedLambertWRegularizer(nn.Module):
    """Gaussian bottleneck whose grouped KL is fixed to ``target`` by construction.

    The input's first channel half (``tau``) sets the means and the second
    half (``gamma_unnorm``) sets, via a softmax over the group dimension, how
    the KL budget is split inside each group. The variance is then solved
    from the KL equation with the Lambert W function: since
    ``var * exp(-var) = exp(mu^2 - 2*kw - 1)``, each element satisfies
    ``0.5 * (mu^2 + var - 1 - log var) = kw`` and the ``kw`` of a group sum
    to ``target`` (in nats).

    Usage sketch::

        reg = GroupedLambertWRegularizer("bchw", group=16, target=11.091)
        zhat, info = reg(torch.randn([2, 32, 8, 8]))
        print(info["kl2_mean"])

    Args:
        format: ``"bchw"`` for ``(b, 2c, h, w)`` input; otherwise the input
            is treated as ``(b, l, 2c)``.
        logvar_range: stored for interface parity; not read by ``forward``.
        group: number of channel groups sharing one KL budget.
        target: total KL (nats) per group-summed position.
    """

    def __init__(self, format, logvar_range=[-30.0, 20.0], group=1, target=16):
        super().__init__()
        self.format = format
        self.logvar_range = logvar_range
        self.target = target
        self.group = group

    def forward(self, z: torch.Tensor):
        """Sample a latent whose grouped KL equals the target.

        Args:
            z: ``(b, 2c, h, w)`` for ``"bchw"``, else ``(b, l, 2c)``.

        Returns:
            ``(zhat, info)``; ``zhat`` has ``c`` channels in the input layout
            and ``info`` holds the detached ``"kl2_mean"``/``"kl2_min"``/
            ``"kl2_max"`` of the grouped KL scaled by 1.4426 (~log2(e)).

        Raises:
            AssertionError: if ``c`` is not divisible by ``group`` or the
                Lambert W evaluation yields NaN.
        """
        z = z.float()
        if self.format != "bchw":
            # (b, l, 2c) -> (b, 2c, l, 1). A unit width avoids requiring the
            # sequence length to be a perfect square.
            z = rearrange(z, "b l c -> b c l").unsqueeze(-1)

        tau, gamma_unnorm = torch.chunk(z, 2, dim=1)
        b, c, h, w = tau.shape
        assert(c % self.group == 0)
        ng = c // self.group

        tau = tau.reshape(b, self.group, ng, h, w)
        gamma_unnorm = gamma_unnorm.reshape(b, self.group, ng, h, w)

        # Share of the budget given to each member of a group.
        gamma = F.softmax(gamma_unnorm, dim=1)

        # Give every member a small floor (0.1 / group) so that no element's
        # budget reaches zero; the floors plus the shares still sum to target.
        kw = gamma * (self.target - 0.1) + 0.1 / self.group

        # |mu| is kept strictly below sqrt(2 * kw) so the Lambert W argument
        # stays inside its real domain.
        mu = torch.sqrt(2 * kw) * F.tanh(tau) * (1 - 1e-2)

        W = special.lambertw
        var = -W(-torch.exp(mu ** 2 - 2 * kw - 1.0))

        nan_mask = torch.isnan(var)
        if nan_mask.any():
            # Report only the offending entries instead of scanning every
            # element in a Python loop.
            bad_mu = mu[nan_mask]
            bad_kw = kw[nan_mask]
            print("caught nan")
            print(bad_mu, bad_kw)
            print(bad_mu.dtype, bad_kw.dtype)
            print("var nan")
            # Raised explicitly (same exception type as the former
            # `assert(0)`) so the check survives `python -O`.
            raise AssertionError(
                "Lambert W produced NaN variance for %d element(s)" % int(nan_mask.sum())
            )

        std = torch.sqrt(var)
        logvar = torch.log(var)

        # Monitoring only: detached, scaled by ~log2(e), summed over groups.
        kls = (
            1.4426 * (0.5 * (torch.pow(mu, 2) + var - 1.0 - logvar)).clone().detach()
        )
        kls = torch.sum(kls, dim=1)

        zhat = mu + std * torch.randn_like(mu)

        zhat = zhat.reshape(b, c, h, w)

        if self.format == "blc":
            zhat = rearrange(zhat, "b c h w -> b (h w) c")

        info = {"kl2_mean": torch.mean(kls), "kl2_min": torch.min(kls), "kl2_max": torch.max(kls)}

        return zhat, info


def cal_mu(kl, var):
    """Debug helper: recover the mean magnitude from a KL value and a variance.

    Solves ``kl = 0.5 * (mu^2 + var - 1 - log(var))`` for ``mu^2`` and
    returns its square root. As a side effect it prints ``mu^2``, the two
    Lambert W evaluations of ``-exp(-2*kl - 1)`` (second argument 0 and -1,
    presumably the branch index), that argument itself, and ``-1/e``.

    Args:
        kl: KL value(s) as a tensor.
        var: variance tensor.

    Returns:
        ``sqrt(2*kl - var + 1 + log(var))``.
    """
    from torchlambertw import special

    lambertw = special.lambertw
    mean_sq = 2 * kl - var + 1.0 + torch.log(var)
    w_arg = -torch.exp(-2 * kl - 1)
    var_first = -lambertw(w_arg, 0)
    var_second = -lambertw(w_arg, -1)
    neg_inv_e = -1 / torch.exp(torch.tensor(1))
    print(mean_sq, var_first, var_second, w_arg, neg_inv_e)
    return torch.sqrt(mean_sq)

if __name__ == "__main__":
    # Smoke test: quantize a random latent, rebuild it from the returned
    # indices, and check that both paths agree (mean abs diff should be 0).
    # Falls back to CPU so the check also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    z = torch.randn([1, 32, 4, 4]).to(device)
    gauss = GaussianQuantRegularizer("bchw", 16, 1024, 42).to(device)
    zhat, info = gauss(z)
    z2 = gauss.dequant(info["indices"])

    print(zhat.shape, z2.shape)

    print(torch.mean(torch.abs(zhat - z2)))
