import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim
import torch.nn as nn
from models.modules.patching import patchify

class MaskedSpecPretrainingTask(pl.LightningModule):
    """Lightning task that pretrains an encoder/decoder pair on masked inputs.

    The encoder (``self.model``), decoder (``self.model_head``) and loss
    (``self.criterion``) are all instantiated from the Hydra config stored in
    ``hparams``. The reconstruction target is the patchified input; the
    criterion receives the decoder prediction together with the batch dict,
    which is augmented with ``"target"`` and ``"token_mask"`` entries.
    """

    def __init__(self, hparams):
        """Build the model components from the given configuration.

        Args:
            hparams: Configuration object exposing ``model``, ``model_head``
                and ``criterion`` (and ``preprocessor`` when
                ``model.using_spectrogram`` is truthy). ``optimizer`` and
                ``scheduler`` are read later in ``configure_optimizers``.
        """
        super().__init__()
        # Counts calls to the image-logging helper (train and validation).
        self.img_logging_step = 0
        self.save_hyperparameters(hparams)

        self.model = hydra.utils.instantiate(self.hparams.model)  # Encoder
        self.model_head = hydra.utils.instantiate(self.hparams.model_head)  # Decoder
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)

        self.patch_size = self.hparams.model.patch_size
        self.keep_chans = self.hparams.model.keep_chans
        self.using_spectrogram = self.hparams.model.using_spectrogram

        # The preprocessor and the square-patch option are only configured
        # when the model operates on spectrograms.
        self.transform = None
        self.square_patches = False
        if self.using_spectrogram:
            self.transform = hydra.utils.instantiate(self.hparams.preprocessor)
            self.square_patches = self.hparams.model.square_patches

    def _log_train_reconstruction_data(self, logging_output, log_frequency=10000):
        """Send reconstruction images to the logger every ``log_frequency`` calls.

        The ``'images'`` entry is always removed from ``logging_output`` (the
        dict is mutated in place) and the internal call counter is advanced,
        whether or not anything was actually written.

        Args:
            logging_output: Dict returned by the criterion; may contain an
                ``'images'`` mapping of image name -> image.
            log_frequency: Number of calls between two image dumps.
        """
        if "images" not in logging_output:
            return

        # Always strip the images so they are not carried around further.
        images = logging_output.pop('images')

        if self.img_logging_step % log_frequency == 0:
            # Skip silently when no logger is attached or the logger's
            # experiment has no ``add_image`` (i.e. is not TensorBoard-like),
            # instead of crashing the training loop.
            experiment = getattr(self.logger, "experiment", None)
            add_image = getattr(experiment, "add_image", None)
            if add_image is not None:
                for image_name, image in images.items():
                    add_image(image_name, image, self.img_logging_step)

        self.img_logging_step += 1

    def _reconstruction_step(self, batch):
        """Run patchify -> encoder -> decoder -> criterion and return the loss.

        Shared by ``training_step`` and ``validation_step``. Adds the
        ``"target"`` and ``"token_mask"`` keys to ``batch`` before calling the
        criterion, and forwards any images in the criterion's logging output
        to the image-logging helper.
        """
        X = batch['input']

        # Ground truth: input split into patches, used in the loss computation.
        batch["target"] = patchify(
            X,
            patch_size=self.patch_size,
            keep_chans=self.keep_chans,
            using_spectrogram=self.using_spectrogram,
            square_patches=self.square_patches,
        )

        # Encoder
        latent, token_mask, ids_restore = self.model(X, mask_tokens=True)
        batch["token_mask"] = token_mask

        # Decoder
        pred = self.model_head(latent, ids_restore)

        # Compute loss
        loss, logging_output = self.criterion(pred, batch)

        # Log images/waveforms in TensorBoard
        self._log_train_reconstruction_data(logging_output, log_frequency=5000)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; returns the loss."""
        loss = self._reconstruction_step(batch)

        # Log the detached tensor rather than ``loss.item()`` to avoid a
        # device-to-host synchronisation on every step.
        self.log("train_loss", loss.detach(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch; returns the loss."""
        loss = self._reconstruction_step(batch)

        self.log("val_loss", loss.detach(), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Replace ``batch['input']`` with its transformed version if configured.

        The transform is only applied when ``using_spectrogram`` is set, in
        which case ``self.transform`` was instantiated from the
        ``preprocessor`` config.
        """
        if self.using_spectrogram:
            # Compute STFT Representation
            batch['input'] = self.transform(batch['input'])
        return batch

    def configure_optimizers(self):
        """
        Define the optimizer and learning-rate scheduler used for training.

        The optimizer is selected by ``hparams.optimizer.optim`` (one of
        ``"SGD"``, ``"Adam"``, ``"AdamW"``, ``"LAMB"``) and receives the
        parameters of both the encoder and the decoder. The scheduler is
        instantiated from ``hparams.scheduler`` and is given the optimizer
        and the trainer's estimated number of stepping batches.

        Returns:
            dict with keys ``"optimizer"`` and ``"lr_scheduler"``; the latter
            is a scheduler config stepped once per optimisation step.

        Raises:
            NotImplementedError: If the configured optimizer name is unknown.
        """
        opt_cfg = self.hparams.optimizer
        params_to_pass = list(self.model.parameters()) + list(self.model_head.parameters())

        if opt_cfg.optim == "SGD":
            optimizer = torch.optim.SGD(params_to_pass, lr=opt_cfg.lr, momentum=opt_cfg.momentum)
        elif opt_cfg.optim == 'Adam':
            optimizer = torch.optim.Adam(params_to_pass, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay)
        elif opt_cfg.optim == 'AdamW':
            # Normalise betas to a plain tuple so that a config container
            # object does not end up inside the optimizer's param_groups
            # (and therefore inside saved checkpoints).
            optimizer = torch.optim.AdamW(
                params_to_pass,
                lr=opt_cfg.lr,
                weight_decay=opt_cfg.weight_decay,
                betas=tuple(opt_cfg.betas),
            )
        elif opt_cfg.optim == 'LAMB':
            optimizer = torch_optim.Lamb(params_to_pass, lr=opt_cfg.lr)
        else:
            raise NotImplementedError("No valid optimizer name")

        print('OPTIMIZER', optimizer)

        print(f"ESTIMATED TRAINING BATCHES: {self.trainer.num_training_batches}")
        print(f"ESTIMATED GRAD ACCUM: {self.trainer.accumulate_grad_batches}")
        print(f"ESTIMATED STEPPING BATCHES FOR ENTIRE TRAINING: {self.trainer.estimated_stepping_batches}")
        print(f"MAX EPOCHS: {self.trainer.max_epochs}")
        scheduler = hydra.utils.instantiate(
            self.hparams.scheduler,
            optimizer=optimizer,
            total_training_opt_steps=self.trainer.estimated_stepping_batches,
        )
        print('SCHEDULER', scheduler)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        """Advance the scheduler using the current global step.

        Overrides the default stepping so that the scheduler's
        ``step_update`` method is called with ``num_updates`` set to
        ``self.global_step``; ``metric`` is ignored.
        """
        scheduler.step_update(num_updates=self.global_step)