import copy
import gc
import inspect
import math
import os
from collections import Counter
from datetime import datetime
from typing import Iterable, Optional

import torch as t
import wandb
from einops import einsum, rearrange
from loguru import logger
from schedulefree import AdamWScheduleFree
from torch.cuda.amp import GradScaler
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.config_enums import ResamplingType, SchedulerType
from auto_encoder.helpers.activation_resampling import apply_resampling_
from auto_encoder.helpers.ae_metrics import AutoEncoderMetrics, MetricsCollection
from auto_encoder.helpers.ae_output_types import SupermodelOutput
from auto_encoder.helpers.benchmarking import wall_clock_timer
from auto_encoder.helpers.buffer import ActivationBuffer
from auto_encoder.helpers.ema_weights import EmaModel
from auto_encoder.helpers.geometric_median import estimate_activations_geometric_median
from auto_encoder.helpers.logging_fns import _train_logging, eval_logging
from auto_encoder.helpers.schedulers import CustomScheduler, TopK_M_DecreasingScheduler
from auto_encoder.models.base_ae import AutoEncoderBase
from auto_encoder.training.supermodel import FrozenTransformerAutoencoderSuperModel
from auto_encoder.utils import log_largest_tensors
from data.ae_data import get_eval_dataloader, get_train_dataloader

if t.cuda.is_available():
    from bitsandbytes.optim import AdamW8bit


class SAETrainer:
    def __init__(
        self,
        ae_config: AutoEncoderConfig,
        use_wandb: bool = True,
        supermodel: Optional[FrozenTransformerAutoencoderSuperModel] = None,
        train_dataloader: Optional[DataLoader] = None,
        eval_dataloader: Optional[DataLoader] = None,
        finetuning: bool = False,
        device: str = device,
    ):
        """Build everything needed to train the autoencoder.

        Args:
            ae_config: Configuration driving data loading, the model, the optimiser
                and the training loop.
            use_wandb: Whether metrics are reported to Weights & Biases.
            supermodel: Pre-built transformer + autoencoder wrapper. When omitted, one
                is created with `initialise_supermodel`.
            train_dataloader: Training data. When omitted, `get_train_dataloader` is used.
            eval_dataloader: Evaluation data. When omitted, `get_eval_dataloader` is used
                with half the transformer batch size.
            finetuning: When True, a frozen copy of the initial autoencoder is kept so
                that `train_step` can apply its KL penalty against it.
            device: Device string (e.g. "cuda", "cpu") used for tensors created here.
        """
        self.use_wandb = use_wandb
        self.ae_config = ae_config
        self.device = device

        try:
            tokenizer = AutoTokenizer.from_pretrained(ae_config.transformer_model_name)
        except Exception as err:
            # Deliberate best-effort fallback, but make it visible in the logs
            logger.warning(
                f"Could not load tokenizer for {ae_config.transformer_model_name!r} "
                f"({err}); falling back to the 'gpt2' tokenizer"
            )
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

        print(ae_config)

        # loguru formats messages with `{}` placeholders, so extra positional args
        # without a placeholder are silently dropped; use an f-string instead.
        logger.info(f"Transformer model batch size: {ae_config.transformer_batch_size}")

        # Compare against None explicitly: the truthiness of a DataLoader goes through
        # `__len__`, which raises for iterable-style datasets without a length.
        if train_dataloader is None:
            train_dataloader = get_train_dataloader(
                batch_size=ae_config.transformer_batch_size,
                tokenizer=tokenizer,
                max_length=ae_config.seq_len,
            )
        self.train_dataloader = train_dataloader

        if eval_dataloader is None:
            eval_dataloader = get_eval_dataloader(
                batch_size=ae_config.transformer_batch_size // 2,
                tokenizer=tokenizer,
                max_length=ae_config.seq_len,
            )
        self.eval_dataloader = eval_dataloader

        if supermodel is None:
            transformer_model_name = ae_config.transformer_model_name
            # NOTE(review): TRANSFORMER_MODELS is not defined or imported in the visible
            # part of this file — confirm where it is meant to come from.
            assert transformer_model_name in TRANSFORMER_MODELS
            supermodel = self.initialise_supermodel(transformer_model_name)
        self.supermodel = supermodel

        self.train_step_count = 0

        self.optimizer = self._get_optimizer()
        self.scheduler = self._get_scheduler()

        if self.ae_config.ema_multiplier is not None:
            # Use the configured decay rather than a hard-coded constant
            self.ema_ae_model = EmaModel(
                self.supermodel.autoencoder, ema_multiplier=self.ae_config.ema_multiplier
            )

        self.use_scaler = self.device.startswith("cuda") and self.ae_config.use_loss_scaling
        self.scaler = GradScaler(enabled=self.use_scaler)

        if self.use_scaler:
            logger.info("Using GradScaler")
        else:
            logger.info("Not using GradScaler")

        self.is_autocast_enabled = self.ae_config.autocast_is_enabled
        # Despite its name, this attribute holds the autocast *dtype* from the config
        self.autocast_device_type = self.ae_config.autocast_dtype

        if finetuning:
            # Frozen reference copy of the starting autoencoder
            self.og_autoencoder = copy.deepcopy(self.supermodel.autoencoder)
            self.og_autoencoder.eval()
            self.og_autoencoder.to(self.device)
            self.og_autoencoder.requires_grad_(False)
        else:
            self.og_autoencoder = None

        t.set_float32_matmul_precision("high")  # switch from float32 to tf32 for speedups
        logger.info(
            "Using tf32 matmul precision for speedups"
        )  # We're switching back right before using the kernel to avoid slowdowns there

        self.activation_buffer = ActivationBuffer(
            supermodel=self.supermodel,
            batch_size=self.ae_config.batch_size,
            train_dataloader=self.train_dataloader,
            do_shuffle=True,
        )

        # Boolean mask over features, updated by `fit` and read by `prep_resampling`
        # to decide which features count as dead
        self.features_activated_tracker_F = t.zeros(
            self.ae_config.num_features, device=self.device
        ).bool()

        self.best_metrics_to_date = AutoEncoderMetrics(math.inf, math.inf, math.inf)
        self.best_proxy_metric_to_date = -math.inf

        self.timer = wall_clock_timer()

    def _get_optimizer(self) -> t.optim.Optimizer:
        """Create the optimiser selected by the config.

        Learning rate and weight decay are set per parameter group by
        `_get_optimizer_grouped_parameters`, so no global `lr` is passed here.

        Returns:
            A schedule-free AdamW, bitsandbytes 8-bit AdamW, or torch AdamW instance.

        Raises:
            RuntimeError: If 8-bit Adam is requested but CUDA is unavailable (the
                bitsandbytes optimiser is only imported when CUDA is present).
        """
        optimizer_grouped_parameters = self._get_optimizer_grouped_parameters()

        if self.ae_config.use_schedule_free_adam:
            return AdamWScheduleFree(
                optimizer_grouped_parameters,
                betas=self.ae_config.betas,
                eps=self.ae_config.adam_eps,
                weight_decay=self.ae_config.weight_decay,
                warmup_steps=min(self.ae_config.num_total_steps // 10, 1_000),
            )

        if self.ae_config.use_8bit_adam:
            if not t.cuda.is_available():
                raise RuntimeError(
                    "use_8bit_adam requires CUDA: bitsandbytes' AdamW8bit is only "
                    "imported when a CUDA device is available"
                )
            optimizer = AdamW8bit(
                optimizer_grouped_parameters,
                betas=self.ae_config.betas,
                eps=self.ae_config.adam_eps,
            )
            logger.info("Using bnb 8-bit Adam")
            return optimizer

        # Use the fused implementation when this torch version offers it and the
        # trainer's own device (not the module-level default) is a GPU.
        fused_available = "fused" in inspect.signature(t.optim.AdamW).parameters
        use_fused = fused_available and self.device.startswith("cuda")
        # Only forward `fused` when the installed AdamW accepts it; otherwise the
        # keyword itself would raise a TypeError.
        fused_kwargs = {"fused": use_fused} if fused_available else {}

        optimizer = t.optim.AdamW(
            optimizer_grouped_parameters,
            betas=self.ae_config.betas,
            eps=self.ae_config.adam_eps,
            **fused_kwargs,
        )
        logger.info("Using torch's regular AdamW optimizer")
        return optimizer

    def _get_optimizer_grouped_parameters(self) -> list[dict]:
        """Split the autoencoder's trainable parameters into two optimiser groups.

        The first group receives the configured weight decay and base learning rate.
        It holds tensors with at least two dimensions whose names do not mention the
        decoder, router or routing. Everything else goes into the second group, which
        has no weight decay and uses the base learning rate multiplied by
        `decoder_learning_rate_multiple`.
        """
        never_decayed_tags = ("decoder", "router", "routing")

        decay_params = []
        no_decay_params = []

        # Only the autoencoder is trained here (not the transformer)
        for name, tensor in self.supermodel.autoencoder.named_parameters():
            if not tensor.requires_grad:
                continue

            name_is_exempt = any(tag in name for tag in never_decayed_tags)
            if tensor.dim() >= 2 and not name_is_exempt:
                decay_params.append(tensor)
            else:
                # Biases and other 1D tensors, plus decoder/router weights
                no_decay_params.append(tensor)

        base_lr = self.ae_config.learning_rate
        decay_group = {
            "params": decay_params,
            "weight_decay": self.ae_config.weight_decay,
            "lr": base_lr,
        }
        no_decay_group = {
            "params": no_decay_params,
            "weight_decay": 0.0,
            "lr": base_lr * self.ae_config.decoder_learning_rate_multiple,
        }

        num_decay_params = sum(tensor.numel() for tensor in decay_params)
        num_no_decay_params = sum(tensor.numel() for tensor in no_decay_params)

        logger.debug(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        logger.debug(
            f"num non-decayed parameter tensors: {len(no_decay_params)}, with {num_no_decay_params:,} parameters"
        )

        return [decay_group, no_decay_group]

    def _get_scheduler(self):
        """Build the scheduler named by `ae_config.schedule_type`.

        Returns:
            The scheduler instance, or None when the schedule type is NONE.

        Raises:
            ValueError: If the schedule type is not one of the supported kinds.
        """
        kind = self.ae_config.schedule_type

        if kind == SchedulerType.COSINE:
            return t.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.ae_config.num_total_steps
            )

        # The project-specific schedulers share the same constructor arguments
        if kind == SchedulerType.CUSTOM:
            return CustomScheduler(optimizer=self.optimizer, config=self.ae_config)

        if kind == SchedulerType.DECREASING_K_M:
            return TopK_M_DecreasingScheduler(optimizer=self.optimizer, config=self.ae_config)

        if kind == SchedulerType.NONE:
            return None

        raise ValueError(f"Unsupported scheduler type: {kind}")

    def train_step(self) -> tuple[t.Tensor, t.Tensor, Optional[t.Tensor], t.Tensor, float]:
        """Run one optimiser step, accumulating gradients over the batch's minibatches.

        Returns:
            A tuple of:
            - the main loss averaged over `num_minibatches` (unscaled, detached),
            - a boolean mask of features whose summed activation was non-zero in at
              least one minibatch,
            - the expert usage averaged over minibatches, or None if the model
              reported none,
            - the gradient norm returned by clipping (a constant 1.0 when clipping
              is skipped because the autocast dtype is not float32),
            - the GradScaler scale in effect for this step (1.0 when disabled).
        """
        activation_batch_BSN = self.activation_buffer.get_activation_batch()
        activation_batch_BSN = activation_batch_BSN.to(self.device)

        autoencoder: AutoEncoderBase = self.supermodel.autoencoder
        autoencoder.train()
        if self.ae_config.use_schedule_free_adam:
            assert isinstance(self.optimizer, AdamWScheduleFree)
            self.optimizer.train()

        self.get_params_from_scheduler()
        self.optimizer.zero_grad()

        uses_secondary_loss = self.ae_config.autoencoder_type.secondary_loss
        autocast_device = "cuda" if self.device.startswith("cuda") else "cpu"

        features_activated_F: t.Tensor = t.zeros(
            self.ae_config.num_features, device=self.device
        ).bool()
        batch_expert_usage: list[t.Tensor] = []
        expert_usage_E: Optional[t.Tensor] = None

        loss_accum = t.tensor(0.0, device=self.device)

        # The scale only changes in `scaler.update()`, so read it once per step.
        # Returns 1.0 if not using GradScaler.
        grad_scale = self.scaler.get_scale()

        for activation_minibatch_BSN in activation_batch_BSN.split(
            self.ae_config.minibatch_size
        ):
            ### Minibatch Loss Calculation ###
            with t.autocast(
                enabled=self.is_autocast_enabled,
                device_type=autocast_device,
                dtype=self.autocast_device_type,
            ):
                supermodel_output: SupermodelOutput = self.supermodel.autoencoder_loss(
                    activation_minibatch_BSN,
                    dtype=(
                        self.autocast_device_type
                        if self.is_autocast_enabled
                        else activation_minibatch_BSN.dtype
                    ),
                )
                feature_activations = supermodel_output.feature_activations_BSF
                expert_usage_E = supermodel_output.metrics.expert_usage
                loss = supermodel_output.scalar_loss

                if (
                    self.og_autoencoder is not None
                    and self.ae_config.finetune_kl_penalty_coef > 0
                ):
                    # If finetuning then apply KL-penalty against the frozen original
                    ft_model_recons_BSN = (
                        supermodel_output.reconstructed_neuron_activations_BSN
                    )

                    with t.no_grad():
                        og_model_recons_BSN = self.og_autoencoder(activation_minibatch_BSN)

                    kl_loss = t.nn.functional.kl_div(
                        t.log(1e-8 + ft_model_recons_BSN),
                        og_model_recons_BSN,
                        reduction="batchmean",
                    )
                    loss = loss + kl_loss * self.ae_config.finetune_kl_penalty_coef

                loss_secondary = (
                    supermodel_output.secondary_loss if uses_secondary_loss else None
                )

            ### From minibatch to full batch ###

            # Average the loss over the minibatches for gradient accumulation
            loss_accum += loss.detach()
            loss = loss.float() / self.ae_config.num_minibatches

            # Scale the loss to improve numerical stability for the optimizer and then backprop
            loss = self.scaler.scale(loss)

            if uses_secondary_loss and loss_secondary is not None:
                # Apply the same loss scaling as the main loss (a no-op when the
                # scaler is disabled) so both gradients are in the same units.
                loss_secondary = self.scaler.scale(loss_secondary)

                trainable_params = [
                    param for param in autoencoder.parameters() if param.requires_grad
                ]

                # Only first-order gradients are needed, so keep the graph alive for
                # the second call instead of building a differentiable gradient graph.
                grad_main = t.autograd.grad(
                    loss, trainable_params, retain_graph=True, allow_unused=True
                )
                grad_secondary = t.autograd.grad(
                    loss_secondary, trainable_params, allow_unused=True
                )

                grad_resolved = self.combine_main_and_secondary_gradients(
                    grad_main, grad_secondary
                )

                for param, grad in zip(trainable_params, grad_resolved):
                    if grad is None:
                        continue
                    grad = grad.detach()
                    # Accumulate across minibatches, mirroring what `backward()`
                    # does in the other branch, rather than overwriting.
                    param.grad = grad if param.grad is None else param.grad + grad

            else:
                loss.backward()

            # Collect the feature activations and expert usage.
            # OR (not AND): the mask starts all-False and should mark every feature
            # seen in any minibatch.
            current_features_activated_F = t.sum(feature_activations, dim=(0, 1)).bool()
            features_activated_F = features_activated_F | current_features_activated_F
            if expert_usage_E is not None:
                batch_expert_usage.append(expert_usage_E.float())

        # Combine the expert usage from the minibatches
        if batch_expert_usage:
            expert_usage_E = t.stack(batch_expert_usage, dim=0).mean(dim=0)

        # Remove component of gradients that are parallel to dictionary vectors each step
        self.remove_gradients_parallel_to_decoder_(autoencoder)

        if self.autocast_device_type == t.float32:
            # Rescale the gradients for the optimiser for the gradient clipping etc.
            self.scaler.unscale_(self.optimizer)

            # Clip gradient norms
            grad_norm = clip_grad_norm_(
                autoencoder.parameters(), max_norm=self.ae_config.max_grad_norm
            )
        else:
            grad_norm = t.tensor(1.0, device=self.device)

        ### Step the optimizer ###
        if self.use_scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        if self.scheduler:
            self.scheduler.step()

        ### Apply EMA multiplier to weights ###
        if self.ae_config.ema_multiplier is not None:
            self.ema_ae_model.step()

        # Ensure updated dictionary vectors (decoder weights) have unit norm
        self.set_decoder_vectors_to_unit_norm_(autoencoder)

        normalised_loss = loss_accum / self.ae_config.num_minibatches

        return (
            normalised_loss,
            features_activated_F,
            expert_usage_E,
            grad_norm,
            grad_scale,
        )

    def remove_gradients_parallel_to_decoder_(self, autoencoder: AutoEncoderBase) -> None:
        """Avoiding shrinkage by removing the component of the gradient that is parallel to the dictionary vectors"""
        with t.no_grad():
            autoencoder.decoder.weight.grad -= (
                autoencoder.decoder.weight.grad * autoencoder.decoder.weight.data
            ) / t.norm(autoencoder.decoder.weight.data, dim=0, keepdim=True)

    def combine_main_and_secondary_gradients(
        self, grad_main: Iterable[t.Tensor], grad_secondary: Iterable[t.Tensor]
    ) -> Iterable[t.Tensor]:
        grad_resolved = []
        for grad_main_element, grad_secondary_element in zip(grad_main, grad_secondary):
            if grad_secondary_element is None:
                grad_resolved.append(grad_main_element)
                continue

            projection = t.sum(grad_main_element * grad_secondary_element) / t.sum(
                grad_main_element * grad_main_element
            )
            grad_secondary_without_main_component = (
                grad_secondary_element - projection * grad_main_element
            )

            grad_resolved_element = grad_main_element + grad_secondary_without_main_component
            grad_resolved.append(grad_resolved_element)

        return grad_resolved

    def _decorrelation_score(self, decoder_weights_NF: t.Tensor) -> float:
        _num_neurons, num_features = decoder_weights_NF.shape

        with t.no_grad():
            decoder_weights_normed_NF = decoder_weights_NF / t.norm(
                decoder_weights_NF, dim=0, keepdim=True
            )
            # Compute correlation matrix
            neural_feature_matrix_FF = einsum(
                decoder_weights_normed_NF,
                decoder_weights_normed_NF,
                "num_neurons num_features1, num_neurons num_features2 -> num_features1 num_features2",
            )

            # Renove diagonal elements (1s)
            identity_FF = t.eye(num_features, device=neural_feature_matrix_FF.device)
            neural_feature_matrix_FF = neural_feature_matrix_FF - identity_FF

            decorrelation_score: t.Tensor = t.sum(neural_feature_matrix_FF**2)

            normalisation = num_features * (num_features - 1)
            decorrelation_score = decorrelation_score / normalisation

            sqrt_decorr_score = t.sqrt(decorrelation_score)

            return sqrt_decorr_score.item()

    def get_params_from_scheduler(self) -> None:
        """Copy the scheduler's current hyperparameters onto the supermodel / autoencoder.

        A `CustomScheduler` supplies the sparsity and balancing loss coefficients; a
        `TopK_M_DecreasingScheduler` supplies top-m, top-k and the stochastic top-k
        temperature. Any other scheduler (or None) leaves the model untouched.
        """
        scheduler = self.scheduler

        if isinstance(scheduler, CustomScheduler):
            supermodel = self.supermodel
            supermodel.auxiliary_l1_sparsity_coef = scheduler.sparsity_lambda

            load_balancing_coef, expert_importance_coef = scheduler.balancing_loss_coefs
            supermodel.load_balancing_loss_coef = load_balancing_coef
            supermodel.expert_importance_loss_coef = expert_importance_coef

        if isinstance(scheduler, TopK_M_DecreasingScheduler):
            autoencoder = self.supermodel.autoencoder
            autoencoder.topm = scheduler.topm
            autoencoder.topk = scheduler.topk
            autoencoder.stochastic_topk_temperature = scheduler.stochastic_topk_temperature

    @staticmethod
    def set_decoder_vectors_to_unit_norm_(autoencoder: AutoEncoderBase) -> None:
        with t.no_grad():
            autoencoder.decoder.weight.data /= t.norm(
                autoencoder.decoder.weight.data, dim=0, keepdim=True
            )

    def fit(
        self, save_best_model: bool = False, wandb_run_name: str = "sae"
    ) -> FrozenTransformerAutoencoderSuperModel:
        """Run the training loop for `ae_config.num_total_steps` steps.

        Each step calls `train_step`, updates the feature-activity tracker, and
        periodically logs, evaluates (`eval_steps`) and checkpoints (`save_steps`).
        After the loop, EMA weights are written into the model when EMA is enabled
        and the model and schedule-free optimiser are put in eval mode.

        Args:
            save_best_model: Forwarded to `eval`; whether to save on a new best metric.
            wandb_run_name: Forwarded to `eval`; used to name the best checkpoint.

        Returns:
            The trained supermodel.
        """
        resample_steps = self.ae_config.resample_steps

        logger.info(f"Batch size: {self.ae_config.batch_size}")
        logger.info(f"Minibatch size: {self.ae_config.minibatch_size}")

        next(self.timer)  # Initialize the timer

        for current_step in range(self.ae_config.num_total_steps):
            # `eval` reads the step count from the instance
            self.train_step_count = current_step

            loss, features_activated_F, expert_usage_E, grad_norm, current_grad_scale = (
                self.train_step()
            )

            ### UPDATE TRACKERS ###

            # Skip tracking entirely when no resampling period is configured,
            # instead of failing on the modulo.
            if resample_steps and current_step % resample_steps > resample_steps / 2:
                # Only track activations for the second half of the resampling period.
                # OR (not AND): the tracker starts all-False and must remember every
                # feature that has been active at least once.
                self.features_activated_tracker_F = (
                    self.features_activated_tracker_F | features_activated_F
                )

            ### LOGGING ###

            # Print the loss every 100 steps
            if current_step % 100 == 0:
                logger.info(f"Run {(wandb.run.name if wandb.run else '')}:")
                logger.info(f"Train loss: {loss.item():.4} after {current_step} samples")

                num_dead_features, num_features = self.get_num_dead_features()

                logger.info(
                    f"{num_dead_features}/{num_features} dead features after {current_step} samples"
                )

            if current_step % 10 == 0:
                # Log train_loss to wandb every 10 steps
                self.train_logging(
                    current_step=current_step,
                    loss=loss,
                    grad_norm=grad_norm,
                    grad_scale=current_grad_scale,
                )

            ### EVALUATION ###

            if current_step % self.ae_config.eval_steps == 0 and current_step > 0:
                # Log full eval_metrics to wandb every eval_steps steps
                self.eval(wandb_run_name=wandb_run_name, save_best_model=save_best_model)

            ### SAVING ###

            if current_step % self.ae_config.save_steps == 0 and current_step > 0:
                date = datetime.now().strftime("%m-%d_%H-%M")
                self.supermodel.save_autoencoder(
                    f"checkpoint_{current_step}_{self.ae_config.autoencoder_type.value}_{date}"
                )

        if self.ae_config.ema_multiplier is not None:
            self.ema_ae_model.update_model_weights()

        self.supermodel.autoencoder.eval()
        if self.ae_config.use_schedule_free_adam:
            assert isinstance(self.optimizer, AdamWScheduleFree)
            self.optimizer.eval()

        return self.supermodel

    def get_num_dead_features(self) -> tuple[int, int]:
        feature_activity_queue = self.supermodel.autoencoder.feature_activation_queue
        recent_activity_F = feature_activity_queue.recent_activity()
        dead_features_F = recent_activity_F == 0
        num_dead_features = int(t.sum(dead_features_F).item())

        return num_dead_features, len(recent_activity_F)

    def train_logging(
        self, current_step: int, loss: t.Tensor, grad_norm: t.Tensor, grad_scale: float
    ):
        """Report training-step statistics through `_train_logging`, plus CUDA memory debug info.

        Args:
            current_step: Index of the step being logged.
            loss: Scalar loss tensor for the step.
            grad_norm: Scalar gradient-norm tensor for the step.
            grad_scale: Loss scale used for the step.
        """
        supermodel = self.supermodel

        # The learning rate of the first parameter group is the one reported
        first_group_lr = self.optimizer.param_groups[0]["lr"]
        tokens_seen = self.ae_config.batch_size * self.ae_config.seq_len * current_step

        _train_logging(
            self.use_wandb,
            current_step,
            loss.item(),
            l1_sparsity_coef=supermodel.auxiliary_l1_sparsity_coef,
            learning_rate=first_group_lr,
            auxiliary_balancing_loss_coef=supermodel.load_balancing_loss_coef,
            expert_importance_loss_coef=supermodel.expert_importance_loss_coef,
            capacity_factor=0,
            grad_norm=grad_norm.item(),
            grad_scale=grad_scale,
            num_tokens=tokens_seen,
            stochastic_topk_temperature=supermodel.autoencoder.stochastic_topk_temperature,
        )

        # Memory diagnostics only make sense when CUDA is present
        if not t.cuda.is_available():
            return

        logger.debug(t.cuda.memory_summary(device=self.device))

        # NOTE(review): with loguru, `logger.level` looks like a method rather than the
        # current level name, so this comparison is likely always False — confirm.
        if logger.level == "DEBUG":
            log_largest_tensors()

    def clear_cache(self, feature_activations: t.Tensor) -> None:
        del feature_activations
        if t.cuda.is_available():
            t.cuda.empty_cache()

        # Print out largest sources of memory usage
        self.optimizer.zero_grad()
        if t.cuda.is_available():
            t.cuda.empty_cache()
        gc.collect()

    def prep_resampling(self, current_step: int) -> tuple[int, int]:
        """Report how many features the tracker considers dead.

        A feature is dead when its entry in `features_activated_tracker_F` is False.

        Args:
            current_step: Step index, used only in the log messages.

        Returns:
            `(number of dead features, number of dead experts)`; the expert count is
            always 0 here.
        """
        tracker_F = self.features_activated_tracker_F
        dead_mask_F = ~tracker_F  # num_features
        dead_count = int(dead_mask_F.sum().item())

        logger.debug(dead_mask_F)
        logger.debug(tracker_F)

        logger.warning(f"{dead_count} dead features detected after {current_step} samples")

        dead_percentage = dead_count / self.ae_config.num_features * 100
        logger.warning(f"That's {dead_percentage:.2f}% of the features")

        return dead_count, 0

    def eval(self, wandb_run_name: str, save_best_model: bool = False) -> None:
        """Evaluate the supermodel on the eval dataloader and log the reduced metrics.

        Metrics from each batch are reduced with `MetricsCollection`, the decoder's
        decorrelation score is attached, and the proxy sweep metric decides whether
        this is a new best (optionally saving a checkpoint). EMA weights are used for
        the forward passes once at least one EMA step has been taken.

        Args:
            wandb_run_name: Prefix for the best-checkpoint file name.
            save_best_model: Whether to save the autoencoder on a new best proxy metric.

        Raises:
            ValueError: If the reduced overall loss is NaN.
        """
        max_eval_batches = 10_000

        # Switch to inference mode once, rather than on every batch
        self.supermodel.autoencoder.eval()
        if self.ae_config.use_schedule_free_adam:
            assert isinstance(self.optimizer, AdamWScheduleFree)
            self.optimizer.eval()

        use_ema_weights = (
            self.ae_config.ema_multiplier is not None and self.ema_ae_model.ema_steps > 0
        )

        # COLLECT METRICS
        metrics_list: list[AutoEncoderMetrics] = []
        features_per_token_count = Counter()

        for sample_num, batch in enumerate(self.eval_dataloader):
            print("Sample num: ", sample_num)
            if sample_num >= max_eval_batches:
                break

            batch_input: t.Tensor = batch["input_ids"]  # batch_size, seq_len
            # Use the trainer's own device, not the module-level default
            batch_input = batch_input.to(self.device)

            with t.no_grad():
                if use_ema_weights:
                    with self.ema_ae_model.use_ema_weights():
                        supermodel_output: SupermodelOutput = self.supermodel(
                            batch_input, output_eval_metrics=True
                        )
                else:
                    supermodel_output = self.supermodel(
                        batch_input, output_eval_metrics=True
                    )

            metrics_list.append(supermodel_output.metrics)

            # Histogram of how many features are active (> 0) per token
            features_BSF = supermodel_output.feature_activations_BSF
            features_per_token_BS = t.sum((features_BSF > 0).float(), dim=-1)
            features_per_token_Bs = rearrange(features_per_token_BS, "b s -> (b s)")
            features_per_token_count.update(features_per_token_Bs.tolist())

        metrics_collection = MetricsCollection(metrics_list)
        reduced_metrics = metrics_collection.reduce()

        reduced_metrics.decorr_score = self._decorrelation_score(
            decoder_weights_NF=self.supermodel.autoencoder.decoder.weight
        )

        logger.debug(f"Reduced metrics: {reduced_metrics}")
        if t.cuda.is_available():
            logger.debug(t.cuda.memory_summary(device=self.device))

        # NaN never compares identical (or equal) to `t.nan`, so test it explicitly
        overall_loss = reduced_metrics.overall_loss
        if overall_loss is not None and math.isnan(float(overall_loss)):
            logger.error("NaN loss detected")
            raise ValueError(f"NaN loss detected after {self.train_step_count} steps")

        proxy_metric = reduced_metrics.proxy_sweep_metric(
            num_features=self.ae_config.num_features,
            density_penalty=self.ae_config.density_penalty_for_proxy_sweep_metric,
        )

        if proxy_metric > self.best_proxy_metric_to_date:
            self.best_proxy_metric_to_date = proxy_metric
            self.best_metrics_to_date = reduced_metrics
            logger.success(f"New best proxy metric: {proxy_metric}")

            if save_best_model:
                self.supermodel.save_autoencoder(
                    f"{wandb_run_name}_best_ckpt",
                    other_details=f"best_checkpoint, proxy_metric = {proxy_metric:.4f}",
                )
                # Only talk to wandb when it is enabled for this trainer
                if self.use_wandb:
                    wandb.save(f"{wandb_run_name}_best_ckpt.pt")

        logger.success(
            f"Evaluated model after {self.train_step_count}/{self.ae_config.num_total_steps} samples"
        )

        time_elapsed = next(self.timer)

        num_dead_features, _total_num_features = self.get_num_dead_features()

        print(features_per_token_count)

        eval_logging(
            train_sample_num=self.train_step_count,
            reduced_metrics=reduced_metrics,
            num_features=self.ae_config.num_features,
            use_wandb=self.use_wandb,
            density_penalty=self.ae_config.density_penalty_for_proxy_sweep_metric,
            time_elapsed_mins=time_elapsed / 60,
            total_num_steps=self.ae_config.num_total_steps,
            num_dead_features=num_dead_features,
        )

        # NOTE(review): `fit` counts steps from 0 to num_total_steps - 1, so this
        # condition looks unreachable from there — confirm the intended final step.
        if self.train_step_count == self.ae_config.num_total_steps:
            logger.success(f"Final eval metrics: {reduced_metrics}")

            logger.success(f"Best eval metrics: {self.best_metrics_to_date}")

    def train(
        self,
        alert_on_success: bool = True,
        wandb_run_name: Optional[str] = None,
        save_best_model: bool = True,
        save_final_model: bool = False,
    ) -> FrozenTransformerAutoencoderSuperModel:
        """Run a full training session, with optional wandb tracking and alerts.

        Args:
            alert_on_success: Send a wandb alert when training completes (wandb only).
            wandb_run_name: Name of the run; defaults to a timestamp plus the
                autoencoder type.
            save_best_model: Forwarded to `fit`.
            save_final_model: Save the autoencoder after training (and upload it to
                wandb when enabled).

        Returns:
            The trained supermodel.

        Raises:
            Exception: Any error raised during training is logged, reported to wandb
                when enabled, and re-raised unchanged.
        """
        ae_config = self.ae_config

        if not wandb_run_name:
            formatted_date_time = datetime.now().strftime("%m-%d %H:%M")
            wandb_run_name = f"{formatted_date_time} | {ae_config.autoencoder_type.value}"

        wandb_run = None
        if self.use_wandb:
            wandb_config = ae_config.to_dict()
            # Record the device this trainer actually uses
            wandb_config["device"] = self.device

            wandb_run = wandb.init(
                project="mo_auto_encoder-1",
                name=wandb_run_name,
                config=wandb_config,
                dir="_wandb",
            )

        try:
            supermodel = self.fit(
                save_best_model=save_best_model, wandb_run_name=wandb_run_name
            )

            logger.success("Training complete!")
            logger.success(ae_config)

            if save_final_model:
                final_model_name = f"{wandb_run_name}_final"
                supermodel.save_autoencoder(final_model_name, other_details="final_checkpoint")
                if self.use_wandb:
                    wandb.save(final_model_name + ".pt")
                    logger.info("Final model saved to wandb")

            if wandb_run is not None:
                if alert_on_success:
                    wandb_run.alert(
                        title="Training Complete",
                        text=f"Training Complete: {wandb.run.name if wandb.run else ''}!",
                    )
                wandb.finish()

            return supermodel

        except Exception as e:
            logger.error(f"Training failed: {e}")
            if t.cuda.is_available():
                logger.error(t.cuda.memory_summary(device=self.device))
            logger.error(ae_config)
            if wandb_run is not None:
                wandb_run.alert(
                    title="Training Failed",
                    text=f"Error: {e}" + f"Config: {ae_config}",
                    level="ERROR",
                )
                # Close the run so it is not left open and is marked as failed
                wandb.finish(exit_code=1)
            # Bare raise keeps the original traceback intact
            raise

    def initialise_supermodel(
        self,
        transformer_name: str,
        get_saved_tensors_if_exist: bool = True,
    ) -> FrozenTransformerAutoencoderSuperModel:
        """Create the transformer + autoencoder supermodel on this trainer's device.

        The geometric median of the activations and a scaling factor are either
        loaded from `geometric_median_<name>.pt` / `scaling_factor_<name>.pt` in the
        working directory or estimated from the training data and then saved there.

        Args:
            transformer_name: Name of the pretrained transformer to wrap.
            get_saved_tensors_if_exist: Reuse the cached tensors when both files exist.

        Returns:
            The initialised supermodel.
        """
        logger.info("Initialising supermodel")

        ae_config = self.ae_config
        # Use the device this trainer was constructed with, not the module default
        target_device = self.device

        # Define model
        # NOTE(review): get_transformer is not defined or imported in the visible
        # part of this file — confirm where it is meant to come from.
        pretrained_transformer = get_transformer(transformer_name, device=target_device)

        # Paths for saving/loading tensors
        geom_median_path = f"geometric_median_{transformer_name}.pt"
        scaling_factor_path = f"scaling_factor_{transformer_name}.pt"

        cached_tensors_available = (
            get_saved_tensors_if_exist
            and os.path.exists(geom_median_path)
            and os.path.exists(scaling_factor_path)
        )

        if cached_tensors_available:
            # These files are unpickled by torch.load; only use caches you trust
            geometric_median_tensor = t.load(geom_median_path, map_location=target_device)
            scaling_factor = t.load(scaling_factor_path, map_location=target_device)
            logger.info("Loaded geometric median tensor and scaling factor from disk")
        else:
            # A throwaway supermodel without initial tensors is only needed to
            # estimate the statistics from the training data
            mock_supermodel = FrozenTransformerAutoencoderSuperModel(
                pretrained_transformer,
                ae_config=ae_config,
                device=target_device,
                medoid_initial_tensor_N=None,
                expert_initial_tensors=None,
                scaling_factor=None,
            )

            # Initialise the neuron bias to an estimate of the geometric median of the training data
            geometric_median_tensor, scaling_factor = estimate_activations_geometric_median(
                mock_supermodel, self.train_dataloader
            )
            geometric_median_tensor = geometric_median_tensor.detach()
            t.save(geometric_median_tensor, geom_median_path)
            t.save(scaling_factor, scaling_factor_path)
            logger.info("Geometric median tensor calculated and saved")

        expert_centroids = None
        logger.info("Not using routing so not retrieved expert centroids")

        supermodel = FrozenTransformerAutoencoderSuperModel(
            pretrained_transformer,
            ae_config=ae_config,
            device=target_device,
            medoid_initial_tensor_N=geometric_median_tensor,
            expert_initial_tensors=expert_centroids,
            scaling_factor=scaling_factor,
        )

        logger.info(f"Models initialised on device {target_device}")
        logger.info(f"Autoencoder type: {type(supermodel.autoencoder).__name__}")

        return supermodel


def main(
    ae_config: AutoEncoderConfig,
    device: str = device,
    use_wandb: bool = False,
    alert_on_success: bool = True,
    wandb_run_name: Optional[str] = None,
):
    """Construct an `SAETrainer` from the given config and run its training session."""
    SAETrainer(
        ae_config=ae_config,
        use_wandb=use_wandb,
        device=device,
    ).train(
        alert_on_success=alert_on_success,
        wandb_run_name=wandb_run_name,
    )
