import functools
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm
from stylegan_t.networks.shared import FullyConnectedLayer
from clap.hook import CLAP_Module
# Checkpoint path handed to CLAP_Module.load_ckpt when ConditionalMBDiscriminator
# is built with d_cond_type='clap_text_encoder'. The path is relative --
# presumably resolved against the process working directory; verify against
# the launcher.
CLAP_MODEL_PATH = 'ckpt/630k-audioset-best.pt'

def weights_init(m):
    """Initialise convolution and batch-norm parameters in place.

    Intended to be passed to ``nn.Module.apply``. Dispatch is done on the
    module's class name:

    * name contains ``'Conv'``      -> weight ~ N(0.0, 0.02)
    * name contains ``'BatchNorm'`` -> weight ~ N(1.0, 0.02), bias = 0

    Any other module is left untouched. Modules that match by name but do
    not own the relevant tensor (e.g. ``BatchNorm*(affine=False)``, whose
    ``weight``/``bias`` are ``None``, or a container class that merely has
    "Conv" in its name) are skipped instead of raising ``AttributeError``.

    Parameters:
        m (nn.Module) -- the module to initialise
    """
    classname = type(m).__name__
    # getattr with a default covers both "attribute missing" and
    # "attribute present but None".
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

def WNConv1d(*args, **kwargs):
    """Build a weight-normalised ``nn.Conv1d``, optionally followed by LeakyReLU.

    All positional and keyword arguments are forwarded to ``nn.Conv1d``,
    except the extra keyword ``act`` (default ``True``), which is consumed
    here. When ``act`` is truthy the result is
    ``nn.Sequential(conv, nn.LeakyReLU(0.1))``; otherwise the bare
    weight-normalised convolution is returned.
    """
    with_activation = kwargs.pop("act", True)
    normed_conv = weight_norm(nn.Conv1d(*args, **kwargs))
    return (
        nn.Sequential(normed_conv, nn.LeakyReLU(0.1))
        if with_activation
        else normed_conv
    )

def WNConv2d(*args, **kwargs):
    """Build a weight-normalised ``nn.Conv2d``, optionally followed by LeakyReLU.

    All positional and keyword arguments are forwarded to ``nn.Conv2d``,
    except the extra keyword ``act`` (default ``True``), which is consumed
    here. When ``act`` is truthy the result is
    ``nn.Sequential(conv, nn.LeakyReLU(0.1))``; otherwise the bare
    weight-normalised convolution is returned.
    """
    with_activation = kwargs.pop("act", True)
    normed_conv = weight_norm(nn.Conv2d(*args, **kwargs))
    return (
        nn.Sequential(normed_conv, nn.LeakyReLU(0.1))
        if with_activation
        else normed_conv
    )

class ActNorm(nn.Module):
    """Per-channel affine normalisation with data-dependent initialisation.

    The forward map is ``h = scale * (x + loc)`` with one learnable ``loc``
    and ``scale`` per channel. The first batch seen in training mode sets
    ``loc = -mean`` and ``scale = 1 / (std + 1e-6)`` (statistics taken per
    channel over all other axes), after which the ``initialized`` buffer is
    set to 1 and the parameters are trained as usual.

    Inputs may be 4-D ``(N, C, H, W)`` or 2-D ``(N, C)``; 2-D inputs are
    temporarily expanded to ``(N, C, 1, 1)`` and squeezed back on return.

    Parameters:
        num_features (int)        -- number of channels C
        logdet (bool)             -- if True, ``forward`` also returns the
                                     log-determinant of the map per sample
        affine (bool)             -- must be True (asserted)
        allow_reverse_init (bool) -- permit the data-dependent initialisation
                                     to be triggered by ``reverse``
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # Stored as a buffer so the "already initialised" state travels with
        # the state_dict.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _as_4d(tensor):
        """Return ``(tensor_4d, was_2d)``, expanding ``(N, C)`` to ``(N, C, 1, 1)``."""
        if len(tensor.shape) == 2:
            return tensor[:, :, None, None], True
        return tensor, False

    def initialize(self, input):
        """Set ``loc``/``scale`` from the per-channel statistics of a 4-D ``input``."""
        with torch.no_grad():
            # (N, C, H, W) -> (C, N*H*W): one row of samples per channel.
            per_channel = (
                input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            )
            mean = per_channel.mean(1).view(1, -1, 1, 1)
            std = per_channel.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            # The epsilon keeps the scale finite for constant channels.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply the normalisation, or its inverse when ``reverse`` is True.

        Returns ``h``, or ``(h, logdet)`` when the module was built with
        ``logdet=True`` (forward direction only); ``logdet`` has one entry
        per sample.
        """
        if reverse:
            return self.reverse(input)

        input, squeeze = self._as_4d(input)
        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            # The same per-channel scale is applied at every spatial
            # position, hence the height*width factor.
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert ``forward``: ``x = output / scale - loc``.

        Raises:
            RuntimeError -- if called in training mode before initialisation
                            while ``allow_reverse_init`` is False.
        """
        # Expand 2-D inputs *before* any initialisation: ``initialize``
        # permutes four axes and would otherwise fail on (N, C) tensors.
        output, squeeze = self._as_4d(output)

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
# Default band edges as (start, end) fractions. The discriminators below
# multiply each edge by n_bins and truncate to int to obtain slice indices
# into the last axis of their input.
BANDS = [(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
class MBDiscriminator(nn.Module):
    """Multi-band convolutional discriminator.

    The last axis of the input is cut into contiguous bands. Each band runs
    through its own stack of five weight-normalised 2-D convolutions (each
    followed by LeakyReLU(0.1)); three of them use stride 2 along the last
    axis. The per-band outputs are concatenated along the last axis and
    reduced to a single channel by ``conv_post`` (no activation).
    """

    def __init__(
        self,
        ndf: int = 64,
        n_bins: int = 64,
        bands: list = BANDS,
        increase_ch: bool = False,
    ):
        """Construct a Multi-Band discriminator
        Parameters:
            ndf (int)       -- base number of filters; also the number of
                               channels each band stack outputs
            n_bins (int)    -- size of the last input axis (documented by the
                               original authors as the number of mel
                               filterbank bins); used to turn the fractional
                               band edges into indices
            bands           -- list, optional, (start, end) fractions of
                               n_bins to run the discriminator over
            increase_ch     -- if True the hidden widths grow as
                               ndf, 2*ndf, 4*ndf, 8*ndf; otherwise every
                               layer has ndf filters
        """
        super().__init__()
        self.bands = [(int(b[0] * n_bins), int(b[1] * n_bins)) for b in bands]

        # Output width of the four (3, 9) convolutions; the closing (3, 3)
        # convolution always maps back to ndf channels.
        if increase_ch:
            widths = [ndf, 2 * ndf, 4 * ndf, 8 * ndf]
        else:
            widths = [ndf] * 4
        strides = [(1, 1), (1, 2), (1, 2), (1, 2)]

        def make_stack():
            # Layers are created in the same order as before so parameter
            # names and seeded initialisation are unchanged.
            layers = []
            in_ch = 1  # the first convolution expects a single input channel
            for out_ch, stride in zip(widths, strides):
                layers.append(
                    WNConv2d(in_ch, out_ch, (3, 9), stride, padding=(1, 4))
                )
                in_ch = out_ch
            layers.append(WNConv2d(in_ch, ndf, (3, 3), (1, 1), padding=(1, 1)))
            return nn.ModuleList(layers)

        self.band_convs = nn.ModuleList([make_stack() for _ in self.bands])
        self.conv_post = WNConv2d(ndf, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def forward(self, input):
        """Score ``input`` band by band.

        Parameters:
            input -- tensor with one channel and the banded axis last;
                     presumably (batch, 1, frames, n_bins) -- TODO confirm
                     against the caller.
        Returns:
            single-channel score map produced by ``conv_post``.
        """
        band_features = []
        for (start, stop), stack in zip(self.bands, self.band_convs):
            feat = input[..., start:stop]
            for layer in stack:
                feat = layer(feat)
            band_features.append(feat)

        return self.conv_post(torch.cat(band_features, dim=-1))


class ConditionalMBDiscriminator(nn.Module):
    """Text-conditioned multi-band convolutional discriminator.

    The last axis of the input is cut into contiguous bands, each processed
    by its own stack of five weight-normalised 2-D convolutions. The band
    outputs are concatenated along the last axis and reduced to one channel
    by ``conv_post``. That map is expanded to ``cmap_dim`` channels by
    ``cls`` and projected onto a text embedding mapped to ``cmap_dim`` by
    ``cmapper`` (projection head following StyleGAN-T), giving a
    single-channel conditional score map.
    """

    def __init__(
        self,
        ndf: int = 64,
        n_bins: int = 64,
        bands: list = BANDS,
        increase_ch: bool = False,
        d_cond_type: str = 'text_encoder',
        c_dim: int = 1024,
        cmap_dim: int = 64,
        device=None,
    ):
        """Construct a conditional Multi-Band discriminator
        Parameters:
            ndf (int)       -- base number of filters; also the number of
                               channels each band stack outputs
            n_bins (int)    -- size of the last input axis (documented by the
                               original authors as the number of mel
                               filterbank bins); used to turn the fractional
                               band edges into indices
            bands           -- list, optional, (start, end) fractions of
                               n_bins to run the discriminator over
            increase_ch     -- if True the hidden widths grow as
                               ndf, 2*ndf, 4*ndf, 8*ndf; otherwise every
                               layer has ndf filters
            d_cond_type     -- 'text_encoder': embed the condition with the
                               ``model`` passed to ``forward``;
                               'clap_text_encoder': embed it with a frozen
                               CLAP model loaded here from CLAP_MODEL_PATH
            c_dim (int)     -- feature size of the text embedding fed to
                               ``cmapper``
            cmap_dim (int)  -- size of the mapped condition and number of
                               channels produced by ``cls``
            device          -- forwarded to CLAP_Module (only used for
                               'clap_text_encoder')
        """
        super().__init__()
        self.d_cond_type = d_cond_type
        self.c_dim = c_dim
        self.cmap_dim = cmap_dim

        self.cmapper = FullyConnectedLayer(self.c_dim, self.cmap_dim)

        self.bands = [(int(b[0] * n_bins), int(b[1] * n_bins)) for b in bands]

        # Output width of the four (3, 9) convolutions; the closing (3, 3)
        # convolution always maps back to ndf channels.
        if increase_ch:
            widths = [ndf, 2 * ndf, 4 * ndf, 8 * ndf]
        else:
            widths = [ndf] * 4
        strides = [(1, 1), (1, 2), (1, 2), (1, 2)]

        def make_stack():
            # Layers are created in the same order as before so parameter
            # names and seeded initialisation are unchanged.
            layers = []
            in_ch = 1  # the first convolution expects a single input channel
            for out_ch, stride in zip(widths, strides):
                layers.append(
                    WNConv2d(in_ch, out_ch, (3, 9), stride, padding=(1, 4))
                )
                in_ch = out_ch
            layers.append(WNConv2d(in_ch, ndf, (3, 3), (1, 1), padding=(1, 1)))
            return nn.ModuleList(layers)

        self.band_convs = nn.ModuleList([make_stack() for _ in self.bands])
        self.conv_post = WNConv2d(ndf, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

        # NOTE: WNConv2d appends LeakyReLU(0.1) unless act=False, so ``cls``
        # is convolution + activation.
        self.cls = WNConv2d(1, self.cmap_dim, (3, 3), 1, padding=(1, 1), bias=False)

        if d_cond_type == 'clap_text_encoder':
            # Per the original authors, the larger CLAP model can be selected
            # with CLAP_Module(enable_fusion=False, amodel='HTSAT-base').
            self.clap = CLAP_Module(enable_fusion=False, device=device)
            self.clap.load_ckpt(CLAP_MODEL_PATH)
            # The text encoder is used as a fixed feature extractor.
            for param in self.clap.model.parameters():
                param.requires_grad = False
            self.clap.model.eval()

    @staticmethod
    def _unwrap(model, method_name):
        """Return the object exposing ``method_name``: ``model`` or ``model.module``.

        Replaces a bare ``try/except`` fallback: a wrapper that hides the
        real model behind ``.module`` is unwrapped explicitly, and errors
        raised *inside* the encoder are no longer swallowed.
        """
        return model if hasattr(model, method_name) else model.module

    def _encode_condition(self, cond, model):
        """Embed ``cond`` according to ``self.d_cond_type``."""
        if self.d_cond_type == 'text_encoder':
            encoder = self._unwrap(model, 'encode_text')
            hidden_states, _ = encoder.encode_text(cond)  # [b, n_ctx, n_dim]
            return torch.mean(hidden_states, dim=1)  # [b, n_dim]
        if self.d_cond_type == 'clap_text_encoder':
            encoder = self._unwrap(self.clap, 'get_text_embedding')
            return encoder.get_text_embedding(cond, use_tensor=True)
        raise ValueError(
            f"Unsupported d_cond_type {self.d_cond_type!r}; expected "
            "'text_encoder' or 'clap_text_encoder'."
        )

    def forward(self, input, cond=None, model=None):
        """Score ``input`` conditioned on ``cond``.

        Parameters:
            input -- tensor with one channel and the banded axis last;
                     presumably (batch, 1, frames, n_bins) -- TODO confirm
                     against the caller.
            cond  -- condition handed to the text encoder
            model -- object providing ``encode_text`` (directly or via
                     ``.module``); required for d_cond_type='text_encoder',
                     ignored for 'clap_text_encoder'
        Returns:
            single-channel conditional score map.
        Raises:
            ValueError -- if ``d_cond_type`` is not a supported value.
        """
        # Resolve the condition first so configuration errors surface before
        # any convolution work is done.
        text_embedding = self._encode_condition(cond, model)

        band_features = []
        for (start, stop), stack in zip(self.bands, self.band_convs):
            feat = input[..., start:stop]
            for layer in stack:
                feat = layer(feat)
            band_features.append(feat)

        score = self.conv_post(torch.cat(band_features, dim=-1))

        # Following StyleGAN-T: project the feature map onto the mapped
        # condition and normalise by sqrt(cmap_dim).
        score = self.cls(score)
        cmap = self.cmapper(text_embedding).unsqueeze(-1).unsqueeze(-1)
        return (score * cmap).sum(1, keepdim=True) * (self.cmap_dim ** -0.5)