import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim
import torch.nn as nn
from models.modules.patching import patchify



def get_params_from_checkpoint(checkpoint, head=False):
    """
    Extract encoder weights from a PyTorch Lightning checkpoint.

    Only entries of ``checkpoint["state_dict"]`` whose key starts with
    ``"model."`` are returned, with that leading prefix removed so the result
    can be fed to the encoder's ``load_state_dict``. Keys of other submodules
    (e.g. ``"model_head.*"``) do not start with ``"model."`` and are always
    skipped.

    Parameters
    ----------
    checkpoint : dict
        The loaded checkpoint dictionary (e.g. from `torch.load`); it must
        contain a ``"state_dict"`` mapping.
    head : bool, optional
        If False, additionally exclude every ``"model."`` key containing
        ``'_head'``. If True, keep them. By default False.

    Returns
    -------
    dict
        Mapping from prefix-stripped parameter names to the stored values.
    """
    prefix = "model."
    model_weights = {}
    for key, value in checkpoint["state_dict"].items():
        if not key.startswith(prefix):
            continue
        if not head and "_head" in key:
            continue
        # Strip only the leading prefix: str.replace would also delete any later
        # "model." substring (e.g. "model.sub_model.bias" -> "sub_bias").
        model_weights[key[len(prefix):]] = value
    return model_weights

class VITMAETask(pl.LightningModule):
    """
    A PyTorch Lightning task for Masked Autoencoder (MAE) training using a ViT-style model.

    This class encapsulates the encoder, decoder (model head), masking logic,
    optimization routines, and logging for MAE training. It inherits from
    `pl.LightningModule` and overrides relevant methods such as `training_step`,
    `validation_step`, `configure_optimizers`, etc.

    Parameters
    ----------
    hparams : dict
        The hyperparameters and configuration dictionary (Hydra). Must provide
        `model`, `model_head`, `criterion`, `optimizer` and `scheduler` entries
        (plus `preprocessor` when `model.using_spectrogram` is true). An optional
        top-level `freeze_backbone` flag is honoured by `load_from_checkpoint`.
    masking_ratio : float
        The ratio of tokens to mask during MAE training.
    """
    def __init__(self, hparams, masking_ratio):
        super().__init__()
        self.img_logging_step = 0
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)  # Encoder
        self.model_head = hydra.utils.instantiate(self.hparams.model_head)  # Decoder
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)
        self.patch_size = self.hparams.model.patch_size
        self.keep_chans = self.hparams.model.keep_chans
        self.using_spectrogram = self.hparams.model.using_spectrogram
        self.strict_loading = False
        self.transform = None
        self.square_patches = False
        if self.using_spectrogram:
            self.transform = hydra.utils.instantiate(self.hparams.preprocessor)
            self.square_patches = self.hparams.model.square_patches

        # Read by `load_from_checkpoint`. Defaults to False when the config has no
        # top-level `freeze_backbone` entry; callers may still overwrite the
        # attribute after construction.
        self.freeze_backbone = bool(self.hparams.get("freeze_backbone", False))

        self.mask_token = self.model.mask_token                     # NOTE: until general TimeFM restructuring, we assume the model has a mask & pad token
        self.pad_token = self.model.pad_token                       # this is done in order to maintain checkpoint loading stability
        self.masking_ratio = masking_ratio

    def _log_train_reconstruction_data(self, logging_output, log_frequency=10000):
        """
        Log reconstruction data (images or waveforms) to TensorBoard at a specified frequency.

        The "images" entry (if any) is always removed from `logging_output`, and
        the internal image-logging counter advances by one per call that carries
        images. Images are only written every `log_frequency` such calls, and
        only when a logger is attached.

        Parameters
        ----------
        logging_output : dict
            Dictionary containing logging items. If it has key "images" (a
            mapping of name -> image), those will be logged.
        log_frequency : int, optional
            Frequency (in calls) at which to log the images/waveforms,
            by default 10000.
        """
        if "images" not in logging_output:
            return

        logging_images = logging_output.pop("images")
        if self.img_logging_step % log_frequency == 0 and self.logger is not None:
            # `add_image` assumes a TensorBoard-style experiment object.
            for image_name, image in logging_images.items():
                self.logger.experiment.add_image(image_name, image, self.img_logging_step)
        self.img_logging_step += 1

    def _shared_step(self, batch):
        """
        Encode, mask, decode and score one batch (shared by train and validation).

        Writes the keys "token_mask", "attn_mask" and "target" into `batch`
        before handing it to `self.criterion` (presumably the criterion reads
        them).

        Parameters
        ----------
        batch : dict
            Batch containing 'input' and optionally 'nr_padded_channels'.

        Returns
        -------
        tuple
            `(loss, logging_output)` as returned by `self.criterion`.
        """
        # Collect ground truth
        X = batch['input']
        # Assumes X is (B, C, ...) with channels on dim 1 — TODO confirm for the
        # spectrogram path.
        nr_channels = X.shape[1]

        # Encoder
        # NOTE: until general TimeFM restructuring, we assume the model has a patch_embed
        # module in order to move masking to the task
        x = self.model.patch_embed(X)
        x, token_mask, ids_restore, attn_mask = self.prepare_tokens(
            x,
            nr_channels_padded=batch.get('nr_padded_channels'),
            mask_tokens=True,
            nr_channels=nr_channels,
        )
        # NOTE(review): after masking, `x` is in shuffled token order while
        # `attn_mask` is in the original token order; confirm how `self.model`
        # interprets `attn_mask`.
        latent = self.model(x, directly_input_tokens=True, attn_mask=attn_mask)

        batch["token_mask"] = token_mask
        batch["attn_mask"] = attn_mask

        # Decoder
        # NOTE: assuming model head is mae_decoder-like. This means it has a way of unshuffling the tokens and combine them with a mask token
        pred = self.model_head(latent, ids_restore)

        # Compute loss
        target = patchify(X, patch_size=self.patch_size, keep_chans=self.keep_chans, using_spectrogram=self.using_spectrogram, square_patches=self.square_patches)
        batch["target"] = target
        return self.criterion(pred, batch)

    def training_step(self, batch, batch_idx):
        """
        Perform one training step.

        Parameters
        ----------
        batch : dict
            Batch containing 'input' and potentially other keys used in training.
        batch_idx : int
            Index of the batch (provided by PyTorch Lightning).

        Returns
        -------
        torch.Tensor
            The loss value for this training step.
        """
        loss, logging_output = self._shared_step(batch)

        # Log images/waveforms in TensorBoard
        self._log_train_reconstruction_data(logging_output, log_frequency=5000)

        # Log a detached tensor rather than `loss.item()` to avoid a host sync.
        self.log("train_loss", loss.detach(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Perform one validation step.

        Parameters
        ----------
        batch : dict
            Batch containing 'input' and potentially other keys used in validation.
        batch_idx : int
            Index of the batch (provided by PyTorch Lightning).

        Returns
        -------
        torch.Tensor
            The loss value for this validation step.
        """
        loss, logging_output = self._shared_step(batch)

        # Log images/waveforms in TensorBoard (shares the counter with training)
        self._log_train_reconstruction_data(logging_output, log_frequency=5000)

        self.log("val_loss", loss.detach(), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """
        Hook that is called after a batch is transferred to the GPU/CPU in the DataLoader.

        If using spectrogram transformation, applies `self.transform` to the input.

        Parameters
        ----------
        batch : dict
            The batch dictionary after it has been transferred (includes 'input', etc.).
        dataloader_idx : int
            The index of the dataloader this batch comes from.

        Returns
        -------
        dict
            The modified batch (possibly transformed if spectrogram is used).
        """
        if self.using_spectrogram:
            # Compute STFT Representation
            batch['input'] = self.transform(batch['input'])
        return batch

    def configure_optimizers(self):
        """
        Define optimizers and learning rate schedulers.

        Builds the optimizer (SGD/Adam/AdamW/LAMB) named by
        `self.hparams.optimizer.optim` over the encoder and decoder parameters,
        and instantiates a LR scheduler from `self.hparams.scheduler`.

        Returns
        -------
        dict
            A dictionary containing the optimizer and a configuration for the
            learning rate scheduler (stepped every optimizer step).

        Raises
        ------
        NotImplementedError
            If the optimizer name is not one of the supported ones.
        """
        opt_cfg = self.hparams.optimizer
        params_to_pass = list(self.model.parameters()) + list(self.model_head.parameters())
        if opt_cfg.optim == "SGD":
            optimizer = torch.optim.SGD(params_to_pass, lr=opt_cfg.lr, momentum=opt_cfg.momentum)
        elif opt_cfg.optim == 'Adam':
            optimizer = torch.optim.Adam(params_to_pass, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay)
        elif opt_cfg.optim == 'AdamW':
            optimizer = torch.optim.AdamW(params_to_pass, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay, betas=opt_cfg.betas)
        elif opt_cfg.optim == 'LAMB':
            optimizer = torch_optim.Lamb(params_to_pass, lr=opt_cfg.lr)
        else:
            raise NotImplementedError("No valid optimizer name")

        print('OPTIMIZER', optimizer)

        print(f"ESTIMATED TRAINING BATCHES: {self.trainer.num_training_batches}")
        print(f"ESTIMATED GRAD ACCUM: {self.trainer.accumulate_grad_batches}")
        print(f"ESTIMATED STEPPING BATCHES FOR ENTIRE TRAINING: {self.trainer.estimated_stepping_batches}")
        print(f"MAX EPOCHS: {self.trainer.max_epochs}")
        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer,
                                            total_training_opt_steps=self.trainer.estimated_stepping_batches)
        print('SCHEDULER', scheduler)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        """
        Step the scheduler via its `step_update` method using the global step.

        Parameters
        ----------
        scheduler : object
            The learning rate scheduler; must expose `step_update(num_updates=...)`.
        metric : float
            Unused; required by the Lightning hook signature.
        """
        scheduler.step_update(num_updates=self.global_step)

    def load_from_checkpoint(self, checkpoint_path, map_location=None, hparams_file=None, strict=None, **kwargs):
        """
        Load encoder weights from a checkpoint into this instance, ignoring the model head.

        Unlike Lightning's classmethod of the same name, this override mutates
        `self` and returns it. Only `"model."`-prefixed entries without `'_head'`
        in their key are considered. Entries whose shape differs from the
        current encoder parameter are skipped (the current value is kept), and
        entries unknown to the encoder are dropped. If `self.freeze_backbone` is
        truthy, every encoder parameter gets `requires_grad = False` afterwards.

        Parameters
        ----------
        checkpoint_path : str
            Path to the checkpoint to load.
        map_location : str or torch.device, optional
            Forwarded to `torch.load`, by default None.
        hparams_file : str, optional
            Ignored; accepted for signature compatibility.
        strict : bool, optional
            Ignored; loading is always non-strict.
        **kwargs : dict
            Ignored; accepted for signature compatibility.

        Returns
        -------
        self : VITMAETask
            The instance with loaded weights.
        """
        print('\n\nOverriding load_from_checkpoint method')

        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        ckp = torch.load(checkpoint_path, map_location=map_location)
        state_dict_no_head = get_params_from_checkpoint(ckp, head=False)

        model_state_dict = self.model.state_dict()
        loadable = {}
        for k, v in state_dict_no_head.items():
            if k not in model_state_dict:
                print(f"Dropping parameter {k}")
                continue
            if v.shape != model_state_dict[k].shape:
                print(f"Skip loading parameter: {k}, "
                      f"required shape: {model_state_dict[k].shape}, "
                      f"loaded shape: {v.shape}")
                continue  # leaving it out keeps the current value (strict=False)
            loadable[k] = v

        self.model.load_state_dict(loadable, strict=False)

        if getattr(self, "freeze_backbone", False):
            print('Freezing encoder params from loaded checkpoint')
            for param in self.model.parameters():
                param.requires_grad = False
        return self

    def prepare_tokens(self, x, nr_channels_padded=None, mask_tokens=True, nr_channels=None):
        """
        Prepare and optionally mask tokens before feeding them to the MAE encoder.

        This includes:
        - Handling padded channels (assigning the [PAD] token to all their tokens).
        - Optionally masking tokens by calling `_mask_tokens`.

        Parameters
        ----------
        x : torch.Tensor
            Token embeddings of shape (B, N, D).
        nr_channels_padded : torch.Tensor, optional
            A tensor of shape (B,) (or a scalar) with the number of trailing
            padded channels per sample, by default None (no padding).
        mask_tokens : bool, optional
            Whether or not to apply random masking, by default True.
        nr_channels : int, optional
            Number of input channels C. The N tokens are assumed to be laid out
            channel-major (N // C consecutive tokens per channel). When None,
            every token is treated as its own channel (C = N), so
            `nr_channels_padded` then counts padded trailing *tokens*.
            By default None.

        Returns
        -------
        x : torch.Tensor
            Token embeddings; (B, N, D) without masking, otherwise as returned
            by `random_masking`.
        mask : torch.Tensor or None
            A binary mask of shape (B, N) in original token order, where 1
            indicates a masked token, or None if `mask_tokens` is False.
        ids_restore : torch.Tensor or None
            A restore index of shape (B, N), used to unshuffle tokens, or None if not applicable.
        attn_mask : torch.Tensor or None
            An int mask of shape (B, N), 1 for non-padded and 0 for padded
            tokens, or None if no padding information was given.

        Raises
        ------
        ValueError
            If padding is requested and N is not divisible by the channel count.
        """
        B, N, D = x.shape
        mask = None
        ids_restore = None
        attn_mask = None

        if nr_channels_padded is not None:
            # NOTE: each to-be-padded location is expected to be already 0-padded in the
            # dataset class, to avoid array stacking errors; here it gets the learned [PAD].
            C = N if nr_channels is None else int(nr_channels)
            if C <= 0 or N % C != 0:
                raise ValueError(f"Cannot split {N} tokens into {C} equally sized channels")
            P = N // C  # tokens (patches) per channel

            nr_padded = torch.as_tensor(nr_channels_padded, device=x.device).reshape(-1, 1)  # (B, 1)
            channel_indices = torch.arange(C, device=x.device).unsqueeze(0)  # (1, C)
            pad_mask = channel_indices >= (C - nr_padded)  # (B, C): trailing channels are padded
            pad_mask = pad_mask.expand(B, -1).repeat_interleave(P, dim=1)  # (B, N)

            # 1 for real tokens, 0 for padded tokens
            attn_mask = (~pad_mask).int()

            # Out-of-place replacement of padded tokens by the learned [PAD] token
            pad_tokens = self.pad_token.to(dtype=x.dtype).expand(B, N, -1)
            x = torch.where(pad_mask.unsqueeze(-1), pad_tokens, x)

        if mask_tokens:
            # mask, shuffle and get ids of tokens to be restored
            x, mask, ids_restore = self._mask_tokens(x, attn_mask)

        return x, mask, ids_restore, attn_mask

    def _mask_tokens(self, x, attn_mask=None):
        """
        Apply `random_masking` with `self.masking_ratio`.

        Parameters
        ----------
        x : torch.Tensor
            Token embeddings of shape (B, N, D).
        attn_mask : torch.Tensor or None
            A binary mask of shape (B, N). 1 indicates a real token,
            0 indicates a padded token.

        Returns
        -------
        tuple
            `(x_masked, mask, ids_restore)` exactly as returned by `random_masking`.
        """
        return self.random_masking(x, self.masking_ratio, attn_mask)

    def random_masking(self, x, mask_ratio, attn_mask=None):
        """
        Perform per-sample random masking by per-sample shuffling.

        Random noise is generated for each token, tokens are sorted by ascending
        noise, and the first `len_keep` tokens of that shuffled order are kept.

        Without `attn_mask`, `len_keep = int(N * (1 - mask_ratio))` for every
        sample and only the kept tokens are returned. With `attn_mask`, padded
        tokens (attn_mask == 0) are pushed behind all real tokens in the shuffle,
        so they are never kept; each sample keeps
        `int(n_real * (1 - mask_ratio))` randomly chosen real tokens.

        Parameters
        ----------
        x : torch.Tensor
            Input token embeddings of shape (B, N, D).
        mask_ratio : float
            Fraction of tokens to mask (in [0,1]).
        attn_mask : torch.Tensor or None
            A binary mask of shape (B, N), where 1 indicates a real token
            and 0 indicates a padded token. If None, no tokens are treated as padded.

        Returns
        -------
        x_masked : torch.Tensor
            Without `attn_mask`: kept tokens, shape (B, len_keep, D).
            With `attn_mask`: shape (B, N, D) in shuffled order, where the
            first `len_keep[b]` positions hold the kept tokens and the rest are 0.
        mask : torch.Tensor
            A float mask of shape (B, N) in original token order, 1 = masked, 0 = kept.
        ids_restore : torch.Tensor
            Indices used to undo the token shuffle, of shape (B, N).
        """
        B, N, D = x.shape  # B = batch_size, N = num_tokens, D = embed_dim
        noise = torch.rand(B, N, device=x.device)  # noise in [0, 1) of shape (B,N)

        if attn_mask is not None:
            is_pad = attn_mask == 0  # (B,N)
            # Padded tokens get noise in [1, 2) so they sort after every real token.
            noise = noise + is_pad.to(noise.dtype)
            nr_real_tokens = (~is_pad).sum(dim=-1)  # (B,)
            len_keep = (nr_real_tokens * (1 - mask_ratio)).to(torch.long)  # (B,)
        else:
            len_keep = int(N * (1 - mask_ratio))  # int

        # sort noise for each sample: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)  # (B,N)
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # (B,N)

        if attn_mask is not None:
            # Keep by *position* in the shuffled order (not by token index).
            positions = torch.arange(N, device=x.device).unsqueeze(0)  # (1,N)
            keep = positions < len_keep.unsqueeze(1)  # (B,N), shuffled order
            x_shuffled = torch.gather(x, dim=1, index=ids_shuffle.unsqueeze(-1).repeat(1, 1, D))  # (B,N,D)
            x_masked = x_shuffled.masked_fill(~keep.unsqueeze(-1), 0)  # (B,N,D)
            # binary mask in shuffled order: 0 is keep, 1 is remove
            mask = torch.ones([B, N], device=x.device).masked_fill(keep, 0.0)  # (B,N)
        else:
            ids_keep = ids_shuffle[:, :len_keep]  # (B,len_keep)
            x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # (B,len_keep,D)
            # binary mask in shuffled order: 0 is keep, 1 is remove
            mask = torch.ones([B, N], device=x.device)  # (B,N)
            mask[:, :len_keep] = 0

        # unshuffle to get the binary mask in original token order
        mask = torch.gather(mask, dim=1, index=ids_restore)  # (B,N)

        return x_masked, mask, ids_restore
