import os
import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim

class MaskedWaveletPretrainingTask(pl.LightningModule):
    """Lightning task that trains an encoder/decoder pair to reconstruct
    wavelet features of the input from a masked encoding.

    The encoder (``self.model``), the reconstruction head (``self.model_head``)
    and the loss (``self.criterion``) are all built with
    ``hydra.utils.instantiate`` from the corresponding entries of ``hparams``.

    Expected collaborators (as used by this class):
      * ``self.model(X)`` returns ``(latent, token_mask, ids_restore)`` and
        exposes ``wavelet_decomp`` and ``patch_width``.
      * ``self.model_head(latent, ids_restore)`` returns the prediction.
      * ``self.criterion(pred, batch)`` returns ``(loss, logging_output)``
        where ``logging_output`` is a dict that may contain an ``'images'``
        mapping of ``name -> image``.
    """

    def __init__(self, hparams):
        """Build the encoder, decoder head and criterion from ``hparams``.

        Args:
            hparams: Configuration object with ``model``, ``model_head``,
                ``criterion``, ``optimizer`` and ``scheduler`` entries
                (presumably an OmegaConf ``DictConfig`` -- confirm against
                the training entry point).
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        # Encoder: wavelet decomposition, patch masking, and transformer backbone
        self.model = hydra.utils.instantiate(self.hparams.model)

        # Decoder: MAE-style reconstruction head
        self.model_head = hydra.utils.instantiate(self.hparams.model_head)

        # Loss function module
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)

        # Counts how many times the criterion has produced images; used both
        # to throttle image logging and as the step value passed to the logger.
        self.img_logging_step = 0

    def _log_train_reconstruction_data(self, logging_output, log_frequency=10000):
        """Log reconstruction images at fixed intervals and strip them from
        ``logging_output``.

        If ``logging_output`` has an ``'images'`` entry it is always removed
        (mutating the dict in place) and the internal image counter is
        advanced; the images are only sent to the logger when the counter is
        a multiple of ``log_frequency``.

        Args:
            logging_output: Dict returned by the criterion. Modified in place.
            log_frequency: Log images once every this many calls that carry
                images.
        """
        if 'images' not in logging_output:
            return

        images = logging_output.pop('images')
        due = self.img_logging_step % log_frequency == 0
        # ``self.logger`` is None when the trainer runs without a logger;
        # skip image logging in that case instead of failing on ``.experiment``.
        if due and self.logger is not None:
            # NOTE(review): ``add_image`` is the TensorBoard SummaryWriter
            # API; other logger backends may not provide it.
            writer = self.logger.experiment
            for name, image in images.items():
                writer.add_image(name, image, self.img_logging_step)
        self.img_logging_step += 1

    def get_wavelet_features(self, X):
        """
        Compute wavelet features without masking.

        Runs ``self.model.wavelet_decomp`` with gradients disabled.

        Args:
            X: Input waveform tensor of shape [B, in_ch, T]

        Returns:
            Wavelet features of shape [B, wave_decomp_ch, T]
        """
        with torch.no_grad():
            return self.model.wavelet_decomp(X)

    def _build_wavelet_targets(self, X):
        """Compute the unmasked wavelet target and its patched view.

        Args:
            X: Input tensor forwarded to ``get_wavelet_features``.

        Returns:
            Tuple ``(wave_gt_2d, wave_gt_patches)`` where ``wave_gt_2d`` has
            shape ``[B, C, T]`` and ``wave_gt_patches`` is the same data
            reshaped to ``[B, C * (T // patch_width), patch_width]``.

        Raises:
            ValueError: If ``T`` is not a multiple of ``self.model.patch_width``.
        """
        with torch.no_grad():
            wave_gt_2d = self.get_wavelet_features(X)
            batch_size, channels, length = wave_gt_2d.shape
            patch_width = self.model.patch_width
            if length % patch_width != 0:
                # Fail with an actionable message rather than the opaque
                # size-mismatch error the reshape below would produce.
                raise ValueError(
                    f"Wavelet feature length {length} is not divisible by "
                    f"patch_width {patch_width}."
                )
            num_patches = length // patch_width
            wave_gt_patches = wave_gt_2d.reshape(
                batch_size, channels * num_patches, patch_width
            )
        return wave_gt_2d, wave_gt_patches

    def _reconstruction_step(self, batch, batch_idx):
        """Run encode -> decode -> target -> loss for one batch.

        Shared by ``training_step`` and ``validation_step``. Adds the keys
        ``'target'``, ``'token_mask'``, ``'wave_gt_2d'``, ``'batch_idx'``,
        ``'epoch_idx'`` (and ``'save_dir'`` when the criterion config defines
        it) to ``batch`` in place before calling the criterion.

        Args:
            batch: Dict containing at least ``'input'``.
            batch_idx: Index of the batch within the current epoch.

        Returns:
            The loss returned by the criterion.
        """
        X = batch['input']  # [B, in_ch, T]

        # Encode and mask
        latent, token_mask, ids_restore = self.model(X)

        # Decode reconstruction
        pred = self.model_head(latent, ids_restore)

        # Prepare target wavelet features (no gradients flow into the target)
        wave_gt_2d, wave_gt_patches = self._build_wavelet_targets(X)

        # Populate batch for loss computation
        batch['target'] = wave_gt_patches
        batch['token_mask'] = token_mask
        batch['wave_gt_2d'] = wave_gt_2d
        batch['batch_idx'] = batch_idx
        batch['epoch_idx'] = self.current_epoch
        if hasattr(self.hparams.criterion, 'save_dir'):
            batch['save_dir'] = self.hparams.criterion.save_dir

        loss, logging_output = self.criterion(pred, batch)

        # NOTE(review): training and validation share one image counter and
        # the same image tags, so their images interleave in the logger.
        self._log_train_reconstruction_data(logging_output, log_frequency=5000)
        return loss

    def training_step(self, batch, batch_idx):
        """
        Training step:
          1) Encode input into latent representations and mask tokens
          2) Reconstruct using the decoder
          3) Generate target wavelet features without masking
          4) Compute and log the loss

        Returns:
            The training loss for this batch.
        """
        loss = self._reconstruction_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step follows the same procedure as training; the loss is
        logged as ``val_loss`` and aggregated per epoch only.

        Returns:
            The validation loss for this batch.
        """
        loss = self._reconstruction_step(batch, batch_idx)
        self.log('val_loss', loss, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """
        Configure optimizer and learning-rate scheduler.

        Only the parameters of ``self.model`` and ``self.model_head`` are
        optimized. The optimizer is selected by ``hparams.optimizer.optim``
        (one of ``'SGD'``, ``'Adam'``, ``'AdamW'``, ``'LAMB'``); the scheduler
        is instantiated from ``hparams.scheduler`` and stepped every
        optimizer step.

        Returns:
            Dict with ``'optimizer'`` and ``'lr_scheduler'`` entries in the
            format Lightning expects.

        Raises:
            NotImplementedError: If the configured optimizer name is unknown.
        """
        params = list(self.model.parameters()) + list(self.model_head.parameters())
        optim_cfg = self.hparams.optimizer
        if optim_cfg.optim == 'SGD':
            optimizer = torch.optim.SGD(params, lr=optim_cfg.lr, momentum=optim_cfg.momentum)
        elif optim_cfg.optim == 'Adam':
            optimizer = torch.optim.Adam(params, lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay)
        elif optim_cfg.optim == 'AdamW':
            # Convert to a plain tuple so a config sequence object does not
            # end up inside the optimizer's param_groups (and checkpoints).
            optimizer = torch.optim.AdamW(params, lr=optim_cfg.lr,
                                          weight_decay=optim_cfg.weight_decay,
                                          betas=tuple(optim_cfg.betas))
        elif optim_cfg.optim == 'LAMB':
            # NOTE(review): unlike the Adam/AdamW branches, any configured
            # weight_decay is not forwarded here -- confirm this is intended.
            optimizer = torch_optim.Lamb(params, lr=optim_cfg.lr)
        else:
            raise NotImplementedError(f"Optimizer '{optim_cfg.optim}' is not supported.")

        scheduler = hydra.utils.instantiate(
            self.hparams.scheduler,
            optimizer=optimizer,
            total_training_opt_steps=self.trainer.estimated_stepping_batches
        )
        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}}

    def lr_scheduler_step(self, scheduler, metric):
        """Advance the scheduler using the global optimizer step count.

        Overrides Lightning's default scheduler stepping so that the
        scheduler's ``step_update(num_updates=...)`` method is called instead
        of ``step()``. ``metric`` is accepted for interface compatibility and
        is not used.
        """
        scheduler.step_update(num_updates=self.global_step)
