from pytorch_lightning.callbacks import Callback
import torch
from custom_op.register import register_measure_perplexity_HOSVD
from pytorch_lightning import seed_everything


class LogActivationMemoryCallback(Callback):
    """
    Callback to log activation memory usage during training and, optionally,
    to collect HOSVD perplexity measurements from the model.

    Args:
        log_activation_mem (bool): If True, asks the model to report its resource
            consumption (AMC / WSI memory lists and the first-batch measurement).
        perplexity: Optional result container for perplexity measurement. This class
            reads ``perplexity.set_of_epsilons`` and writes ``perplexity.perplexity``,
            ``perplexity.ranks`` and ``perplexity.layer_mems``, each indexed as
            ``[layer][epsilon]``. Required whenever the model has a truthy
            ``measure_perplexity_HOSVD_var`` attribute.
    """

    def __init__(self, log_activation_mem=False, perplexity=None):
        super().__init__()

        self.log_activation_mem = log_activation_mem    # if True: log estimation of activation memory
        self.first_train_batch_start_logged = False     # set once the first train batch start has been logged
        self.first_train_batch_end_logged = False       # set once the first train batch end has been logged
        self.training_begin = False                     # set at the start of the first training epoch
        self.num_train_batches = None                   # number of training batches seen so far in an epoch
        self.num_val_batches = None                     # number of validation batches seen so far in an epoch

        self.first_validation_done = False              # set at the end of the first validation epoch
        self.first_validation_batch_done = False        # not used within this class

        self.previous_loss = 0                          # not used within this class
        self.considered_layer = 0                       # not used within this class

        # Always create the perplexity bookkeeping so the hooks never hit a missing
        # attribute when no container was supplied.
        self.perplexity = perplexity
        self.total_epsilon = 0
        self.total_layer = 0
        self.epsilon_idx = 0                            # column of the epsilon currently being measured
        self.layer_idx = 0                              # not used within this class
        if perplexity is not None:
            self.total_epsilon = len(perplexity.perplexity[0])
            self.total_layer = len(perplexity.perplexity)

    @staticmethod
    def _to_scalar(value):
        """Return ``value.item()`` for a ``torch.Tensor``; any other value is returned unchanged."""
        return value.item() if isinstance(value, torch.Tensor) else value

    @staticmethod
    def _wsi_every_iteration(model, required_method):
        """
        Return True when ``model`` has a truthy ``with_WSI``, provides ``required_method``
        and has a falsy ``WSI_with_sub_iter`` (i.e. SVD decomposes the weight at every iteration).
        """
        return bool(
            getattr(model, 'with_WSI', False)
            and hasattr(model, required_method)
            and not model.WSI_with_sub_iter
        )

    def _measuring_perplexity(self, model):
        """
        Return True when ``model`` has a truthy ``measure_perplexity_HOSVD_var``.

        Raises:
            RuntimeError: If measurement is requested but this callback was built
                without a ``perplexity`` container to store the results in.
        """
        if not getattr(model, 'measure_perplexity_HOSVD_var', False):
            return False
        if self.perplexity is None:
            raise RuntimeError(
                "model.measure_perplexity_HOSVD_var is set but LogActivationMemoryCallback "
                "was created without a `perplexity` container."
            )
        return True

    def on_train_epoch_start(self, trainer, model):
        """
        Called at the beginning of a training epoch.

        On the first epoch only, and only if logging is enabled, attaches the memory
        lists to the model (AMC and/or per-iteration WSI). On every epoch, if the model
        measures HOSVD perplexity, selects the current epsilon, re-seeds, registers the
        measurement on the model and refreshes its optimizer.
        """
        if not self.training_begin:
            if self.log_activation_mem:
                if getattr(model, 'with_AMC', False):
                    model.attach_memory_list()

                if self._wsi_every_iteration(model, 'attach_memory_list_weight'):
                    model.attach_memory_list_weight()

            self.training_begin = True

        if self._measuring_perplexity(model):
            # NOTE(review): raises IndexError once epsilon_idx runs past set_of_epsilons;
            # presumably the run is configured with one epoch per epsilon - confirm.
            epsilon = self.perplexity.set_of_epsilons[self.epsilon_idx]
            model.filter_cfgs["explain_variance_threshold"] = epsilon
            print(f"For epsilon is {epsilon}")
            seed_everything(233)  # same seed before every epsilon
            register_measure_perplexity_HOSVD(model, model.filter_cfgs)
            model.update_optimizer()

    def on_train_epoch_end(self, trainer, model):
        """
        Called at the end of a training epoch.

        If logging is enabled, asks the model to report AMC / per-iteration WSI
        consumption for the number of training batches seen, then resets the
        corresponding model-side state.
        """
        if not self.log_activation_mem:
            return

        if getattr(model, 'with_AMC', False):
            model.get_resource_consumption_AMC(self.num_train_batches)  # decomposition only occurs during training
            model.reset_k_hosvd()

        if self._wsi_every_iteration(model, 'get_weight_size_WSI'):
            model.get_weight_size_WSI(self.num_train_batches)  # decomposition only occurs during training
            model.reset_memory_list_weight()

    def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx=0):
        """
        Called at the start of a training batch.

        If logging is enabled, calls ``model.get_resource_consumption(register_hook=True)``
        once, on the first training batch this callback sees.

        ``dataloader_idx`` is unused; it defaults to 0 so the hook can be invoked both
        with and without that argument.
        """
        if self.log_activation_mem and not self.first_train_batch_start_logged:
            model.get_resource_consumption(register_hook=True)
            self.first_train_batch_start_logged = True

    def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Called at the end of a training batch.

        Updates the training batch count; if logging is enabled, calls
        ``model.get_resource_consumption(register_hook=False)`` once on the first batch.
        If the model has a truthy ``just_log``, requests the trainer to stop and sets
        ``trainer.limit_val_batches`` to 0. If the model measures HOSVD perplexity,
        copies its per-layer measurements into the result container for the current
        epsilon, clears them on the model and advances to the next epsilon.

        ``dataloader_idx`` is unused; it defaults to 0 so the hook can be invoked both
        with and without that argument.
        """
        self.num_train_batches = batch_idx + 1

        if self.log_activation_mem and not self.first_train_batch_end_logged:
            model.get_resource_consumption(register_hook=False)
            self.first_train_batch_end_logged = True

        # Guarded like the other optional model flags: not every model defines `just_log`.
        if getattr(model, 'just_log', False):
            trainer.should_stop = True
            trainer.limit_val_batches = 0

        if self._measuring_perplexity(model):
            results = self.perplexity
            column = self.epsilon_idx
            for layer in range(self.total_layer):
                results.perplexity[layer][column] = self._to_scalar(model.perplexity[layer])
                results.ranks[layer][column] = model.measured_rank[layer]
                results.layer_mems[layer][column] = self._to_scalar(model.layer_mem[layer])

            model.clear_measured_variables()

            # NOTE(review): the index advances per batch while the epsilon is chosen per
            # epoch, so this presumably relies on one training batch per epoch - confirm.
            self.epsilon_idx += 1

    def on_validation_batch_end(self, trainer, model, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Called at the end of a validation batch.

        Records the number of validation batches seen so far (``batch_idx + 1``).
        """
        self.num_val_batches = batch_idx + 1

    def on_validation_epoch_end(self, trainer, model):
        """
        Called at the end of a validation epoch.

        Marks that the first validation epoch has completed.
        """
        if not self.first_validation_done:
            self.first_validation_done = True
