import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim

# Here we import the patchify_2d and unpatchify_2d utilities
from models.modules.patching2D import patchify_2d, unpatchify_2d

class MaskedWaveletPretrainingTask(pl.LightningModule):
    """Lightning task for masked (MAE-style) reconstruction of wavelet-decomposed inputs.

    The task wires together three Hydra-instantiated components taken from the
    hyperparameters: an encoder (``model``), a reconstruction decoder
    (``model_head``) and a loss (``criterion``). Training and validation share
    the same forward / target-building / loss pipeline.
    """

    def __init__(self, hparams):
        """Build the encoder, decoder and criterion from ``hparams``.

        Args:
            hparams: Configuration object. This class reads the entries
                ``model``, ``model_head``, ``criterion``, ``optimizer`` and
                ``scheduler`` from it.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        # 1) Encoder. It must expose ``wavelet_decomp`` and ``patch_size`` and be
        #    callable as ``model(X, mask_tokens=True)`` (see _reconstruction_step).
        self.model = hydra.utils.instantiate(self.hparams.model)

        # 2) Decoder: called as ``model_head(latent, ids_restore)``.
        self.model_head = hydra.utils.instantiate(self.hparams.model_head)

        # 3) Loss: called as ``criterion(pred, batch)`` -> (loss, logging_output).
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)

        # Counts how many times a criterion output containing "images" was seen;
        # used both as the logging period counter and as the image step index.
        self.img_logging_step = 0

    def _log_train_reconstruction_data(self, logging_output, log_frequency=10000):
        """Log reconstruction images every ``log_frequency`` calls.

        If ``logging_output`` has an ``"images"`` entry it is always removed
        from the dict and the internal counter is advanced; the images are
        only sent to the logger when the counter is a multiple of
        ``log_frequency``.

        Args:
            logging_output: Dict returned by the criterion; mutated in place.
            log_frequency: Period (in calls carrying images) between uploads.
        """
        if "images" not in logging_output:
            return

        images = logging_output.pop("images")
        # Skip the upload (but keep the bookkeeping) when no logger is attached,
        # e.g. Trainer(logger=False); otherwise ``self.logger.experiment`` would
        # raise AttributeError on None.
        if self.img_logging_step % log_frequency == 0 and self.logger is not None:
            # NOTE(review): ``add_image`` presumes a TensorBoard-like experiment.
            for image_name, image in images.items():
                self.logger.experiment.add_image(image_name, image, self.img_logging_step)
        self.img_logging_step += 1

    def _reconstruction_step(self, batch, batch_idx):
        """Run encoder, decoder and criterion on one batch; shared by train and val.

        Steps:
          1) Encode ``batch['input']`` with token masking enabled.
          2) Decode the latent into per-patch predictions.
          3) Build the target by patchifying the wavelet decomposition of the
             same input with the encoder's own ``patch_size``.
          4) Hand prediction and the augmented ``batch`` to the criterion.

        Args:
            batch: Dict with at least the key ``'input'``. It is mutated in
                place: ``target``, ``token_mask``, ``wave_gt_2d``,
                ``batch_idx``, ``epoch_idx`` (and optionally ``save_dir``)
                are added for the criterion.
            batch_idx: Index of the batch within the current epoch.

        Returns:
            The loss returned by the criterion.
        """
        X = batch['input']  # original note: [B, in_ch, T] raw waveform — TODO confirm

        # 1) Encoder => latent plus the masking bookkeeping.
        latent, token_mask, ids_restore = self.model(X, mask_tokens=True)

        # 2) Decoder => pred (original note: [B, num_patches, 50], set by the
        #    model_head config — TODO confirm).
        pred = self.model_head(latent, ids_restore)

        # 3) Target: no gradient should flow through the ground-truth path.
        with torch.no_grad():
            wave_gt = self.model.wavelet_decomp(X)
            # Insert a dimension at position 1 before handing to patchify_2d.
            wave_gt_2d = wave_gt.unsqueeze(1)
            # Use the same patch_size as the encoder so pred and target align.
            wave_gt_patches = patchify_2d(wave_gt_2d, patch_size=self.model.patch_size)

        # 4) Everything the criterion needs travels inside ``batch``.
        batch["target"] = wave_gt_patches
        batch["token_mask"] = token_mask  # original note: [B, num_patches], 1 => masked
        batch["wave_gt_2d"] = wave_gt_2d
        batch["batch_idx"] = batch_idx
        batch["epoch_idx"] = self.current_epoch
        if hasattr(self.hparams.criterion, 'save_dir'):
            batch["save_dir"] = self.hparams.criterion.save_dir

        loss, logging_output = self.criterion(pred, batch)

        self._log_train_reconstruction_data(logging_output, log_frequency=5000)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for a training batch and log ``train_loss``.

        Args:
            batch: Dict with the key ``'input'``; mutated in place (see
                ``_reconstruction_step``).
            batch_idx: Index of the batch within the current epoch.

        Returns:
            The loss tensor to backpropagate.
        """
        loss = self._reconstruction_step(batch, batch_idx)
        # Log the detached tensor instead of ``loss.item()`` to avoid forcing a
        # host-device synchronisation on every step.
        self.log("train_loss", loss.detach(), on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the reconstruction loss for a validation batch and log ``val_loss``.

        Uses the same pipeline as ``training_step`` (masking stays enabled).

        Args:
            batch: Dict with the key ``'input'``; mutated in place (see
                ``_reconstruction_step``).
            batch_idx: Index of the batch within the validation loader.

        Returns:
            The loss tensor.
        """
        loss = self._reconstruction_step(batch, batch_idx)
        self.log("val_loss", loss.detach(), on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Create the optimizer and a per-step LR scheduler from the hyperparameters.

        The optimizer is selected by ``hparams.optimizer.optim`` (one of
        ``"SGD"``, ``"Adam"``, ``"AdamW"``, ``"LAMB"``) and receives the
        parameters of ``model`` and ``model_head`` only. The scheduler is
        instantiated via Hydra from ``hparams.scheduler``.

        Returns:
            Dict with ``"optimizer"`` and ``"lr_scheduler"`` entries.

        Raises:
            NotImplementedError: If the optimizer name is not recognised.
        """
        opt_cfg = self.hparams.optimizer
        params_to_pass = list(self.model.parameters()) + list(self.model_head.parameters())

        if opt_cfg.optim == "SGD":
            optimizer = torch.optim.SGD(params_to_pass, lr=opt_cfg.lr,
                                        momentum=opt_cfg.momentum)
        elif opt_cfg.optim == 'Adam':
            optimizer = torch.optim.Adam(params_to_pass, lr=opt_cfg.lr,
                                         weight_decay=opt_cfg.weight_decay)
        elif opt_cfg.optim == 'AdamW':
            optimizer = torch.optim.AdamW(params_to_pass, lr=opt_cfg.lr,
                                          weight_decay=opt_cfg.weight_decay,
                                          betas=opt_cfg.betas)
        elif opt_cfg.optim == 'LAMB':
            optimizer = torch_optim.Lamb(params_to_pass, lr=opt_cfg.lr)
        else:
            raise NotImplementedError("No valid optimizer name")

        # The scheduler is told the total number of optimizer steps up front.
        scheduler = hydra.utils.instantiate(
            self.hparams.scheduler,
            optimizer=optimizer,
            total_training_opt_steps=self.trainer.estimated_stepping_batches
        )

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",  # step the scheduler after every optimizer step
            "frequency": 1
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        """Advance the scheduler using the current global step.

        Overrides the default stepping so that the scheduler's
        ``step_update(num_updates=...)`` is called instead of ``step()``.

        Args:
            scheduler: The scheduler created in ``configure_optimizers``.
            metric: Monitored metric value passed by Lightning; unused here.
        """
        scheduler.step_update(num_updates=self.global_step)
