from datasets import load_dataset
from transformer_lens import HookedTransformer
from torch.utils.data import Dataset, DataLoader, random_split
import os, sys

import multiprocessing
from tqdm import tqdm, trange
# os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["VSCODE_PROXY_CUDA_DEVICE"] # FIXME: remove this line in sbatch script
# Make this file's parent and grandparent directories importable, presumably so the
# project-local packages imported below (Simtransformer, Group_SAE) resolve.
# `parent_dir` is also reused later to locate the default training config.
current_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
grandpa_dir = os.path.dirname(parent_dir)
sys.path.append(grandpa_dir)

import random
import torch
import torch.nn as nn
import lightning.pytorch as pl
import wandb
import yaml

from Simtransformer.simtransformer.utils import (
    EasyDict, clever_load, check_cosine_similarity,  Shampoo, signSGD, normSGD, NormalizeSGD, CosineAnnealingWarmup, dominance_metrics, AlignmentLoss,
    MIR, HDR,
    clean_config
)

from Simtransformer.simtransformer.model_base import SAEWithChannel, GradRescaler, SparseAutoEncoder
from typing import List, Tuple, Union, Any, Optional
from lightning.pytorch.utilities.types import LRSchedulerTypeUnion
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import seed_everything
import lightning
import time
import copy
from Group_SAE.SAE_model_v2 import LitSAEWithChannel
from datetime import datetime
import matplotlib.pyplot as plt
import h5py
import argparse
import math



class HDF5Dataset(Dataset):
    """Map-style dataset that reads rows of ``non_padding_cache`` from an HDF5 file on demand.

    The file is opened only briefly in ``__init__`` to read the number of rows.
    The long-lived read handle is opened lazily on the first ``__getitem__`` call
    in whichever process performs it (e.g. each DataLoader worker), then reused.
    The handle is dropped when the dataset is pickled, so an instance that has
    already been read from can still be sent to ``spawn``-started workers.
    """

    def __init__(self, h5_path):
        self.h5_path = h5_path
        # Open the file briefly to get the number of samples.
        with h5py.File(self.h5_path, 'r') as f:
            self.length = f["non_padding_cache"].shape[0]
        # Lazily opened read handle; see _init_file().
        self.h5_file = None

    def __len__(self):
        return self.length

    def __getstate__(self):
        # h5py file objects cannot be pickled: ship everything else and let the
        # receiving process reopen the file on its first access.
        state = self.__dict__.copy()
        state['h5_file'] = None
        return state

    def _init_file(self):
        """Open the read handle for this process if it is not open yet."""
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')

    def close(self):
        """Close the lazily opened read handle, if any (it is reopened on next access)."""
        if self.h5_file is not None:
            self.h5_file.close()
            self.h5_file = None

    def __getitem__(self, idx):
        # Reuse the per-process handle, opening it on first use (not once per call).
        self._init_file()
        return self.h5_file["non_padding_cache"][idx]

class BatchHDF5DataModule(pl.LightningDataModule):
    """Lightning DataModule that serves ``non_padding_cache`` rows lazily via ``HDF5Dataset``.

    ``setup`` splits the whole file into a training and a validation subset with
    a seeded ``random_split``. Both DataLoaders use ``pin_memory=True``; only the
    training loader shuffles. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, h5_path, batch_size=128, num_workers=3, split_ratio=0.9, seed=42, **kwargs):
        super().__init__()
        self.h5_path = h5_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Fraction of samples used for training, e.g. 0.8 -> 80% train / 20% validation.
        self.split_ratio = split_ratio
        self.seed = seed

    def setup(self, stage=None):
        dataset = HDF5Dataset(self.h5_path)
        total = len(dataset)
        n_train = int(self.split_ratio * total)
        # A freshly seeded generator makes the split identical across runs.
        split_generator = torch.Generator().manual_seed(self.seed)
        self.train_ds, self.val_ds = random_split(
            dataset, [n_train, total - n_train], generator=split_generator
        )

    def _loader(self, subset, shuffle):
        """Wrap ``subset`` in a DataLoader using this module's batch size and worker count."""
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        # persistent_workers stays off: with a running buffer the loaders get rebuilt.
        return self._loader(self.val_ds, shuffle=False)


# An in-memory dataset that holds a contiguous chunk (buffer) from the HDF5 file.
class BufferedHDF5Dataset(Dataset):
    """In-memory dataset holding rows ``[start_idx, end_idx)`` of ``non_padding_cache``.

    The whole slice is read eagerly in ``__init__`` and the file is closed
    immediately; items are then served from the in-memory array.
    """

    def __init__(self, h5_path, start_idx, end_idx):
        """
        Loads data from start_idx to end_idx from the HDF5 file into memory.
        """
        self.h5_path = h5_path
        self.start_idx = start_idx
        self.end_idx = end_idx
        with h5py.File(self.h5_path, 'r') as handle:
            cache = handle["non_padding_cache"]
            # Slicing an h5py dataset reads that range into a NumPy array.
            self.buffer_data = cache[start_idx:end_idx]

    def __len__(self):
        return len(self.buffer_data)

    def __getitem__(self, idx):
        # One raw sample (wrap it in a dict here if a collate_fn ever needs named fields).
        sample = self.buffer_data[idx]
        return sample

# A DataModule that loads one buffer (chunk) at a time.
class BufferedBatchHDF5DataModule(pl.LightningDataModule):
    """Lightning DataModule that streams the training split through a fixed-size in-memory buffer.

    The first ``split_ratio`` fraction of ``non_padding_cache`` rows is the
    training region. It is read ``buffer_size_samples`` rows at a time into a
    ``BufferedHDF5Dataset``. ``update_buffer`` advances to the next chunk and wraps
    to the start once the training region is exhausted. The remaining rows form
    the validation set, which is loaded into memory once in ``setup``.

    NOTE: replacing ``train_ds`` only takes effect if Lightning rebuilds the train
    DataLoader, e.g. ``Trainer(reload_dataloaders_every_n_epochs=1)``.
    """

    def __init__(self, h5_path, buffer_size_samples, batch_size=128, num_workers=0, split_ratio=0.9, seed=42, **kwargs):
        """
        h5_path: Path to the HDF5 file.
        buffer_size_samples: Number of samples to load into memory (the buffer size); must be positive.
        batch_size: Batch size for the DataLoader.
        num_workers: Number of workers for the train DataLoader (0 means data loaded in main process).
        split_ratio: Fraction of the file used for training; the rest is the validation set.
        seed: Stored for reproducibility (not used by the sequential buffer split).

        Raises ValueError if buffer_size_samples is not positive: the buffer pointer
        would never advance and training would loop on an empty/stuck buffer.
        """
        super().__init__()
        self.h5_path = h5_path
        self.buffer_size_samples = int(buffer_size_samples)
        if self.buffer_size_samples <= 0:
            raise ValueError(
                f"buffer_size_samples must be positive, got {buffer_size_samples!r}"
            )
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_ratio = split_ratio  # ratio for training set
        self.seed = seed

        # Get total number of samples in the file.
        with h5py.File(self.h5_path, 'r') as f:
            self.total_samples = f["non_padding_cache"].shape[0]

        # Start pointer for the buffer.
        self.current_buffer_start = 0
        # (start, end) row range currently held in train_ds; None until the first load.
        self._loaded_range = None

        # Check that both splits are non-empty.
        assert self.train_sample_end > 0, "The training set should not be empty."
        assert self.val_sample_start < self.total_samples, "The validation set should not be empty."

    @property
    def train_sample_end(self):
        """Exclusive end row of the training region."""
        return int(self.split_ratio * self.total_samples)

    @property
    def val_sample_start(self):
        """First row of the validation region (equal to ``train_sample_end``)."""
        return int(self.split_ratio * self.total_samples)

    def setup(self, stage=None):
        # Load (or keep) the current training buffer.
        self.load_current_buffer()

        # The validation region never changes, so read it from disk only once even
        # if Lightning calls setup() for several stages.
        if getattr(self, 'val_ds', None) is None:
            self.val_ds = BufferedHDF5Dataset(
                self.h5_path,
                self.val_sample_start,
                self.total_samples,
            )

    def load_current_buffer(self):
        """Load rows ``[current_buffer_start, current_buffer_start + buffer_size)`` of the training region."""
        start = self.current_buffer_start
        end = min(start + self.buffer_size_samples, self.train_sample_end)
        if self._loaded_range == (start, end):
            # Same rows are already in memory (e.g. a single buffer covers the whole
            # training split): skip re-reading them from disk.
            return
        print(f"Loading buffer: samples {start} to {end}")
        # Create an in-memory dataset from the current buffer.
        self.train_ds = BufferedHDF5Dataset(self.h5_path, start, end)
        self._loaded_range = (start, end)

    def train_dataloader(self, **kwargs):
        # DataLoader over the current training buffer; kwargs may override batch_size / num_workers.
        return DataLoader(
            self.train_ds,
            batch_size=kwargs.get('batch_size', self.batch_size),
            shuffle=True,
            num_workers=kwargs.get('num_workers', self.num_workers),
            pin_memory=True,
        )

    def val_dataloader(self, **kwargs):
        # DataLoader over the in-memory validation set; kwargs may override batch_size.
        return DataLoader(
            self.val_ds,
            batch_size=kwargs.get('batch_size', self.batch_size),
            shuffle=False,
            num_workers=0,  # data is already in memory; no workers needed for validation
            pin_memory=True,
        )

    def update_buffer(self):
        """Advance to the next training chunk, wrapping to the start after the last one."""
        self.current_buffer_start += self.buffer_size_samples
        if self.current_buffer_start >= self.train_sample_end:
            print("All buffers processed. Restarting from beginning.")
            self.current_buffer_start = 0
        # Reload the dataset with the new buffer.
        self.load_current_buffer()

# A callback to update the buffer at the end of each training epoch.
class BufferUpdateCallback(pl.Callback):
    """Advance the datamodule's running buffer at the end of every training epoch.

    Works with any datamodule exposing ``update_buffer()`` (e.g.
    ``BufferedBatchHDF5DataModule``); other datamodules are left untouched.

    The new buffer is only consumed if Lightning rebuilds the train DataLoader,
    which requires ``Trainer(reload_dataloaders_every_n_epochs=1)``. A warning is
    emitted once when the trainer is configured never to reload dataloaders.
    """

    def on_train_epoch_end(self, trainer, pl_module):
        datamodule = trainer.datamodule
        if not hasattr(datamodule, "update_buffer"):
            return

        reload_every = getattr(trainer, "reload_dataloaders_every_n_epochs", None)
        if reload_every == 0 and not getattr(self, "_warned_no_reload", False):
            import warnings
            warnings.warn(
                "BufferUpdateCallback swaps the datamodule's train buffer, but the Trainer "
                "has reload_dataloaders_every_n_epochs=0, so training keeps using the "
                "DataLoader built from the first buffer. Pass "
                "reload_dataloaders_every_n_epochs=1 to the Trainer.",
                stacklevel=2,
            )
            self._warned_no_reload = True

        print("Updating buffer after epoch.")
        datamodule.update_buffer()

class GroupedAdamWOptimizer(torch.optim.AdamW):
    """AdamW variant whose first parameter can use per-row-block learning rates.

    Every parameter gets the standard decoupled-weight-decay, bias-corrected Adam
    update (``amsgrad`` and ``maximize`` are not supported). There is one
    exception: the first parameter of the first param group, when that group has
    a ``'group_indices'`` entry. That group's ``'lr'`` must then be a list with
    one learning rate per block. Rows ``[group_indices[k-1], group_indices[k])``
    (block 0 starts at row 0) are stepped with ``lr[k]``, and weight decay for
    that parameter uses the last learning rate in the list.
    """

    # Options accepted by torch.optim.AdamW that this step() does not implement.
    _UNSUPPORTED_FLAGS = ('amsgrad', 'maximize')

    def __init__(self, params, **kwargs):
        # Reject options that would otherwise be silently ignored by step().
        for flag in self._UNSUPPORTED_FLAGS:
            if kwargs.get(flag, False):
                raise NotImplementedError(
                    f"GroupedAdamWOptimizer does not implement {flag}=True"
                )
        super(GroupedAdamWOptimizer, self).__init__(params, **kwargs)

        # Eagerly create state for every parameter known at construction time;
        # parameters added later via add_param_group() are initialized in step().
        for group in self.param_groups:
            for p in group['params']:
                self._init_state(p)

    def _init_state(self, p):
        """Reset the AdamW step counter and moment buffers for parameter ``p``."""
        state = self.state[p]
        state['step'] = 0
        state['exp_avg'] = torch.zeros_like(p.data)
        state['exp_avg_sq'] = torch.zeros_like(p.data)

    @staticmethod
    def _adam_update(param, exp_avg, exp_avg_sq, lr, bias_correction1, bias_correction2, eps):
        """Apply one bias-corrected Adam step of size ``lr`` to ``param`` (or a view of it) in place."""
        step_size = lr / bias_correction1
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-step_size)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # The closure usually runs forward + backward, so autograd must be on
            # even though the parameter update itself runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]
                if 'step' not in state:
                    self._init_state(p)

                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running averages.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Per-row-block path: first parameter of the first group, when that
                # group carries block boundaries.
                if p is self.param_groups[0]['params'][0] and 'group_indices' in group:
                    group_lrs = group['lr']
                    if weight_decay != 0:
                        p.data.mul_(1 - group_lrs[-1] * weight_decay)  # last lr drives weight decay

                    group_indices = group['group_indices']
                    for k, end_idx in enumerate(group_indices):
                        start_idx = group_indices[k - 1] if k > 0 else 0
                        self._adam_update(
                            p.data[start_idx:end_idx],
                            exp_avg[start_idx:end_idx],
                            exp_avg_sq[start_idx:end_idx],
                            group_lrs[k],
                            bias_correction1,
                            bias_correction2,
                            eps,
                        )
                else:
                    if weight_decay != 0:
                        p.data.mul_(1 - group['lr'] * weight_decay)
                    self._adam_update(
                        p.data, exp_avg, exp_avg_sq, group['lr'],
                        bias_correction1, bias_correction2, eps,
                    )

        return loss

class LitSAEWithChannelNew(LitSAEWithChannel):
    """``LitSAEWithChannel`` with per-group optimizer learning rates and per-group ``b_enc`` adjustment.

    Neurons are split into groups according to
    ``config.phase_0_config.adjust_b_enc_config.group_partitions``. Group ``k``
    (1-based) owns ``self.model._W_enc[k-1]`` / ``self.model._b_enc[k-1]`` and the
    neuron index range ``[group_start(k), group_end(k))``.
    """

    def configure_optimizers(self):
        """Build the optimizer (one param group per tensor) and an optional LR scheduler.

        Param groups: ``b_dec`` at the base lr. Then, when ``group_partitions`` is
        configured, one group per ``_W_enc[i]`` followed by one per ``_b_enc[i]``.
        Each of these uses ``adjust_b_enc_config['group_{i+1}']['lr']`` if given,
        else the base lr. Otherwise there is a single group each for ``_W_enc``
        and ``_b_enc``.

        Returns the optimizer alone, or ``([optimizer], [scheduler_dict])`` with a
        per-step interval when ``config.lr_scheduler`` is ``"cosine"`` or ``"step"``.

        Raises:
            ValueError: if ``config.optimizer`` is not one of the supported names.
        """
        optimizer_dict = {
            'SGD': torch.optim.SGD,
            'Adam': torch.optim.Adam,
            'AdamW': torch.optim.AdamW,
            'RMSprop': torch.optim.RMSprop,
            'Shampoo': Shampoo,
            'signSGD': signSGD,
            'normSGD': normSGD,
            'NormalizeSGD': NormalizeSGD,
            'GroupedAdamW': GroupedAdamWOptimizer,
        }
        optimizer_name = self.config.optimizer
        if optimizer_name not in optimizer_dict:
            raise ValueError(f"Optimizer {optimizer_name} is not implemented!")

        optimizer_config = self.config[f'{optimizer_name}_optimizer_config']
        lr = optimizer_config.lr
        adjust_b_enc_config = self.config.phase_0_config.adjust_b_enc_config

        # b_dec is always optimized at the base learning rate.
        param_groups = [{'params': [self.model.b_dec], 'lr': lr}]

        if 'group_partitions' in adjust_b_enc_config:
            # NOTE(review): per-group indexing requires _W_enc/_b_enc to hold leaf
            # Parameters per group; if they were single tensors the slices would be
            # non-leaf and torch raises "can't optimize a non-leaf Tensor" (see FIXME below).
            group_lrs = [
                adjust_b_enc_config[f'group_{i + 1}'].get('lr', lr)
                for i in range(len(self.group_indices))
            ]
            param_groups += [
                {'params': [self.model._W_enc[i]], 'lr': group_lr}
                for i, group_lr in enumerate(group_lrs)
            ]
            param_groups += [
                {'params': [self.model._b_enc[i]], 'lr': group_lr}
                for i, group_lr in enumerate(group_lrs)
            ]
        else:
            param_groups += [
                {'params': [self.model._W_enc], 'lr': lr},
                {'params': [self.model._b_enc], 'lr': lr},
            ]

        # Per-group 'lr' entries override the default lr passed in optimizer_config.
        # NOTE: GroupedAdamWOptimizer's per-row-block path needs a group with a list
        # 'lr' and 'group_indices'; the groups built here use scalar lrs only.
        optimizer = optimizer_dict[optimizer_name](param_groups, **optimizer_config)

        # Print parameter shapes rather than the tensors themselves, which for a
        # large SAE would flood the log.
        for group in optimizer.param_groups:
            print(f"Params: {[tuple(p.shape) for p in group['params']]}")
            print(f"Learning rate: {group['lr']}")
        print(f'Optimizer: {optimizer_name}')
        for k, v in optimizer_config.items():
            print(f'{k}: {v}')

        # Configure the learning rate scheduler.
        if self.config.lr_scheduler == "cosine":
            cosine_scheduler_config = self.config.cosine_scheduler_config
            scheduler = CosineAnnealingWarmup(
                optimizer=optimizer,
                warmup_steps=cosine_scheduler_config.warmup_steps,
                learning_rate=[group['lr'] for group in optimizer.param_groups],
                min_lr=cosine_scheduler_config.min_lr,
                lr_decay_steps=cosine_scheduler_config.lr_decay_steps,
            )
        elif self.config.lr_scheduler == "step":
            StepLR_config = self.config.StepLR_scheduler_config
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=StepLR_config.step_size,
                gamma=StepLR_config.gamma,
            )
        else:
            # use no scheduler
            scheduler = None

        if scheduler is None:
            return optimizer
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
    # FIXME: ValueError: can't optimize a non-leaf Tensor

    def group_start(self, group_idx):
        """
        Get the start index of the group. Note that group_idx starts from 1.

        The first group starts at 0; group k > 1 starts at the end index of group k-1.
        """
        return self.group_indices[group_idx - 2] if group_idx > 1 else 0

    def group_end(self, group_idx):
        """
        Get the (exclusive) end index of the group. Note that group_idx starts from 1.
        """
        return self.group_indices[group_idx - 1]

    @property
    def group_partitions(self):
        """Group partition fractions from ``phase_0_config.adjust_b_enc_config``."""
        return self.config.phase_0_config.adjust_b_enc_config.group_partitions

    @property
    def group_indices(self):
        """Per-group exclusive end indices, computed once via ``get_group_index`` and cached."""
        if not hasattr(self, '_group_indices'):
            self._group_indices = self.get_group_index(self.group_partitions, self.config.num_neurons)
        return self._group_indices

    def adjust_b_enc(self, adjust_b_enc_config):
        """
        Periodically adjust each group's encoder bias ``b_enc`` from aggregated firing statistics.

        Acts only when ``step_interval = self.get_step_interval(adjust_b_enc_config.interval)``
        is positive and the number of optimizer steps since ``self.phase_start_steps[0]``
        is a non-zero multiple of it. On such steps it:
          1. aggregates the activation buffers;
          2. applies the per-group adjustment when the global ``adjust_b_enc`` flag is set;
          3. clamps every group's ``b_enc`` to at most ``clamp_b_enc_max`` when that key is present;
          4. empties the activation buffers.
        """
        step_interval = self.get_step_interval(adjust_b_enc_config.interval)
        if step_interval <= 0:
            return
        # global_step counts optimizer steps across all epochs.
        steps_since_phase_start = self.global_step - self.phase_start_steps[0]
        if steps_since_phase_start == 0 or steps_since_phase_start % step_interval != 0:
            return

        with torch.no_grad():
            aggregated_info = self.aggregate_and_log_activation_buffers()
            fraction_pre_act_ge_zero = aggregated_info.fraction_pre_act_ge_zero
            max_pre_act = aggregated_info.max_pre_act

            # Global switch: when off, no group is adjusted regardless of its own flag;
            # when on, every group is adjusted unless its own 'adjust_b_enc' is False.
            if adjust_b_enc_config.get('adjust_b_enc', False):
                neuron_index = self.get_group_index(adjust_b_enc_config.group_partitions, self.config.num_neurons)
                for i in range(len(adjust_b_enc_config.group_partitions)):
                    start_idx = neuron_index[i - 1] if i > 0 else 0
                    end_idx = neuron_index[i]
                    self._adjust_group_b_enc(
                        i, start_idx, end_idx, adjust_b_enc_config,
                        fraction_pre_act_ge_zero, max_pre_act,
                    )

            # Clamp the bias of every group (not just the last one adjusted).
            if 'clamp_b_enc_max' in adjust_b_enc_config:
                for i in range(len(adjust_b_enc_config.group_partitions)):
                    self.model._b_enc[i].data.clamp_(max=adjust_b_enc_config.clamp_b_enc_max)

            self.empty_activation_buffers()

    def _adjust_group_b_enc(self, i, start_idx, end_idx, adjust_b_enc_config,
                            fraction_pre_act_ge_zero, max_pre_act):
        """Log firing statistics for 0-based group ``i`` and nudge its ``b_enc`` in place.

        Neurons whose firing fraction is below ``freq_threshold_low`` have their bias
        raised. The increase is ``factor_up`` times the mean ``max_pre_act`` of the
        group's other neurons, or ``factor_up * 1e-3`` if no neuron in the group is
        above the threshold. Neurons above ``freq_threshold_high`` have their bias
        lowered by ``factor_down`` times their own ``max_pre_act``. Per-group
        ``factor_up`` / ``factor_down`` override the global ones. A per-group
        ``adjust_b_enc: False`` disables the update, but statistics are still logged.
        """
        group_config = adjust_b_enc_config[f'group_{i + 1}']
        b_enc = self.model._b_enc[i].data
        group_freq = fraction_pre_act_ge_zero[start_idx:end_idx]
        group_max_pre_act = max_pre_act[start_idx:end_idx]
        group_size = end_idx - start_idx

        # NOTE(review): freq_threshold_high is compared numerically here; confirm how
        # non-numeric values (the script's CLI default is 'mixed') are resolved upstream.
        never_activated = group_freq < group_config['freq_threshold_low']
        too_often_activated = group_freq > group_config['freq_threshold_high']
        sometimes_activated = torch.logical_not(never_activated)

        # Tensor.sum() rather than builtin sum(): the builtin iterates element by
        # element (one tiny tensor op per neuron), which is very slow on GPU.
        n_never = never_activated.sum().item()
        n_too_often = too_often_activated.sum().item()
        self.logger.experiment.log({
            f'never_activated_{i + 1}': n_never,
            f'too_often_activated_{i + 1}': n_too_often,
            f'never_activated_rate_{i + 1}': n_never / group_size if group_size > 0 else float('nan'),
            f'too_often_activated_rate_{i + 1}': n_too_often / group_size if group_size > 0 else float('nan'),
        })

        adjust_group = group_config.get('adjust_b_enc', True)

        # If the neuron is (almost) never activated, increase the bias term.
        if adjust_group:
            factor_up = group_config.get('factor_up', adjust_b_enc_config.factor_up)
            if torch.any(sometimes_activated):
                b_enc[never_activated] += factor_up * group_max_pre_act[sometimes_activated].mean()
            else:
                # Fallback uses the same (per-group) factor_up as the branch above.
                b_enc[never_activated] += factor_up * 1e-3
        self.logger.experiment.log({
            f'max_pre_act_moderate_activated_{i + 1}': group_max_pre_act[sometimes_activated].mean().item(),
        })

        # If the neuron is too often activated, decrease the bias term.
        if adjust_group:
            factor_down = group_config.get('factor_down', adjust_b_enc_config.factor_down)
            b_enc[too_often_activated] -= factor_down * group_max_pre_act[too_often_activated]
        self.logger.experiment.log({
            f'max_pre_act_too_often_activated_{i + 1}': group_max_pre_act[too_often_activated].mean().item(),
        })
        
    
def preprocess_ds(ds, model, hook_name='0_mlp_out', batch_size=128):
    """
    Collect one residual-stream component for every example in ``ds``.

    ``ds`` is processed in chunks of ``batch_size`` rows; the final chunk may be
    shorter. Each chunk's ``input_ids`` are run through ``model.run_with_cache``,
    and the component labelled ``hook_name`` in
    ``cache.decompose_resid(mode="all")`` is kept.

    Args:
        ds: dataset exposing ``len``, ``select`` and an ``'input_ids'`` column
            (presumably a HuggingFace ``datasets.Dataset`` — confirm).
        model: model exposing ``run_with_cache`` (presumably a ``HookedTransformer``).
        hook_name: label of the residual component to keep.
        batch_size: number of examples per forward pass.

    Returns:
        The per-chunk components concatenated along dim 0, so the first dimension
        indexes the ``len(ds)`` examples.

    Raises:
        ValueError: if ``ds`` is empty.
    """
    num_examples = len(ds)
    if num_examples == 0:
        raise ValueError("preprocess_ds received an empty dataset")

    chunks = []
    # Only activations are needed: don't build/keep an autograd graph per chunk.
    with torch.no_grad():
        for start in tqdm(range(0, num_examples, batch_size)):
            # Clamp the last chunk so select() never indexes past the end of ds.
            stop = min(start + batch_size, num_examples)
            batch_ds = ds.select(range(start, stop))
            _logits, cache = model.run_with_cache(torch.tensor(batch_ds['input_ids']))
            residual_stream, labels = cache.decompose_resid(return_labels=True, mode="all")
            chunks.append(residual_stream[labels.index(hook_name)])
    # cat (not stack) so a shorter final chunk is allowed and dim 0 is num_samples.
    return torch.cat(chunks, dim=0)


def get_experiment_name(config: EasyDict) -> str:
    """
    Build an experiment name from the key hyper-parameters in ``config``.

    The name joins, with underscores: hidden size, neuron count, activation,
    learning rate (from ``config[<optimizer>_optimizer_config]``), optimizer name,
    batch size and the ``tune_b_dec`` flag. It then appends ``num_samples``,
    ``topk`` and ``L1_decay`` when those keys are present.
    Adjust this function to include whatever parameters are most relevant.
    """
    optimizer = config.optimizer
    learning_rate = config[optimizer + '_optimizer_config'].lr
    name_parts = [
        f"h{config.hidden_size}",
        f"n{config.num_neurons}",
        f"act-{config.activation}",
        f"lr{learning_rate}",
        f"opt-{optimizer}",
        f"bs{config.batch_size}",
        f"tb_dec-{config.tune_b_dec}",
    ]
    # Optional parts, appended in a fixed order so names stay comparable across runs.
    optional_parts = (
        ('num_samples', 'num_samples-'),
        ('topk', 'topk-'),
        ('L1_decay', 'L1-'),
    )
    for key, prefix in optional_parts:
        if key in config:
            name_parts.append(f"{prefix}{getattr(config, key)}")

    return "_".join(name_parts)

def get_run_name(config: EasyDict) -> str:
    """
    Return a run name of the form ``seed-<seed>_<YYYYmmdd_HHMMSS>``.

    The timestamp is the local wall-clock time at the moment of the call, so runs
    sharing a seed still get distinct names unless started within the same second.
    Adjust this function to include whatever parameters are most relevant.
    """
    seed_part = "seed-" + str(config.seed)
    timestamp_part = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{seed_part}_{timestamp_part}"

# class LitSAEWithChannelNew(LitSAEWithChannel):
#     def validation_step(self, batch, batch_idx):
#         """
#         Lightning calls this inside the validation loop.
#         """
        
#         # step_output = self._Step(batch, batch_idx, step_type='val', log=False)
#         step_output = self._Step(batch, batch_idx, step_type='val', log=True)
#         self.log('val_loss', step_output.loss, on_step=False, on_epoch=True, prog_bar=True)

#         return step_output.loss


if __name__ == '__main__':

    torch.set_float32_matmul_precision('high')
    
    # ------------------------------
    # data configurations
    # ------------------------------
    # multiprocessing.set_start_method("spawn", force=True)

    # ds = load_dataset("apollo-research/roneneldan-TinyStories-tokenizer-gpt2")
    
    # model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")

    parser = argparse.ArgumentParser(description="SAE Transformer Training Arguments")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
    parser.add_argument("--input_norm", type=int, default=None, help="Input normalization parameter (if any)")
    parser.add_argument("--max_epochs", type=int, default=3, help="Maximum number of epochs for training")
    parser.add_argument("--buffer_in_GB", type=int, default=2, help="Buffer size in GB for data loading")
    parser.add_argument("--dataset", type=str, default="Pile_github-Qwen2.5-1.5B-L26-mlp-out-2048", help="Dataset name")
    parser.add_argument("--split_ratio", type=float, default=0.999, help="Split ratio for training and validation datasets")
    def float_or_str(value):
        try:
            return float(value)
        except ValueError:
            return value

    parser.add_argument("--freq_threshold_high", type=float_or_str, default='mixed', help="Frequency threshold for high frequency neurons (can be float or string)")

    args = parser.parse_args()

    batch_size = args.batch_size
    num_workers = args.num_workers
    input_norm = args.input_norm
    max_epochs = args.max_epochs
    buffer_in_GB = args.buffer_in_GB
    dataset_name = args.dataset
    split_ratio = args.split_ratio
    freq_threshold_high = args.freq_threshold_high
    debug = False # FIXME: set to False when running the script
    
    
    # data_path = os.path.join(parent_dir, dataset_name, 'non_padding_cache.h5')
    data_path = os.path.join("<SCRATCH_DIR>", dataset_name, 'train_data.h5')
    # data_info_path = os.path.join(parent_dir, dataset_name, 'dataset_info.yaml')
    def load_data_info(data_path: str) -> dict:
        info_path = os.path.join(os.path.dirname(data_path), "dataset_info.yaml")
        return clever_load(info_path)
    data_info_dict = load_data_info(data_path)  
    
    hidden_size = data_info_dict['dimensions']
    num_samples = data_info_dict['length']
    data_type = data_info_dict['data_type'] # 'float32' or 'float64'
    
    # decide the running buffer size
    bytes_per_sample = hidden_size * 4
    buffer_size_samples = int(buffer_in_GB * 1024**3 // bytes_per_sample)
    num_buffer_per_epoch = int(math.ceil(num_samples / buffer_size_samples))

    data_config = EasyDict(
        # ds = ds,
        # model = model,
        ds_name = dataset_name,
        model_name = 'Qwen/Qwen2.5-1.5B',
        # hook_name = hook_name,
        batch_size=batch_size,
        num_workers=num_workers,
        data_type=data_type,
        num_samples= int(num_samples * split_ratio), 
    )
    
    model_config = EasyDict(
        hidden_size = hidden_size, # Number of input features
        # num_neurons = 1024, # Number of neurons in the hidden layer
        # num_neurons = 8192,
        num_neurons = 65536,
        activation = 'relu', # Activation function to use
        channel_size_ls = [],
        use_neuron_weight = True, # Use neuron weight
    )
    
    # data_module = TransformerDataModule(ds, model, **data_config)
    # data_module = BatchHDF5DataModule(data_path, batch_size=batch_size, num_workers=num_workers, split_ratio=0.9, seed=42)
    
    data_module = BufferedBatchHDF5DataModule(
        h5_path=data_path,
        buffer_size_samples=buffer_size_samples,
        batch_size=batch_size,
        num_workers=num_workers,
        split_ratio=split_ratio, 
        seed=42,
    )
    
    
    # Start from the shared default training configuration; selected keys are overridden below.
    default_train_cfg_path = os.path.join(
        parent_dir, 'Simtransformer/simtransformer/configurations/train_config_default.yaml'
    )
    train_config = EasyDict(clever_load(default_train_cfg_path))

    # Optimizer-step budget: batches per epoch (floor division, plus one) times epochs.
    total_steps = int(num_samples // batch_size + 1) * max_epochs

    # Warm up for 1000 steps only when the run is longer than that; the remaining
    # steps become the decay span of the cosine schedule.
    sched_warmup_steps = 1000 if total_steps > 1000 else 0
    sched_decay_steps = total_steps - sched_warmup_steps

    adamw_cfg = {
        'lr': 1e-6,
        'weight_decay': 0.01,
        'betas': [0.9, 0.999],
    }
    # NOTE(review): min_lr equals the base lr (1e-6). If the cosine scheduler anneals from
    # lr towards min_lr, the learning rate is effectively constant after warm-up — confirm.
    cosine_cfg = {
        'lr_decay_steps': sched_decay_steps,
        'min_lr': 1e-6,
        'warmup_steps': sched_warmup_steps,
    }
    train_config.update({
        'max_epochs': max_epochs,
        'batch_size': data_config.batch_size,
        'optimizer': 'AdamW',
        'AdamW_optimizer_config': adamw_cfg,
        'wandb_config': {
            'wandb_project': dataset_name,
            'wandb_entity': 'SAE_atomic',
        },
        'num_neuron_vis': 50,  # number of neurons to visualize
        'seed': None,          # None -> a random seed is drawn further down
        'lr_scheduler': 'cosine',  # set to None to disable scheduling
        'cosine_scheduler_config': cosine_cfg,
        'normalize_W_enc': False,
        'SAE_output_scale': 1,
        'b_dec_zero': False,
        'divide_by': 5,  # FIXME: rescaling factor for the input
        'normalize_batch_with_tanh_threshold': 20,
    })

    # Settings for the b_enc adjustment used in phase 0 (single group by default; the
    # "mixed" mode below overrides these). The exact update rule lives in the model.
    single_group_adjust_cfg = {
        'group_size': 1,
        'group_partitions': [1.0],  # fractions per group; must sum to 1
        'interval': 50,             # adjust b_enc every 50 steps
        'factor_up': 0.01,
        'factor_down': 0.2,
        'group_1': {  # the most frequent neurons
            'freq_threshold_low': 1 / (num_samples * 0.1),
            'freq_threshold_high': freq_threshold_high,
        },
    }

    # Phase 0 covers the whole run: W_enc is trained, b_enc does not take gradients
    # (its requires_grad flag is off) and is clamped to [clamp_b_enc_min, clamp_b_enc_max].
    phase_0 = {
        'start_step': 0,
        'end_step': total_steps,
        'clamp_b_enc_max': 0.0,
        'clamp_b_enc_min': -24.0,  # FIXME: set to -24.0 for now (previously -16.0)
        'tune_W_enc': True,
        'tune_b_enc': False,
        'use_alignment_loss': False,
        'adjust_b_enc_config': single_group_adjust_cfg,
        'L1_decay': 0.0,
    }

    train_config.update({
        'tune_b_dec': True,
        'phase_0_config': phase_0,
        # Disabled second phase, kept for reference:
        # 'phase_1_config': {
        #     'start_step': num_samples // batch_size,
        #     'end_step': (num_samples // batch_size + 1) * max_epochs,
        #     'tune_neuron_weight': True,
        #     'tune_b_enc': False,
        #     'tune_W_enc': True,
        #     'use_threshold_activation': True,
        #     'L1_decay': 0.0,
        # },
    })
    # "mixed" mode: split the neurons into four groups (10/20/30/40 %) that share the same
    # lower firing-frequency threshold but get progressively smaller upper thresholds.
    # The isinstance check short-circuits the string comparison for numeric thresholds.
    if not isinstance(freq_threshold_high, float) and freq_threshold_high == "mixed":
        mixed_freq_low = 1 / (num_samples * 0.1)
        # Upper thresholds for group_1 (most frequent) .. group_4 (least frequent).
        # Alternatives tried: [0.1, 0.05, 0.001, 0.0005] and [0.4, 0.1, 0.005, 0.001];
        # alternative partitions: [0.01, 0.1, 0.4, 0.49].
        mixed_freq_highs = [0.2, 0.05, 0.005, 0.001]

        mixed_adjust_cfg = {
            'group_size': len(mixed_freq_highs),
            'group_partitions': [0.1, 0.2, 0.3, 0.4],  # must sum to 1
            'interval': 50,  # adjust b_enc every 50 steps
            'factor_up': 0.01,
            'factor_down': 0.2,
        }
        for group_idx, group_high in enumerate(mixed_freq_highs, start=1):
            mixed_adjust_cfg[f'group_{group_idx}'] = {
                'freq_threshold_low': mixed_freq_low,
                'freq_threshold_high': group_high,
            }

        train_config.phase_0_config.adjust_b_enc_config.update(mixed_adjust_cfg)

    # Single run config: data settings stay nested under 'data_config', while model and
    # training settings are flattened to the top level (training keys win on clashes).
    merged_config = {'data_config': data_config}
    merged_config.update(model_config)
    merged_config.update(train_config)
    config = EasyDict(merged_config)

     # Create the model
    
        # ckpt_path = os.path.join(current_dir, 'checkpoints/SAE_h256_n1024_act-relu_lr0.001_opt-AdamW_wd0.01_bs128_os0.1_AdjBenc_ClpBenc_bdec0-seed-100 - 2025-01-11-15:10:38/last.ckpt')
    ckpt_path = None
    
    if ckpt_path is not None:
        # model = LitSAEWithChannel.load_from_checkpoint(ckpt_path, config=config, feat_set=None)
        model = LitSAEWithChannelNew.load_from_checkpoint(ckpt_path, config=config, feat_set=None)
        
        # get the seed from the ckpt path by parsing the string
        seed = int(ckpt_path.split('seed-')[1].split(' - ')[0])
        config.seed = seed
        
        # get the global step from the ckpt
        ckpt_dict = torch.load(ckpt_path)
        config.last_global_step = ckpt_dict['global_step']
        
    else:
        # init the model
        # model = LitSAEWithChannel(config, feat_set=None)
        model = LitSAEWithChannelNew(config, feat_set=None)
    # Seed all RNGs through Lightning. Use the seed already in the config (e.g. recovered
    # from a checkpoint path) when present; otherwise draw a fresh one and record it so the
    # run can be reproduced. Binding `seed` in both cases avoids a NameError when
    # config.seed is preset without going through the checkpoint branch.
    seed = config.seed if config.seed is not None else random.randint(0, 1000)
    pl.seed_everything(seed)
    config.seed = seed  # save the seed in the config
    print(f'Using seed {seed}')
    
    # Run name: "seed-<seed> - <YYYYmmdd_HHMMSS>", suffixed with the resumed global step
    # when training continues from a checkpoint.
    run_name = 'seed-' + str(config.seed) + ' - ' + datetime.now().strftime("%Y%m%d_%H%M%S")
    if 'last_global_step' in config:
        run_name += f'- {config.last_global_step}'
    exp_name = get_experiment_name(config)

    # W&B logging is switched off in debug mode.
    if debug:
        logger = None
    else:
        logger = pl.loggers.WandbLogger(
            project=config.wandb_config.wandb_project,
            entity=config.wandb_config.wandb_entity,
            group=exp_name,
            name=run_name,
            config=config,
            reinit=True,
        )
    
    # One checkpoint directory per run: checkpoints/<experiment>-<run>.
    checkpoint_dir = os.path.join(current_dir, f'checkpoints/{exp_name}-{run_name}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Snapshot the cleaned config next to the checkpoints. This happens before
    # checkpoint_dir is added to the config, so the YAML does not contain it.
    config_yaml_path = os.path.join(checkpoint_dir, 'config.yaml')
    with open(config_yaml_path, 'w') as yaml_file:
        yaml.dump(clean_config(config), yaml_file, default_flow_style=False)

    config.checkpoint_dir = checkpoint_dir

    # Every 500 training steps, keep the two checkpoints with the lowest step-level
    # normalized reconstruction loss, plus the most recent one (save_last).
    # (Monitoring 'val_loss' would require it to be logged by the LightningModule.)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor='train_normalized_reconstruction_loss_step',
        mode='min',
        filename='{step}-{train_normalized_reconstruction_loss_step:.4f}',
        every_n_train_steps=500,
        save_top_k=2,
        save_last=True,
    )

    # Log the learning rate every step; BufferUpdateCallback is defined elsewhere in this file.
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
    callbacks = [
        lr_monitor,
        checkpoint_callback,
        BufferUpdateCallback(),
    ]
    
    # Lightning's max_epochs is counted in buffers here, not passes over the dataset, so the
    # requested number of epochs is scaled by the number of buffers per epoch.
    trainer_kwargs = dict(
        max_epochs=config.max_epochs * num_buffer_per_epoch,
        logger=logger,
        callbacks=callbacks,
        val_check_interval=getattr(config, 'val_check_interval', 0.5),
        precision=getattr(config, 'precision', '32'),
        accelerator='gpu',
        log_every_n_steps=30,
        # devices=1,
    )

    # trainer_kwargs is a plain dict: getattr() never sees its keys (so this message was
    # never printed), and attribute access would raise. Read the key with .get().
    trainer_precision = trainer_kwargs.get('precision')
    if trainer_precision is not None:
        print(f'Precision set to {trainer_precision}.')

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, data_module)