import os
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

class ReconstructionNorm(nn.Module):
    """
    Masked-reconstruction loss with an optional diagnostic figure.

    The loss is the mean per-patch error over masked patches plus ``alpha``
    times the mean per-patch error over visible patches. When
    ``using_spectrogram`` is true and the incoming ``batch['batch_idx']``
    equals ``vis_batch_idx``, a figure comparing target, prediction, mask,
    masked input and reconstruction for the first sample of the batch is
    written to ``save_dir`` as both SVG and high-resolution PNG.

    Args:
        save_dir: directory the figures are written to (created on demand).
        alpha: weight of the visible-region loss; 0 disables that term.
        loss_type: one of 'l1', 'l2' or 'smooth_l1'.
        using_spectrogram: enables the visualization branch in ``forward``.
        vis_batch_idx: batch index at which the figure is produced.
        dpi: requested resolution of the PNG output.
    """

    # The Agg rasterizer rejects images of 2**16 pixels or more per side.
    _MAX_RENDER_PIXELS = 2 ** 16 - 1

    def __init__(
        self,
        save_dir="/Figures",
        alpha=0.1,
        loss_type="smooth_l1",
        using_spectrogram=True,
        vis_batch_idx=1000,
        dpi=1200,
    ):
        super().__init__()
        self.save_dir = save_dir
        self.alpha = alpha
        self.loss_type = loss_type
        self.using_spectrogram = using_spectrogram
        self.vis_batch_idx = vis_batch_idx
        self.dpi = dpi

    def loss_fn(self, pred, target, token_mask):
        """
        Compute reconstruction loss for masked and visible regions.

        Args:
            pred: [B, num_patches, patch_dim]
            target: [B, num_patches, patch_dim]
            token_mask: [B, num_patches], 1/True indicates masked patches.
                Float, integer and boolean masks are all accepted.

        Returns:
            Scalar loss: masked-region loss, plus ``alpha`` times the
            visible-region loss when ``alpha`` is non-zero.

        Raises:
            ValueError: if ``self.loss_type`` is not a supported name.
        """
        if self.loss_type == 'l1':
            loss_pixel = F.l1_loss(pred, target, reduction='none')
        elif self.loss_type == 'l2':
            loss_pixel = F.mse_loss(pred, target, reduction='none')
        elif self.loss_type == 'smooth_l1':
            loss_pixel = F.smooth_l1_loss(pred, target, reduction='none')
        else:
            raise ValueError("Invalid loss_type.")

        # Mean error per patch => [B, num_patches]
        loss_per_patch = loss_pixel.mean(dim=-1)

        # Cast once: `1 - mask` is not defined for boolean tensors, and a
        # float mask keeps both normalizers in the loss dtype.
        mask = token_mask.to(loss_per_patch.dtype)

        # The 1e-6 guards against division by zero when no patch is masked.
        masked_loss = (loss_per_patch * mask).sum() / (mask.sum() + 1e-6)

        if self.alpha == 0:
            return masked_loss

        visible = 1 - mask
        visible_loss = (loss_per_patch * visible).sum() / (visible.sum() + 1e-6)
        return masked_loss + self.alpha * visible_loss

    def reshape_patches_to_wave(self, patches, B, C, T, patch_width):
        """
        Reshape patches back to the original waveform.

        Args:
            patches: [B, C*N, patch_width]
            B: batch size
            C: channels
            T: total time steps
            patch_width: width of each patch

        Returns:
            Waveform [B, C, T]
        """
        N = T // patch_width
        # The intermediate 4-D reshape makes the call fail if the patch count
        # does not match C * (T // patch_width).
        return patches.reshape(B, C, N, patch_width).reshape(B, C, T)

    def forward(self, pred, batch):
        """
        Compute loss and optionally visualize reconstructions.

        Args:
            pred: [B, num_patches, patch_width] from decoder
            batch: dict with keys 'target', 'token_mask' and 'batch_idx';
                'wave_gt_2d' and 'epoch_idx' are read only when the
                visualization is triggered.

        Returns:
            loss, logging_output
        """
        target = batch['target']
        token_mask = batch['token_mask']
        loss = self.loss_fn(pred, target, token_mask)

        logging_output = {'recon_loss': loss.item()}
        batch_idx = batch['batch_idx']

        # Visualize only at the chosen batch
        if self.using_spectrogram and batch_idx == self.vis_batch_idx:
            epoch = batch['epoch_idx']
            # Plotting is a side effect only: keep it out of the autograd
            # graph so tensors can be converted to NumPy.
            with torch.no_grad():
                self._save_visualization(
                    pred.detach(),
                    target.detach(),
                    token_mask,
                    batch['wave_gt_2d'].detach(),  # [B, channels, T]
                    epoch,
                    batch_idx,
                )
            logging_output['note'] = f'Reconstruction visualization saved for epoch {epoch}, batch {batch_idx}'

        return loss, logging_output

    def _png_dpi(self, width_in, height_in):
        """Return ``self.dpi`` clamped so the PNG stays within the Agg size limit."""
        # One inch of slack covers the padding bbox_inches='tight' may add.
        longest_side_in = max(width_in, height_in) + 1.0
        return min(self.dpi, int(self._MAX_RENDER_PIXELS // longest_side_in))

    def _save_visualization(self, pred, target, token_mask, wave_gt_2d, epoch, batch_idx):
        """Plot the first sample of the batch (all channels) and save it as SVG and PNG."""
        B, C, T = wave_gt_2d.shape
        _, _, patch_width = target.shape

        pred_wave = self.reshape_patches_to_wave(pred, B, C, T, patch_width)
        target_wave = self.reshape_patches_to_wave(target, B, C, T, patch_width)

        # Build 2D mask [B, C, T] by repeating each patch flag over its width
        N = T // patch_width
        mask_exp = token_mask.reshape(B, C, N, 1).repeat(1, 1, 1, patch_width)
        mask_2d = mask_exp.reshape(B, C, T).bool()

        # Masked input zeroes the masked samples; the reconstruction keeps the
        # visible samples and takes the prediction where the mask is set.
        masked_input = wave_gt_2d.masked_fill(mask_2d, 0)
        reconstruction = wave_gt_2d.where(~mask_2d, pred_wave)

        sample = 0
        width_in, height_in = 25, 5 * C
        # squeeze=False always yields a [C, 5] axes array, including C == 1.
        fig, axs = plt.subplots(C, 5, figsize=(width_in, height_in), squeeze=False)
        try:
            for ch in range(C):
                panels = {
                    'Target': target_wave[sample, ch],
                    'Prediction': pred_wave[sample, ch],
                    'Mask': (~mask_2d[sample, ch]).float(),
                    'Masked Input': masked_input[sample, ch],
                    'Reconstruction': reconstruction[sample, ch],
                }
                # float32 CPU copies: NumPy cannot take tensors that require
                # grad, and has no bfloat16 dtype.
                panels = {
                    title: values.detach().float().cpu()
                    for title, values in panels.items()
                }
                # Shared color scale across the five panels of this channel
                vmin = min(values.min().item() for values in panels.values())
                vmax = max(values.max().item() for values in panels.values())

                for ax, (title, values) in zip(axs[ch], panels.items()):
                    ax.imshow(values.unsqueeze(0).numpy(), aspect='auto', vmin=vmin, vmax=vmax)
                    ax.set_title(f'Channel {ch} - {title}')

            fig.tight_layout()
            # Save figures in both SVG and high-res PNG formats
            os.makedirs(self.save_dir, exist_ok=True)
            svg_path = os.path.join(self.save_dir, f'recon_epoch{epoch}_batch{batch_idx}.svg')
            png_path = os.path.join(self.save_dir, f'recon_epoch{epoch}_batch{batch_idx}.png')
            fig.savefig(svg_path, format='svg', bbox_inches='tight')
            fig.savefig(png_path, dpi=self._png_dpi(width_in, height_in), bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
