import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer, loading_from_pretrained
from sae_lens import SAE
from sparsify.sparsify.data import chunk_and_tokenize
from datasets import load_dataset


# Load SAE
def load_sae(model_name, width, layer_idx, location='res', device='cuda'):
    """
    Load a pretrained sparse autoencoder through ``SAE.from_pretrained``.

    Parameters:
    - model_name: 'gpt2-small', or any name containing 'gemma'.
    - width: Width tag inserted into the SAE id; only used on the gemma branch.
    - layer_idx: Index of the layer whose SAE is requested.
    - location: Hook-site tag. For 'gpt2-small' only 'res' is accepted; on the
      gemma branch the value is inserted verbatim into the release name.
    - device: Device argument forwarded to ``SAE.from_pretrained``.

    Returns:
    - sae: The first element of the triple returned by ``SAE.from_pretrained``.

    Raises:
    - ValueError: For an unsupported model name, or an unsupported location
      on the 'gpt2-small' branch.
    """
    if model_name == 'gpt2-small':
        # Only the residual-stream site is wired up for gpt2-small.
        if location != 'res':
            raise ValueError('Invalid location')
        release_site, hook_site = 'resid-post', 'resid_post'
        release = f"gpt2-small-{release_site}-v5-32k"
        sae_id = f"blocks.{layer_idx}.hook_{hook_site}"
    elif 'gemma' in model_name:
        release = f'gemma-scope-2b-pt-{location}-canonical'
        sae_id = f"layer_{layer_idx}/width_{width}/canonical"
    else:
        raise ValueError('Invalid model name')

    # The loader hands back (sae, config dict, sparsity); only the SAE is kept.
    sae, _cfg_dict, _sparsity = SAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device=device,
    )
    return sae


def tl_name_to_hf_name(model_name):
    """
    Resolve a TransformerLens model name to its official model name.

    Thin wrapper around ``loading_from_pretrained.get_official_model_name``;
    the result is returned unchanged.
    """
    return loading_from_pretrained.get_official_model_name(model_name)


def load_model_from_tl_name(model_name, device='cuda', cache_dir=None, hf_token=None, hf_model=False):
    """
    Load a model and its tokenizer from a TransformerLens model name.

    Parameters:
    - model_name: TransformerLens model name; resolved with ``tl_name_to_hf_name``.
    - device: Device passed to ``HookedTransformer.from_pretrained``. It is not
      used when ``hf_model`` is truthy.
    - cache_dir: Cache directory forwarded to every ``from_pretrained`` call.
    - hf_token: Token forwarded to the tokenizer and ``AutoModelForCausalLM`` loads.
    - hf_model: If truthy, return the plain ``AutoModelForCausalLM`` instead of
      a ``HookedTransformer``.

    Returns:
    - (model, tokenizer)
    """
    hf_model_name = tl_name_to_hf_name(model_name)
    print(f"Loading model from {hf_model_name}")

    tokenizer = AutoTokenizer.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        cache_dir=cache_dir,
        token=hf_token,
    )

    def _load_hf_causal_lm():
        # Shared by the raw-HF return path and the HookedTransformer wrapping path.
        return AutoModelForCausalLM.from_pretrained(hf_model_name, token=hf_token, cache_dir=cache_dir)

    if hf_model:
        return _load_hf_causal_lm(), tokenizer

    # For these families the weights are loaded explicitly (with the token)
    # and handed to HookedTransformer together with the tokenizer.
    lowered_name = model_name.lower()
    wrap_hf_weights = any(family in lowered_name for family in ("llama", "gemma", "mistral"))

    if wrap_hf_weights:
        model = HookedTransformer.from_pretrained(
            model_name=model_name,
            hf_model=_load_hf_causal_lm(),
            tokenizer=tokenizer,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model = HookedTransformer.from_pretrained(model_name, device=device, cache_dir=cache_dir)

    return model, tokenizer


def adjust_vectors(v, u, target_values):
    """
    Shift each row of v along u so that its dot product with u hits a target.

    Parameters:
    - v: A 2D tensor of shape (n, d), the batch of vectors to shift.
    - u: A 1D tensor of shape (d,), the direction of the shift. The result
      only has the requested projections when u has unit norm.
    - target_values: A 1D tensor of shape (n,), the desired value of ``v @ u``
      for each row.

    Returns:
    - adjusted_v: Tensor of shape (n, d) equal to v plus a per-row multiple of u.
    """
    # How far each row's projection is from its target.
    gap = target_values - v.matmul(u)
    # Broadcast the per-row gap across the d components of u.
    correction = gap[..., None] * u
    return v + correction


def ablate_subspace(v, X):
    """
    Removes from a batch of vectors v the component lying in the subspace spanned by the rows of X.

    Computes ``v - v @ pinv(X) @ X``, i.e. v minus its orthogonal projection
    onto the row space of X.

    Parameters:
    - v: A 2D tensor of shape (n, d), representing the batch of vectors to be adjusted.
    - X: A 2D tensor of shape (m, d), whose rows span the subspace to remove.

    Returns:
    - adjusted_v: Tensor of shape (n, d) with no component along any row of X.
      When X is bfloat16 or float16 the computation runs in float32 and the
      result is cast back to X's dtype.
    """
    # torch.pinverse does not support reduced-precision floats, so both
    # operands are upcast to float32 for the computation.
    X_dtype = X.dtype
    reduced_precision = X_dtype in (torch.bfloat16, torch.float16)
    if reduced_precision:
        X = X.float()
        v = v.float()

    X_pinv = torch.pinverse(X)  # shape (d, m)
    # Left-to-right evaluation keeps the intermediate at (n, m) rather than (d, d).
    proj = v @ X_pinv @ X
    adjusted_v = v - proj

    if reduced_precision:
        adjusted_v = adjusted_v.to(X_dtype)
    return adjusted_v


def project_onto_subspace(v, X):
    """
    Projects a batch of vectors v onto the subspace spanned by the rows of X.

    Computes ``v @ pinv(X) @ X``, the orthogonal projection of each row of v
    onto the row space of X.

    Parameters:
    - v: A 2D tensor of shape (n, d), representing the batch of vectors.
    - X: A 2D tensor of shape (k, d), whose rows span the subspace to keep.

    Returns:
    - projected_v: Tensor of shape (n, d), the projection of v onto the
      subspace spanned by X. When X is bfloat16 or float16 the computation
      runs in float32 and the result is cast back to X's dtype.
    """
    # torch.pinverse does not support reduced-precision floats, so both
    # operands are upcast to float32 for the computation.
    X_dtype = X.dtype
    reduced_precision = X_dtype in (torch.bfloat16, torch.float16)
    if reduced_precision:
        X = X.float()
        v = v.float()

    X_pinv = torch.pinverse(X)  # shape (d, k)
    projected_v = v @ X_pinv @ X  # shape (n, d)

    if reduced_precision:
        projected_v = projected_v.to(X_dtype)
    return projected_v


def load_tokenized_data(cache_dir, dataset_name, n_sequences, context_length, model_name):
    """
    Return tokenized data for a dataset, reusing an on-disk cache when one matches.

    The cache lives in ``<cache_dir>/tokenized_<model_name>/<dataset_name>``.
    A cached file is reused when its name (extension ignored) starts with
    ``<n_sequences>_<context_length>``; both numbers must match exactly.
    Otherwise the first ``n_sequences`` rows of the dataset's "train" split
    are tokenized with ``chunk_and_tokenize`` and the result is saved to the
    cache before being returned.

    Args:
        cache_dir (str): Root directory for the tokenized cache; also passed
            to ``load_dataset`` as its cache directory.
        dataset_name (str): Name of the dataset.
        n_sequences (int): Number of dataset rows to select before tokenizing.
        context_length (int): Passed to ``chunk_and_tokenize`` as ``max_seq_len``.
        model_name (str): Name passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        tokenized: The tokenized dataset.
    """
    folder = os.path.join(cache_dir, f"tokenized_{model_name}", dataset_name)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)

    def _cached_dims(filename):
        # Cache files are named "<n_sequences>_<context_length>[...]"; anything
        # that does not parse that way is reported as None.
        stem = os.path.splitext(filename)[0]
        fields = stem.split('_')
        if len(fields) < 2:
            return None
        try:
            return int(fields[0]), int(fields[1])
        except ValueError:
            return None

    # Look for a cached file whose sequence count and context length both match.
    for entry in os.listdir(folder):
        dims = _cached_dims(entry)
        if dims is None:
            continue
        cached_n_sequences, cached_context_length = dims
        if cached_n_sequences == n_sequences and cached_context_length == context_length:
            path = os.path.join(folder, entry)
            print(f"Loading tokenized data from {path}")
            # NOTE: weights_only=False unpickles arbitrary objects; only point
            # cache_dir at caches you trust.
            return torch.load(path, weights_only=False)

    print(f"No tokenized data found for {n_sequences} sequences and {context_length} tokens. Generating new data.")

    # Cache miss: build the tokenized data from the raw dataset.
    raw_dataset = load_dataset(dataset_name, split="train", cache_dir=cache_dir, streaming=False)
    subset = raw_dataset.select(range(n_sequences))
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenized = chunk_and_tokenize(subset, tokenizer, max_seq_len=context_length)

    # Persist the result so the next call with the same sizes hits the cache.
    save_path = os.path.join(folder, f"{n_sequences}_{context_length}")
    print(f"Saving tokenized data to {save_path}")
    torch.save(tokenized, save_path)

    return tokenized


def _orthonormal_basis(M):
    """Return orthonormal columns spanning the column space of the 2D tensor M."""
    U, S, _ = torch.linalg.svd(M, full_matrices=False)
    if S.numel() == 0:
        return U[:, :0]
    # Same style of cutoff as a numerical rank test: singular values at or
    # below max(shape) * eps * largest singular value are treated as zero.
    tol = S.max() * max(M.shape) * torch.finfo(S.dtype).eps
    return U[:, S > tol]


def principal_angles_from_vectors(A, B):
    """
    Compute the principal angles between the subspaces spanned by the columns of A and B.

    Parameters:
    - A: torch.Tensor of shape (d, n), an arbitrary set of vectors spanning the first subspace.
    - B: torch.Tensor of shape (d, m), an arbitrary set of vectors spanning the second subspace.

    The columns may be linearly dependent; only their span matters.

    Returns:
    - angles: 1D torch.Tensor (on CPU) containing the principal angles (in radians)
      between the two subspaces, in ascending order. Its length is the smaller
      of the two numerical ranks.

    Steps:
      1. Build an orthonormal basis for each column space from a rank-truncated SVD.
      2. Compute M = Q_A.T @ Q_B, whose singular values equal the cosines of the principal angles.
      3. Return the principal angles by taking the arccos of these singular values.
    """
    A = A.to('cpu')
    B = B.to('cpu')
    # The linalg decompositions do not support reduced-precision floats.
    if A.dtype in (torch.bfloat16, torch.float16):
        A = A.float()
    if B.dtype in (torch.bfloat16, torch.float16):
        B = B.float()

    # A reduced QR always returns min(d, n) columns, so for linearly dependent
    # inputs it would add directions outside the span. Truncating an SVD at the
    # numerical rank keeps only directions that are actually spanned.
    Q_A = _orthonormal_basis(A)  # Q_A: (d, r1)
    Q_B = _orthonormal_basis(B)  # Q_B: (d, r2)

    # Compute the matrix of cosines between the basis vectors
    M = Q_A.T @ Q_B  # shape: (r1, r2)

    # Compute the singular values of M; these are the cosines of the principal angles.
    singular_values = torch.linalg.svdvals(M)

    # Clamp singular values to account for potential numerical issues.
    singular_values = torch.clamp(singular_values, -1.0, 1.0)

    # Compute the principal angles (in radians)
    angles = torch.acos(singular_values)
    return angles

