from sae_lens import SAETrainingRunner
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.evals import EvalConfig, run_evals
from tqdm import tqdm
import torch
import os
from typing import override, cast, Any
import wandb

class ModifiedSAETrainer(SAETrainer):
    """SAE trainer that snapshots the initial SAE and logs a step-0 baseline.

    Before the training loop, ``fit`` optionally rescales the SAE weights by
    the activation store's estimated norm scaling factor, optionally installs
    checkpoint thresholds taken from the config, and writes the initial SAE
    plus the scaling factor under ``cfg.checkpoint_path``.
    """

    @override
    def fit(self):
        """Train the SAE until ``cfg.total_training_tokens`` tokens are consumed.

        Returns:
            ``self.sae`` after training. If the activation store still holds
            an estimated norm scaling factor when the loop ends, it is folded
            into the SAE and cleared on the store before the final checkpoint
            is saved.
        """
        pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
        # try/finally so the progress bar is released even when training is
        # interrupted or raises.
        try:
            self.activations_store.set_norm_scaling_factor_if_needed()
            self._unfold_norm_scaling_factor_if_requested()
            self._apply_custom_checkpoint_thresholds()
            self._save_initial_state()

            while self.n_training_tokens < self.cfg.total_training_tokens:
                # Take index 0 along the second axis of the next batch.
                layer_acts = self.activations_store.next_batch()[:, 0, :].to(
                    self.sae.device
                )

                if self.n_training_steps == 0:
                    self._log_step_zero_baseline(layer_acts)

                self.n_training_tokens += self.cfg.train_batch_size_tokens
                step_output = self._train_step(sae=self.sae, sae_in=layer_acts)

                if self.cfg.log_to_wandb:
                    # NOTE(review): on the first iteration n_training_steps is
                    # still 0 here, so when the step-0 baseline above ran, the
                    # evals run a second time and log to wandb step 0 again.
                    # Confirm whether that duplication is intended.
                    self._log_train_step(step_output)
                    self._run_and_log_evals()

                self._checkpoint_if_needed()
                self.n_training_steps += 1
                self._update_pbar(step_output, pbar)
                # Switch to fine-tuning once enough tokens have been seen
                # (if we haven't already).
                self._begin_finetuning_if_needed()

            # Fold the estimated norm scaling factor into the SAE weights.
            final_scale = self.activations_store.estimated_norm_scaling_factor
            if final_scale is not None:
                self.sae.fold_activation_norm_scaling_factor(final_scale)
                self.activations_store.estimated_norm_scaling_factor = None

            # Save the final SAE to the checkpoints folder.
            self.save_checkpoint(
                trainer=self,
                checkpoint_name=f"final_{self.n_training_tokens}",
                wandb_aliases=["final_model"],
            )
        finally:
            pbar.close()

        return self.sae

    def _unfold_norm_scaling_factor_if_requested(self):
        """Rescale W_enc, W_dec and b_dec by the estimated norm scaling factor.

        Runs only when ``cfg.unfold_estimated_norm_factor`` is truthy; a
        config without that attribute is treated as "do not unfold". W_enc is
        divided by the factor while W_dec and b_dec are multiplied by it.
        """
        if not getattr(self.cfg, "unfold_estimated_norm_factor", False):
            return

        scale = self.activations_store.estimated_norm_scaling_factor
        if scale is None:
            # Without a factor the arithmetic below would raise a TypeError.
            print("Skipping unfold: no estimated norm scaling factor available")
            return

        print("Unfolding current estimated norm factor")
        # `.data` is already detached and these ops are out-of-place, so no
        # explicit detach()/clone() is needed before rebinding.
        self.sae.W_enc.data = self.sae.W_enc.data / scale
        self.sae.W_dec.data = self.sae.W_dec.data * scale
        self.sae.b_dec.data = self.sae.b_dec.data * scale

    def _apply_custom_checkpoint_thresholds(self):
        """Install ``cfg.checkpoint_thresholds`` (given in steps) as token counts.

        Does nothing when the config has no such attribute or it is ``None``.
        """
        step_thresholds = getattr(self.cfg, "checkpoint_thresholds", None)
        if step_thresholds is None:
            return

        tokens_per_step = self.cfg.train_batch_size_tokens
        # Sorted ascending on the assumption that `_checkpoint_if_needed`
        # consumes thresholds in order -- TODO confirm against the installed
        # sae_lens. Already-sorted input is unaffected.
        self.checkpoint_thresholds = sorted(
            tokens_per_step * threshold for threshold in step_thresholds
        )

    def _save_initial_state(self):
        """Write the pre-training SAE and the norm scaling factor to disk.

        The SAE goes to ``<cfg.checkpoint_path>/initial_config`` and the
        factor to ``<cfg.checkpoint_path>/estimated_norm_scaling_factor.pt``.
        """
        scale = self.activations_store.estimated_norm_scaling_factor
        initial_config_dir = os.path.join(self.cfg.checkpoint_path, 'initial_config')
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(initial_config_dir, exist_ok=True)

        print("Saving initial model to", initial_config_dir)
        self.sae.save_model(initial_config_dir)

        print("Saving estimated norm scaling factor", scale)
        torch.save(
            scale,
            os.path.join(self.cfg.checkpoint_path, "estimated_norm_scaling_factor.pt"),
        )

    def _log_step_zero_baseline(self, layer_acts: torch.Tensor):
        """Log train metrics and evals for the SAE before its first update.

        Only runs when wandb logging is enabled and the current L1
        coefficient is non-zero; otherwise the forward pass is skipped
        entirely because its output would be discarded.
        """
        if not (self.cfg.log_to_wandb and self.current_l1_coefficient != 0):
            return

        with torch.no_grad():
            baseline_output = self.sae.training_forward_pass(
                sae_in=layer_acts,
                dead_neuron_mask=self.dead_neurons,
                current_l1_coefficient=self.current_l1_coefficient,
            )
            self._log_train_step(baseline_output)
            self._run_and_log_evals()

    @override
    @torch.no_grad()
    def _run_and_log_evals(self):
        """Run evals and log them to wandb at the current training step.

        Fires only when ``n_training_steps`` is a multiple of
        ``cfg.wandb_log_frequency * cfg.eval_every_n_wandb_logs`` (which
        includes step 0). The SAE is put in eval mode for the duration and
        switched back to train mode afterwards, even if evaluation raises.
        """
        # Record evals frequently, but not all the time.
        eval_period = self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs
        if self.n_training_steps % eval_period != 0:
            return

        self.sae.eval()
        try:
            eval_metrics, _ = run_evals(
                sae=self.sae,
                activation_store=self.activations_store,
                model=self.model,
                eval_config=self.trainer_eval_config,
                model_kwargs=self.cfg.model_kwargs,
            )  # not calculating featurewise metrics here.

            # Drop metrics already logged during training, plus one that is
            # not useful for wandb logging.
            for redundant_key in (
                "metrics/explained_variance",
                "metrics/explained_variance_std",
                "metrics/l0",
                "metrics/l1",
                "metrics/mse",
                "metrics/total_tokens_evaluated",
            ):
                eval_metrics.pop(redundant_key, None)

            W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=1).cpu().numpy()
            eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist)  # type: ignore

            architecture = self.sae.cfg.architecture
            if architecture == "standard":
                b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
                eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist)  # type: ignore
            elif architecture == "gated":
                b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
                eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist)  # type: ignore
                b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
                eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist)  # type: ignore

            wandb.log(
                eval_metrics,
                step=self.n_training_steps,
            )
        finally:
            # Always return to train mode so a failed eval cannot leave the
            # SAE stuck in eval mode.
            self.sae.train()


class ModifiedSAETrainingRunner(SAETrainingRunner):
    """Training runner that trains through ``ModifiedSAETrainer``."""

    @override
    def run(self):
        """Train the SAE, wrapping the run in a wandb session when enabled.

        When ``cfg.log_to_wandb`` is set, a wandb run is opened before
        training and finished after training returns.

        Returns:
            Whatever ``run_trainer_with_interruption_handling`` returns for
            the constructed trainer.
        """
        if self.cfg.log_to_wandb:
            # Run name is the configured name with the normalisation mode and
            # initialisation appended.
            run_label = (
                self.cfg.run_name
                + f"{self.cfg.normalize_activations} {self.cfg.initialization}"
            )
            wandb_settings = {
                "project": self.cfg.wandb_project,
                "entity": self.cfg.wandb_entity,
                "config": cast(Any, self.cfg),
                "name": run_label,
                "notes": f"{self.cfg.dir_id}\n",
                "id": self.cfg.wandb_id,
            }
            wandb.init(**wandb_settings)

        # The trainer is built before the optional compile step, as before.
        sae_trainer = ModifiedSAETrainer(
            cfg=self.cfg,
            model=self.model,
            sae=self.sae,
            activation_store=self.activations_store,
            save_checkpoint_fn=self.save_checkpoint,
        )
        self._compile_if_needed()

        trained_sae = self.run_trainer_with_interruption_handling(sae_trainer)

        if self.cfg.log_to_wandb:
            wandb.finish()

        return trained_sae



