import math
import torch
from einops import rearrange

from .base import BaseModule
import torch.nn.functional as F


class Mish(BaseModule):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``."""

    def forward(self, x):
        smooth = torch.nn.functional.softplus(x)
        return x * smooth.tanh()


class Upsample(BaseModule):
    """Learned 2x spatial upsampling via a 4x4 transposed convolution.

    With kernel 4, stride 2 and padding 1 each spatial size ``n`` maps to
    ``2 * n``; the channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled


class Downsample(BaseModule):
    """Learned 2x spatial downsampling via a strided 3x3 convolution.

    With kernel 3, stride 2 and padding 1 each spatial size ``n`` maps to
    ``ceil(n / 2)``; the channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled


class Rezero(BaseModule):
    """Wrap ``fn`` and scale its output by a learnable scalar ``g``.

    ``g`` starts at zero, so the wrapped branch initially contributes nothing.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        branch = self.fn(x)
        return branch * self.g


class Block(BaseModule):
    """Masked 3x3 conv -> GroupNorm -> Mish unit.

    The mask is applied both to the input and to the output, so masked
    positions are zero on the way out.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        conv = torch.nn.Conv2d(dim, dim_out, 3, padding=1)
        norm = torch.nn.GroupNorm(groups, dim_out)
        self.block = torch.nn.Sequential(conv, norm, Mish())

    def forward(self, x, mask):
        masked_input = x * mask
        activated = self.block(masked_input)
        return activated * mask


class ResnetBlock(BaseModule):
    """Two masked ``Block``s with a time-embedding bias and a skip connection.

    The time embedding is projected to ``dim_out`` channels and added after
    the first block. The skip path is a 1x1 conv when the channel count
    changes, otherwise the identity.
    """

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = (torch.nn.Conv2d(dim, dim_out, 1)
                         if dim != dim_out else torch.nn.Identity())

    def forward(self, x, mask, time_emb):
        hidden = self.block1(x, mask)
        # Append two trailing singleton axes so the per-channel bias
        # broadcasts over the two spatial axes.
        time_bias = self.mlp(time_emb)[..., None, None]
        hidden = self.block2(hidden + time_bias, mask)
        shortcut = self.res_conv(x * mask)
        return hidden + shortcut


class LinearAttention(BaseModule):
    """Multi-head linear attention over the flattened spatial positions.

    Keys are softmax-normalised along the position axis. A per-head
    ``(dim_head x dim_head)`` context is then built from keys and values and
    applied to the queries, so cost is linear in the number of positions.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = torch.nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        projected = self.to_qkv(x)
        q, k, v = rearrange(projected, 'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        merged = rearrange(attended, 'b heads c (h w) -> b (heads c) h w',
                           heads=self.heads, h=height, w=width)
        return self.to_out(merged)


class Residual(BaseModule):
    """Add the input back onto the output of ``fn`` (extra args go to ``fn``)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x


class SinusoidalPosEmb(BaseModule):
    """Sinusoidal embedding of a batch of scalar positions/times.

    For input ``x`` of shape ``(batch,)`` the output has shape
    ``(batch, 2 * (dim // 2))``: sines in the first half, cosines in the
    second. Frequencies are geometrically spaced from 1 to 1/10000, and
    ``x`` is multiplied by ``scale`` before the frequencies are applied.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = (torch.arange(half_dim, device=x.device).float() * -log_step).exp()
        phases = (scale * x[:, None]) * freqs[None, :]
        return torch.cat([phases.sin(), phases.cos()], dim=-1)


class GradLogPEstimator2d(BaseModule):
    """Time-conditioned U-Net that estimates the score of a noisy feature map.

    ``mu`` and ``x`` (each ``(batch, n_feats, frames)``) are stacked as the
    channels of a 2-D image. For multi-speaker models a projected speaker map
    is stacked as a third channel. The image passes through ResNet blocks
    with linear attention at each resolution, and a single output channel is
    returned. The output is masked and has the channel axis squeezed.
    NOTE(review): the output matches ``x``'s size only when the frame count
    survives the down/up-sampling round trip (e.g. divisible by
    ``2 ** (len(dim_mults) - 1)``); confirm callers pad accordingly.

    Args:
        dim: Base channel width; also the width of the time embedding.
        dim_mults: Channel multipliers for each resolution level.
        groups: GroupNorm group count used by every conv block.
        n_spks: Number of speakers; ``None`` is treated as 1 (single speaker).
        spk_emb_dim: Size of the speaker embedding passed to ``forward``.
        n_feats: Size of the feature axis; the speaker embedding is projected
            to this size so it can be stacked with ``mu`` and ``x``.
        pe_scale: Multiplier applied to ``t`` in the sinusoidal embedding.
    """

    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8,
                 n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(GradLogPEstimator2d, self).__init__()
        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        # Normalise ``None`` once and use the normalised value below:
        # comparing the raw ``None`` default with ``> 1`` raises TypeError.
        self.n_spks = n_spks if n_spks is not None else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale
        multi_speaker = self.n_spks > 1

        if multi_speaker:
            # Project the speaker embedding onto the feature axis so that it
            # can be broadcast over frames and stacked as an input channel.
            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
                                       torch.nn.Linear(dim * 4, dim))

        # Input channels: mu and x, plus the speaker map when multi-speaker.
        dims = [2 + (1 if multi_speaker else 0), *(dim * m for m in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(torch.nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups),
                Residual(Rezero(LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else torch.nn.Identity()]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)

        for dim_in, dim_out in reversed(in_out[1:]):
            self.ups.append(torch.nn.ModuleList([
                # Input is the upsampled features concatenated with the skip.
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim, groups=groups),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups),
                Residual(Rezero(LinearAttention(dim_in))),
                Upsample(dim_in)]))
        self.final_block = Block(dim, dim, groups=groups)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        """Estimate the score for ``x`` at diffusion time ``t``.

        Args:
            x: Noisy features, stacked with ``mu`` along a new channel axis.
            mask: Validity mask; ``unsqueeze(1)`` is applied before use.
            mu: Prior mean with the same shape as ``x``.
            t: Batch of diffusion times, one per sample.
            spk: Speaker embedding; required for multi-speaker models and
                ignored for single-speaker ones (which have no ``spk_mlp``).

        Raises:
            ValueError: if the model is multi-speaker and ``spk`` is None.
        """
        t = self.mlp(self.time_pos_emb(t, scale=self.pe_scale))

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            if spk is None:
                raise ValueError("spk must be provided for a multi-speaker GradLogPEstimator2d")
            s = self.spk_mlp(spk)
            # Repeat the projected speaker vector across every frame.
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            # Only the last (frame) axis of the mask is subsampled.
            # Presumably the mask's feature axis is a singleton that broadcasts.
            masks.append(mask_down[:, :, :, ::2])

        # The last appended mask belongs to no down level; drop it so the
        # remaining masks line up with the up path when popped in reverse.
        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)


def get_noise(t, beta_init, beta_term, cumulative=False):
    """Evaluate a linear noise schedule at time ``t``.

    ``beta(t) = beta_init + (beta_term - beta_init) * t``. With
    ``cumulative=True`` the integral of ``beta`` from 0 to ``t`` is returned
    instead. ``t`` may be a Python number or a tensor.
    """
    slope = beta_term - beta_init
    if not cumulative:
        return beta_init + slope * t
    return beta_init * t + 0.5 * slope * (t**2)


class Diffusion(BaseModule):
    """Score-based diffusion decoder with trajectory-comparison attack probes.

    The forward process uses the linear schedule from ``get_noise`` with
    ``beta_min``/``beta_max``. It has marginal mean
    ``x0 * exp(-cum/2) + mu * (1 - exp(-cum/2))`` and variance ``1 - exp(-cum)``.
    ``self.estimator`` is trained (``loss_t``) so that
    ``estimator(xt) * std(t) ~= -z``, i.e. it predicts the score of ``xt``.

    Besides training and sampling, the class exposes ``naive_attack``,
    ``SecMI``, ``PIA`` and ``PIAN``. Each builds a "reverse" and a "denoise"
    list of per-step tensors and returns a ``(n_steps, batch)`` distance
    between them. These appear to be diffusion membership-inference attacks.
    """

    def __init__(self, n_feats, dim,
                 n_spks=1, spk_emb_dim=64,
                 beta_min=0.05, beta_max=20, pe_scale=1000):
        super(Diffusion, self).__init__()
        self.n_feats = n_feats
        self.dim = dim
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale

        # Forward n_feats: otherwise the estimator's default (80) sizes the
        # speaker projection, and stacking it with mu/x fails when n_feats != 80.
        self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks,
                                             spk_emb_dim=spk_emb_dim,
                                             n_feats=n_feats,
                                             pe_scale=pe_scale)

    # ------------------------------------------------------------------
    # Private helpers shared by training, sampling and the attack probes.
    # ------------------------------------------------------------------
    @staticmethod
    def _time_vector(value, ref):
        """Return a 1-D tensor holding ``value`` once per batch item of ``ref``."""
        return value * torch.ones(ref.shape[0], dtype=ref.dtype, device=ref.device)

    def _cum_noise(self, t):
        """Integrated noise at times ``t`` with two trailing singleton axes."""
        time = t.unsqueeze(-1).unsqueeze(-1)
        return get_noise(time, self.beta_min, self.beta_max, cumulative=True)

    def _noise_std(self, t):
        """Standard deviation ``sqrt(1 - exp(-cum_noise))`` of the marginal at ``t``."""
        return torch.sqrt(1.0 - torch.exp(-self._cum_noise(t)))

    def _marginal_mean(self, x0, mu, t):
        """Mean of the forward-process marginal at ``t`` (``x0`` decaying toward ``mu``)."""
        decay = torch.exp(-0.5 * self._cum_noise(t))
        return x0 * decay + mu * (1.0 - decay)

    def _estimate_eps(self, xt, mask, mu, t, spk=None):
        """Convert the estimator's score at ``t`` into a masked noise prediction."""
        eps = -self.estimator(xt, mask, mu, t, spk) * self._noise_std(t)
        return eps * mask

    def _euler_increment(self, xt, mask, mu, t, h, stoc=False, spk=None):
        """Return one Euler increment ``dxt`` of step size ``h`` at time ``t``.

        Callers subtract it (moving toward t=0) or add it (moving toward t=1).
        With ``stoc`` a Gaussian term scaled by ``sqrt(beta(t) * h)`` is added.
        """
        time = t.unsqueeze(-1).unsqueeze(-1)
        noise_t = get_noise(time, self.beta_min, self.beta_max, cumulative=False)
        if stoc:  # adds stochastic term
            dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
            dxt_det = dxt_det * noise_t * h
            dxt_stoc = torch.randn(xt.shape, dtype=xt.dtype, device=xt.device, requires_grad=False)
            return dxt_det + dxt_stoc * torch.sqrt(noise_t * h)
        dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
        return dxt * noise_t * h

    # ------------------------------------------------------------------
    # Forward process and sampling.
    # ------------------------------------------------------------------
    def forward_diffusion(self, x0, mask, mu, t):
        """Sample ``xt`` from the forward marginal at ``t``.

        Returns:
            ``(xt * mask, z * mask)``, where ``z`` is the standard-normal draw used.
        """
        mean = self._marginal_mean(x0, mu, t)
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt = mean + z * self._noise_std(t)
        return xt * mask, z * mask

    @torch.no_grad()
    def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        """Integrate from t=1 to t=0 in ``n_timesteps`` midpoint-timed Euler steps, starting at ``z``."""
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = self._time_vector(1.0 - (i + 0.5) * h, z)
            xt = (xt - self._euler_increment(xt, mask, mu, t, h, stoc, spk)) * mask
        return xt

    def forward_diffusion_eps(self, x0, mask, mu, eps, t):
        """Like ``forward_diffusion`` but with a caller-supplied noise ``eps``; returns masked ``xt``."""
        xt = self._marginal_mean(x0, mu, t) + eps * self._noise_std(t)
        return xt * mask

    # ------------------------------------------------------------------
    # Naive attack: random noise vs. the model's noise prediction.
    # ------------------------------------------------------------------
    def naive_attack(self, x0, mu, mask, n_timesteps, reverse_init=0, denoise_init=0, terminal_time=0.12, spk=None):
        """Return the squared L2 distance ``(n_timesteps, batch)`` between random and predicted noise.

        ``reverse_init``/``denoise_init`` offset the step index of each side.
        They now default to 0 and are passed through, as in ``SecMI``/``PIA``.
        """
        intermediate_reverse = self.naive_reverse_eps_abs(x0, mask, mu, n_timesteps, terminal_time,
                                                          init_point=reverse_init, spk=spk)
        intermediate_denoise = self.naive_denoise_eps_abs(x0, intermediate_reverse, mask, mu, terminal_time,
                                                          init_point=denoise_init, spk=spk)
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise)**2).flatten(2).sum(dim=-1)

    def naive_denoise_eps_abs(self, x0, xt_intermediate, mask, mu, terminal_time=0.12, stoc=False, spk=None, init_point=0):
        """For each supplied noise ``eps_i``, diffuse ``x0`` with it and predict the noise back.

        Step ``i`` uses ``t = (i + 1 + init_point) * terminal_time / len(xt_intermediate)``.
        ``stoc`` is unused (kept for signature parity).
        """
        h = terminal_time / len(xt_intermediate)
        x0 = x0 * mask
        intermediate_result = []
        for i, eps in enumerate(xt_intermediate):
            t = self._time_vector((i + 1 + init_point) * h, eps)
            xt = self.forward_diffusion_eps(x0, mask, mu, eps, t)
            intermediate_result.append(self._estimate_eps(xt, mask, mu, t, spk))
        return intermediate_result

    def naive_reverse_eps_abs(self, x0, mask, mu, n_timesteps, terminal_time=0.12, stoc=False, spk=None, init_point=0):
        """Draw ``n_timesteps`` independent masked standard-normal tensors shaped like ``x0``.

        No model evaluation happens here. Parameters other than ``x0``,
        ``mask`` and ``n_timesteps`` are unused and kept for signature parity.
        """
        return [torch.randn_like(x0) * mask for _ in range(n_timesteps)]

    # ------------------------------------------------------------------
    # SecMI: deterministic/stochastic Euler reverse vs. one-step denoise.
    # ------------------------------------------------------------------
    def SecMI(self, x0, mu, mask, n_timesteps, reverse_init=0, denoise_init=0, terminal_time=0.12, spk=None):
        """Return the squared L2 distance ``(n_timesteps - 1, batch)`` between reverse and denoise steps."""
        intermediate_reverse = self.ddim_reverse(x0, mask, mu, n_timesteps, terminal_time, init_point=reverse_init, spk=spk)
        intermediate_denoise = self.ddim_denoise(intermediate_reverse, mask, mu, terminal_time, init_point=denoise_init, spk=spk)
        # Shift by one so reverse state k is compared with the denoise step
        # that started from reverse state k + 1.
        del intermediate_reverse[-1]
        del intermediate_denoise[0]
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise)**2).flatten(2).sum(dim=-1)

    def ddim_denoise(self, xt_intermediate, mask, mu, terminal_time=0.12, stoc=False, spk=None, init_point=0):
        """Take one Euler step toward t=0 from each supplied state.

        The step from state ``i`` uses ``t = (i + 1 + init_point) * h``.
        """
        h = terminal_time / len(xt_intermediate)
        intermediate_result = []
        for i, xt in enumerate(xt_intermediate):
            xt = xt * mask
            t = self._time_vector((i + 1 + init_point) * h, xt)
            xt = (xt - self._euler_increment(xt, mask, mu, t, h, stoc, spk)) * mask
            intermediate_result.append(xt)
        return intermediate_result

    def ddim_reverse(self, x0, mask, mu, n_timesteps, terminal_time=0.12, stoc=False, spk=None, init_point=0):
        """Integrate from ``x0`` toward larger t, recording every intermediate state.

        Step ``i`` uses ``t = (i + init_point) * terminal_time / n_timesteps``.
        """
        h = terminal_time / n_timesteps
        xt = x0 * mask
        intermediate_result = []
        for i in range(n_timesteps):
            t = self._time_vector((i + init_point) * h, x0)
            xt = (xt + self._euler_increment(xt, mask, mu, t, h, stoc, spk)) * mask
            intermediate_result.append(xt)
        return intermediate_result

    # ------------------------------------------------------------------
    # PIA / PIAN: a single predicted noise reused along the trajectory.
    # ------------------------------------------------------------------
    def PIA(self, x0, mu, mask, n_timesteps, reverse_init=0, denoise_init=0, terminal_time=0.12, spk=None, lp=4,
            eps_step=10):
        """Return the L``lp`` distance ``(n_timesteps, batch)`` between diffused states and ``mu - score``.

        ``eps_step`` selects the step at which the reused noise is estimated
        (previously hard-coded to 10). Unlike ``PIAN``, no step shift is applied.
        """
        intermediate_reverse = self.ddim_reverse_groundtruth_correct(x0, mask, mu, n_timesteps, terminal_time,
                                                                     init_point=reverse_init, spk=spk, eps_step=eps_step)
        intermediate_denoise = self.ddim_denoise_groundtruth_correct(intermediate_reverse, mask, mu, terminal_time,
                                                                     init_point=denoise_init, spk=spk)
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**lp).flatten(2).sum(dim=-1)

    def ddim_denoise_groundtruth_correct(self, xt_intermediate, mask, mu, terminal_time=0.12, stoc=False, spk=None, init_point=0):
        """Return ``mu - estimator(xt_i)`` for each state, at ``t = (i + 1 + init_point) * h``.

        ``stoc`` is unused (kept for signature parity).
        """
        h = terminal_time / len(xt_intermediate)
        intermediate_result = []
        for i, xt in enumerate(xt_intermediate):
            t = self._time_vector((i + 1 + init_point) * h, xt)
            intermediate_result.append(mu - self.estimator(xt * mask, mask, mu, t, spk))
        return intermediate_result

    def ddim_reverse_groundtruth_correct(self, x0, mask, mu, n_timesteps, terminal_time=0.12, stoc=False, spk=None,
                                         init_point=0, eps_step=10):
        """Estimate one noise at step ``eps_step`` and diffuse ``x0`` with it at every step.

        Returns the states of the reverse trajectory: ``forward_diffusion_eps``
        evaluated at ``t = (i + 1 + init_point) * h`` for ``i`` in ``range(n_timesteps)``.
        ``stoc`` is unused (kept for signature parity).
        """
        h = terminal_time / n_timesteps
        t = self._time_vector((eps_step + init_point) * h, x0)
        eps = self._estimate_eps(x0, mask, mu, t, spk)

        intermediate_result = []
        for i in range(n_timesteps):
            t = self._time_vector((i + 1 + init_point) * h, eps)
            # Store the intermediate states of the reverse trajectory.
            intermediate_result.append(self.forward_diffusion_eps(x0, mask, mu, eps, t))
        return intermediate_result

    def PIAN(self, x0, mu, mask, n_timesteps, reverse_init=0, denoise_init=0, terminal_time=0.12, spk=None, lp=4,
             eps_step=10):
        """PIA variant whose reused noise is rescaled to unit variance; returns ``(n_timesteps - 1, batch)``.

        NOTE(review): unlike ``PIA``, this shifts the two lists by one step
        (same pattern as ``SecMI``); confirm that asymmetry is intended.
        """
        intermediate_reverse = self.ddim_reverse_groundtruth_correct_abs(x0, mask, mu, n_timesteps, terminal_time,
                                                                         init_point=reverse_init, spk=spk,
                                                                         eps_step=eps_step)
        intermediate_denoise = self.ddim_denoise_groundtruth_correct_abs(intermediate_reverse, mask, mu, terminal_time,
                                                                         init_point=denoise_init, spk=spk)
        del intermediate_reverse[-1]
        del intermediate_denoise[0]
        intermediate_reverse, intermediate_denoise = torch.stack(intermediate_reverse), torch.stack(intermediate_denoise)
        return ((intermediate_reverse - intermediate_denoise).abs()**lp).flatten(2).sum(dim=-1)

    def ddim_denoise_groundtruth_correct_abs(self, xt_intermediate, mask, mu, terminal_time=0.12, stoc=False, spk=None, init_point=0):
        """Identical to ``ddim_denoise_groundtruth_correct``; kept as a separate name for ``PIAN``."""
        return self.ddim_denoise_groundtruth_correct(xt_intermediate, mask, mu, terminal_time,
                                                     stoc=stoc, spk=spk, init_point=init_point)

    def ddim_reverse_groundtruth_correct_abs(self, x0, mask, mu, n_timesteps, terminal_time=0.12, stoc=False, spk=None,
                                             init_point=0, eps_step=10):
        """Like ``ddim_reverse_groundtruth_correct`` but with the reused noise rescaled to unit std.

        ``stoc`` is unused (kept for signature parity).
        """
        h = terminal_time / n_timesteps
        t = self._time_vector((eps_step + init_point) * h, x0)
        eps = self._estimate_eps(x0 * mask, mask, mu, t, spk)

        # For e ~ N(0, s^2), E|e| = s * sqrt(2/pi), so e / E|e| * sqrt(2/pi)
        # has unit std. The factor multiplies *after* dividing; dividing by
        # E|e| * sqrt(2/pi) would leave a std of pi/2. E|e| is averaged over
        # unmasked entries only, so zeroed padding does not shrink it.
        valid = mask.expand_as(eps).sum([-1, -2], keepdim=True).clamp_min(1)
        abs_mean = eps.abs().sum([-1, -2], keepdim=True) / valid
        eps = eps / (abs_mean + 1e-12) * (2 / torch.pi) ** 0.5

        intermediate_result = []
        for i in range(n_timesteps):
            t = self._time_vector((i + 1 + init_point) * h, eps)
            intermediate_result.append(self.forward_diffusion_eps(x0, mask, mu, eps, t))
        return intermediate_result

    # ------------------------------------------------------------------
    # Entry point and training loss.
    # ------------------------------------------------------------------
    @torch.no_grad()
    def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
        """Sample by running ``reverse_diffusion`` from ``z``."""
        return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)

    def loss_t(self, x0, mask, mu, t, spk=None):
        """Denoising score-matching loss at times ``t``; returns ``(loss, xt)``.

        The loss is averaged over valid mask positions times ``n_feats``.
        """
        xt, z = self.forward_diffusion(x0, mask, mu, t)
        noise_estimation = self.estimator(xt, mask, mu, t, spk) * self._noise_std(t)
        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask) * self.n_feats)
        return loss, xt

    def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
        """``loss_t`` at per-sample times drawn uniformly from ``[offset, 1 - offset]``."""
        t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device, requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(x0, mask, mu, t, spk)
