"""
    use_tunable_threshold_activation: will use JumpReLU with tunable threshold as the bias
    use_threshold_activation: will use JumpReLU with fixed threshold theta
    These two methods are mutually exclusive, and does not work with topK for now. 
"""


import os, sys
from tqdm import tqdm, trange
# os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["VSCODE_PROXY_CUDA_DEVICE"] # FIXME: remove this line in sbatch script
current_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.dirname(current_dir)
grandpa_dir = os.path.dirname(parent_dir)
# Make this directory and its two ancestors importable (appended in that order),
# presumably so that the `Simtransformer...` imports below resolve without installation.
sys.path.extend([current_dir, parent_dir, grandpa_dir])

import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import lightning.pytorch as pl
import wandb
import yaml

from Simtransformer.simtransformer.utils import (
    EasyDict, clever_load, check_cosine_similarity,  Shampoo, signSGD, normSGD, NormalizeSGD, CosineAnnealingWarmup, dominance_metrics, AlignmentLoss,
    MIR, HDR
)
from Simtransformer.simtransformer.model_base import (
    SAEWithChannel, 
    GradRescaler, 
    SparseAutoEncoder, 
    JumpReLU, 
    ThresholdReLU)
from typing import List, Tuple, Union, Any
from lightning.pytorch.utilities.types import LRSchedulerTypeUnion
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import seed_everything
import lightning
import time
import copy

from datetime import datetime
import matplotlib.pyplot as plt



    
class SytheticDataset(Dataset):
    """
    On-the-fly synthetic dataset whose samples are weighted sums of features.

    Every ``__getitem__`` call draws one random row from *each* feature tensor in
    ``feat_dict`` and returns their weighted sum. The per-feature weights ("factors")
    depend on ``data_type``:

    * ``'balanced'`` (or any unrecognised value): every factor is 1.
    * ``'perturbed'``: factors are ``1 + perturb_std * N(0, 1)``, clamped to
      ``[clamp_min, clamp_max]`` and rescaled so they sum to the number of features.
    * ``'imbalanced'``: factors are ``factor_1``, ``factor_2``, ... (one per feature,
      default 1), rescaled so they sum to the number of features.

    If ``input_norm`` is given, each sample is rescaled to that L2 norm.
    Samples are random, so the ``idx`` passed to ``__getitem__`` is ignored.
    """

    def __init__(self, feat_dict, total_length=10_000, data_type='balanced', return_label=False, **kwargs):
        """
        Args:
            feat_dict (dict): Maps each feature name to a tensor whose first dimension
                indexes candidate rows (for example shape [N, D]).
            total_length (int): Value reported by ``__len__``. Samples are generated
                on the fly, so this only controls the epoch length.
            data_type (str): ``'balanced'``, ``'perturbed'`` or ``'imbalanced'``.
            return_label (bool): If True, ``__getitem__`` also returns the sampled row
                index of each feature.
            **kwargs: Optional knobs read in ``__getitem__``: ``clamp_min`` (0.5),
                ``clamp_max`` (1.5), ``perturb_std`` (0.5), ``factor_<i>`` (1),
                ``input_norm`` (None).
        """
        super().__init__()
        self.feat_dict = feat_dict
        self.total_length = total_length
        # Cache the keys: fixes the feature order and avoids rebuilding the list per item.
        self.keys = list(self.feat_dict.keys())
        self.return_label = return_label
        self.data_type = data_type
        self.kwargs = kwargs

    def __len__(self):
        return self.total_length

    def _factors(self, device):
        """Return one weight per feature (in ``self.keys`` order) for the configured data_type."""
        kwargs = self.kwargs
        num_feats = len(self.keys)
        if self.data_type == 'perturbed':
            # Perturb around 1, clamp to a sane range, then renormalise to sum to num_feats.
            perturbation = torch.randn(num_feats, device=device) * kwargs.get('perturb_std', 0.5)
            factors = torch.ones(num_feats, device=device) + perturbation
            factors = torch.clamp(factors, min=kwargs.get('clamp_min', 0.5), max=kwargs.get('clamp_max', 1.5))
            return factors / factors.sum() * num_feats
        if self.data_type == 'imbalanced':
            factors = torch.tensor(
                [kwargs.get(f'factor_{i + 1}', 1) for i in range(num_feats)],
                device=device,
            )
            return factors / factors.sum() * num_feats
        return torch.ones(num_feats, device=device)

    def __getitem__(self, idx):
        # Use the first feature's device (no hard-coded key name).
        device = self.feat_dict[self.keys[0]].device
        factors = self._factors(device)

        # One random row per feature; also returned as the label when requested.
        rows = [random.randint(0, self.feat_dict[k].shape[0] - 1) for k in self.keys]
        sum_features = 0
        for k, row, weight in zip(self.keys, rows, factors):
            sum_features = sum_features + self.feat_dict[k][row] * weight

        input_norm = self.kwargs.get('input_norm', None)
        if input_norm is not None:
            sum_features = sum_features / (torch.norm(sum_features, p=2, dim=-1, keepdim=True) + 1e-8) * input_norm
        if self.return_label:
            return sum_features, torch.tensor(rows)
        return sum_features

class SynSAEDataModule(pl.LightningDataModule):
    """
    DataModule serving ``SytheticDataset`` samples built from precomputed feature tensors.

    Train, validation and test datasets all sample from the same (debiased) features.
    Samples are generated on the fly, so the datasets differ only in nominal length.
    """

    def __init__(self, 
                 feat_set, 
                 batch_size, 
                 num_workers=4, 
                 num_samples=10_000, 
                 return_label=False,
                 **kwargs):
        """
        Args:
            feat_set (dict): Maps each key to the path of a .pt file holding a feature tensor.
            batch_size (int): Batch size used in every DataLoader.
            num_workers (int): Number of DataLoader workers.
            num_samples (int): Length of the training dataset. Validation and test datasets
                get ``int(num_samples * (1 - split_ratio))`` samples.
            return_label (bool): Forwarded to ``SytheticDataset``.
            **kwargs: Forwarded to ``SytheticDataset`` (e.g. ``data_type``, ``input_norm``).
                ``split_ratio`` (default 0.95) is also read here.
        """
        super().__init__()
        self.feat_set = {}
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_samples = num_samples
        for k, path in feat_set.items():
            # The files should hold plain tensors, so refuse arbitrary pickled objects
            # (same as FeatCodeDataModule).
            feats = torch.load(path, weights_only=True)
            # Debias: subtract the mean over rows (dim 0).
            self.feat_set[k] = feats - feats.mean(dim=0)
        self.return_label = return_label
        self.kwargs = kwargs

    def setup(self, stage=None):
        """Record the feature width (``shape[1]`` of the first feature tensor) as ``hidden_size``."""
        first_key = next(iter(self.feat_set))
        self.hidden_size = self.feat_set[first_key].shape[1]

    def _eval_length(self):
        """Size of the val/test datasets: the ``1 - split_ratio`` share of ``num_samples``."""
        return int(self.num_samples * (1 - self.kwargs.get('split_ratio', 0.95)))

    def _make_loader(self, total_length, shuffle):
        """Build a DataLoader over a fresh ``SytheticDataset`` of ``total_length`` samples."""
        dataset = SytheticDataset(self.feat_set, total_length=total_length, return_label=self.return_label, **self.kwargs)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def train_dataloader(self):
        return self._make_loader(self.num_samples, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self._eval_length(), shuffle=False)

    def test_dataloader(self):
        # Same length rule as validation, so `split_ratio` is honoured here too.
        return self._make_loader(self._eval_length(), shuffle=False)
        
class FeatCodeDataModule(pl.LightningDataModule):
    """
    DataModule whose samples are weighted sums of a few rows of a feature matrix.

    Three .pt files are read in ``setup``:
      * ``feat_path``: feature matrix indexed by row (``feat.shape[0]`` features of
        width ``feat.shape[1]``);
      * ``code_path``: integer row indices per sample, ``[num_samples, sparsity]``,
        either as a tensor or as a dict holding it under ``'decoded_codes'``;
      * ``weight_path``: mixing weights aligned row-for-row with the codes.

    Dataloaders yield ``(code, weight)`` pairs. ``on_before_batch_transfer`` turns each
    pair into ``sum_j weight[:, j] * feat[code[:, j]]``.
    """

    def __init__(self, feat_path, code_path, weight_path, batch_size=128, num_workers=4, split_ratio=0.9, seed=42, **kwargs):
        """
        Args:
            feat_path, code_path, weight_path (str): Paths of the .pt files described above.
            batch_size (int): DataLoader batch size.
            num_workers (int): Number of DataLoader workers.
            split_ratio (float): Fraction of samples used for training; the rest is validation.
            seed (int): Seed of the train/val permutation.
            **kwargs: Accepted for config compatibility and ignored.
        """
        super().__init__()
        self.feat_path = feat_path
        self.code_path = code_path
        self.weight_path = weight_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_ratio = split_ratio
        self.seed = seed
        self.train_code = None
        self.val_code = None

    def setup(self, stage=None):
        """
        Load and split the data on the first call; later calls reuse the cached split.

        Returns:
            EasyDict: dataset summary (sample counts, ``num_features``, ``hidden_size``,
            ``sparsity``) plus the constructor settings.
        """
        if self.train_code is None:
            self._load_and_split()
        return self._summary()

    def _load_and_split(self):
        """Load feat/code/weight tensors and split samples with a seeded random permutation."""
        self.feat = torch.load(self.feat_path, weights_only=True)
        code = torch.load(self.code_path, weights_only=True)
        if isinstance(code, dict):
            code = code['decoded_codes']
        code = code.to(torch.long)

        weight = torch.load(self.weight_path, weights_only=True)

        self.s = code.shape[1]

        train_length = int(len(code) * self.split_ratio)
        generator = torch.Generator().manual_seed(self.seed)
        indices = torch.randperm(len(code), generator=generator)
        train_indices, val_indices = indices[:train_length], indices[train_length:]
        self.train_code, self.val_code = code[train_indices], code[val_indices]
        self.train_weight, self.val_weight = weight[train_indices], weight[val_indices]

    def _summary(self):
        """Describe the loaded split and the module settings."""
        return EasyDict(
            num_samples=len(self.train_code),
            num_train_samples=len(self.train_code),
            num_val_samples=len(self.val_code),
            num_features=self.feat.shape[0],
            hidden_size=self.feat.shape[1],
            sparsity=self.train_code.shape[1],
            split_ratio=self.split_ratio,
            seed=self.seed,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            feat_path=self.feat_path,
            code_path=self.code_path,
            weight_path=self.weight_path,
        )

    def train_dataloader(self):
        dataset = TensorDataset(self.train_code, self.train_weight)
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        dataset = TensorDataset(self.val_code, self.val_weight)
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def on_before_batch_transfer(self, batch, dataloader_idx):
        """Turn a ``(code, weight)`` batch into the weighted sum of the referenced feature rows."""
        code, weight = batch[0], batch[1]
        # Gather: (batch_size, sparsity, feature width); reshape also accepts non-contiguous codes.
        gathered = self.feat[code].reshape(code.shape[0], code.shape[1], -1)
        return (gathered * weight.unsqueeze(-1)).sum(dim=1)  # shape: (batch_size, d)

class LitSAEWithChannel(pl.LightningModule):
    def __init__(
        self,
        config: EasyDict,
        **kwargs
    ):
        """
        Wrap ``SAEWithChannel`` in a LightningModule for training and inference.

        Args:
            config (EasyDict): Experiment configuration. Keys read here:
                ``hidden_size`` (model input width), ``num_neurons`` (SAE width),
                ``channel_size_ls``, ``activation``, optional ``use_neuron_weight``
                (default False), optional ``alignment_loss_normalize`` (default True).
                Also read: ``num_neuron_vis`` (only when ``feat_set`` is given) and
                ``phase_0_config.adjust_b_enc_config.group_partitions`` (only when
                ``split_neuron_by_group`` is given).
            **kwargs: Forwarded verbatim to ``SAEWithChannel``. Keys also read here:
                ``split_neuron_by_group`` (its presence enables grouped neurons),
                ``feat_set`` (ground-truth features; enables neuron-visualisation
                sampling) and ``model_old`` (stored as ``self.model_old``).
        """
        super().__init__()
        self.config = config

        # Optional neuron grouping: cumulative end index of each group.
        extra_model_kwargs = {}
        if 'split_neuron_by_group' in kwargs:
            extra_model_kwargs['group_indices'] = self.get_group_index(
                config.phase_0_config.adjust_b_enc_config.group_partitions, config.num_neurons
            )

        self.model = SAEWithChannel(
            input_size=config.hidden_size,
            hidden_size=config.num_neurons,
            channel_size_ls=config.channel_size_ls,
            activation=config.activation,
            use_neuron_weight=config.use_neuron_weight if 'use_neuron_weight' in config else False,
            **extra_model_kwargs,
            **kwargs,
        )

        # Start from unit-norm encoder rows.
        self.normalize_W_enc()

        # Element-wise MSE. `reduction='none'` alone specifies this; the legacy
        # `reduce=False` argument is deprecated and only triggered a warning.
        self.reconstruction_loss = nn.MSELoss(reduction='none')

        self.alignment_loss = AlignmentLoss(normalize=config.alignment_loss_normalize if 'alignment_loss_normalize' in config else True)

        # Optional ground-truth features; when given, sample neurons to visualise.
        if 'feat_set' in kwargs:
            self.feat_set = kwargs['feat_set']
            self.neuron_vis = random.sample(range(config.num_neurons), config.num_neuron_vis)
        else:
            self.feat_set = None

        if 'model_old' in kwargs:
            self.model_old = kwargs['model_old']

        # Disable Lightning's automatic optimization (optimizer steps are issued manually).
        self.automatic_optimization = False

        self.save_hyperparameters(config)

        # Per-neuron threshold, shape (1, num_neurons), initialised to 0.5.
        # It is used by the fixed-threshold JumpReLU path in forward.
        threshold = torch.ones(config.num_neurons, device=self.model.W_enc.device) * 0.5
        self.threshold = nn.Parameter(threshold.unsqueeze(0))

        self.grad_norm = EasyDict()

        
    def on_load_checkpoint(self, checkpoint):
        """
        Migrate encoder tensors in ``checkpoint['state_dict']`` to the current model's layout.

        ``W_enc`` / ``b_enc`` may be stored in a checkpoint as:
          * ``model.W_enc`` (legacy public name),
          * ``model._W_enc`` (a single ``nn.Parameter``), or
          * ``model._W_enc.<i>`` (one entry per group of an ``nn.ParameterList``).

        The migration depends on the current type of ``self.model._W_enc`` / ``_b_enc``:
          * ``nn.ParameterList``: single tensors are split into groups at
            ``self.model.group_indices``.
          * ``nn.Parameter``: legacy names are renamed, and grouped entries are
            concatenated back in index order.
        ``W_enc`` groups live along dim -2 and ``b_enc`` groups along dim -1. Any key
        still containing ``model.W_enc`` / ``model.b_enc`` is then dropped. The state
        dict is modified in place.
        """
        state_dict = checkpoint['state_dict']

        # Single tensor -> per-group entries (current model is grouped).
        self._split_encoder_param(state_dict, 'W_enc', dim=-2)
        self._split_encoder_param(state_dict, 'b_enc', dim=-1)
        # Legacy public name -> private name (current model holds one Parameter).
        self._rename_legacy_encoder_param(state_dict, 'W_enc')
        self._rename_legacy_encoder_param(state_dict, 'b_enc')
        # Per-group entries -> single tensor (current model holds one Parameter).
        self._merge_encoder_param(state_dict, 'W_enc', dim=-2)
        self._merge_encoder_param(state_dict, 'b_enc', dim=-1)

        # Drop leftover legacy keys (e.g. 'model.W_enc' after it was copied to 'model._W_enc').
        stale_keys = [key for key in state_dict.keys() if 'model.W_enc' in key or 'model.b_enc' in key]
        for key in stale_keys:
            del state_dict[key]

        checkpoint['state_dict'] = state_dict

    def _split_encoder_param(self, state_dict, name, dim):
        """Split ``model.<name>`` / ``model._<name>`` along ``dim`` into ``model._<name>.<i>`` if the model is grouped."""
        legacy_key, key = f'model.{name}', f'model._{name}'
        if not (
            (legacy_key in state_dict or key in state_dict)
            and isinstance(getattr(self.model, f'_{name}'), nn.ParameterList)
            and f'{key}.0' not in state_dict
        ):
            return
        # The legacy key wins when both are present.
        full = state_dict[legacy_key] if legacy_key in state_dict else state_dict[key]
        num_groups = len(getattr(self.model, f'_{name}'))
        start_idx = 0
        for i in range(num_groups):
            # group_indices[i] is the (exclusive) cumulative end index of group i.
            end_idx = self.model.group_indices[i]
            index = [slice(None)] * full.dim()
            index[dim] = slice(start_idx, end_idx)
            state_dict[f'{key}.{i}'] = full[tuple(index)]
            start_idx = end_idx
        state_dict.pop(legacy_key, None)
        state_dict.pop(key, None)

    def _rename_legacy_encoder_param(self, state_dict, name):
        """Copy ``model.<name>`` to ``model._<name>`` when the model holds ``_<name>`` as one Parameter."""
        legacy_key, key = f'model.{name}', f'model._{name}'
        if legacy_key in state_dict and isinstance(getattr(self.model, f'_{name}'), nn.Parameter) and key not in state_dict:
            state_dict[key] = state_dict[legacy_key]

    def _merge_encoder_param(self, state_dict, name, dim):
        """Concatenate ``model._<name>.<i>`` (by index) along ``dim`` when the model holds ``_<name>`` as one Parameter."""
        key = f'model._{name}'
        if not (f'{key}.0' in state_dict and isinstance(getattr(self.model, f'_{name}'), nn.Parameter)):
            return
        group_keys = [k for k in state_dict.keys() if k.startswith(f'{key}.')]
        ordered = sorted(group_keys, key=lambda k: int(k.split('.')[-1]))
        state_dict[key] = torch.cat([state_dict[k] for k in ordered], dim=dim)
        for k in group_keys:
            del state_dict[k]
        
    
    
    def get_group_index(self, group_partitions, num_neurons):
        """
        Convert fractional group sizes into cumulative neuron end indices.

        Args:
            group_partitions: One fraction per group; the fractions must sum to 1,
                e.g. ``[0.1, 0.2, 0.3, 0.4]``.
            num_neurons (int): Total number of neurons to partition.

        Returns:
            list[int]: Exclusive end index of each group, i.e.
            ``floor(cumsum(group_partitions) * num_neurons)``. The last entry is always
            ``num_neurons``. For ``[0.1, 0.2, 0.3, 0.4]`` and 10 neurons this is
            ``[1, 3, 6, 10]``.

        Raises:
            AssertionError: If the fractions do not sum to 1 (within 1e-6).
        """
        import math

        # Accumulate in double precision. An exact `== 1.0` on a float32 cumsum
        # rejects valid inputs such as [0.1] * 10.
        cumsum = []
        running = 0.0
        for part in group_partitions:
            running += float(part)
            cumsum.append(running)
        assert cumsum and math.isclose(cumsum[-1], 1.0, abs_tol=1e-6), "The sum of the group partitions should be 1."

        # Add a tiny epsilon before flooring to absorb rounding error
        # (0.7999999999999999 * 10 must give 8, not 7).
        neuron_index = [math.floor(c * num_neurons + 1e-6) for c in cumsum[:-1]]
        # The last group always ends at num_neurons, so no neuron is dropped.
        neuron_index.append(num_neurons)
        return neuron_index
    
    def prepare_phase(self):
        """
        Collect the phase sub-configs and compute each phase's [start_step, end_step] window.

        Every top-level ``self.config`` entry whose key starts with ``'phase'`` is a phase
        config. All other top-level entries are merged into each phase config in place
        (``dict.update``), so a top-level key overrides a phase key of the same name.

        A missing ``start_step`` defaults to 0. A missing ``end_step`` defaults to the total
        number of training steps, ``max_epochs * ceil(num_samples / batch_size)``; an
        explicit ``end_step`` is capped at that value. Phases are then sorted by
        (start_step, end_step).

        Sets ``self.phase_configs``, ``self.remaining_configs``, ``self.phase_start_steps``
        and ``self.phase_end_steps``. The step lists stay empty lists when there are no
        phases and become tuples after sorting otherwise.

        Raises:
            ValueError: If a phase starts before the previous one ends, or ends before it starts.
        """
        import math

        self.phase_configs = {k: v for k, v in self.config.items() if k.startswith('phase')}

        # Every non-phase entry is shared with all phases.
        # NOTE: if we use dict.keys(), it will only loop through the keys of depth 1. If we use in dict, it will loop through all the keys in the dictionary
        self.remaining_configs = EasyDict({k: v for k, v in self.config.items() if k not in self.phase_configs.keys()})
        for phase_cfg in self.phase_configs.values():
            # NOTE(review): update() lets top-level values overwrite phase-level values
            # with the same key; confirm this precedence is intended.
            phase_cfg.update(self.remaining_configs)

        self.phase_start_steps = []
        self.phase_end_steps = []
        if not self.phase_configs:
            return

        # Total optimizer steps in the run; ceil because the last partial batch is still a step.
        num_training_batches = math.ceil(self.config.data_config.num_samples / self.config.data_config.batch_size)
        max_steps = self.config.max_epochs * num_training_batches

        for name, phase_cfg in self.phase_configs.items():
            if 'start_step' in phase_cfg:
                start_step = int(phase_cfg['start_step'])
            else:
                start_step = 0
                print(f'No start_step is provided for {name}. Default to 0.')
            if 'end_step' in phase_cfg:
                end_step = int(min(phase_cfg['end_step'], max_steps))
            else:
                end_step = int(max_steps)
                print(f'No end_step is provided for {name}. Default to the maximum number of steps {end_step}.')
            self.phase_start_steps.append(start_step)
            self.phase_end_steps.append(end_step)

        # Sort phases by (start, end). Unique names break ties, so configs are never compared.
        sorted_steps = sorted(zip(self.phase_start_steps, self.phase_end_steps, self.phase_configs.items()))
        self.phase_start_steps, self.phase_end_steps, sorted_configs = zip(*sorted_steps)
        self.phase_configs = EasyDict(dict(sorted_configs))
        # phase_configs is keyed by name, so error messages look up names by position here.
        phase_names = [name for name, _ in sorted_configs]

        # Consecutive phases must not overlap.
        for i in range(1, len(phase_names)):
            if self.phase_start_steps[i] < self.phase_end_steps[i - 1]:
                raise ValueError(f'The start step for {phase_names[i]} is less than the end step for {phase_names[i - 1]}.')

        # Each phase must end no earlier than it starts.
        for i, name in enumerate(phase_names):
            if self.phase_end_steps[i] < self.phase_start_steps[i]:
                raise ValueError(f'The end step for {name} is less than the start step.')

    def prepare_phase_start(self, save_ckpt=True):
        """
        Configure which parameters train in the phase that is starting.

        Flags are read from ``self.current_phase_config`` (all default False):
          * ``tune_W_enc``, ``tune_b_enc``, ``tune_b_dec``, ``tune_neuron_weight``: applied
            to every model parameter whose name contains the matching substring (first
            match in that order); all other model parameters are frozen.
          * ``tune_threshold``: ``requires_grad`` of ``self.threshold``.

        If ``use_threshold_activation`` is set, ``self.threshold`` is overwritten with a
        detached copy of ``b_enc``. When ``save_ckpt`` is True,
        ``phase_<global_phase>_start.ckpt`` is written to ``config.checkpoint_dir``. This
        happens after the threshold copy and before the threshold's trainability is set.
        """
        phase_cfg = self.current_phase_config
        # (substring, flag) pairs; the first matching substring decides, mirroring an if/elif chain.
        grad_rules = (
            ('W_enc', phase_cfg.get('tune_W_enc', False)),
            ('b_enc', phase_cfg.get('tune_b_enc', False)),
            ('b_dec', phase_cfg.get('tune_b_dec', False)),
            ('neuron_weight', phase_cfg.get('tune_neuron_weight', False)),
        )
        tune_threshold = phase_cfg.get('tune_threshold', False)

        for param_name, param in self.model.named_parameters():
            param.requires_grad = next((flag for pattern, flag in grad_rules if pattern in param_name), False)
            print(f'{param_name} requires gradient: {param.requires_grad}')

        if phase_cfg.get('use_threshold_activation', False):
            # Initialise the fixed threshold from the current encoder bias (plain tensor copy).
            # NOTE(review): no sign change is applied here, although forward's tunable path
            # uses -b_enc as its threshold; confirm the intended sign.
            self.set_threshold(self.model.b_enc.detach().clone())

        if save_ckpt:
            ckpt_path = os.path.join(self.config.checkpoint_dir, f'phase_{self.global_phase}_start.ckpt')
            self.trainer.save_checkpoint(ckpt_path)

        if not tune_threshold:
            self.threshold.requires_grad = False
            print('Do not tune threshold.')
        else:
            print('Tune threshold.')
            self.threshold.requires_grad = True


    @torch.no_grad()
    def normalize_W_enc(self, return_norm=False):
        """
        L2-normalise every encoder row in place (norm over the last, input dimension).

        Works with either a single ``nn.Parameter`` ``_W_enc`` or an ``nn.ParameterList``
        of per-group blocks split along dim -2 (the neuron dimension).

        Args:
            return_norm (bool): If True, return the row norms measured *before* normalisation.

        Returns:
            The pre-normalisation norms with size-1 dims squeezed, or None when
            ``return_norm`` is False. For grouped encoders the groups are concatenated
            along the neuron dimension.
        """
        W_enc = self.model._W_enc
        if isinstance(W_enc, torch.nn.Parameter):
            w_norm = torch.norm(W_enc.data, p=2, dim=-1, keepdim=True)
            W_enc.data /= (w_norm + 1e-8)
            if return_norm:
                return w_norm.squeeze()
        elif isinstance(W_enc, torch.nn.ParameterList):
            group_norms = []
            for block in W_enc:
                w_norm = torch.norm(block.data, p=2, dim=-1, keepdim=True)
                block.data /= (w_norm + 1e-8)
                # Drop only the keepdim axis: a full squeeze() turns a 1-neuron group
                # into a 0-d tensor, which torch.cat rejects.
                group_norms.append(w_norm.squeeze(-1))
            if return_norm:
                # Neurons sit on the last axis of the per-group norms (dim -2 of W_enc).
                return torch.cat(group_norms, dim=-1).squeeze()
    
    def threshold_activation(self, pre_act: torch.Tensor) -> torch.Tensor:
        """
        Clamp ``pre_act`` from below at ``self.threshold``, i.e. element-wise ``max(pre_act, threshold)``.

        NOTE(review): despite its name this does not zero sub-threshold values (it is not a
        JumpReLU), and ``forward`` does not call it. Confirm the intended semantics before use.

        Returns:
            Tensor with the broadcast shape of ``pre_act`` and ``self.threshold``.
        """
        threshold = self.threshold
        if threshold.device != pre_act.device:
            # Move a local copy only. Assigning the plain tensor returned by `.to()` back to
            # the registered Parameter `self.threshold` raises TypeError in nn.Module.__setattr__.
            threshold = threshold.to(pre_act.device)
        return pre_act.clamp(min=threshold)
    
    def set_threshold(self, threshold: torch.Tensor):
        """
        Overwrite the values of ``self.threshold`` in place; the Parameter object itself is kept.

        Args:
            threshold: Shape ``(num_neurons,)`` or ``(1, num_neurons)``. A 1-D input gets a
                leading dimension of size 1.
        """
        expected_shape = (1, self.config.num_neurons)
        if threshold.shape == expected_shape[1:]:
            threshold = threshold[None, :]
        assert threshold.shape == expected_shape, f"Threshold shape {threshold.shape} does not match the hidden size {self.config.num_neurons}!"
        self.threshold.data.copy_(threshold)

    
    def prune_threshold(self, neuron_mask=None):
        """
        Keep only the thresholds of the neurons selected by ``neuron_mask``.

        Args:
            neuron_mask: Mask of shape ``(num_neurons,)`` or ``(1, num_neurons)``; non-zero
                entries are kept. The mask is converted to bool on the threshold's device, so
                a 0/1 float or int mask selects neurons instead of being treated as integer
                indices. ``None`` means "nothing to prune" and leaves the threshold unchanged.

        ``self.threshold`` is replaced by a new frozen (``requires_grad=False``) Parameter
        holding the kept values. Boolean indexing flattens the result to ``(num_kept,)``.
        """
        if neuron_mask is None:
            return
        neuron_mask = neuron_mask.to(device=self.threshold.device, dtype=torch.bool)
        if neuron_mask.shape == (self.config.num_neurons,):
            neuron_mask = neuron_mask.unsqueeze(0)
        assert neuron_mask.shape == (1, self.config.num_neurons), f"neuron_mask shape {neuron_mask.shape} does not match the hidden size {self.config.num_neurons}!"
        self.threshold = nn.Parameter(self.threshold[neuron_mask], requires_grad=False)
        

    def forward(self, 
                x: torch.Tensor, 
                SAE_output_scale=1.0, 
                neuron_mask=None,  
                topk=None, 
                use_threshold_activation=False,
                use_tunable_threshold_activation=False,
                **kwargs, 
                ) -> Tuple[torch.Tensor, dict]:
        """
        Encode ``x`` with the SAE and decode it back.

        Args:
            x: Input, presumably ``(batch_size, *channel_size_ls, input_size)`` — TODO confirm.
            SAE_output_scale (float): Multiplies the reconstruction; ``GradRescaler`` receives
                ``1 / SAE_output_scale`` (presumably to undo the scaling in the gradient).
            neuron_mask: Optional mask multiplied (as float) into the activations.
            topk (int | None): If set, keep only activations whose pre-activation is among
                the ``topk`` largest along the last dim (ties may keep more).
            use_threshold_activation (bool): JumpReLU with the fixed ``self.threshold``;
                takes precedence over ``use_tunable_threshold_activation``.
            use_tunable_threshold_activation (bool): ThresholdReLU with threshold ``-b_enc``.
            **kwargs: Accepted and ignored.

        Returns:
            ``(reconstruction, info)`` where ``info`` holds ``post_act``, ``pre_act`` and
            ``proj`` (the raw encoder projection).
        """
        b_dec = self.model.b_dec
        W_enc = self.model.W_enc
        b_enc = self.model.b_enc

        # Project the decoder-bias-centred input onto every encoder row.
        proj = torch.einsum('...ij,...j->...i', W_enc, x - b_dec)

        if use_threshold_activation:
            post_act = JumpReLU(self.threshold)(proj)
            pre_act = proj - self.threshold
        elif use_tunable_threshold_activation:
            # The threshold is -b_enc, so it is learned through the encoder bias.
            post_act = ThresholdReLU()(proj, -b_enc)
            pre_act = proj + b_enc
        else:
            pre_act = proj + b_enc
            post_act = self.model.act(pre_act)  # shape: (batch_size, *channel_size_ls, hidden_size)

        if topk is not None:
            # k-th largest pre-activation per row. With sorted=False the top-k values come
            # back in unspecified order, so take their minimum rather than the last element.
            kth_largest = torch.topk(pre_act, topk, dim=-1, largest=True, sorted=False).values.amin(dim=-1, keepdim=True)
            post_act = post_act * (pre_act >= kth_largest)

        if neuron_mask is not None:
            post_act = post_act * neuron_mask.float()

        if hasattr(self.model, 'neuron_weight'):
            post_act = post_act * self.model.neuron_weight

        x_reconstructed = torch.einsum('...ij,...i->...j', W_enc, post_act) + b_dec

        # GradRescaler is an autograd.Function, hence `.apply`.
        return GradRescaler.apply(SAE_output_scale * x_reconstructed, 1 / SAE_output_scale), EasyDict({
            'post_act': post_act,
            'pre_act': pre_act,
            'proj': proj,
        })
        
    def init_activation_buffers(self):
        """
        Allocate per-neuron running statistics for pre-activations and projections.

        For each of ``pre_act`` and ``proj`` this creates, on ``self.device`` with length
        ``config.num_neurons``: ``sum_*``, ``sum_*_ge_zero``, ``max_*`` and ``sum_square_*``,
        plus the integer counter ``count_*``. All start at zero; the running maxima also
        start at 0 rather than -inf.
        """
        num_neurons = self.config.num_neurons
        device = self.device

        def _zeros():
            return torch.zeros(num_neurons, device=device)

        self.sum_pre_act = _zeros()
        self.sum_pre_act_ge_zero = _zeros()
        self.max_pre_act = torch.full((num_neurons,), 0.0, device=device)
        self.sum_square_pre_act = _zeros()
        self.count_pre_act = 0

        self.sum_proj = _zeros()
        self.sum_proj_ge_zero = _zeros()
        self.max_proj = torch.full((num_neurons,), 0.0, device=device)
        self.sum_square_proj = _zeros()
        self.count_proj = 0
        
    def update_activation_buffers(self, pre_act, proj=None):
        """
        Accumulate one batch into the running statistics (without autograd).

        Every statistic is reduced over dim 0. Per input this adds: the sum of positive
        parts, the count of entries >= 1e-6, the sum of squares and the row count. It also
        updates the running maximum.

        Args:
            pre_act: Batch of pre-activations (presumably ``(batch_size, num_neurons)`` — confirm).
            proj: Optional batch of projections, accumulated the same way; skipped when None.
        """
        with torch.no_grad():
            self.sum_pre_act += pre_act.clamp(min=0.0).sum(dim=0)
            self.sum_pre_act_ge_zero += pre_act.ge(1e-6).float().sum(dim=0)
            self.sum_square_pre_act += pre_act.square().sum(dim=0)
            self.count_pre_act += pre_act.size(0)
            self.max_pre_act = torch.maximum(self.max_pre_act, pre_act.max(dim=0).values)

            if proj is None:
                return

            self.sum_proj += proj.clamp(min=0.0).sum(dim=0)
            self.sum_proj_ge_zero += proj.ge(1e-6).float().sum(dim=0)
            self.sum_square_proj += proj.square().sum(dim=0)
            self.count_proj += proj.size(0)
            self.max_proj = torch.maximum(self.max_proj, proj.max(dim=0).values)

    def empty_activation_buffers(self):
        """
        Reset every running statistic to zero in place; the tensors keep their identity and device.

        The running maxima are reset to 0 as well, not to -inf.
        """
        tensor_buffers = (
            self.sum_pre_act,
            self.sum_pre_act_ge_zero,
            self.sum_square_pre_act,
            self.max_pre_act,
            self.sum_proj,
            self.sum_proj_ge_zero,
            self.sum_square_proj,
            self.max_proj,
        )
        for buf in tensor_buffers:
            buf.zero_()
        self.count_pre_act = 0
        self.count_proj = 0
    
    # Called before training starts.
    def on_fit_start(self):
        """Lightning hook run once before fitting: set up the phase, then the activation buffers."""
        for setup_fn in (self.prepare_phase, self.init_activation_buffers):
            setup_fn()
        
    def on_test_start(self):
        """Lightning hook: run ``prepare_phase`` before testing, then defer to the parent hook."""
        self.prepare_phase()
        parent_result = super().on_test_start()
        return parent_result
    
    def on_predict_start(self):
        """Lightning hook: run ``prepare_phase`` before prediction, then defer to the parent hook."""
        self.prepare_phase()
        parent_result = super().on_predict_start()
        return parent_result

    def _Step(self, batch, batch_idx, step_type='train', log=True):
        """
        Shared forward pass and loss computation for the train/val/test steps.

        The batch is first preprocessed according to the current phase config
        (``divide_by``, ``normalize_batch`` or ``normalize_batch_with_tanh_threshold``,
        ``random_scale``) and then reconstructed by the SAE. When ``self.model_old``
        exists, its reconstruction (computed without gradients, using its own
        phase config for the optional settings) is added to the new one before
        the losses are computed.

        Args:
            batch: input tensor of shape (batch_size, *channel_size_ls, input_size).
            batch_idx: batch index (not used here).
            step_type: prefix for the logged metric names ('train', 'val', 'test').
            log: whether to log the individual losses with ``self.log``.

        Returns:
            EasyDict with ``reconstruction_loss``, ``alignment_loss``,
            ``normalized_reconstruction_loss``, ``L1_loss``, the optimized ``loss``
            (alignment + L1 when ``use_alignment_loss`` is set, otherwise
            reconstruction + L1), ``x_rescaled`` and the forward ``info`` dict.
        """
        current_phase_config = self.current_phase_config

        def _cfg(config, key, default):
            # read ``config.<key>`` only when the key is present in that same config
            return getattr(config, key) if key in config else default

        if 'divide_by' in current_phase_config:
            batch = batch / current_phase_config.divide_by

        # hard normalize or soft (tanh-saturated) normalize
        if _cfg(current_phase_config, 'normalize_batch', False):
            batch = batch / batch.norm(dim=-1, keepdim=True)
        elif 'normalize_batch_with_tanh_threshold' in current_phase_config:
            threshold = current_phase_config.normalize_batch_with_tanh_threshold
            batch_norm = batch.norm(dim=-1, keepdim=True)
            target_batch_norm = torch.tanh(batch_norm / threshold) * threshold
            batch = batch * (target_batch_norm / (batch_norm + 1e-6))
        if 'random_scale' in current_phase_config:
            low, high = current_phase_config.random_scale[0], current_phase_config.random_scale[1]
            sampled_scale = torch.rand(batch.size(0), device=batch.device) * (high - low) + low
            batch = batch * sampled_scale.unsqueeze(-1)

        use_threshold_activation = current_phase_config.get('use_threshold_activation', False)
        use_tunable_threshold_activation = current_phase_config.get('use_tunable_threshold_activation', False)

        # batch: shape (batch_size, *channel_size_ls, input_size)
        x_rescaled, info = self(
            batch,
            SAE_output_scale=_cfg(current_phase_config, 'SAE_output_scale', 1.0),
            neuron_mask=_cfg(current_phase_config, 'neuron_mask', None),
            topk=_cfg(current_phase_config, 'topk', None),
            use_threshold_activation=use_threshold_activation,
            use_tunable_threshold_activation=use_tunable_threshold_activation,
        )

        if hasattr(self, 'model_old'):
            with torch.no_grad():
                # the old model's optional settings are looked up in its own phase config
                model_old_config = self.model_old.current_phase_config
                x_rescaled_old, _ = self.model_old(
                    batch,
                    SAE_output_scale=_cfg(model_old_config, 'SAE_output_scale', 1.0),
                    neuron_mask=_cfg(model_old_config, 'neuron_mask', None),
                    topk=_cfg(model_old_config, 'topk', None),
                    # NOTE(review): the threshold flags come from the current model's
                    # phase config, not model_old's — confirm this is intended.
                    use_threshold_activation=use_threshold_activation,
                    use_tunable_threshold_activation=use_tunable_threshold_activation,
                )
            # the losses are computed on the sum of both reconstructions
            x_reconstructed = x_rescaled + x_rescaled_old
        else:
            x_reconstructed = x_rescaled

        reconstruction_loss = self.reconstruction_loss(x_reconstructed, batch)
        alignment_loss = self.alignment_loss(x_reconstructed, batch)

        # NOTE(review): assumes reconstruction_loss is element-wise (same shape as
        # batch), so summing the last dim gives a per-sample loss — confirm.
        per_sample_loss = reconstruction_loss.sum(dim=-1)  # shape: batch.shape[:-1]
        batch_l2_norm_squared = torch.norm(batch, dim=-1, p=2) ** 2  # shape: batch.shape[:-1]
        # per-sample loss relative to the squared input norm, averaged over samples
        normalized_loss = (per_sample_loss / batch_l2_norm_squared).mean()
        reconstruction_loss = per_sample_loss.mean()

        if log:
            self.log(f'{step_type}_normalized_reconstruction_loss', normalized_loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log(f'{step_type}_reconstruction_loss', reconstruction_loss, on_step=True, on_epoch=True, prog_bar=True)
            self.log(f'{step_type}_alignment_loss', alignment_loss, on_step=True, on_epoch=True, prog_bar=True)

        # L1 regularization on the post-activations, each neuron weighted by the
        # norm of its encoder row
        lbda_1 = current_phase_config.L1_decay if 'L1_decay' in current_phase_config else 0.0
        W_enc_row_norm = self.model.W_enc.norm(dim=-1).unsqueeze(0).expand_as(info['post_act'])
        L1_loss = lbda_1 * torch.norm(info['post_act'] * W_enc_row_norm, p=1, dim=-1).mean()

        if log:
            self.log(f'{step_type}_L1_loss', L1_loss, on_step=True, on_epoch=True, prog_bar=True)

        if 'use_alignment_loss' in current_phase_config and current_phase_config.use_alignment_loss:
            loss = alignment_loss + L1_loss
        else:
            loss = reconstruction_loss + L1_loss

        results = EasyDict(
            reconstruction_loss=reconstruction_loss,
            alignment_loss=alignment_loss,
            normalized_reconstruction_loss=normalized_loss,
            L1_loss=L1_loss,
            loss=loss,
            x_rescaled=x_rescaled,
            info=info,
        )

        # print losses for debugging purposes
        if getattr(self.config, 'debug', False) and step_type == 'train':
            for k, v in results.items():
                if 'loss' in k:
                    print(f'{step_type}_{k}: {v}')
            print(f'global_step: {self.global_step}')

        return results
    
    def prepare_phase_end(self, save_ckpt=True):
        """
        Hook run on the last step of a phase.

        When ``save_ckpt`` is truthy, saves a checkpoint named
        ``phase_<global_phase>_last.ckpt`` into ``config.checkpoint_dir``.
        """
        if not save_ckpt:
            return
        checkpoint_path = os.path.join(
            self.config.checkpoint_dir,
            f'phase_{self.global_phase}_last.ckpt',
        )
        self.trainer.save_checkpoint(checkpoint_path)
        
    def training_step(self, batch, batch_idx):
        """
        Run one manually-optimized training step.

        Order of operations:
        1. Phase start/end bookkeeping, with optional checkpoints.
        2. Forward pass and loss via ``_Step``, then the activation-buffer update.
        3. Backward, optional gradient rescaling/normalization.
        4. Optimizer and scheduler step.
        5. Post-step parameter constraints configured for the phase: W_enc
           normalization, b_enc adjustment and clamping, zeroing b_dec.

        Diagnostics that use ``logger.experiment.log`` are only sent when the
        logger is a ``WandbLogger``.
        """
        # Captured before opt.step(): the step may advance ``global_step`` into the
        # next phase, and every config read below refers to this step's phase.
        current_phase_config = self.current_phase_config

        if self.global_step in self.phase_start_steps:
            print(f"Start phase {self.global_phase} at step {self.global_step}.")
            self.prepare_phase_start(save_ckpt=getattr(self.config, 'save_phase_start_ckpt', True))
        if self.global_step + 1 in self.phase_end_steps:
            print(f"End phase {self.global_phase} at step {self.global_step}.")
            self.prepare_phase_end(save_ckpt=getattr(self.config, 'save_phase_end_ckpt', True))

        step_output = self._Step(batch, batch_idx, step_type='train', log=True)
        self.log('train_loss', step_output.loss, on_step=True, on_epoch=True, prog_bar=True)

        pre_act = step_output.info['pre_act']
        # fraction of pre-activation entries below the 1e-6 firing threshold
        self.log('train_pre_act_sparsity', 1 - pre_act.ge(1e-6).float().mean(), on_step=True, on_epoch=False, prog_bar=True)
        self.update_activation_buffers(pre_act, proj=step_output.info['proj'])

        step_output.loss.backward()

        # NOTE(review): the optimizer below steps and zeroes gradients on every batch,
        # so this only scales down every k-th gradient; it does not accumulate k
        # batches. Confirm whether true gradient accumulation was intended.
        if 'accumulate_grad_batches' in current_phase_config:
            accumulate = current_phase_config.accumulate_grad_batches
            if (batch_idx + 1) % accumulate == 0:
                for param in self.parameters():
                    if param.grad is not None:
                        param.grad /= accumulate

        self.log_record_grad_norm()

        if 'normalize_gradient' in current_phase_config and current_phase_config.normalize_gradient:
            self.normalize_gradient()

        opt = self.optimizers()
        opt.step()
        opt.zero_grad()

        sch = self.lr_schedulers()
        if sch:
            sch.step()

        # ``experiment.log`` is a wandb API; other loggers (e.g. TensorBoard/CSV)
        # do not provide it, so wandb-only diagnostics are guarded.
        use_wandb = isinstance(self.logger, pl.loggers.WandbLogger)

        if 'normalize_W_enc' in current_phase_config and current_phase_config.normalize_W_enc:
            W_enc_norm = self.normalize_W_enc(return_norm=True)
        else:
            W_enc_norm = torch.norm(self.model.W_enc, p=2, dim=-1)
        if use_wandb:
            self.logger.experiment.log({
                'W_enc_norm': W_enc_norm,
            })

        if 'adjust_b_enc_config' in current_phase_config:
            self.adjust_b_enc(current_phase_config.adjust_b_enc_config)

        # set b_dec to be zero
        if 'b_dec_zero' in current_phase_config and current_phase_config.b_dec_zero:
            self.model.b_dec.data.zero_()

        def clamp_b_enc(**bounds):
            # _b_enc is either a ParameterList or a single Parameter
            b_enc_params = self.model._b_enc
            if isinstance(b_enc_params, torch.nn.ParameterList):
                for b_enc_param in b_enc_params:
                    b_enc_param.data.clamp_(**bounds)
            elif isinstance(b_enc_params, torch.nn.Parameter):
                b_enc_params.data.clamp_(**bounds)

        if 'clamp_b_enc_max' in current_phase_config:
            clamp_b_enc(max=current_phase_config.clamp_b_enc_max)
        if 'clamp_b_enc_min' in current_phase_config:
            clamp_b_enc(min=current_phase_config.clamp_b_enc_min)

        if use_wandb:
            self.logger.experiment.log({
                'train_b_enc': self.model.b_enc.data,
                'train_b_enc_mean': self.model.b_enc.mean(),
                'train_b_dec': self.model.b_dec.data,
                'train_b_dec_mean': self.model.b_dec.mean(),
            })
            
    
    def validation_step(self, batch, batch_idx):
        """
        Compute and log the validation loss for one batch.

        The individual losses are logged by ``_Step`` with the ``val_`` prefix;
        the total is logged per epoch as ``val_loss`` and returned.
        """
        val_loss = self._Step(batch, batch_idx, step_type='val', log=True).loss
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True)
        return val_loss
    
    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop.
        """
        
        step_output = self._Step(batch, batch_idx, step_type='test', log=False)
        
        self.log('test_loss', step_output.loss, on_step=False, on_epoch=True, prog_bar=True)
        
        return step_output.loss
    
    def on_validation_end(self):
        """
        Log feature-alignment and bias diagnostics after a validation run.

        For every feature matrix in ``self.feat_set``, this computes the cosine
        similarity between the rows of ``W_enc`` and the features. From it, it
        derives the per-neuron Z-score and max cosine similarity, and the max
        projection, with and without ``b_enc``.

        It then logs the ``b_enc`` / ``b_dec`` statistics, and optionally the
        threshold and ``neuron_weight`` values. All logging goes through wandb
        and is skipped when the logger is not a ``WandbLogger``.
        """
        if self.feat_set is not None:
            for feat_name, feat in self.feat_set.items():
                cos_sim = check_cosine_similarity(self.model.W_enc, feat, return_tensor=True, verbose=False)
                # how strongly each neuron's best match dominates its other similarities
                Z_score = dominance_metrics(cos_sim, dim=-1, metrics_to_use='Z-Score')
                max_cos_sim = cos_sim.max(dim=-1).values
                max_proj_to_feat_per_neuron = max_cos_sim * self.model.W_enc.norm(dim=-1).cpu()
                max_proj_to_feat_plus_b_enc = max_proj_to_feat_per_neuron + self.model.b_enc.cpu()

                if not isinstance(self.logger, pl.loggers.WandbLogger):
                    continue
                self.logger.experiment.log({
                    f'val_Z_score_cos_sim_{feat_name}': Z_score,
                    f'val_max_cos_sim_{feat_name}': max_cos_sim,
                    f'val_Z_score_cos_sim_{feat_name}_mean': Z_score.mean(),
                    f'val_max_cos_sim_{feat_name}_mean': max_cos_sim.mean(),
                    f'val_max_proj_to_feat_per_neuron_{feat_name}': max_proj_to_feat_per_neuron,
                    f'val_max_proj_to_feat_plus_b_enc_{feat_name}': max_proj_to_feat_plus_b_enc,
                })

        if not isinstance(self.logger, pl.loggers.WandbLogger):
            return

        experiment = self.logger.experiment
        experiment.log({
            'val_b_enc': self.model.b_enc.data,
            'val_b_dec_norm': self.model.b_dec.norm(),
            'val_b_enc_mean': self.model.b_enc.mean(),
        })
        if self.current_phase_config.get('use_threshold_activation', False):
            experiment.log({
                'val_threshold': self.threshold.data,
                'val_threshold_mean': self.threshold.mean(),
            })
        if hasattr(self.model, 'neuron_weight'):
            experiment.log({
                'val_neuron_weight': self.model.neuron_weight.data,
                'val_neuron_weight_mean': self.model.neuron_weight.mean(),
            })
            
    
    def configure_optimizers(self):
        """
        Build the optimizer and, optionally, a per-step LR scheduler from ``self.config``.

        The optimizer class is selected by ``config.optimizer``. It is built
        with the keyword arguments in ``config['<name>_optimizer_config']``.
        ``config.lr_scheduler`` selects ``"cosine"`` (CosineAnnealingWarmup) or
        ``"step"`` (StepLR); any other value means no scheduler.

        Returns:
            ``optimizer`` alone, or ``([optimizer], [{"scheduler": ..., "interval": "step"}])``.

        Raises:
            ValueError: if ``config.optimizer`` is not a supported name.
        """
        optimizer_classes = {
            'SGD': torch.optim.SGD,
            'Adam': torch.optim.Adam,
            'AdamW': torch.optim.AdamW,
            'RMSprop': torch.optim.RMSprop,
            'Shampoo': Shampoo,
            'signSGD': signSGD,
            'normSGD': normSGD,
            'NormalizeSGD': NormalizeSGD,
        }
        optimizer_name = self.config.optimizer
        if optimizer_name not in optimizer_classes:
            raise ValueError(f"Optimizer {optimizer_name} is not implemented!")

        optimizer_kwargs = self.config[f'{optimizer_name}_optimizer_config']
        optimizer = optimizer_classes[optimizer_name](self.parameters(), **optimizer_kwargs)
        for param_group in optimizer.param_groups:
            print(param_group['params'])
        print(f'Optimizer: {optimizer_name}')
        for option, value in optimizer_kwargs.items():
            print(f'{option}: {value}')

        scheduler_kind = self.config.lr_scheduler
        scheduler = None
        if scheduler_kind == "cosine":
            cosine_cfg = self.config.cosine_scheduler_config
            scheduler = CosineAnnealingWarmup(
                optimizer=optimizer,
                warmup_steps=cosine_cfg.warmup_steps,
                learning_rate=optimizer_kwargs.lr,
                min_lr=cosine_cfg.min_lr,
                lr_decay_steps=cosine_cfg.lr_decay_steps,
            )
        elif scheduler_kind == "step":
            step_cfg = self.config.StepLR_scheduler_config
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=step_cfg.step_size,
                gamma=step_cfg.gamma,
            )

        if scheduler is None:
            return optimizer
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
    
    def lr_scheduler_step(
            self,
            scheduler: LRSchedulerTypeUnion,
            metric: Any,
    ) -> None:
        """
        Advance ``scheduler`` by one step.

        ``metric`` is part of the Lightning hook signature and is ignored here.
        """
        step = scheduler.step
        step()
    
    def manual_backward(self, loss, opt, **kwargs):
        """
        Back-propagate ``loss`` and then immediately step ``opt``.

        NOTE(review): this overrides ``LightningModule.manual_backward`` with a
        different signature and additionally performs the optimizer step. It
        is not the hook run after each optimizer step. Extra ``kwargs`` are
        ignored.
        """
        loss.backward()
        opt.step()
        
        
    @property
    def global_phase(self):
        """
        Index of the phase that contains ``global_step``, or None.

        Phase ``i`` covers the half-open step range
        ``[phase_start_steps[i], phase_end_steps[i])``; the last phase also
        includes its end step. Returns None when no phases are defined or the
        current step lies outside every phase.
        """
        starts = self.phase_start_steps
        if starts == []:
            return None

        step = self.global_step
        last_index = len(starts) - 1
        for index, start in enumerate(starts):
            if start <= step < self.phase_end_steps[index]:
                return index
            # the final phase is closed on the right
            if index == last_index and step == self.phase_end_steps[index]:
                return index
        return None
            
    @property
    def current_phase_config(self):
        """
        Configuration for the phase that contains ``global_step``.

        Phase ``i`` maps to the ``i``-th key of ``self.phase_configs`` in
        iteration order. Outside every phase, ``self.remaining_configs`` is
        returned.
        """
        phase_index = self.global_phase
        if phase_index is None:
            return self.remaining_configs  # return the remaining configurations
        phase_key = list(self.phase_configs)[phase_index]
        return self.phase_configs[phase_key]
    
    def get_step_interval(self, interval: Union[int, float]) -> int:
        """
        Convert an interval specification into a number of training steps.

        A float is read as a fraction of an epoch. For example, 1.0 means once
        per epoch and 0.5 means twice per epoch; the result is at least 1. An
        int is taken as a step count unchanged.

        Raises:
            ValueError: if ``interval`` is neither a float nor an int.
        """
        batches_per_epoch = self.trainer.num_training_batches
        if isinstance(interval, float):
            return max(int(interval * batches_per_epoch), 1)
        if isinstance(interval, int):
            return interval
        raise ValueError("Invalid type for interval. Must be int or float.")

    def aggregate_and_log_activation_buffers(self, log=True):
        """
        Turn the accumulated activation buffers into per-neuron statistics.

        The same statistics are computed for the pre-activations and for the
        projections:
        - the mean of the accumulated (ReLU-clamped) sums;
        - the fraction of entries counted as firing;
        - a spread ``sqrt(E[x^2] - mean^2)``, clamped at 0 before the sqrt so
          rounding cannot produce NaN;
        - the running maximum;
        - the Z-score ``(max - mean) / (std + 1e-6)``.

        A count of 0 is treated as 1, so empty buffers give zeros, not NaNs.

        Args:
            log: when True and the logger is a ``WandbLogger``, log the statistics to wandb.

        Returns:
            EasyDict with mean / fraction_ge_zero / max / std for ``pre_act`` and ``proj``.
        """
        def _stats(total, total_ge_zero, total_square, running_max, count):
            # (mean, fraction_ge_zero, std, max, Z-score) for one family of buffers
            denom = max(count, 1)
            mean = total / denom
            fraction_ge_zero = total_ge_zero / denom
            mean_square = total_square / denom
            # E[x^2] - mean^2 can dip slightly below 0 through rounding; sqrt would give NaN
            std = (mean_square - mean.square()).clamp(min=0.0).sqrt()
            Z_score = (running_max - mean) / (std + 1e-6)
            return mean, fraction_ge_zero, std, running_max, Z_score

        (mean_pre_act, fraction_pre_act_ge_zero, std_pre_act,
         max_pre_act, Z_score_pre_act) = _stats(
            self.sum_pre_act, self.sum_pre_act_ge_zero, self.sum_square_pre_act,
            self.max_pre_act, self.count_pre_act)

        if isinstance(self.logger, pl.loggers.WandbLogger) and log:
            self.logger.experiment.log({
                'mean_pre_act': mean_pre_act,
                'fraction_pre_act_ge_zero': fraction_pre_act_ge_zero,
                'Z_score_pre_act': Z_score_pre_act,
                'fraction_pre_act_ge_zero_mean': fraction_pre_act_ge_zero.mean(),
                'mean_Z_score_pre_act': Z_score_pre_act.mean(),
                'max_pre_act': max_pre_act,
            })

        (mean_proj, fraction_proj_ge_zero, std_proj,
         max_proj, Z_score_proj) = _stats(
            self.sum_proj, self.sum_proj_ge_zero, self.sum_square_proj,
            self.max_proj, self.count_proj)

        if isinstance(self.logger, pl.loggers.WandbLogger) and log:
            self.logger.experiment.log({
                'mean_proj': mean_proj,
                'fraction_proj_ge_zero': fraction_proj_ge_zero,
                'Z_score_proj': Z_score_proj,
                'fraction_proj_ge_zero_mean': fraction_proj_ge_zero.mean(),
                'mean_Z_score_proj': Z_score_proj.mean(),
                'max_proj': max_proj,
            })

        return EasyDict({
            'mean_pre_act': mean_pre_act,
            'fraction_pre_act_ge_zero': fraction_pre_act_ge_zero,
            'max_pre_act': max_pre_act,
            'std_pre_act': std_pre_act,

            'mean_proj': mean_proj,
            'fraction_proj_ge_zero': fraction_proj_ge_zero,
            'max_proj': max_proj,
            'std_proj': std_proj,
        })
            
    
    
    def adjust_b_enc(self, adjust_b_enc_config):
        """
        Periodically nudge the encoder bias ``b_enc`` based on neuron firing statistics.

        The adjustment runs every ``adjust_b_enc_config.interval`` steps, counted
        from ``phase_start_steps[0]`` and excluding that first step (see
        ``get_step_interval``). Each time, the accumulated activation buffers
        are aggregated, which also logs them to wandb when available. If
        ``adjust_b_enc_config.adjust_b_enc`` is true:

        * neurons whose firing fraction is below ``freq_threshold_low`` have
          ``factor_up`` times the mean max pre-activation of the other neurons
          added to their bias (``factor_up * 1e-3`` if no neuron is above the
          threshold);
        * neurons whose firing fraction is above ``freq_threshold_high`` have
          ``factor_down`` times their own max pre-activation subtracted;
        * the bias is clamped from above by ``clamp_b_enc_max`` when that key is set.

        The activation buffers are emptied after every such window.
        This method only works if the SAE is not split into groups.
        """
        step_interval = self.get_step_interval(adjust_b_enc_config.interval)
        if step_interval <= 0:
            return
        # global_step counts optimizer steps across epochs
        steps_since_start = self.global_step - self.phase_start_steps[0]
        if steps_since_start == 0 or steps_since_start % step_interval != 0:
            return

        with torch.no_grad():
            aggregated_info = self.aggregate_and_log_activation_buffers()
            fraction_pre_act_ge_zero = aggregated_info.fraction_pre_act_ge_zero
            max_pre_act = aggregated_info.max_pre_act
            # experiment.log is a wandb API; other loggers do not provide it
            use_wandb = isinstance(self.logger, pl.loggers.WandbLogger)

            if adjust_b_enc_config.get('adjust_b_enc', False):
                never_activated = fraction_pre_act_ge_zero < adjust_b_enc_config.freq_threshold_low
                too_often_activated = fraction_pre_act_ge_zero > adjust_b_enc_config.freq_threshold_high
                sometimes_activated = torch.logical_not(never_activated)

                b_enc = self.model.b_enc.data

                # raise the bias of silent neurons, scaled by how strongly the others fire
                if torch.any(sometimes_activated):
                    b_enc[never_activated] += adjust_b_enc_config.factor_up * max_pre_act[sometimes_activated].mean()
                else:
                    b_enc[never_activated] += adjust_b_enc_config.factor_up * 1e-3
                if use_wandb:
                    self.logger.experiment.log({
                        'max_pre_act_moderate_activated_1': max_pre_act[sometimes_activated].mean(),
                    })

                # lower the bias of neurons that fire too often, proportional to their own max
                b_enc[too_often_activated] -= adjust_b_enc_config.factor_down * max_pre_act[too_often_activated]
                if use_wandb:
                    self.logger.experiment.log({
                        'max_pre_act_too_often_activated_1': max_pre_act[too_often_activated].mean(),
                    })

                if 'clamp_b_enc_max' in adjust_b_enc_config:
                    b_enc.clamp_(max=adjust_b_enc_config.clamp_b_enc_max)

            # start a fresh statistics window
            self.empty_activation_buffers()
    
    def on_train_epoch_end(self):
        """Lightning end-of-training-epoch hook; currently a no-op."""
        return None
    
    def prune_neurons(self, neuron_mask, verbose=True):
        """
        Prune SAE neurons according to ``neuron_mask``.

        Delegates to ``self.model.prune_neurons``. It then records the model's
        new ``hidden_size`` as ``num_neurons`` in the config and, if present,
        in ``self.hparams``.
        """
        self.model.prune_neurons(neuron_mask, verbose=verbose)
        remaining_neurons = self.model.hidden_size
        self.config.num_neurons = remaining_neurons
        try:
            hyperparams = self.hparams
        except AttributeError:
            return
        hyperparams.num_neurons = remaining_neurons
    
    def log_record_grad_norm(self):
        """
        Log each parameter's gradient norm and keep a running average of it.

        For every parameter with a gradient, the norms of the gradient along
        its last dimension are averaged. The result is logged as
        ``{name}_grad_norm`` and folded into ``self.grad_norm[name]`` as an
        exponential moving average (0.95 old / 0.05 new); the first
        observation initializes the average. With ``config.debug`` set, the
        averages are also printed.
        """
        with torch.no_grad():
            for name, param in self.named_parameters():
                if param.grad is None:
                    continue
                grad_norm = param.grad.norm(dim=-1).mean()
                self.log(f'{name}_grad_norm', grad_norm, on_step=True, on_epoch=False, prog_bar=True)
                # key membership (not hasattr): dict keys are not attributes of the dict
                if name in self.grad_norm:
                    self.grad_norm[name] = 0.95 * self.grad_norm[name] + 0.05 * grad_norm
                else:
                    self.grad_norm[name] = grad_norm
            if getattr(self.config, 'debug', False):
                for name, param in self.named_parameters():
                    if param.grad is not None:
                        print(f'{name}_grad_norm: {self.grad_norm[name]}')

                print(f'global_step: {self.global_step}')
    
    def normalize_gradient(self):
        """
        Divide every parameter gradient by the sum of the running gradient norms.

        The running norms are the values stored in ``self.grad_norm``, which is
        filled by ``log_record_grad_norm``. If that sum is zero (for example,
        all gradients vanished or nothing has been recorded yet), the
        gradients are left unchanged rather than turned into inf/NaN.
        """
        with torch.no_grad():
            cum_grad_norm = sum(self.grad_norm.values())
            if float(cum_grad_norm) == 0.0:
                return
            for _, param in self.named_parameters():
                if param.grad is not None:
                    param.grad /= cum_grad_norm

def get_experiment_name(config: EasyDict) -> str:
    """
    Build an experiment (wandb group) name from the main hyperparameters.

    The name joins with underscores: ``SAE``, the data type, hidden size, number
    of neurons, activation, learning rate, optimizer, weight decay and batch
    size. It then appends ``topk-<k>`` and ``L1-<decay>`` when those keys are
    present in ``config``. Adjust the components to whatever matters most for
    your experiments.
    """
    optimizer = config.optimizer
    optimizer_config = config[optimizer + '_optimizer_config']
    name_parts = [
        "SAE",  # or something else to indicate the mode
        f'{config.data_config.data_type}',
        f"h{config.hidden_size}",
        f"n{config.num_neurons}",
        f"act-{config.activation}",
        f"lr{optimizer_config.lr}",
        f"opt-{optimizer}",
        f"wd{optimizer_config.weight_decay}",
        f"bs{config.batch_size}",
    ]

    # optional components, only when configured
    if 'topk' in config:
        name_parts.append(f"topk-{config.topk}")
    if 'L1_decay' in config:
        name_parts.append(f"L1-{config.L1_decay}")

    return "_".join(name_parts)

def get_run_name(config: EasyDict) -> str:
    """
    Build a run name of the form ``seed-<seed>_<YYYYmmdd_HHMMSS>``.

    The timestamp is the local time at which this function is called. Adjust
    the components to whatever identifies a run best.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"seed-{config.seed!s}_{timestamp}"


def set_up_experiment(config, 
                      data_config, 
                      get_experiment_name, 
                      get_run_name, 
                      DataModuleClass=SynSAEDataModule,
                      ModelClass=LitSAEWithChannel,
                      debug=False):
    """
    Set up a new experiment: seed the RNGs, then build the data module, the model
    and the Trainer keyword arguments.

    Args:
        config: experiment configuration. ``config.seed`` is used when it is not
            None; otherwise a random seed in [0, 1000] is drawn. Either way the
            seed is written back to ``config.seed``. ``config.checkpoint_dir``
            is also set here.
        data_config: keyword arguments for ``DataModuleClass``.
        get_experiment_name: callable building the wandb group / checkpoint name.
        get_run_name: callable building the wandb run / checkpoint name.
        DataModuleClass: data module class to instantiate.
        ModelClass: LightningModule class to instantiate.
        debug: when True, ``logger`` is None and no callbacks are created.

    Returns:
        (data_module, model, trainer_kwargs)
    """
    # Seed before building the data module and the model so that their random
    # initialization is reproducible from the recorded seed.
    seed = config.seed if config.seed is not None else random.randint(0, 1000)
    pl.seed_everything(seed)
    config.seed = seed  # save the seed in the config
    print(f'Using seed {seed}')

    data_module = DataModuleClass(**data_config)
    model = ModelClass(config, feat_set=data_module.feat_set)

    run_name = get_run_name(config)
    exp_name = get_experiment_name(config)

    logger = pl.loggers.WandbLogger(
        project=config.wandb_config.wandb_project,
        entity=config.wandb_config.wandb_entity,
        group=exp_name,
        name=run_name, 
        config = config
    ) if not debug else None

    checkpoint_dir = os.path.join(current_dir, f'checkpoints/{exp_name}-{run_name}')
    config.checkpoint_dir = checkpoint_dir

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{val_loss:.4f}',
        monitor='val_loss', # this depends on logging in the LightningModule
        mode='min',
        save_top_k=1,
        save_last=True
    )

    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor, checkpoint_callback] if not debug else []

    trainer_kwargs = dict(
        max_epochs=config.max_epochs,
        logger=logger,
        callbacks=callbacks,
        val_check_interval=getattr(config, 'val_check_interval', 1.0),
        precision=getattr(config, 'precision', '32'),
        accelerator='gpu', 
        # devices=1,
    )

    # trainer_kwargs is a dict: read it with .get (getattr() would always yield None)
    if trainer_kwargs.get('precision') is not None:
        print(f'Using {trainer_kwargs["precision"]} precision.')

    return data_module, model, trainer_kwargs

import re
def resume_experiment(ckpt_path, 
                      config, 
                      data_config, 
                      get_experiment_name, 
                      get_run_name, 
                      DataModuleClass=SynSAEDataModule,
                      ModelClass=LitSAEWithChannel,
                      debug=False):
    """
    Resume an experiment from a Lightning checkpoint.

    Recovers the seed encoded in ``ckpt_path`` (``seed-<int>``), seeds all RNGs
    with it, rebuilds the data module, restores the model from the checkpoint,
    reads the checkpoint's ``global_step``, and prepares logger/callback kwargs
    for a ``pl.Trainer``.

    Args:
        ckpt_path: path to a Lightning ``.ckpt`` file. The path must contain
            ``seed-`` followed (after any non-digit characters) by an integer.
        config: experiment config. Updated in place with ``seed``,
            ``last_global_step`` and ``checkpoint_dir``.
        data_config: keyword arguments forwarded to ``DataModuleClass``.
        get_experiment_name: callable ``config -> str`` used as the W&B group.
        get_run_name: callable ``config -> str`` used as the W&B run name
            (the resumed global step is appended).
        DataModuleClass: data module class to instantiate.
        ModelClass: LightningModule class whose ``load_from_checkpoint`` is used.
        debug: if True, no W&B logger and no callbacks are created.

    Returns:
        (data_module, model, trainer_kwargs)

    Raises:
        ValueError: if no seed can be parsed from ``ckpt_path``.
    """
    # Parse the seed first so we fail fast (before any expensive loading) and can
    # seed the RNGs before anything random happens, as the fresh-run path does.
    seed_match = re.search(r'seed-\D*(\d+)', ckpt_path)
    if seed_match is None:
        raise ValueError(f'Cannot parse seed ("seed-<int>") from checkpoint path: {ckpt_path}')
    config.seed = int(seed_match.group(1))
    pl.seed_everything(config.seed)

    data_module = DataModuleClass(**data_config)
    model = ModelClass.load_from_checkpoint(ckpt_path, config=config, feat_set=data_module.feat_set)

    # Only the step counter is needed here; load onto CPU so no GPU memory is touched.
    ckpt_dict = torch.load(ckpt_path, map_location='cpu')
    config.last_global_step = ckpt_dict['global_step']

    run_name = get_run_name(config)
    exp_name = get_experiment_name(config)
    run_name += f' - {config.last_global_step}'

    logger = pl.loggers.WandbLogger(
        project=config.wandb_config.wandb_project,
        entity=config.wandb_config.wandb_entity,
        group=exp_name,
        name=run_name, 
        config = config
    ) if not debug else None

    checkpoint_dir = os.path.join(current_dir, f'checkpoints/{exp_name}-{run_name}')
    # keep the config in sync with where checkpoints go (same as a fresh run)
    config.checkpoint_dir = checkpoint_dir

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{val_loss:.4f}',
        monitor='val_loss', # this depends on logging in the LightningModule
        mode='min',
        save_top_k=1,
        save_last=True
    )

    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor, checkpoint_callback] if not debug else []

    trainer_kwargs = dict(
        max_epochs=config.max_epochs,
        logger=logger,
        callbacks=callbacks,
        val_check_interval=getattr(config, 'val_check_interval', 1.0),
        precision=getattr(config, 'precision', '32'),
        accelerator='gpu', 
        # devices=1,
    )

    # trainer_kwargs is a plain dict: use .get (getattr on a dict never sees keys)
    if trainer_kwargs.get('precision') is not None:
        print(f'Using {trainer_kwargs["precision"]} precision.')

    return data_module, model, trainer_kwargs

def generate_post_act(model, x, batch_size = 256, verbose=True):
    """
    Run the SAE over ``x`` in mini-batches and collect its post-activations.

    Args:
        model: the SAE model. It is called as
            ``model(x_batch, SAE_output_scale=model.config.SAE_output_scale)``
            and must return ``(x_reconstructed, info)`` with ``info['post_act']``;
            it must also provide ``reconstruction_loss(x_rec, x)`` and
            ``alignment_loss(x_rec, x)`` returning scalar tensors.
        x: tensor whose first dimension indexes samples, e.g. (num_samples, hidden_size).
        batch_size: number of samples per forward pass (must be >= 1).
        verbose: if True, show a progress bar and print the sample-weighted
            mean reconstruction and alignment losses.

    Returns:
        post_act: per-batch ``info['post_act']`` concatenated along dim 0,
            i.e. (num_samples, num_neurons) for the SAE.

    Raises:
        ValueError: if ``x`` has no samples or ``batch_size`` < 1.
    """
    num_samples = x.size(0)
    if num_samples == 0:
        raise ValueError("generate_post_act: x must contain at least one sample.")
    if batch_size < 1:
        raise ValueError(f"generate_post_act: batch_size must be >= 1, got {batch_size}.")

    batch_starts = range(0, num_samples, batch_size)
    if verbose:
        batch_starts = tqdm(batch_starts, desc="Generating post activations")

    post_act_list = []
    loss = 0.0
    alignment_loss = 0.0
    # remember the caller's mode so this helper does not leak an eval() switch
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in batch_starts:
                x_batch = x[i:i + batch_size]
                n_batch = x_batch.size(0)
                x_reconstructed, info = model(x_batch, SAE_output_scale=model.config.SAE_output_scale)
                # weight each batch mean by its size so the last (short) batch counts correctly
                loss += model.reconstruction_loss(x_reconstructed, x_batch).item() * n_batch
                alignment_loss += model.alignment_loss(x_reconstructed, x_batch).item() * n_batch
                post_act_list.append(info['post_act'])
    finally:
        model.train(was_training)

    post_act = torch.cat(post_act_list, dim=0)
    loss /= num_samples
    alignment_loss /= num_samples
    if verbose:
        print(f"Reconstruction loss: {loss}")
        print(f"Alignment loss: {alignment_loss}")
    return post_act


if __name__ == '__main__':

    # ------------------------------
    # data configurations
    # ------------------------------

    feat_set = {
        'feat0': os.path.join(current_dir, 'data/basis_d32_0.pt'),
        'feat1': os.path.join(current_dir, 'data/basis_d32_1.pt'),
        'feat2': os.path.join(current_dir, 'data/basis_d32_2.pt')
    }
    # feat_set = {
    #     'feat0': os.path.join(parent_dir, 'data/L0_Aout_var_0_feat.pt'),
    #     'feat1': os.path.join(parent_dir, 'data/L0_Aout_var-2_feat.pt'),
    #     'feat2': os.path.join(parent_dir, 'data/L0_Aout_var-4_feat.pt')
    # }

    batch_size = 128
    num_workers = 3
    num_samples = 100_000
    # input_norm = 16
    input_norm = None
    ### Balanced data configuration
    data_config = EasyDict(
        feat_set=feat_set,
        batch_size=batch_size,
        num_workers=num_workers,
        num_samples=num_samples,
        input_norm=input_norm,
        data_type='balanced',
    )
    # ### Perturbed data configuration
    # data_config = EasyDict(
    #     feat_set=feat_set,
    #     batch_size=batch_size,
    #     num_workers=num_workers,
    #     num_samples=num_samples,
    #     data_type='perturbed',
    #     clamp_min=0.5,
    #     clamp_max=1.5,
    # )
    # ### Imbalanced data configuration
    # data_config = EasyDict(
    #     feat_set=feat_set,
    #     batch_size=batch_size,
    #     num_workers=num_workers,
    #     num_samples=num_samples,
    #     data_type='imbalanced',
    #     factor_1=0.8,
    #     factor_2=1.2,
    #     factor_3=1.0,
    # )

    ### Model configuration

    hidden_size = 32 # be consistent with the input size of the model
    model_config = EasyDict(
        hidden_size = hidden_size, # Number of input features
        num_neurons = 1024, # Number of neurons in the hidden layer
        activation = 'relu', # Activation function to use
        channel_size_ls = [],
        use_neuron_weight = True, # Use neuron weight
    )

    train_config = EasyDict(clever_load(os.path.join(parent_dir, 'Simtransformer/simtransformer/configurations/train_config_default.yaml')))
    train_config.update(
        {
            'max_epochs': 1000,
            'batch_size': data_config.batch_size,
            'optimizer': 'AdamW',
            'wandb_config':{
                'wandb_project': 'SAE_synthetic_data',
                'wandb_entity': 'SAE_atomic', 
            },
            'num_neuron_vis': 50, # Number of neurons to visualize
            'seed': None,
            'lr_scheduler': None,
            'normalize_W_enc': True,
            'SAE_output_scale': 1,
            'b_dec_zero': True,
        })


    train_config.update(
        {
            'phase_0_config': {
                'start_step': 0,
                'end_step': 4e5,
                # 'end_step': 5, 
                'clamp_b_enc_max': 0.0,
                'tune_W_enc': True, # Tune W_enc
                'tune_b_enc': False, # b_enc requires_grad
                # 'use_alignment_loss': False,
                'use_alignment_loss':True,
                'adjust_b_enc_config': {
                    'interval': 0.2, # adjust b_enc every 20% of an epoch
                    'factor_up': 0.2, 
                    'factor_down': 0.2,
                    'freq_threshold_low': 1 / (data_config.num_samples * 0.2), 
                    'freq_threshold_high': 0.002,
                    'clamp_b_enc_max': 0.0,
                },
                'L1_decay': 1e-6,
            },
            'phase_1_config': {
                'start_step': 4e5,
                'end_step': 10e6,
                'tune_neuron_weight': True, # Tune neuron weight
                'tune_b_enc': False, # Tune b_enc
                'tune_W_enc': True, # Tune W_enc
                'use_threshold_activation': True,
                #     'adjust_b_enc_config': {
                #     'interval': 0.5, # adjust b_enc every 20% of an epoch
                #     'factor_up': 1e-4, 
                #     'factor_down': 1e-4,
                #     'freq_threshold_low': 1 / (data_config.num_samples * 0.2), 
                #     'freq_threshold_high': 0.2,
                #     'clamp_b_enc_max': 0.0,
                # },
                'L1_decay': 1e-6,
            },
            # 'phase_2_config': {
            #     'start_step': 3e5,
            #     'end_step': 1e10,
            #     'adjust_b_enc_config': {
            #         'interval': 0.5, # adjust b_enc every 20% of an epoch
            #         'factor_up': 1e-4, 
            #         'factor_down': 1e-4,
            #         'freq_threshold_low': 1 / (data_config.num_samples * 0.2), 
            #         'freq_threshold_high': 0.05,
            #         'clamp_b_enc_max': 0.0
            #     }, 
            #     'L1_decay': 1e-5,
            # },
        }
    )
    config = EasyDict({
        'data_config': data_config,
        **model_config,
        **train_config
    })

    debug = False # FIXME: set to False when running the script

    # Path of a checkpoint to resume from, or None to train from scratch.
    # ckpt_path = os.path.join(current_dir, 'checkpoints/SAE_h256_n1024_act-relu_lr0.001_opt-AdamW_wd0.01_bs128_os0.1_AdjBenc_ClpBenc_bdec0-seed-100 - 2025-01-11-15:10:38/last.ckpt')
    ckpt_path = None

    # ------------------------------
    # seeding
    # ------------------------------
    # Seed BEFORE building the data module and the model so that data generation
    # and weight initialisation are reproducible from the recorded seed.
    if ckpt_path is not None:
        # run names embed the seed as 'seed-<int>'; recover it from the path
        seed_match = re.search(r'seed-\D*(\d+)', ckpt_path)
        if seed_match is None:
            raise ValueError(f'Cannot parse seed ("seed-<int>") from checkpoint path: {ckpt_path}')
        seed = int(seed_match.group(1))
    elif config.seed is not None:
        seed = config.seed
    else:
        seed = random.randint(0, 1000)
    config.seed = seed  # save the seed in the config
    pl.seed_everything(seed)
    print(f'Using seed {seed}')

    # ------------------------------
    # data module and model
    # ------------------------------
    data_module = SynSAEDataModule(**data_config)

    if ckpt_path is not None:
        model = LitSAEWithChannel.load_from_checkpoint(ckpt_path, config=config, feat_set=data_module.feat_set)
        # only the step counter is needed; load onto CPU to avoid touching GPU memory
        ckpt_dict = torch.load(ckpt_path, map_location='cpu')
        config.last_global_step = ckpt_dict['global_step']
    else:
        # init the model
        model = LitSAEWithChannel(config, feat_set=data_module.feat_set)

    # ------------------------------
    # logging and checkpointing
    # ------------------------------
    run_name = 'seed-' + str(config.seed) + ' - ' + datetime.now().strftime("%Y%m%d_%H%M%S")
    if 'last_global_step' in config:
        run_name += f'- {config.last_global_step}'
    exp_name = get_experiment_name(config)

    logger = pl.loggers.WandbLogger(
        project=config.wandb_config.wandb_project,
        entity=config.wandb_config.wandb_entity,
        group=exp_name,
        name=run_name, 
        config = config,
        reinit = True
    ) if not debug else None

    checkpoint_dir = os.path.join(current_dir, f'checkpoints/{exp_name}-{run_name}')

    # make sure the checkpoint directory exists
    os.makedirs(checkpoint_dir, exist_ok=True)
    # save the configuration
    with open(os.path.join(checkpoint_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f,  default_flow_style=False)

    config.checkpoint_dir = checkpoint_dir

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='{epoch}-{val_loss:.4f}',
        monitor='val_loss', # this depends on logging in the LightningModule
        mode='min',
        save_top_k=1,
        save_last=True
    )

    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor, checkpoint_callback]

    trainer_kwargs = dict(
        max_epochs=config.max_epochs,
        logger=logger,
        callbacks=callbacks,
        val_check_interval=getattr(config, 'val_check_interval', 1.0),
        precision=getattr(config, 'precision', '32'),
        accelerator='gpu', 
        # devices=1,
    )

    # trainer_kwargs is a plain dict: use .get / [] (attribute access does not see keys)
    if trainer_kwargs.get('precision') is not None:
        print(f'Precision set to {trainer_kwargs["precision"]}.')

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, data_module)