from collections import defaultdict
from dataclasses import asdict
from fnmatch import fnmatchcase
from typing import Sized
from glob import glob

import torch
import torch.distributed as dist
from datasets import Dataset as HfDataset
from natsort import natsorted
from safetensors.torch import load_model
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import PreTrainedModel, get_linear_schedule_with_warmup

from .config import TrainConfig
from .data import MemmapDataset
from .sparse_coder import SparseCoder
from .utils import get_layer_list, resolve_widths, generate_data_from_basis_with_variance, get_resize_hook


class Trainer:
    def __init__(
        self,
        cfg: TrainConfig,
        dataset: HfDataset | MemmapDataset,
        model: PreTrainedModel,
    ):
        """Set up the sparse coders, optimizer and LR schedule for `model`.

        Resolves `cfg.hookpoints` (expanding wildcard patterns against the
        model's module names, or deriving them from `cfg.layers`), creates one
        `SparseCoder` per local hookpoint and init seed, and builds the
        optimizer and learning-rate scheduler.

        Note that `cfg` is mutated in place: `cfg.hookpoints` (and `cfg.layers`
        when it was empty) are overwritten with the resolved values.

        Args:
            cfg: Training configuration.
            dataset: Sized dataset of training examples.
            model: Model whose activations the sparse coders are trained on.

        Raises:
            ValueError: If the hookpoint patterns match no module, if no sparse
                coder would be created, or if `cfg.distribute_modules` is set
                while the hookpoints have differing widths.
        """
        if cfg.hookpoints:
            assert not cfg.layers, "Cannot specify both `hookpoints` and `layers`."

            # Replace wildcard patterns
            patterns = cfg.hookpoints
            raw_hookpoints = [
                name
                for name, _ in model.named_modules()
                if any(fnmatchcase(name, pat) for pat in patterns)
            ]
            if not raw_hookpoints:
                # Fail early with a readable message instead of an IndexError
                # further down when there is nothing to train on
                raise ValueError(
                    f"No modules in the model match the hookpoint patterns "
                    f"{patterns}"
                )

            # Natural sort to impose a consistent order
            cfg.hookpoints = natsorted(raw_hookpoints)
        else:
            # If no layers are specified, train on all of them
            if not cfg.layers:
                N = model.config.num_hidden_layers
                cfg.layers = list(range(0, N, cfg.layer_stride))

            # Now convert layers to hookpoints
            layers_name, _ = get_layer_list(model)
            cfg.hookpoints = [f"{layers_name}.{i}" for i in cfg.layers]

        self.cfg = cfg
        self.dataset = dataset
        # Sets `self.module_plan`, which `local_hookpoints()` relies on
        self.distribute_modules()
        self.random_basis = None

        assert isinstance(dataset, Sized)
        num_examples = len(dataset)

        device = model.device
        input_widths = resolve_widths(model, cfg.hookpoints)
        unique_widths = set(input_widths.values())

        if cfg.distribute_modules and len(unique_widths) > 1:
            # dist.all_to_all requires tensors to have the same shape across ranks
            raise ValueError(
                f"All modules must output tensors of the same shape when using "
                f"`distribute_modules=True`, got {unique_widths}"
            )

        self.model = model

        # Initialize all the SAEs
        print(f"Initializing SAEs with random seed(s) {cfg.init_seeds}")
        self.saes = {}
        for hook in self.local_hookpoints():
            for seed in cfg.init_seeds:
                # Seed right before construction so each coder's init is
                # determined by its seed alone
                torch.manual_seed(seed)

                # Add suffix to the name to disambiguate multiple seeds
                name = f"{hook}/seed{seed}" if len(cfg.init_seeds) > 1 else hook
                self.saes[name] = SparseCoder(
                    input_widths[hook], cfg.sae, device, dtype=torch.float32
                )

        if not self.saes:
            raise ValueError(
                "No sparse coders were created; check `hookpoints`/`layers` "
                "and `init_seeds` in the config"
            )

        pgs = [
            {
                "params": sae.parameters(),
                # Auto-select LR using 1 / sqrt(d) scaling law from Fig 3 of the paper
                "lr": cfg.lr or 2e-4 / (sae.num_latents / (2**14)) ** 0.5,
            }
            for sae in self.saes.values()
        ]
        # Dedup the learning rates we're using, sort them, and format them in
        # scientific notation with two decimals
        lrs = [f"{lr:.2e}" for lr in sorted(set(pg["lr"] for pg in pgs))]
        print(f"Learning rates: {lrs}" if len(lrs) > 1 else f"Learning rate: {lrs[0]}")

        # Prefer the 8-bit optimizer when bitsandbytes is installed
        try:
            from bitsandbytes.optim import Adam8bit as Adam

            print("Using 8-bit Adam from bitsandbytes")
        except ImportError:
            from torch.optim import Adam

            print("bitsandbytes 8-bit Adam not available, using torch.optim.Adam")
            print("Run `pip install bitsandbytes` for less memory usage.")

        self.global_step = 0
        # Per coder: a long counter per latent, used in `fit` to build the
        # dead-latent mask
        self.num_tokens_since_fired = {
            name: torch.zeros(sae.num_latents, device=device, dtype=torch.long)
            for name, sae in self.saes.items()
        }

        self.optimizer = Adam(pgs)
        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer, cfg.lr_warmup_steps, num_examples // cfg.batch_size
        )

        # k schedule endpoints (see `get_current_k`): start at 10x the width of
        # the first hookpoint, capped at the number of latents, and end at the
        # configured k
        num_latents = list(self.saes.values())[0].num_latents
        self.initial_k = min(num_latents, round(list(input_widths.values())[0] * 10))
        self.final_k = self.cfg.sae.k

    def load_state(self, path: str):
        """Load the trainer state from disk.

        Restores the global step, the per-coder `num_tokens_since_fired`
        counters, the optimizer and LR scheduler state, and the weights of
        every local sparse coder from the checkpoint directory `path`.

        Args:
            path: Directory containing `state.pt`, the `rank_*_state.pt` files,
                `optimizer.pt`, `lr_scheduler.pt` and one sub-directory per
                sparse coder.
        """
        device = self.model.device

        # Load the train state first so we can print the step number
        train_state = torch.load(
            f"{path}/state.pt", map_location=device, weights_only=True
        )

        # The counters are stored in one file per saving rank; merge them all.
        # Sorting makes the merge order deterministic.
        tokens_since_fired = {}
        for file in sorted(glob(f"{path}/rank_*_state.pt")):
            rank_train_state = torch.load(
                file, map_location=device, weights_only=True
            )
            tokens_since_fired.update(rank_train_state["num_tokens_since_fired"])

        self.global_step = train_state["global_step"]
        # The counters are keyed by sparse coder name, which carries a
        # `/seedN` suffix when several init seeds are used, so index by the
        # names in `self.saes` rather than by bare hookpoint
        self.num_tokens_since_fired = {
            name: tokens_since_fired[name] for name in self.saes
        }

        print(
            f"\033[92mResuming training at step {self.global_step} from '{path}'\033[0m"
        )

        lr_state = torch.load(
            f"{path}/lr_scheduler.pt", map_location=device, weights_only=True
        )
        opt_state = torch.load(
            f"{path}/optimizer.pt", map_location=device, weights_only=True
        )
        self.optimizer.load_state_dict(opt_state)
        self.lr_scheduler.load_state_dict(lr_state)

        for name, sae in self.saes.items():
            load_model(sae, f"{path}/{name}/sae.safetensors", device=str(device))

    def get_current_k(self) -> int:
        """Get the current k value based on a linear decay schedule."""
        if self.global_step >= self.cfg.k_decay_steps:
            return self.final_k
        
        progress = self.global_step / self.cfg.k_decay_steps
        return round(self.initial_k * (1 - progress) + self.final_k * progress)
    
    
    def fit(self):
        """Run the training loop over the dataset.

        For every batch, activations for each hookpoint are obtained either by
        running the model with forward hooks attached or, when
        `cfg.noise_dim > 0`, by sampling synthetic vectors from a random
        subspace. They are fed through the sparse coders and the losses are
        backpropagated. An optimizer step is taken once every
        `cfg.grad_acc_steps` batches. Metrics are optionally logged to Weights
        & Biases, and checkpoints are written periodically according to
        `cfg.save_every` and once more when the loop finishes.
        """
        # Use Tensor Cores even for fp32 matmuls
        torch.set_float32_matmul_precision("high")

        rank_zero = not dist.is_initialized() or dist.get_rank() == 0
        ddp = dist.is_initialized() and not self.cfg.distribute_modules

        if self.cfg.log_to_wandb and rank_zero:
            try:
                import wandb

                # The full training config is recorded with the run
                wandb.init(
                    name=self.cfg.run_name,
                    project="sae",
                    config=asdict(self.cfg),
                    save_code=True,
                )
            except ImportError:
                print("Weights & Biases not installed, skipping logging.")
                self.cfg.log_to_wandb = False

        num_sae_params = sum(
            p.numel() for s in self.saes.values() for p in s.parameters()
        )
        num_model_params = sum(p.numel() for p in self.model.parameters())
        print(f"Number of SAE parameters: {num_sae_params:_}")
        print(f"Number of model parameters: {num_model_params:_}")

        num_batches = len(self.dataset) // self.cfg.batch_size
        if self.global_step > 0:
            # Resuming: skip the examples that were already consumed
            assert hasattr(self.dataset, "select"), "Dataset must implement `select`"

            n = self.global_step * self.cfg.batch_size
            ds = self.dataset.select(range(n, len(self.dataset)))  # type: ignore
        else:
            ds = self.dataset

        device = self.model.device
        dl = DataLoader(
            ds,  # type: ignore
            batch_size=self.cfg.batch_size,
            # NOTE: We do not shuffle here for reproducibility; the dataset should
            # be shuffled before passing it to the trainer.
            shuffle=False,
        )
        pbar = tqdm(
            desc="Training",
            disable=not rank_zero,
            initial=self.global_step,
            total=num_batches,
        )

        # Which latents fired since the last optimizer step
        did_fire = {
            name: torch.zeros(sae.num_latents, device=device, dtype=torch.bool)
            for name, sae in self.saes.items()
        }
        # How often each latent was selected within the current density window
        act_count = {
            name: torch.zeros(sae.num_latents, device=device, dtype=torch.float32)
            for name, sae in self.saes.items()
        }
        num_tokens_window = 0
        num_tokens_in_step = 0

        # For logging purposes
        avg_auxk_loss = defaultdict(float)
        avg_fvu = defaultdict(float)
        avg_l2_loss = defaultdict(float)
        avg_multi_topk_fvu = defaultdict(float)

        input_dict: dict[str, Tensor] = {}
        output_dict: dict[str, Tensor] = {}
        name_to_module = {
            name: self.model.get_submodule(name) for name in self.cfg.hookpoints
        }
        maybe_wrapped: dict[str, DDP] | dict[str, SparseCoder] = {}
        module_to_name = {v: k for k, v in name_to_module.items()}

        def hook(module: nn.Module, inputs, outputs):
            # Maybe unpack tuple inputs and outputs
            if isinstance(inputs, tuple):
                inputs = inputs[0]
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            name = module_to_name[module]
            output_dict[name] = outputs.flatten(0, 1)

            # Remember the inputs if we're training a transcoder
            if self.cfg.sae.transcode:
                input_dict[name] = inputs.flatten(0, 1)

        k = self.get_current_k()
        for name, sae in self.saes.items():
            sae.cfg.k = k

        d_model = self.model.config.n_embd if hasattr(self.model.config, 'n_embd') else self.model.config.hidden_size

        # With a positive noise dimension the coders are trained on synthetic
        # data drawn from a random subspace instead of real model activations
        use_noise = self.cfg.noise_dim > 0
        if use_noise:
            self.generate_basis(d_model, self.cfg.noise_dim)

        input_widths = resolve_widths(self.model, self.cfg.hookpoints)

        for batch in dl:
            input_dict.clear()
            output_dict.clear()

            # Bookkeeping for dead feature detection
            N = batch["input_ids"].numel()
            num_tokens_in_step += N
            num_tokens_window += N

            if use_noise:
                num_seqs, seq_len = batch["input_ids"].shape[0], batch["input_ids"].shape[1]
                for module_name in name_to_module:
                    # generate noise that has size batch x seq x d_model
                    if self.cfg.noise_variance is None:
                        noise = self.generate_data_from_basis(num_seqs, seq_len)
                    else:
                        noise = generate_data_from_basis_with_variance(self.random_basis, num_seqs, seq_len, self.cfg.noise_variance)
                    output_dict[module_name] = noise.flatten(0, 1).to(device)
            else:
                # Forward pass on the model to get the next batch of activations.
                # Hooks are only attached on this path so that every registered
                # hook is removed again in the `finally` below.
                handles = [
                    mod.register_forward_hook(hook) for mod in name_to_module.values()
                ]
                # Modules whose width differs from the model width additionally
                # get the hook built by `get_resize_hook`; it is registered after
                # the capturing hook above
                handles.extend(
                    mod.register_forward_hook(get_resize_hook(d_model))
                    for name, mod in name_to_module.items()
                    if input_widths[name] != d_model
                )
                try:
                    with torch.no_grad():
                        self.model(batch["input_ids"].to(device))
                finally:
                    for handle in handles:
                        handle.remove()

            if self.cfg.distribute_modules:
                input_dict = self.scatter_hiddens(input_dict)
                output_dict = self.scatter_hiddens(output_dict)

            for name, raw in self.saes.items():
                # Name may optionally contain a suffix of the form /seedN where N is an
                # integer. We only care about the part before the slash.
                hookpoint, _, _ = name.partition("/")

                # 'inputs' is distinct from outputs iff we're transcoding. Both
                # dicts are keyed by hookpoint, not by the seed-suffixed name.
                outputs = output_dict[hookpoint]
                inputs = input_dict.get(hookpoint, outputs)

                # On the first iteration, initialize the decoder bias
                if self.global_step == 0:
                    mean = self.maybe_all_reduce(outputs.mean(0))
                    raw.b_dec.data = mean.to(raw.dtype)

                if not maybe_wrapped:
                    # Wrap the SAEs with Distributed Data Parallel. We have to do this
                    # after we set the decoder bias, otherwise DDP will not register
                    # gradients flowing to the bias after the first step.
                    maybe_wrapped = (
                        {
                            name: DDP(sae, device_ids=[dist.get_rank()])
                            for name, sae in self.saes.items()
                        }
                        if ddp
                        else self.saes
                    )

                # Make sure the W_dec is still unit-norm if we're autoencoding
                if raw.cfg.normalize_decoder and not self.cfg.sae.transcode:
                    raw.set_decoder_norm_to_unit_norm()

                acc_steps = self.cfg.grad_acc_steps * self.cfg.micro_acc_steps
                denom = acc_steps * self.cfg.wandb_log_frequency
                wrapped = maybe_wrapped[name]

                # Save memory by chunking the activations
                in_chunks = inputs.chunk(self.cfg.micro_acc_steps)
                out_chunks = outputs.chunk(self.cfg.micro_acc_steps)
                for in_chunk, out_chunk in zip(in_chunks, out_chunks):
                    out = wrapped(
                        x=in_chunk,
                        y=out_chunk,
                        dead_mask=(
                            self.num_tokens_since_fired[name]
                            > self.cfg.dead_feature_threshold
                            if self.cfg.auxk_alpha > 0
                            else None
                        ),
                    )

                    avg_fvu[name] += float(
                        self.maybe_all_reduce(out.fvu.detach()) / denom
                    )
                    avg_l2_loss[name] += float(
                        self.maybe_all_reduce(out.l2_loss.detach()) / denom
                    )
                    if self.cfg.auxk_alpha > 0:
                        avg_auxk_loss[name] += float(
                            self.maybe_all_reduce(out.auxk_loss.detach()) / denom
                        )
                    if self.cfg.sae.multi_topk:
                        avg_multi_topk_fvu[name] += float(
                            self.maybe_all_reduce(out.multi_topk_fvu.detach()) / denom
                        )

                    loss = (
                        out.fvu
                        + self.cfg.auxk_alpha * out.auxk_loss
                        + out.multi_topk_fvu / 8
                    )
                    loss.div(acc_steps).backward()

                    # Update the did_fire mask
                    indices = out.latent_indices.flatten()
                    did_fire[name][indices] = True
                    self.maybe_all_reduce(did_fire[name], "max")  # max is boolean "any"

                    # Update the activation count: add one per selected latent
                    updates = torch.ones(indices.shape, dtype=act_count[name].dtype, device=act_count[name].device)
                    act_count[name].index_add_(0, indices, updates)

                # Clip gradient norm independently for each sparse coder
                torch.nn.utils.clip_grad_norm_(raw.parameters(), 1.0)

            # Check if we need to actually do a training step
            step, substep = divmod(self.global_step + 1, self.cfg.grad_acc_steps)
            if substep == 0:
                if self.cfg.sae.normalize_decoder and not self.cfg.sae.transcode:
                    for sae in self.saes.values():
                        sae.remove_gradient_parallel_to_decoder_directions()

                self.optimizer.step()
                self.optimizer.zero_grad()
                self.lr_scheduler.step()

                # Advance the k schedule
                k = self.get_current_k()
                for name, sae in self.saes.items():
                    sae.cfg.k = k

                with torch.no_grad():
                    # Update the dead feature mask
                    for name, counts in self.num_tokens_since_fired.items():
                        counts += num_tokens_in_step
                        counts[did_fire[name]] = 0

                    # Reset stats for this step
                    num_tokens_in_step = 0
                    for mask in did_fire.values():
                        mask.zero_()

                if (
                    self.cfg.log_to_wandb
                    and (step + 1) % self.cfg.wandb_log_frequency == 0
                ):
                    info = {}

                    for name in self.saes:
                        mask = (
                            self.num_tokens_since_fired[name]
                            > self.cfg.dead_feature_threshold
                        )

                        info.update(
                            {
                                f"fvu/{name}": avg_fvu[name],
                                f"l2_loss/{name}": avg_l2_loss[name],
                                f"dead_pct/{name}": mask.mean(
                                    dtype=torch.float32
                                ).item(),
                            }
                        )
                        if self.cfg.auxk_alpha > 0:
                            info[f"auxk/{name}"] = avg_auxk_loss[name]
                        if self.cfg.sae.multi_topk:
                            info[f"multi_topk_fvu/{name}"] = avg_multi_topk_fvu[name]

                    avg_auxk_loss.clear()
                    avg_fvu.clear()
                    avg_l2_loss.clear()
                    avg_multi_topk_fvu.clear()

                    if self.cfg.distribute_modules:
                        outputs = [{} for _ in range(dist.get_world_size())]
                        dist.gather_object(info, outputs if rank_zero else None)
                        info.update({k: v for out in outputs for k, v in out.items()})

                    if (step + 1) % self.cfg.save_every == 0:
                        # get density statistics
                        for name, sae in self.saes.items():
                            densities = act_count[name] / num_tokens_window
                            info[f'density/below_1e-5/{name}'] = (densities < 1e-5).sum().item()
                            info[f'density/above_1e-1/{name}'] = (densities > 1e-1).sum().item()
                            info[f'density/above_2e-1/{name}'] = (densities > 2e-1).sum().item()

                            act_count[name].zero_()

                        # Start a new window only after every coder has used the
                        # token count; resetting it inside the loop would make
                        # all coders after the first divide by zero
                        num_tokens_window = 0

                    if rank_zero:
                        info["k"] = k

                        wandb.log(info, step=step)

                if (step + 1) % self.cfg.save_every == 0:
                    self.save()

            self.global_step += 1
            pbar.update()

        self.save()
        pbar.close()

    def generate_basis(self, b, n):
        """
        Generates an orthonormal basis for an n-dimensional subspace in ℝ^b.

        Parameters:
        - b (int): Ambient space dimensionality.
        - n (int): Subspace dimensionality (must be less than b).

        Returns:
        - basis: A torch.Tensor of shape (b, n) whose columns form an orthonormal basis.
        """
        assert n < b, "Subspace dimensionality n must be less than ambient dimensionality b."
        A = torch.randn(b, n)
        Q, _ = torch.linalg.qr(A, mode='reduced')  # Q has shape (b, n)
        self.random_basis = Q

    def generate_data_from_basis_with_variance(self, m, k, variance_explained):
        """
        Generates a batch of m sets, each containing k vectors in ℝ^b that lie in the subspace
        spanned by the given basis. The variation along each basis direction is determined by the
        provided variance_explained vector.

        Parameters:
        - basis: torch.Tensor of shape (b, n) representing an orthonormal basis for an n-dimensional subspace.
        - m (int): Batch size (number of sets).
        - k (int): Number of vectors per set.
        - variance_explained: torch.Tensor of shape (n,), representing the variance along each basis direction.
        
        Returns:
        - data: A torch.Tensor of shape (m, k, b) containing the generated vectors.
        """
        b, n = self.random_basis.shape
        variance_explained = torch.tensor(variance_explained, device=self.random_basis.device)
        assert variance_explained.shape[0] == n, "variance_explained must have length equal to n"
        
        # Compute standard deviations for each direction
        stds = torch.sqrt(variance_explained)  # shape (n,)
        
        # Generate random coefficients for each basis direction.
        # Each coefficient is drawn from N(0, std_i^2), i.e. std_i * N(0,1)
        # coeffs shape: (m, k, n)
        coeffs = torch.randn(m, k, n) * stds.view(1, 1, n)
        
        # Multiply the coefficients with the basis to produce vectors in ℝ^b.
        # basis.T has shape (n, b), so the result has shape (m, k, b)
        data = torch.matmul(coeffs, self.random_basis.T)
        
        return data

    def generate_data_from_basis(self, m, k, noise_std=1.0):
        """
        Generates a batch of m sets, each containing k vectors in ℝ^b that lie in the subspace
        spanned by the given basis. The variation is provided by random Gaussian coefficients along the basis.

        Parameters:
        - basis: torch.Tensor of shape (b, n) representing an orthonormal basis.
        - m (int): Batch size (number of sets).
        - k (int): Number of vectors per set.
        - noise_std (float): Standard deviation of the Gaussian noise for the coefficients.

        Returns:
        - data: A torch.Tensor of shape (m, k, b) containing the generated vectors.
        """
        b, n = self.random_basis.shape
        # Generate random coefficients for each set; shape (m, k, n)
        coeffs = noise_std * torch.randn(m, k, n)
        # Multiply coefficients with the transpose of the basis to produce vectors in ℝ^b.
        # basis.T has shape (n, b), so the result is of shape (m, k, b)
        data = torch.matmul(coeffs, self.random_basis.T)
        return data

    def local_hookpoints(self) -> list[str]:
        return (
            self.module_plan[dist.get_rank()]
            if self.module_plan
            else self.cfg.hookpoints
        )

    def maybe_all_cat(self, x: Tensor) -> Tensor:
        """Concatenate a tensor across all processes."""
        if not dist.is_initialized() or self.cfg.distribute_modules:
            return x

        buffer = x.new_empty([dist.get_world_size() * x.shape[0], *x.shape[1:]])
        dist.all_gather_into_tensor(buffer, x)
        return buffer

    def maybe_all_reduce(self, x: Tensor, op: str = "mean") -> Tensor:
        if not dist.is_initialized() or self.cfg.distribute_modules:
            return x

        if op == "sum":
            dist.all_reduce(x, op=dist.ReduceOp.SUM)
        elif op == "mean":
            dist.all_reduce(x, op=dist.ReduceOp.SUM)
            x /= dist.get_world_size()
        elif op == "max":
            dist.all_reduce(x, op=dist.ReduceOp.MAX)
        else:
            raise ValueError(f"Unknown reduction op '{op}'")

        return x

    def distribute_modules(self):
        """Prepare a plan for distributing modules across ranks."""
        if not self.cfg.distribute_modules:
            self.module_plan = []
            print(f"Training on modules: {self.cfg.hookpoints}")
            return

        layers_per_rank, rem = divmod(len(self.cfg.hookpoints), dist.get_world_size())
        assert rem == 0, "Number of modules must be divisible by world size"

        # Each rank gets a subset of the layers
        self.module_plan = [
            self.cfg.hookpoints[start : start + layers_per_rank]
            for start in range(0, len(self.cfg.hookpoints), layers_per_rank)
        ]
        for rank, modules in enumerate(self.module_plan):
            print(f"Rank {rank} modules: {modules}")

    def scatter_hiddens(self, hidden_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        """Scatter & gather the hidden states across ranks."""
        # Short-circuit if we have no data
        if not hidden_dict:
            return hidden_dict

        outputs = [
            # Add a new leading "layer" dimension to each tensor
            torch.stack([hidden_dict[hook] for hook in hookpoints], dim=1)
            for hookpoints in self.module_plan
        ]
        local_hooks = self.module_plan[dist.get_rank()]
        shape = next(iter(hidden_dict.values())).shape

        # Allocate one contiguous buffer to minimize memcpys
        buffer = outputs[0].new_empty(
            # The (micro)batch size times the world size
            shape[0] * dist.get_world_size(),
            # The number of layers we expect to receive
            len(local_hooks),
            # All other dimensions
            *shape[1:],
        )

        # Perform the all-to-all scatter
        inputs = buffer.split([len(output) for output in outputs])
        dist.all_to_all([x for x in inputs], outputs)

        # Return a list of results, one for each layer
        return {hook: buffer[:, i] for i, hook in enumerate(local_hooks)}

    def save(self):
        """Save the SAEs to disk."""

        save_run_name = self.cfg.run_name.replace(' ', '_').replace('openai-community/', '')
        path = f"sae-ckpts/{save_run_name}" or "sae-ckpts"
        rank_zero = not dist.is_initialized() or dist.get_rank() == 0

        if self.random_basis is not None:
            torch.save(self.random_basis, f"{path}/random_basis.pt")

        if rank_zero or self.cfg.distribute_modules:
            print("Saving checkpoint")

            for name, sae in self.saes.items():
                assert isinstance(sae, SparseCoder)

                sae.save_to_disk(f"{path}/{name}")

            rank = 0 if rank_zero else dist.get_rank()
            torch.save({"num_tokens_since_fired": self.num_tokens_since_fired}, f"{path}/rank_{rank}_state.pt")

        if rank_zero:
            torch.save(self.lr_scheduler.state_dict(), f"{path}/lr_scheduler.pt")
            torch.save(self.optimizer.state_dict(), f"{path}/optimizer.pt")
            torch.save({"global_step": self.global_step}, f"{path}/state.pt")

            self.cfg.save_json(f"{path}/config.json")

        # Barrier to ensure all ranks have saved before continuing
        if dist.is_initialized():
            dist.barrier()


# Backwards-compatible alias so that code referring to the old `SaeTrainer`
# name keeps working
SaeTrainer = Trainer
