import torch
import torch.nn as nn
import torch.nn.functional as F


def get_model(config):
    '''
    Instantiate the model class named by ``config['model_class']``.

    The name is resolved among this module's global names (e.g. ``'PianoArrangeUnet'``)
    rather than passed to ``eval``, so a config string cannot execute arbitrary code.

    Args:
    - config: mapping with key 'model_class'; the whole config is passed to the
      model constructor.

    Returns:
    - The constructed ``nn.Module`` instance.

    Raises:
    - NameError: no object with that name exists in this module.
    - TypeError: the name does not refer to an ``nn.Module`` subclass.
    '''
    model_name = config['model_class']
    model_class = globals().get(model_name)
    if model_class is None:
        raise NameError(f"Unknown model class: {model_name!r}")
    # Only allow real model classes, not arbitrary module-level objects (modules, functions).
    if not (isinstance(model_class, type) and issubclass(model_class, nn.Module)):
        raise TypeError(f"{model_name!r} is not an nn.Module subclass")
    return model_class(config)


class PianoArrangeUnet(nn.Module):
    '''
    Small 2-D U-Net that scores every (position, pitch) cell of a mixture piano roll
    and keeps the mixture's value where a note is predicted.

    Config keys:
    - 'thres': probability threshold used to build ``pred`` (required).
    - 'note_loss_only': if true, skip the note-density losses (default False).
    - 'pos_weight': positive-class weight of the note BCE loss (default 4).

    Three 2x average-pooling stages are applied, so the last two input dims must be
    divisible by 8 for the skip connections to line up (the shape comments below
    assume inputs of shape [bs, 16, 128]).
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(4, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

        # Decoder: each stage doubles both spatial dims
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True)
        )
        # 1x1 projection to a single logit channel (no activation: these are logits)
        self.final = nn.Conv2d(4, 1, kernel_size=1)

    def forward(self, batch):
        '''
        Args:
        - batch: dict with 'mix_prolls' (Tensor [bs, pos, pitch]; nonzero entries are
          treated as notes) and optionally 'piano_prolls' (same shape) to compute losses.

        Returns:
        - dict with 'logits' ([bs, pos, pitch]) and 'pred' ([bs, pos, pitch] int32: the
          mixture's value where sigmoid(logits) > thres, 0 elsewhere). When
          'piano_prolls' is present, also 'loss_note', 'loss_pos', 'loss_bar', 'loss_tot'.
        '''
        mix_prolls = batch['mix_prolls']

        # Binarize the input
        x = (mix_prolls != 0).float()

        # Encoder
        x = x.unsqueeze(1)  # [bs, 1, 16, 128]
        skip1 = self.enc1(x)  # [bs, 4, 16, 128]
        skip2 = self.enc2(F.avg_pool2d(skip1, 2))  # [bs, 8, 8, 64]
        skip3 = self.enc3(F.avg_pool2d(skip2, 2))  # [bs, 16, 4, 32]
        bottleneck = self.enc4(F.avg_pool2d(skip3, 2))  # [bs, 32, 2, 16]

        # Decoder with additive skip connections
        up3 = self.dec3(bottleneck) + skip3  # [bs, 16, 4, 32]
        up2 = self.dec2(up3) + skip2  # [bs, 8, 8, 64]
        up1 = self.dec1(up2) + skip1  # [bs, 4, 16, 128]

        output = self.final(up1)  # [bs, 1, 16, 128]
        logits = output.squeeze(1)  # [bs, 16, 128]

        # Assemble the output: copy the mixture's value wherever a note is predicted
        prob = torch.sigmoid(logits)
        pred = (prob > self.config['thres']).int()
        non_zero_idx = pred != 0
        pred[non_zero_idx] = mix_prolls[non_zero_idx].int()
        ret = {
            'logits': logits,
            'pred': pred,
        }

        # Calculate losses
        if 'piano_prolls' in batch:
            ground_truth = batch['piano_prolls']
            tgt = (ground_truth != 0).float()
            # Build pos_weight with the logits' device/dtype instead of an int64 CPU tensor.
            pos_weight = logits.new_tensor(self.config.get('pos_weight', 4.0))
            loss_note = F.binary_cross_entropy_with_logits(input=logits, target=tgt, pos_weight=pos_weight)
            ret['loss_note'] = loss_note

            if self.config.get('note_loss_only', False):
                # Zero placeholders on the same device/dtype as loss_note, so all losses
                # can be summed, stacked or logged uniformly.
                loss_pos = logits.new_zeros(())
                loss_bar = logits.new_zeros(())
            else:
                loss_pos = note_density_loss_pos(logits, tgt)
                loss_bar = note_density_loss_bar(logits, tgt)

            ret['loss_pos'] = loss_pos
            ret['loss_bar'] = loss_bar
            ret['loss_tot'] = loss_note + loss_pos + loss_bar

        return ret
    

class PianoArrangeUnetNoPool(nn.Module):
    '''
    U-Net variant that downsamples with stride-2 convolutions instead of pooling.

    The original author reported that this model performs poorly and advised against
    using it; treat it as experimental.

    Config keys:
    - 'thres': probability threshold used to build ``pred`` (required).
    - 'note_loss_only': if true, skip the note-density losses (default False).
    - 'pos_weight': positive-class weight of the note BCE loss (default 4).

    Four stride-2 stages are applied, so the last two input dims must be divisible
    by 16 (the shape comments below assume inputs of shape [bs, 16, 128]).
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoder: Conv2d(k=4, s=2, p=1) halves each spatial dim exactly
        kernel_size = 4
        stride = 2
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=kernel_size, padding=1, stride=stride),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(4, 8, kernel_size=kernel_size, padding=1, stride=stride),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=kernel_size, padding=1, stride=stride),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=kernel_size, padding=1, stride=stride),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

        # Decoder: ConvTranspose2d(k=4, s=2, p=1) outputs exactly 2x the input size and
        # is the spatial adjoint of the encoder convs, so each stage's output lines up
        # cell-for-cell with the matching skip connection (no cropping needed).
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(16, 8, kernel_size=kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(8, 4, kernel_size=kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True)
        )
        # No ReLU on the last stage: its output is used directly as logits, which must be
        # able to go negative (a ReLU would force sigmoid(logits) >= 0.5 everywhere).
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(4, 1, kernel_size=kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(1),
        )

    def forward(self, batch):
        '''
        Args:
        - batch: dict with 'mix_prolls' (Tensor [bs, pos, pitch]; nonzero entries are
          treated as notes) and optionally 'piano_prolls' (same shape) to compute losses.

        Returns:
        - dict with 'logits' ([bs, pos, pitch]) and 'pred' ([bs, pos, pitch] int32: the
          mixture's value where sigmoid(logits) > thres, 0 elsewhere). When
          'piano_prolls' is present, also 'loss_note', 'loss_pos', 'loss_bar', 'loss_tot'.
        '''
        mix_prolls = batch['mix_prolls']

        # Binarize the input
        x = (mix_prolls != 0).float()

        # Encoder
        x = x.unsqueeze(1)  # [bs, 1, 16, 128]
        skip1 = self.enc1(x)  # [bs, 4, 8, 64]
        skip2 = self.enc2(skip1)  # [bs, 8, 4, 32]
        skip3 = self.enc3(skip2)  # [bs, 16, 2, 16]
        bottleneck = self.enc4(skip3)  # [bs, 32, 1, 8]

        # Decoder with additive skip connections
        up3 = self.dec4(bottleneck) + skip3  # [bs, 16, 2, 16]
        up2 = self.dec3(up3) + skip2  # [bs, 8, 4, 32]
        up1 = self.dec2(up2) + skip1  # [bs, 4, 8, 64]

        output = self.dec1(up1)  # [bs, 1, 16, 128]
        logits = output.squeeze(1)  # [bs, 16, 128]

        # Assemble the output: copy the mixture's value wherever a note is predicted
        prob = torch.sigmoid(logits)
        pred = (prob > self.config['thres']).int()
        non_zero_idx = pred != 0
        pred[non_zero_idx] = mix_prolls[non_zero_idx].int()
        ret = {
            'logits': logits,
            'pred': pred,
        }

        # Calculate losses
        if 'piano_prolls' in batch:
            ground_truth = batch['piano_prolls']
            tgt = (ground_truth != 0).float()
            # Build pos_weight with the logits' device/dtype instead of an int64 CPU tensor.
            pos_weight = logits.new_tensor(self.config.get('pos_weight', 4.0))
            loss_note = F.binary_cross_entropy_with_logits(input=logits, target=tgt, pos_weight=pos_weight)
            ret['loss_note'] = loss_note

            if self.config.get('note_loss_only', False):
                # Zero placeholders on the same device/dtype as loss_note.
                loss_pos = logits.new_zeros(())
                loss_bar = logits.new_zeros(())
            else:
                loss_pos = note_density_loss_pos(logits, tgt)
                loss_bar = note_density_loss_bar(logits, tgt)

            ret['loss_pos'] = loss_pos
            ret['loss_bar'] = loss_bar
            ret['loss_tot'] = loss_note + loss_pos + loss_bar

        return ret
    

def note_density_loss_pos(logits, tgt):
    '''
    Position-level note density loss.

    A relaxed (Gumbel-sigmoid) sample of the prediction and the target are each
    summed over the pitch axis. The resulting per-position note counts are compared
    with ``jensen_shannon_loss_2d``.

    Args:
    - logits: Tensor, shape [batch_size, pos=16, pitch=128]
    - tgt: Tensor, shape [batch_size, pos=16, pitch=128]

    Returns:
    - The scalar tensor produced by ``jensen_shannon_loss_2d``.
    '''
    # The soft sample is drawn first, then both tensors are reduced over pitch.
    relaxed_notes = gumbel_sigmoid(logits=logits)
    return jensen_shannon_loss_2d(
        prediction=relaxed_notes.sum(dim=2),  # [batch_size, pos]
        target=tgt.sum(dim=2),  # [batch_size, pos]
    )


def note_density_loss_bar(logits, tgt):
    '''
    Bar-level note density loss.

    A relaxed (Gumbel-sigmoid) sample of the prediction and the target are each
    summed over both the position and pitch axes. This gives one note count per
    sample, and the counts across the batch are compared with
    ``jensen_shannon_loss_1d``.

    Args:
    - logits: Tensor, shape [batch_size, pos=16, pitch=128]
    - tgt: Tensor, shape [batch_size, pos=16, pitch=128]

    Returns:
    - The scalar tensor produced by ``jensen_shannon_loss_1d``.
    '''
    relaxed_notes = gumbel_sigmoid(logits=logits)
    reduce_dims = (1, 2)
    return jensen_shannon_loss_1d(
        prediction=relaxed_notes.sum(dim=reduce_dims),  # [batch_size]
        target=tgt.sum(dim=reduce_dims),  # [batch_size]
    )


def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False, threshold: float = 0.5) -> torch.Tensor:
    """
    Draw a noisy, sigmoid-relaxed sample for every element of `logits`, optionally discretized.

    Adapted from PyTorch's `gumbel_softmax`:
    https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax

    Args:
      logits: tensor of any shape holding unnormalized log-odds.
      tau: positive temperature; smaller values push the soft sample towards 0/1.
      hard: if ``True``, the returned values are discretized (1 where the soft sample
            exceeds `threshold`, else 0) but are differentiated as the soft sample
            (straight-through estimator).
      threshold: discretization threshold used when ``hard=True``.

    Returns:
      Tensor with the same shape and dtype as `logits`. When ``hard=False``, the values
      are per-element soft samples in (0, 1). When ``hard=True``, the forward values
      are 0/1.
    """
    # -log(Exp(1)) ~ Gumbel(0, 1)
    gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    # NOTE(review): a single Gumbel sample is added to the logits. The binary-concrete /
    # Gumbel-sigmoid relaxation in the literature uses the difference of two Gumbel
    # samples (logistic noise), for which P(y_soft > 0.5) == sigmoid(logits). Kept as-is
    # to preserve existing training behaviour — confirm before changing.
    y_soft = ((logits + gumbels) / tau).sigmoid()

    if not hard:
        # Reparametrization trick: return the differentiable soft sample.
        return y_soft

    # Straight-through: forward value is the hard 0/1 sample, gradient flows via y_soft.
    # An elementwise comparison works for tensors of any rank.
    y_hard = (y_soft > threshold).to(y_soft.dtype)
    return y_hard - y_soft.detach() + y_soft


def jensen_shannon_loss_2d(prediction: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Computes the Jensen-Shannon divergence between prediction and target, per sample.

    Each row (sample) of `prediction` and `target` is normalized to sum to 1 along the
    `pos` axis. The JS divergence is computed per row (natural log, so each value lies
    in [0, ln 2]), and the batch mean is returned.

    Args:
        prediction (torch.Tensor): Non-negative tensor, shape [batch_size, pos].
        target (torch.Tensor): Non-negative tensor, shape [batch_size, pos].
        eps (float): Small constant guarding the normalization and the logarithms.

    Returns:
        torch.Tensor: Scalar, mean JS divergence over the batch.
    """
    # Normalize each row into a probability distribution over pos
    target = target / (target.sum(dim=1, keepdim=True) + eps)
    prediction = prediction / (prediction.sum(dim=1, keepdim=True) + eps)

    # Mixture distribution M = 0.5 * (P + Q)
    mixture = 0.5 * (target + prediction)
    # Clamp before log so cells where both distributions are 0 give a finite value
    # (their contribution is then 0 * finite = 0 instead of NaN).
    log_mixture = mixture.clamp_min(eps).log()

    # KL(P || M) = sum_i P_i * (log P_i - log M_i), computed explicitly with inputs in
    # probability space; zero-probability cells contribute 0 thanks to the clamp.
    kl_target_mixture = (target * (target.clamp_min(eps).log() - log_mixture)).sum(dim=1)
    kl_prediction_mixture = (prediction * (prediction.clamp_min(eps).log() - log_mixture)).sum(dim=1)

    # D_JS = 0.5 * (KL(P || M) + KL(Q || M)), per sample, then averaged over the batch
    js_divergence = 0.5 * (kl_target_mixture + kl_prediction_mixture)  # [batch_size]
    return js_divergence.mean()


def jensen_shannon_loss_1d(prediction: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Computes the Jensen-Shannon divergence between two distributions over the batch.

    `prediction` and `target` are each normalized to sum to 1 across the batch
    dimension, and treated as a single distribution over samples. The JS divergence
    between them is returned (natural log, so the value lies in [0, ln 2]).

    Args:
        prediction (torch.Tensor): Non-negative values (e.g. note counts), shape [batch_size,].
        target (torch.Tensor): Non-negative values, shape [batch_size,].
        eps (float): Small constant guarding the normalization (e.g. an all-zero target)
            and the logarithms.

    Returns:
        torch.Tensor: Scalar JS divergence.
    """
    # Normalize into probability distributions across the batch
    prediction = prediction / (prediction.sum(dim=0, keepdim=True) + eps)
    target = target / (target.sum(dim=0, keepdim=True) + eps)

    # Mixture distribution M = 0.5 * (P + Q)
    mixture = 0.5 * (target + prediction)
    # Clamp before log so entries where both distributions are 0 stay finite.
    log_mixture = mixture.clamp_min(eps).log()

    # KL(P || M) = sum_i P_i * (log P_i - log M_i), computed explicitly with inputs in
    # probability space; summed over the single batch-level distribution.
    kl_target_mixture = (target * (target.clamp_min(eps).log() - log_mixture)).sum()
    kl_prediction_mixture = (prediction * (prediction.clamp_min(eps).log() - log_mixture)).sum()

    # JS divergence = 0.5 * (KL(target || mixture) + KL(prediction || mixture))
    return 0.5 * (kl_target_mixture + kl_prediction_mixture)