import torch
import torch.nn as nn

import espnet2.gan_svs.pits.modules as modules


# TODO (Yifeng): The docstrings below were originally generated by ChatGPT and may not be accurate.
class YingDecoder(nn.Module):
    """Ying decoder module.

    Crops a window of ``yin_scope`` channels out of a latent yin
    representation, starting at channel ``yin_start`` plus a per-sample
    ``scope_shift``. The crop is then decoded with ``pre`` (1x1 conv) ->
    ``dec`` (``modules.WN``) -> ``proj`` (1x1 conv), and the mask is applied
    after the input and output projections.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        yin_start,
        yin_scope,
        yin_shift_range,
        gin_channels=0,
    ):
        """Initialize the YingDecoder module.

        Args:
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Size of the convolutional kernel in ``modules.WN``.
            dilation_rate (int): Dilation rate of the ``modules.WN`` layers.
            n_layers (int): Number of ``modules.WN`` layers.
            yin_start (int): Channel index where the unshifted crop window starts.
            yin_scope (int): Number of channels in the crop window. This is also
                the input and output channel count of the decoder.
            yin_shift_range (int): Magnitude of the random channel offset applied
                to the crop window when no explicit ``scope_shift`` is given.
            gin_channels (int, optional): Number of global conditioning channels.
                Defaults to 0.
        """
        super().__init__()
        self.in_channels = yin_scope
        self.out_channels = yin_scope
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.yin_start = yin_start
        self.yin_scope = yin_scope
        self.yin_shift_range = yin_shift_range

        self.pre = nn.Conv1d(self.in_channels, hidden_channels, 1)
        self.dec = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, self.out_channels, 1)

    def _resolve_scope_shift(self, batch_size, scope_shift=None):
        """Return the per-sample channel shifts as an int32 tensor of shape [B].

        Args:
            batch_size (int): Batch size B.
            scope_shift (torch.Tensor | Sequence[int] | int, optional): Explicit
                shift(s). A single value is broadcast to the whole batch. When
                None, shifts are sampled randomly.

        Returns:
            torch.Tensor: Integer shifts of shape [B].
        """
        if scope_shift is None:
            # NOTE(review): the upper bound of torch.randint is exclusive, so
            # shifts are drawn from [-yin_shift_range, yin_shift_range - 1].
            # Kept as-is to preserve the existing training behavior; confirm
            # whether a symmetric range was intended.
            return torch.randint(
                -self.yin_shift_range, self.yin_shift_range, (batch_size,), dtype=torch.int
            )
        shift = torch.as_tensor(scope_shift, dtype=torch.int)
        return shift.reshape(-1).expand(batch_size).clone()

    def crop_scope(self, x, yin_start, scope_shift):
        """Crop a ``yin_scope``-channel window from every sample of ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T].
            yin_start (int): Channel index where the unshifted window starts.
            scope_shift (torch.Tensor | Sequence[int] | int): Per-sample channel
                offset of shape [B]. A single value is broadcast to the batch.

        Returns:
            torch.Tensor: Tensor of shape [B, yin_scope, T]. Sample ``i`` equals
            ``x[i, s_i : s_i + yin_scope, :]``, where
            ``s_i = yin_start + scope_shift[i]``.

        Raises:
            ValueError: If any crop window falls outside the channel range
                ``[0, C)`` of ``x``.
        """
        batch_size, num_channels = x.shape[0], x.shape[1]
        shifts = torch.as_tensor(scope_shift, dtype=torch.long, device=x.device)
        starts = yin_start + shifts.reshape(-1).expand(batch_size)  # [B]

        if batch_size > 0:
            lowest = int(starts.min())
            highest = int(starts.max()) + self.yin_scope
            # Plain slicing would silently wrap a negative start or truncate an
            # overlong window; fail loudly instead of returning wrong data.
            if lowest < 0 or highest > num_channels:
                raise ValueError(
                    f"crop window [{lowest}, {highest}) exceeds the channel "
                    f"range [0, {num_channels}) of the input"
                )

        # [B, yin_scope] channel indices, broadcast over the trailing (time) dims.
        index = starts.unsqueeze(1) + torch.arange(self.yin_scope, device=x.device)
        index = index.view(*index.shape, *([1] * (x.dim() - 2)))
        return torch.gather(x, 1, index.expand(-1, -1, *x.shape[2:]))

    def _decode(self, z_yin_crop, z_mask, g=None):
        """Run ``pre`` -> ``dec`` -> ``proj`` on a cropped window, applying the mask."""
        hidden = self.pre(z_yin_crop) * z_mask
        hidden = self.dec(hidden, z_mask, g=g)
        return self.proj(hidden) * z_mask

    def infer(self, z_yin, z_mask, g=None, scope_shift=None):
        """Generate a yin prediction.

        Args:
            z_yin (torch.Tensor): Latent yin tensor of shape [B, C, T]. C must
                cover ``yin_start + scope_shift + yin_scope`` for every sample.
            z_mask (torch.Tensor): Mask tensor of shape [B, 1, T].
            g (torch.Tensor, optional): Global conditioning tensor passed to
                ``modules.WN`` (presumably [B, gin_channels, 1]). Defaults to None.
            scope_shift (torch.Tensor | Sequence[int] | int, optional): Explicit
                channel shift(s). When None (the default), a random shift is
                sampled per sample, matching the previous behavior.

        Returns:
            torch.Tensor: Predicted yin tensor of shape [B, yin_scope, T].
        """
        scope_shift = self._resolve_scope_shift(z_yin.shape[0], scope_shift)
        z_yin_crop = self.crop_scope(z_yin, self.yin_start, scope_shift)
        return self._decode(z_yin_crop, z_mask, g=g)

    def forward(self, z_yin, yin_gt, z_mask, g=None, scope_shift=None):
        """Forward pass of the decoder.

        Args:
            z_yin (torch.Tensor): Latent yin tensor of shape (B, C, T).
            yin_gt (torch.Tensor): Ground-truth yin tensor of shape (B, C, T).
            z_mask (torch.Tensor): Mask tensor of shape (B, 1, T).
            g (torch.Tensor, optional): Global conditioning tensor passed to
                ``modules.WN``. Defaults to None.
            scope_shift (torch.Tensor | Sequence[int] | int, optional): Explicit
                channel shift(s). When None (the default), a random shift is
                sampled per sample.

        Returns:
            Tuple[torch.Tensor, ...]: In this order:
                - yin_gt_crop: Unshifted ground-truth crop, (B, yin_scope, T).
                - yin_gt_shifted_crop: Ground-truth crop using the same shift
                  as the input, (B, yin_scope, T).
                - yin_hat_crop: Predicted yin, (B, yin_scope, T).
                - z_yin_crop: Shifted input crop, (B, yin_scope, T).
                - scope_shift: Integer shifts that were used, (B,).
        """
        scope_shift = self._resolve_scope_shift(z_yin.shape[0], scope_shift)
        z_yin_crop = self.crop_scope(z_yin, self.yin_start, scope_shift)
        # Crop the ground truth twice: once aligned with the shifted input, and
        # once at the reference (zero-shift) position.
        yin_gt_shifted_crop = self.crop_scope(yin_gt, self.yin_start, scope_shift)
        yin_gt_crop = self.crop_scope(
            yin_gt, self.yin_start, torch.zeros_like(scope_shift)
        )
        yin_hat_crop = self._decode(z_yin_crop, z_mask, g=g)
        return yin_gt_crop, yin_gt_shifted_crop, yin_hat_crop, z_yin_crop, scope_shift
