import os
from typing import Any, Type, TypeVar, cast

import torch
from accelerate.utils import send_to_device
from torch import Tensor, nn
from transformers import PreTrainedModel
from functools import partial

T = TypeVar("T")


def assert_type(typ: Type[T], obj: Any) -> T:
    """Return `obj` unchanged after verifying at runtime that it is a `typ`.

    Args:
        typ: The class `obj` must be an instance of.
        obj: The value to check.

    Returns:
        `obj` itself, statically narrowed to `typ`.

    Raises:
        TypeError: If `obj` is not an instance of `typ`.
    """
    if isinstance(obj, typ):
        # `cast` does nothing at runtime; it only informs the type checker.
        return cast(typ, obj)

    expected, actual = typ.__name__, type(obj).__name__
    raise TypeError(f"Expected {expected}, got {actual}")


def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]:
    """Locate the model's stack of hidden layers (the layers to train SAEs on).

    The stack is identified as the single `nn.ModuleList` among
    `model.named_modules()` whose length equals
    `model.config.num_hidden_layers`.

    Returns:
        The `(qualified_name, module_list)` pair for that stack.

    Raises:
        TypeError: If `config.num_hidden_layers` is not an `int`.
        AssertionError: If zero or several module lists match.
    """
    depth = assert_type(int, model.config.num_hidden_layers)

    matches = []
    for path, submodule in model.named_modules():
        if not isinstance(submodule, nn.ModuleList):
            continue
        if len(submodule) == depth:
            matches.append((path, submodule))

    assert len(matches) == 1, "Could not find the list of layers."

    return matches[0]


@torch.inference_mode()
def resolve_widths(
    model: PreTrainedModel,
    module_names: list[str],
    dim: int = -1,
) -> dict[str, int]:
    """Find number of output dimensions for the specified modules.

    Registers a temporary forward hook on each named submodule, runs the model
    once on `model.dummy_inputs`, and records the size of each hooked module's
    output along `dim`.

    Args:
        model: Model to probe.
        module_names: Names accepted by `model.get_submodule`.
        dim: Axis of each module's output whose size is recorded.

    Returns:
        Mapping from module name to the size of its output along `dim`.

    Raises:
        RuntimeError: If the forward pass fails for a reason other than the
            tolerated "expected input with shape" mismatch before every
            requested module has produced an output.
    """
    module_to_name = {model.get_submodule(name): name for name in module_names}
    shapes: dict[str, int] = {}

    def hook(module, _, output):
        # Unpack tuples if needed
        if isinstance(output, tuple):
            output, *_ = output

        name = module_to_name[module]
        shapes[name] = output.shape[dim]

    handles = [mod.register_forward_hook(hook) for mod in module_to_name]
    try:
        # Prepared inside the outer `try` so that a failure here still removes
        # the hooks registered above.
        dummy = send_to_device(model.dummy_inputs, model.device)
        try:
            model(**dummy)
        except RuntimeError as e:
            if 'expected input with shape' in str(e):
                print('Shape mismatch should not be a problem')
            elif len(shapes) < len(module_to_name):
                # The forward pass died before every hooked module ran, so the
                # result would be silently incomplete. Surface the real error
                # instead of letting callers hit a KeyError later.
                raise
            # Otherwise every requested width was already recorded and the
            # late failure does not affect the result.
    finally:
        for handle in handles:
            handle.remove()

    return shapes

def get_resize_hook(size):
    """Build a forward hook that replaces a module's output with zeros.

    The returned callable takes `(module, inputs, output)` and returns an
    all-zero tensor of shape `(batch, seq, size)`, where `batch` and `seq` are
    the first two dimensions of `output` (or of its first element when
    `output` is a tuple). The zeros share `output`'s dtype and device.
    """

    def resize_hook(module, _, output, output_size):
        """Return zeros shaped like `output` but with a last dim of `output_size`."""
        # Only the first element of a tuple output is inspected.
        if isinstance(output, tuple):
            output, *_rest = output

        batch_size, seq_len = output.shape[:2]
        return torch.zeros(
            batch_size,
            seq_len,
            output_size,
            dtype=output.dtype,
            device=output.device,
        )

    return partial(resize_hook, output_size=size)

# Fallback implementation of SAE decoder
def eager_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor):
    """Decode sparse activations with `embedding_bag` (pure PyTorch fallback).

    For each row, sums the rows of `W_dec.mT` selected by `top_indices`,
    weighted by the matching entries of `top_acts`.

    Args:
        top_indices: Integer indices into the first axis of `W_dec.mT`, with
            shape `(..., k)`.
        top_acts: Per-index weights, same shape as `top_indices`.
        W_dec: Decoder weight; its matrix transpose is used as the embedding
            table, so the output width is `W_dec.shape[0]`.

    Returns:
        Tensor of shape `(..., W_dec.shape[0])` on the device of `top_acts`.
    """
    orig_device = top_acts.device
    # Always bound, so the check at the end works on every device.
    was_mps = orig_device.type == "mps"
    if was_mps:
        # MPS inputs are decoded on the CPU and moved back afterwards
        # (presumably because embedding_bag is unavailable on MPS — TODO confirm).
        top_acts = top_acts.cpu()
        W_dec = W_dec.cpu()
        top_indices = top_indices.cpu()

    # embedding_bag wants 2-D (bags, k) inputs: fold any extra leading
    # dimensions into one and restore them on the output.
    lead_shape = top_acts.shape[:-1]
    if top_acts.dim() > 2:
        top_acts = top_acts.flatten(0, -2)
        top_indices = top_indices.flatten(0, -2)

    out = nn.functional.embedding_bag(
        top_indices, W_dec.mT, per_sample_weights=top_acts, mode='sum'
    )

    if len(lead_shape) > 1:
        out = out.reshape(*lead_shape, out.shape[-1])

    if was_mps:
        # Return the result on the device the caller passed in.
        out = out.to(orig_device)
    return out

# Triton implementation of SAE decoder
def triton_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor):
    """Decode sparse activations by delegating to `xformers_embedding_bag`.

    Passes the indices, the matrix transpose of `W_dec`, and the activations,
    in that order. `xformers_embedding_bag` is only bound when the module's
    optional `.xformers` import succeeded.
    """
    decoder_table = W_dec.mT
    return xformers_embedding_bag(top_indices, decoder_table, top_acts)


# Choose the decoder implementation once, at import time, and expose it as
# `decoder_impl`. A message is printed whenever the eager fallback is selected.
try:
    # Optional dependency: any ImportError here selects the eager fallback.
    from .xformers import xformers_embedding_bag
except ImportError:
    decoder_impl = eager_decode
    print("Triton not installed, using eager implementation of SAE decoder.")
else:
    # The import succeeded, but setting SAE_DISABLE_TRITON=1 in the environment
    # still forces the eager implementation.
    if os.environ.get("SAE_DISABLE_TRITON") == "1":
        print("Triton disabled, using eager implementation of SAE decoder.")
        decoder_impl = eager_decode
    else:
        decoder_impl = triton_decode


def generate_data_from_basis_with_variance(basis, m, k, variance_explained):
    """
    Generates a batch of m sets, each containing k vectors in ℝ^b that lie in the subspace
    spanned by the given basis. The variation along each basis direction is determined by the
    provided variance_explained vector.

    Parameters:
    - basis: torch.Tensor of shape (b, n) representing an orthonormal basis for an n-dimensional subspace.
    - m (int): Batch size (number of sets).
    - k (int): Number of vectors per set.
    - variance_explained: torch.Tensor (or any sequence torch.as_tensor accepts) of shape (n,),
      representing the variance along each basis direction.

    Returns:
    - data: A torch.Tensor of shape (m, k, b) containing the generated vectors, with the same
      dtype and device as basis.
    """
    b, n = basis.shape
    # as_tensor avoids the "copy construct" UserWarning that torch.tensor() emits when handed
    # an existing tensor. Matching basis.dtype keeps the final matmul from failing on mixed
    # dtypes (e.g. a float64 basis with float32 variances).
    variance_explained = torch.as_tensor(
        variance_explained, dtype=basis.dtype, device=basis.device
    )
    assert variance_explained.shape[0] == n, "variance_explained must have length equal to n"

    # Compute standard deviations for each direction
    stds = torch.sqrt(variance_explained)  # shape (n,)

    # Generate random coefficients for each basis direction.
    # Each coefficient is drawn from N(0, std_i^2), i.e. std_i * N(0,1)
    # coeffs shape: (m, k, n)
    # The noise is still sampled with the default (CPU) generator and then moved, so results
    # for a given seed are unchanged from the previous implementation.
    noise = torch.randn(m, k, n).to(device=stds.device, dtype=stds.dtype)
    coeffs = noise * stds.view(1, 1, n)

    # Multiply the coefficients with the basis to produce vectors in ℝ^b.
    # basis.T has shape (n, b), so the result has shape (m, k, b)
    data = torch.matmul(coeffs, basis.T)

    return data
