import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim
import torch.nn as nn
from models.modules.patching import patchify



def get_params_from_checkpoint(checkpoint, head=False):
    """
    Extracts encoder weights from a PyTorch Lightning checkpoint.

    Only entries of ``checkpoint["state_dict"]`` whose key starts with
    ``"model."`` are considered. Keys of a decoder stored under
    ``"model_head."`` therefore never match, whatever ``head`` is. The leading
    ``"model."`` prefix is stripped exactly once, so nested sub-module names
    such as ``"model.sub.model.w"`` are preserved as ``"sub.model.w"``.

    Args:
        checkpoint (dict): A PyTorch checkpoint containing a 'state_dict' entry.
        head (bool, optional): If `True`, keep every ``"model.*"`` parameter.
            If `False`, additionally drop those whose key contains '_head'.
            Defaults to False.

    Returns:
        dict: The filtered weights, keyed by parameter name with the leading
        'model.' prefix removed, mapped to their original values (tensors).
    """
    prefix = "model."
    model_weights = {}
    for k, v in checkpoint["state_dict"].items():
        if not k.startswith(prefix):
            continue
        if not head and "_head" in k:
            continue
        # Strip only the leading prefix; str.replace would also rewrite any
        # nested "model." occurring later in the key.
        model_weights[k[len(prefix):]] = v
    return model_weights

class SimmimMAETask(pl.LightningModule):
    """
    A PyTorch Lightning module for a SimMIM MAE Task.

    This task handles the forward pass of an encoder-decoder style model,
    random token masking, and the computation of reconstruction losses.
    It also sets up optimizers, schedulers, and custom checkpoint-loading
    functionality.

    Attributes:
        hparams (dict): Hyperparameters for model configuration, optimizer, scheduler, etc.
        masking_ratio (float): The ratio of tokens to mask during training.
        model (nn.Module): The encoder component of the architecture.
        model_head (nn.Module): The decoder component of the architecture.
        criterion (nn.Module): Loss function for reconstruction; called as
            ``criterion(pred, batch)`` and returning ``(loss, logging_output)``.
        patch_size (int): Patch size used in patch embedding.
        keep_chans: Forwarded to ``patchify`` when building the target.
        using_spectrogram (bool): Whether to apply a spectrogram transformation.
        strict_loading (bool): For controlling strictness when loading state dicts.
        transform (nn.Module): Transformation applied to inputs when using spectrograms.
        square_patches (bool): Whether patches are square or not (applies primarily to spectrogram tasks).
        mask_token (torch.Tensor): Learned embedding used for masked tokens (shared with ``model``).
        pad_token (torch.Tensor): Learned embedding used for padded tokens (shared with ``model``).
        img_logging_step (int): Internal counter for controlling how often to log images.
    """
    def __init__(self, hparams, masking_ratio):
        """
        Initializes the SimmimMAETask.

        Args:
            hparams (dict): A dictionary containing all the hyperparameters.
                This includes model configuration, optimizer, and scheduler settings.
            masking_ratio (float): The ratio of tokens to mask during training.
        """
        super().__init__()
        self.img_logging_step = 0
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)  # Encoder
        self.model_head = hydra.utils.instantiate(self.hparams.model_head)  # Decoder
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)
        self.patch_size = self.hparams.model.patch_size
        self.keep_chans = self.hparams.model.keep_chans
        self.using_spectrogram = self.hparams.model.using_spectrogram
        self.strict_loading = False
        self.transform = None
        self.square_patches = False
        if self.using_spectrogram:
            self.transform = hydra.utils.instantiate(self.hparams.preprocessor)
            self.square_patches = self.hparams.model.square_patches

        self.mask_token = self.model.mask_token                     # NOTE: until general TimeFM restructuring, we assume the model has a mask & pad token
        self.pad_token = self.model.pad_token                       # this is done in order to maintain checkpoint loading stability
        self.masking_ratio = masking_ratio

    def _log_train_reconstruction_data(self, logging_output, log_frequency=10000):
        """
        Logs image data to the logger at a specified frequency.

        The 'images' entry is always removed from ``logging_output`` and the
        internal counter is advanced, even on steps where nothing is logged.

        Args:
            logging_output (dict): May contain the key 'images', a dict of
                name -> torch.Tensor images to log.
            log_frequency (int, optional): Log every ``log_frequency`` calls.
                Defaults to 10000.
        """
        if "images" not in logging_output:
            return
        images = logging_output.pop("images")
        # Skip the actual write when no logger is attached (e.g. logger=False).
        if self.logger is not None and self.img_logging_step % log_frequency == 0:
            for image_name, image in images.items():
                self.logger.experiment.add_image(image_name, image, self.img_logging_step)
        self.img_logging_step += 1

    def _shared_step(self, batch):
        """
        Runs patch embedding, masking, encoder, decoder and the criterion.

        Adds the keys 'token_mask', 'attn_mask' and 'target' to ``batch``
        (the criterion receives the whole batch), and logs reconstruction
        images if the criterion returned any.

        Args:
            batch (dict): Must contain 'input'; may contain 'nr_padded_channels'.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        X = batch['input']
        # NOTE(review): X.shape[1] is taken as the channel count, as the original
        # training_step labelled it; tokens are assumed to be laid out channel-major.
        nr_channels = X.shape[1]

        # Encoder
        x = self.model.patch_embed(X) #NOTE: until general TimeFM restructuring, we assume the model has a patch_embed module in order to move masking to the task
        x, token_mask, ids_restore, attn_mask = self.prepare_tokens(
            x,
            nr_channels_padded=batch.get('nr_padded_channels'),
            mask_tokens=True,
            nr_channels=nr_channels,
        )
        latent = self.model(x, directly_input_tokens=True, attn_mask=attn_mask)

        batch["token_mask"] = token_mask
        batch["attn_mask"] = attn_mask

        # Decoder
        pred = self.model_head(latent, ids_restore)

        # Compute loss
        batch["target"] = patchify(X, patch_size=self.patch_size, keep_chans=self.keep_chans, using_spectrogram=self.using_spectrogram, square_patches=self.square_patches)
        loss, logging_output = self.criterion(pred, batch)

        # Log images/waveforms
        self._log_train_reconstruction_data(logging_output, log_frequency=5000)
        return loss

    def training_step(self, batch, batch_idx):
        """
        Defines a single step of training.

        Args:
            batch (dict): A dictionary containing input data with key 'input'.
            batch_idx (int): Index of this batch.

        Returns:
            torch.Tensor: The computed loss for this batch.
        """
        loss = self._shared_step(batch)
        # Log a detached tensor rather than .item() to avoid a device sync per step.
        self.log("train_loss", loss.detach(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Defines a single step of validation (same computation as training).

        Args:
            batch (dict): A dictionary containing input data with key 'input'.
            batch_idx (int): Index of this batch in the current validation epoch.

        Returns:
            torch.Tensor: The computed validation loss for this batch.
        """
        loss = self._shared_step(batch)
        self.log("val_loss", loss.detach(), on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """
        Lightning hook called after a batch is moved to the device.

        If ``using_spectrogram`` is set, replaces ``batch['input']`` with
        ``self.transform(batch['input'])``.

        Args:
            batch (dict): The batch of data after transfer to device.
            dataloader_idx (int): Index of the dataloader.

        Returns:
            dict: The (possibly transformed) batch.
        """
        if self.using_spectrogram:
            batch['input'] = self.transform(batch['input'])
        return batch

    def configure_optimizers(self):
        """
        Defines the optimizer and the per-step learning-rate scheduler.

        Encoder and decoder parameters are optimized together. The scheduler is
        instantiated via hydra with ``total_training_opt_steps`` set to the
        trainer's estimated number of optimizer steps.

        Returns:
            dict: {"optimizer": optimizer, "lr_scheduler": {"scheduler", "interval", "frequency"}}.

        Raises:
            NotImplementedError: If ``hparams.optimizer.optim`` is not one of
                'SGD', 'Adam', 'AdamW', 'LAMB'.
        """
        params_to_pass = list(self.model.parameters()) + list(self.model_head.parameters())
        opt_cfg = self.hparams.optimizer
        if opt_cfg.optim == "SGD":
            optimizer = torch.optim.SGD(params_to_pass, lr=opt_cfg.lr, momentum=opt_cfg.momentum)
        elif opt_cfg.optim == 'Adam':
            optimizer = torch.optim.Adam(params_to_pass, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay)
        elif opt_cfg.optim == 'AdamW':
            optimizer = torch.optim.AdamW(params_to_pass, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay, betas=opt_cfg.betas)
        elif opt_cfg.optim == 'LAMB':
            optimizer = torch_optim.Lamb(params_to_pass, lr=opt_cfg.lr)
        else:
            raise NotImplementedError("No valid optimizer name")

        print('OPTIMIZER', optimizer)

        print(f"ESTIMATED TRAINING BATCHES: {self.trainer.num_training_batches}")
        print(f"ESTIMATED GRAD ACCUM: {self.trainer.accumulate_grad_batches}")
        print(f"ESTIMATED STEPPING BATCHES FOR ENTIRE TRAINING: {self.trainer.estimated_stepping_batches}")
        print(f"MAX EPOCHS: {self.trainer.max_epochs}")
        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer,
                                            total_training_opt_steps=self.trainer.estimated_stepping_batches)
        print('SCHEDULER', scheduler)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        """
        Steps the scheduler via ``scheduler.step_update(num_updates=global_step)``.

        Args:
            scheduler (object): The instantiated scheduler; must provide ``step_update``.
            metric (float): Unused.
        """
        scheduler.step_update(num_updates=self.global_step)

    def load_from_checkpoint(self, checkpoint_path, map_location=None, hparams_file=None, strict=None, **kwargs):
        """
        Loads encoder weights from a checkpoint into ``self.model``, skipping the decoder.

        Parameters whose shape differs from the current encoder keep their
        current values; parameters unknown to the encoder are reported and
        ignored. Loading is always non-strict.

        Args:
            checkpoint_path (str): Path to the .ckpt file.
            map_location (torch.device or str, optional): Passed to ``torch.load``.
            hparams_file (str, optional): Unused.
            strict (bool, optional): Unused; loading is always ``strict=False``.
            **kwargs: Unused.

        Returns:
            SimmimMAETask: ``self``, with loaded encoder weights.
        """
        print('\n\nOverriding load_from_checkpoint method')

        # NOTE(review): newer torch versions default torch.load to weights_only=True,
        # which may reject checkpoints that pickle non-tensor hyper-parameters.
        ckp = torch.load(checkpoint_path, map_location=map_location)
        state_dict_no_head = get_params_from_checkpoint(ckp, head=False)

        model_state_dict = self.model.state_dict()
        for k, loaded in state_dict_no_head.items():
            if k not in model_state_dict:
                print(f"Dropping parameter {k}")
                continue
            if loaded.shape != model_state_dict[k].shape:
                print(f"Skip loading parameter: {k}, "
                      f"required shape: {model_state_dict[k].shape}, "
                      f"loaded shape: {loaded.shape}")
                # Overwriting an existing key does not change dict size, so this is safe mid-iteration.
                state_dict_no_head[k] = model_state_dict[k]

        self.model.load_state_dict(state_dict_no_head, strict=False)

        # freeze_backbone is not set in __init__; presumably provided by a caller
        # or subclass. Default to not freezing instead of raising AttributeError.
        if getattr(self, "freeze_backbone", False):
            print('Freezing encoder params from loaded checkpoint')
            for param in self.model.parameters():
                param.requires_grad = False
        return self

    def prepare_tokens(self, x, nr_channels_padded=None, mask_tokens=True, nr_channels=None):
        """
        Prepares tokens by padding and masking them.

        1. If `nr_channels_padded` is given, the tokens of the last
           ``nr_channels_padded[b]`` channels of each sample are replaced by the
           pad token and excluded via ``attn_mask``.
        2. Optionally masks tokens in `x`.

        Args:
            x (torch.Tensor): Token embeddings of shape (B, N, D), laid out
                channel-major, i.e. N = C * P with P tokens per channel.
            nr_channels_padded (torch.Tensor, optional): Number of padded channels
                per batch element. Shape: (B,).
            mask_tokens (bool, optional): Whether to apply random masking. Defaults to True.
            nr_channels (int, optional): Number of channels C the N tokens span.
                If None, every token is treated as its own channel (P = 1),
                which reproduces the legacy behaviour.

        Returns:
            tuple: (x, mask, ids_restore, attn_mask)
                - x (torch.Tensor): (B, N, D) tokens with pad/mask tokens substituted.
                - mask (torch.Tensor or None): (B, N), 1 for tokens not visible to
                  the encoder (masked or padded), 0 for kept tokens.
                - ids_restore (torch.Tensor or None): (B, N) indices restoring the
                  original token order after shuffling.
                - attn_mask (torch.Tensor or None): (B, N) int, 1 for real tokens,
                  0 for padded tokens.

        Raises:
            ValueError: If N is not divisible by ``nr_channels``.
        """
        B, N, D = x.shape
        C = N if nr_channels is None else int(nr_channels)
        if C <= 0 or N % C != 0:
            raise ValueError(f"Cannot split {N} tokens evenly over {C} channels")
        P = N // C  # tokens (patches) per channel
        attn_mask = None
        mask = None
        ids_restore = None

        if nr_channels_padded is not None:
            # Padded channels vary per sample, so build an index mask over channels and
            # expand it to tokens. Padded locations are expected to already be 0-padded
            # by the dataset so that samples can be stacked.
            nr_real_chans = C - nr_channels_padded  # (B,)
            channel_indices = torch.arange(C, device=x.device).unsqueeze(0)  # (1, C)
            pad_mask = channel_indices >= nr_real_chans.unsqueeze(1)  # (B, C)
            pad_mask = pad_mask.repeat_interleave(P, dim=1)  # (B, N)

            # 1 for real tokens, 0 for padded tokens
            attn_mask = (~pad_mask).int()  # (B, N)

            # Out-of-place replacement with the learned [PAD] embedding (cast to x's dtype).
            x = torch.where(pad_mask.unsqueeze(-1), self.pad_token.to(x.dtype), x)

        if mask_tokens:
            x, mask, ids_restore = self._mask_tokens(x, attn_mask)  # (B, N, D)

        return x, mask, ids_restore, attn_mask

    def _mask_tokens(self, x, attn_mask=None):
        """
        Randomly masks tokens and substitutes the learned mask token.

        Without ``attn_mask``, masked tokens are appended as mask tokens and the
        sequence is un-shuffled. With ``attn_mask``, kept tokens are unchanged,
        masked real tokens become the mask token and padded tokens keep their
        current (pad) embedding.

        Args:
            x (torch.Tensor): Embedding tensor of shape (B, N, D).
            attn_mask (torch.Tensor or None): (B, N), 1 for real tokens, 0 for padded.

        Returns:
            tuple: (x, mask, ids_restore) with shapes (B, N, D), (B, N), (B, N);
                ``mask`` is 1 for tokens not kept.
        """
        x_masked, mask, ids_restore = self.random_masking(x, self.masking_ratio, attn_mask)

        if attn_mask is None:
            B, _, D = x.shape
            nr_mask_tokens = ids_restore.shape[1] - x_masked.shape[1]
            mask_tokens = self.mask_token.repeat(B, nr_mask_tokens, 1)  # (B, N - len_keep, D)
            x_full = torch.cat([x_masked, mask_tokens], dim=1)  # (B, N, D), shuffled order
            x_out = torch.gather(x_full, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))  # (B, N, D)
        else:
            # Kept tokens equal the input at their original position, so fill
            # masked *real* tokens in place. Padded tokens keep their pad embedding.
            masked_real = (mask.bool() & attn_mask.bool()).unsqueeze(-1)  # (B, N, 1)
            x_out = torch.where(masked_real, self.mask_token.to(x.dtype), x)
        return x_out, mask, ids_restore

    def random_masking(self, x, mask_ratio, attn_mask=None):
        """
        Performs per-sample random masking by per-sample shuffling.

        1. Generate uniform noise per token; padded tokens (``attn_mask == 0``)
           get noise shifted into [1, 2) so they always sort last.
        2. Sort tokens by ascending noise; the first ``len_keep`` are kept.
           Without ``attn_mask``, ``len_keep = int(N * (1 - mask_ratio))``;
           with it, ``len_keep`` is computed per sample from the number of real tokens.

        Args:
            x (torch.Tensor): Embedding tensor of shape (B, N, D).
            mask_ratio (float): Fraction of (real) tokens to mask.
            attn_mask (torch.Tensor, optional): Shape (B, N). 1 for real tokens, 0 for padded.

        Returns:
            tuple: (x_masked, mask, ids_restore)
                - x_masked (torch.Tensor): without ``attn_mask``, the kept tokens in
                  shuffled order, shape (B, len_keep, D); with ``attn_mask``, shape
                  (B, N, D) in shuffled order with non-kept positions zeroed.
                - mask (torch.Tensor): (B, N) float, 1 for not kept, 0 for kept,
                  in original order.
                - ids_restore (torch.Tensor): (B, N) indices to restore the original order.
        """
        B, N, D = x.shape  # B = batch size, N = num tokens, D = embed dim
        noise = torch.rand(B, N, device=x.device)  # uniform in [0, 1)
        if attn_mask is not None:
            # Padded tokens move to [1, 2) so they can never be among the kept ones.
            noise = noise + (1 - attn_mask.to(noise.dtype))

        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        mask = torch.ones([B, N], device=x.device)  # 1 = remove, 0 = keep (shuffled order)
        if attn_mask is None:
            len_keep = int(N * (1 - mask_ratio))
            ids_keep = ids_shuffle[:, :len_keep]  # (B, len_keep)
            x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # (B, len_keep, D)
            mask[:, :len_keep] = 0
        else:
            nr_real_tokens = attn_mask.sum(dim=-1)  # (B,)
            len_keep = (nr_real_tokens * (1 - mask_ratio)).long()  # (B,), per-sample keep count
            # Keep the first len_keep[b] *shuffled positions* of each sample.
            keep_shuffled = torch.arange(N, device=x.device).unsqueeze(0) < len_keep.unsqueeze(1)  # (B, N)
            x_shuffled = torch.gather(x, dim=1, index=ids_shuffle.unsqueeze(-1).repeat(1, 1, D))  # (B, N, D)
            x_masked = x_shuffled.masked_fill(~keep_shuffled.unsqueeze(-1), 0)  # (B, N, D)
            mask[keep_shuffled] = 0

        # unshuffle to get the binary mask in original token order
        mask = torch.gather(mask, dim=1, index=ids_restore)  # (B, N)

        return x_masked, mask, ids_restore
