import torch
import pytorch_lightning as pl
import hydra
import torch_optimizer as torch_optim
from models.modules.augmentations import SpecAugment, WhiteNoiseAugment
from torchmetrics.regression import PearsonCorrCoef, R2Score, MeanSquaredError
import torch.nn as nn

def get_params_from_checkpoint(checkpoint, head=False):
    """
    Retrieve backbone parameters from a PyTorch Lightning checkpoint.

    Only entries of ``checkpoint["state_dict"]`` whose key starts with
    ``"model."`` are kept, and that leading prefix is stripped so the result
    can be loaded directly into the wrapped ``model`` submodule. By default,
    keys containing ``'_head'`` are excluded so a (classification/regression)
    head can be ignored.

    Parameters
    ----------
    checkpoint : dict
        A checkpoint dictionary as loaded by ``torch.load``. Must contain
        a ``"state_dict"`` key.
    head : bool, optional
        If False (default), excludes all keys that contain ``'_head'``.
        If True, includes them.

    Returns
    -------
    model_weights : dict
        Dictionary of filtered model weights (keys without the leading
        ``"model."``) that can be used to load part of a model's state_dict.

    Raises
    ------
    KeyError
        If ``checkpoint`` has no ``"state_dict"`` entry.
    """
    prefix = "model."
    model_weights = {}
    for key, value in checkpoint["state_dict"].items():
        if not key.startswith(prefix):
            continue
        # Skip head weights unless the caller explicitly asked for them.
        if not head and '_head' in key:
            continue
        # Strip only the *leading* prefix. str.replace would also delete any
        # later "model." substring (e.g. "model.sub_model.w" -> "sub_w").
        model_weights[key[len(prefix):]] = value
    return model_weights

class RegressionTask(pl.LightningModule):
    """
    A PyTorch Lightning Module for multi-output regression tasks.

    Wraps a backbone model (instantiated from ``hparams.model``) for feature
    extraction plus a regression head (instantiated from ``hparams.model_head``).
    Head outputs are squashed with ``tanh`` into [-1, 1] and reshaped to
    ``(-1, num_outputs)`` before the loss and the regression metrics
    (Pearson correlation, R2 score, RMSE) are computed during training,
    validation and testing.

    Parameters
    ----------
    hparams : dict-like
        Hyperparameters for the model, optimizer, scheduler, etc. Must provide
        ``model`` (including a ``depth`` entry used for layer-wise LR decay),
        ``model_head``, ``criterion``, ``optimizer`` and ``scheduler`` entries;
        they are read via attribute access and ``hydra.utils.instantiate``.
    transform : callable, optional
        A data transformation function (e.g. STFT). Stored as
        ``self.transform``; the step methods of this class do not apply it.
        Defaults to None.
    freeze_backbone : bool, optional
        If True, freeze the parameters of the backbone encoder so that only
        the head is trained. Defaults to False.
    layerwise_lr_decay : float, optional
        Factor by which the learning rate decays per transformer block in the
        backbone (e.g. 0.9). Defaults to 0.9.
    freq_mask_param : int, optional
        Frequency masking parameter for ``SpecAugment`` (only built when the
        backbone reports ``using_spectrogram``). Default is 0.
    time_mask_param : int, optional
        Time masking parameter for ``SpecAugment``. Default is 0.
    noise_level : float, optional
        Noise level for ``WhiteNoiseAugment``. Default is 0.15.
    augment_prob : float, optional
        Probability passed to the augmentation modules. Default is 0.5.
    num_outputs : int, optional
        Number of regression targets per row; predictions and labels are
        reshaped to ``(-1, num_outputs)``. Defaults to 12.

    Notes
    -----
    The augmentation modules (``self.white_noise_augment`` and, for
    spectrogram backbones, ``self.spec_augment``) are constructed here but are
    not invoked by the step methods of this class.
    """

    def __init__(self, hparams, transform=None, freeze_backbone=False,
                 layerwise_lr_decay=0.9, freq_mask_param=0, time_mask_param=0, noise_level=0.15, augment_prob=0.5,
                 num_outputs=12):
        super().__init__()
        self.img_logging_step = 0
        self.save_hyperparameters(hparams)
        self.layerwise_lr_decay = layerwise_lr_decay
        self.num_outputs = num_outputs
        self.model = hydra.utils.instantiate(self.hparams.model)
        self.model_head = hydra.utils.instantiate(self.hparams.model_head)
        self.freeze_backbone = freeze_backbone

        # Loss for regression (e.g. MSE), built from config. It is called with
        # the flattened predictions and the full batch.
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)
        self.transform = transform
        self.using_spectrogram = self.model.using_spectrogram
        self.augment_prob = augment_prob
        self.noise_level = noise_level
        # Allow checkpoints whose keys do not exactly match this module.
        self.strict_loading = False
        self.tanh = nn.Tanh()

        # White noise augmentation for both waveforms and spectrograms
        self.white_noise_augment = WhiteNoiseAugment(noise_level=self.noise_level, augment_prob=self.augment_prob)

        if self.using_spectrogram:
            self.freq_mask_param = freq_mask_param
            self.time_mask_param = time_mask_param
            self.spec_augment = SpecAugment(freq_mask_param=self.freq_mask_param,
                                            time_mask_param=self.time_mask_param,
                                            augment_prob=self.augment_prob)

        # REGRESSION METRICS - one instance per stage so their states never mix.
        # 1) Pearson Correlation Coefficient
        self.train_pearson = PearsonCorrCoef(num_outputs=num_outputs)
        self.val_pearson = PearsonCorrCoef(num_outputs=num_outputs)
        self.test_pearson = PearsonCorrCoef(num_outputs=num_outputs)

        # 2) R2 Score
        self.train_r2 = R2Score(num_outputs=num_outputs, multioutput="uniform_average")
        self.val_r2 = R2Score(num_outputs=num_outputs, multioutput="uniform_average")
        self.test_r2 = R2Score(num_outputs=num_outputs, multioutput="uniform_average")

        # 3) RMSE (Root Mean Squared Error)
        self.train_rmse = MeanSquaredError(squared=False, num_outputs=num_outputs)
        self.val_rmse = MeanSquaredError(squared=False, num_outputs=num_outputs)
        self.test_rmse = MeanSquaredError(squared=False, num_outputs=num_outputs)

        if self.freeze_backbone:
            print('Freezing encoder params when training from scratch')
            for param in self.model.parameters():
                param.requires_grad = False

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _shared_step(self, batch):
        """
        Run backbone + head on ``batch`` and compute the loss.

        Parameters
        ----------
        batch : dict
            Must contain ``'input'`` and ``'label'``. May also contain
            ``'nr_padded_channels'``, which is currently ignored.

        Returns
        -------
        tuple
            ``(loss, y_preds_flat, y_flat)`` where predictions and labels are
            both reshaped to ``(-1, num_outputs)``.
        """
        X = batch['input']
        y = batch['label']

        # TODO: this task does not function with a pad token. First determine if that is needed.
        # In our project (cerebro), it was not needed, so batch['nr_padded_channels'] is ignored.

        encoder_output = self.model(X)
        # Regression head, then tanh to force predictions into the range [-1, 1].
        y_preds = self.tanh(self.model_head(encoder_output))
        y_flat = y.reshape(-1, self.num_outputs)
        y_preds_flat = y_preds.reshape(-1, self.num_outputs)

        # Flattening of the targets is done inside the criterion (it receives
        # the whole batch), so every stage passes the same flat predictions.
        loss = self.criterion(y_preds_flat, batch)
        return loss, y_preds_flat, y_flat

    @staticmethod
    def _update_metrics(metrics, preds, target):
        """Accumulate ``preds``/``target`` into each torchmetrics metric in ``metrics``.

        ``update`` only accumulates state; it skips the per-batch ``compute``
        that calling the metric object would do.
        """
        for metric in metrics:
            metric.update(preds, target)

    # ------------------------------------------------------------------ #
    # Lightning steps
    # ------------------------------------------------------------------ #
    def training_step(self, batch, batch_idx):
        """
        Defines a single training step.

        Parameters
        ----------
        batch : dict
            A dictionary containing 'input' (the input data) and 'label' (the regression labels).
            May also contain 'nr_padded_channels'.
        batch_idx : int
            The index of the batch, provided by PyTorch Lightning.

        Returns
        -------
        torch.Tensor
            The loss value for this step.
        """
        # Keep a frozen backbone in eval mode (e.g. dropout / norm layers)
        # while the head trains.
        if self.freeze_backbone:
            self.model.eval()
        self.model_head.train()

        loss, y_preds_flat, y_flat = self._shared_step(batch)
        self._update_metrics((self.train_pearson, self.train_r2, self.train_rmse), y_preds_flat, y_flat)

        # detach() instead of item(): same logged value without a host sync per step.
        self.log('train_loss', loss.detach(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Defines a single validation step.

        Parameters
        ----------
        batch : dict
            A dictionary containing 'input' and 'label'.
            May also contain 'nr_padded_channels'.
        batch_idx : int
            The index of the batch, provided by PyTorch Lightning.

        Returns
        -------
        torch.Tensor
            The loss value for this validation step.
        """
        loss, y_preds_flat, y_flat = self._shared_step(batch)
        self._update_metrics((self.val_pearson, self.val_r2, self.val_rmse), y_preds_flat, y_flat)

        # Log performance metrics in Tensorboard
        self.log('val_loss', loss.detach(), prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        """
        Defines a single test step.

        Parameters
        ----------
        batch : dict
            A dictionary containing 'input' and 'label'.
            May also contain 'nr_padded_channels'.
        batch_idx : int
            The index of the batch, provided by PyTorch Lightning.

        Returns
        -------
        torch.Tensor
            The loss value for this test step.
        """
        loss, y_preds_flat, y_flat = self._shared_step(batch)
        self._update_metrics((self.test_pearson, self.test_r2, self.test_rmse), y_preds_flat, y_flat)

        # Log performance metrics in Tensorboard
        self.log('test_loss', loss.detach(), prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    # ------------------------------------------------------------------ #
    # Epoch-end logging
    # ------------------------------------------------------------------ #
    def on_train_epoch_end(self):
        """
        Called at the end of the training epoch.

        Logs the epoch-level training R2. Pearson and RMSE are accumulated in
        ``training_step`` but are not logged here.
        """
        self.log('train_r2', self.train_r2, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        """
        Called at the end of the validation epoch.

        Logs the epoch-level validation R2. Pearson and RMSE are accumulated in
        ``validation_step`` but are not logged here.
        """
        self.log('val_r2', self.val_r2, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)

    def on_test_epoch_end(self):
        """
        Called at the end of the test epoch.

        Logs the epoch-level test R2 and the RMSE averaged over all outputs,
        then resets the RMSE metric (the R2 metric object is logged directly).
        """
        self.log('test_r2', self.test_r2, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)
        rmse = torch.mean(self.test_rmse.compute())
        self.log('test_rmse', rmse, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)
        self.test_rmse.reset()

    # ------------------------------------------------------------------ #
    # Optimization
    # ------------------------------------------------------------------ #
    @staticmethod
    def _block_index(param_name):
        """Return ``i`` for a parameter named ``blocks.<i>.…``, else None."""
        if param_name.startswith('blocks.'):
            return int(param_name.split('.')[1])
        return None

    def configure_optimizers(self):
        """
        Define optimizers and learning-rate schedulers to use in your optimization.

        Each backbone parameter gets its own param group. Parameters inside
        ``blocks.<i>`` use ``lr * layerwise_lr_decay ** (depth - i)``; all
        other backbone parameters and every head parameter use the base lr.

        Returns
        -------
        dict
            ``{"optimizer": optimizer, "lr_scheduler": config}``; the scheduler
            is stepped every optimizer step.

        Raises
        ------
        NotImplementedError
            If ``hparams.optimizer.optim`` is not one of SGD, Adam, AdamW, LAMB.
        """
        opt_cfg = self.hparams.optimizer
        base_lr = opt_cfg.lr
        decay_factor = self.layerwise_lr_decay
        # Number of Transformer blocks in the encoder
        num_blocks = self.hparams.model.depth

        # One group per backbone parameter with its layer-wise learning rate.
        params_to_pass = []
        block_lrs = {}
        for name, param in self.model.named_parameters():
            block_nr = self._block_index(name)
            lr = base_lr
            if block_nr is not None:
                lr = base_lr * decay_factor ** (num_blocks - block_nr)
                block_lrs[block_nr] = lr
            params_to_pass.append({"params": param, "lr": lr})

        # Head parameters fall back to the optimizer's default (base) lr.
        params_to_pass.extend({"params": param} for _, param in self.model_head.named_parameters())

        print("\nLearning rates for encoder blocks:")
        for block_nr, lr in block_lrs.items():
            print(f"Block {block_nr}: {lr}")

        optim_name = opt_cfg.optim
        if optim_name == "SGD":
            optimizer = torch.optim.SGD(params_to_pass, lr=base_lr, momentum=opt_cfg.momentum)
        elif optim_name == 'Adam':
            optimizer = torch.optim.Adam(params_to_pass, lr=base_lr, weight_decay=opt_cfg.weight_decay)
        elif optim_name == 'AdamW':
            optimizer = torch.optim.AdamW(params_to_pass, lr=base_lr, weight_decay=opt_cfg.weight_decay, betas=opt_cfg.betas)
        elif optim_name == 'LAMB':
            optimizer = torch_optim.Lamb(params_to_pass, lr=base_lr)
        else:
            raise NotImplementedError("No valid optimizer name")

        print('OPTIMIZER', optimizer)
        print(f"ESTIMATED TRAINING BATCHES: {self.trainer.num_training_batches}")
        print(f"ESTIMATED GRAD ACCUM: {self.trainer.accumulate_grad_batches}")
        print(f"ESTIMATED STEPPING BATCHES FOR ENTIRE TRAINING: {self.trainer.estimated_stepping_batches}")
        print(f"MAX EPOCHS: {self.trainer.max_epochs}")
        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer,
                                            total_training_opt_steps=self.trainer.estimated_stepping_batches)
        print('SCHEDULER', scheduler)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        """
        Step the scheduler; called by the Trainer once per optimizer update.

        Parameters
        ----------
        scheduler : object
            The instantiated scheduler. It must expose
            ``step_update(num_updates=...)``.
        metric : float
            Ignored.
        """
        scheduler.step_update(num_updates=self.global_step)

    # ------------------------------------------------------------------ #
    # Checkpoint loading
    # ------------------------------------------------------------------ #
    def load_from_checkpoint(self, checkpoint_path, map_location=None, hparams_file=None, strict=None, **kwargs):
        """
        Partially load a checkpoint: backbone weights only (no head).

        Unlike ``LightningModule.load_from_checkpoint``, this is an instance
        method that mutates and returns ``self``. Checkpoint keys with a shape
        different from the current backbone keep the current values. Keys
        unknown to the backbone are reported and ignored. If
        ``freeze_backbone`` is set, all backbone parameters are frozen except
        the embedding/encoding ones.

        Parameters
        ----------
        checkpoint_path : str
            Path to the saved checkpoint file.
        map_location : str or torch.device, optional
            Device mapping passed to ``torch.load``.
        hparams_file : str, optional
            Ignored; accepted for signature compatibility.
        strict : bool, optional
            Ignored; the backbone is always loaded with ``strict=False``.
        **kwargs : dict
            Ignored; accepted for signature compatibility.

        Returns
        -------
        self
            The current instance (with loaded weights).
        """
        print('\n\nOverriding load_from_checkpoint method')

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckp = torch.load(checkpoint_path, map_location=map_location)
        state_dict_no_head = get_params_from_checkpoint(ckp, head=False)

        model_state_dict = self.model.state_dict()
        for k, loaded in list(state_dict_no_head.items()):
            if k not in model_state_dict:
                # strict=False below makes load_state_dict ignore this key.
                print(f"Dropping parameter {k}")
                continue
            if loaded.shape != model_state_dict[k].shape:
                print(f"Skip loading parameter: {k}, "
                      f"required shape: {model_state_dict[k].shape}, "
                      f"loaded shape: {loaded.shape}")
                # Keep the model's current tensor for mismatched shapes.
                state_dict_no_head[k] = model_state_dict[k]

        self.model.load_state_dict(state_dict_no_head, strict=False)

        if self.freeze_backbone:
            print('Freezing encoder params from loaded checkpoint')
            # Embedding / positional-encoding params stay trainable.
            keep_unfrozen = ('patch_embed', 'channel_encoding', 'patch_encoding', 'pos_encoding')
            for name, param in self.model.named_parameters():
                requires_grad = any(key in name for key in keep_unfrozen)
                param.requires_grad = requires_grad
                print(f"{name} trainable={requires_grad}")

        return self
