import os
import sys
import pandas as pd
import torch
from functools import partial
import re 

# Resolve this script's directory and the project root (its parent directory),
# and put both on sys.path so the `utils` and `sparsify` imports below resolve.
# The paths are derived from __file__, so they do not depend on the current
# working directory. `project_dir` is also the root for the data/ and
# sae-ckpts/ paths built by the functions in this module.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.join(script_dir, '..')
sys.path.append(script_dir)
sys.path.append(project_dir)

from utils import load_sae, project_onto_subspace, ablate_subspace
from sparsify.sparsify import Sae


def get_sae_subspace(model_name, hook_layer, device, type='dense', threshold=0.1):
    """Return decoder rows of a pretrained 16k residual SAE for dense or sparse latents.

    Latent activation frequencies are read from
    ``{project_dir}/data/frequencies/{model}/{hook_layer}.json``, a table with
    ``feature_idx`` and ``frac_nonzero`` columns, and used to pick a latent subset.

    Args:
        model_name: Model identifier. A ``google/`` prefix is stripped when
            building the frequency path; any name containing ``Llama-3.1-8B``
            maps to ``llama3.1-8b``. The lower-cased name is passed to ``load_sae``.
        hook_layer: SAE layer index (also the frequency file's stem).
        device: Device on which the returned tensor is placed.
        type: ``'dense'`` selects latents with ``frac_nonzero > threshold``;
            ``'sparse'`` selects latents with ``0 < frac_nonzero <= threshold``
            (latents that never fire are excluded).
        threshold: Activation-frequency cut-off.

    Returns:
        float32 tensor of the selected ``W_dec`` rows, in ascending ``feature_idx`` order.

    Raises:
        ValueError: If ``type`` is not ``'dense'`` or ``'sparse'`` (checked before any I/O).
    """
    if type not in ('dense', 'sparse'):
        raise ValueError(f'Unknown type: {type}')

    # Map the model name to the directory name used for frequency files.
    file_model_name = model_name.replace('google/', '')
    if 'Llama-3.1-8B' in model_name:
        file_model_name = 'llama3.1-8b'
    path = f"{project_dir}/data/frequencies/{file_model_name}/{hook_layer}.json"

    df = pd.read_json(path)
    # Index by latent id, keep the first row of any duplicated id, and sort so
    # the selected ids come out in ascending order.
    df = df.set_index('feature_idx')
    df = df[~df.index.duplicated(keep='first')]
    df = df.sort_index()
    print(f'Loaded {len(df)} entries')

    freq = df['frac_nonzero']
    if type == 'dense':
        mask = freq > threshold
    else:
        # Exclude latents that never fire (frac_nonzero == 0).
        mask = (freq <= threshold) & (freq > 0)
    latents_subset = df[mask].index
    print(f'Found {len(latents_subset)} {type} features')

    sae = load_sae(model_name.lower(), width='16k', layer_idx=hook_layer, location='res', device='cpu')
    W_dec = sae.W_dec.detach().float().to(device)

    # Build an explicit int64 index tensor on W_dec's device rather than
    # relying on torch's implicit handling of a pandas Index.
    idx = torch.tensor(latents_subset.tolist(), dtype=torch.long, device=W_dec.device)
    return W_dec[idx]


def get_custom_sae_subspace(model_name, ckpt_name, device, type='dense', threshold=0.1, dataset='stas/c4-en-10k', num_docs=5000, decoder=True):
    """Return weight rows of a locally trained SAE for dense or sparse latents.

    The SAE is loaded from ``{project_dir}/sae-ckpts/{ckpt_name}/``. Per-latent
    densities are loaded from ``{project_dir}/data/custom_sae_frequencies/{model_name}/
    {ckpt}/{dataset_name}_{num_docs}/densities.pt``.

    Args:
        model_name: Model directory name used in the densities path.
        ckpt_name: Checkpoint directory name. If it contains ``'-ckpt'``, the
            densities are looked up under ``{ckpt_name}_ablated_subspace``.
        device: Device for the SAE and the returned tensor.
        type: ``'dense'`` selects latents with density ``> threshold``;
            ``'sparse'`` selects latents with ``0 < density <= threshold``.
        threshold: Density cut-off.
        dataset: Dataset identifier. The part after ``'/'`` (or the whole
            string if there is no ``'/'``) names the densities sub-directory.
        num_docs: Number of documents the densities were computed over (path component).
        decoder: If True, use ``W_dec`` rows; otherwise use ``encoder.weight`` rows.

    Returns:
        float32 tensor of the selected weight rows on ``device``.

    Raises:
        ValueError: If ``type`` is not ``'dense'`` or ``'sparse'`` (checked before any I/O).
    """
    if type not in ('dense', 'sparse'):
        raise ValueError(f'Unknown type: {type}')

    ckpt_path = f'{project_dir}/sae-ckpts/{ckpt_name}/'
    saes = Sae.load_many(ckpt_path, local=True, device=device)
    # NOTE(review): only the first SAE in the checkpoint is used; confirm the
    # checkpoint holds a single hookpoint.
    sae = list(saes.values())[0]

    density_ckpt_name = ckpt_name + '_ablated_subspace' if '-ckpt' in ckpt_name else ckpt_name
    # Tolerate dataset names without an '/' instead of raising IndexError.
    dataset_name = dataset.split('/')[1] if '/' in dataset else dataset
    density_path = f'{project_dir}/data/custom_sae_frequencies/{model_name}/{density_ckpt_name}/{dataset_name}_{num_docs}/densities.pt'
    print(f'Loading densities from {density_path}')
    # map_location='cpu' lets densities saved from a GPU load on any machine.
    densities = torch.load(density_path, map_location='cpu')

    if type == 'dense':
        mask = densities > threshold
    else:
        # Exclude latents that never fire (density == 0).
        mask = (densities <= threshold) & (densities > 0)
    latents_subset = torch.where(mask)[0]
    print(f'Found {len(latents_subset)} {type} features')

    # NOTE(review): assumes encoder.weight is (num_latents, d_in), so its rows
    # correspond to latents like W_dec's rows do; confirm for sparsify's Sae.
    W = sae.W_dec if decoder else sae.encoder.weight
    W = W.detach().float().to(device)

    # Keep the index on the same device as the weights being indexed.
    return W[latents_subset.to(W.device)]


def get_subspace_ablation(model_name, hook_layer, device, threshold=0.1, type='only_dense', ckpt_name=None, dataset='stas/c4-en-10k', num_docs=5000, hf_hook=True, dtype=torch.float32, random_indices=None):
    """Build a hook that applies ``project_onto_subspace`` or ``ablate_subspace`` to activations.

    ``type`` selects both which SAE weight rows form the subspace and which operation runs:

    * ``'only_dense'``: dense-latent rows, ``project_onto_subspace``.
    * ``'only_sparse'``: dense-latent rows, ``ablate_subspace``.
    * containing ``'dense_dec_score_<float>'``: dense-latent rows, keeping only those
      whose dec score exceeds ``<float>``; ``project_onto_subspace``.
    * ``'projection_sparse_subset_<n>'`` / ``'ablation_sparse_subset_<n>'``: ``n``
      randomly chosen sparse-latent rows (or ``random_indices`` if given),
      projected onto / ablated respectively.

    Args:
        model_name: Model identifier. When ``hf_hook`` is True, any name containing
            ``'openai-community/gpt2'`` is mapped to ``'gpt2'``.
        hook_layer: SAE layer index (used for pretrained SAEs).
        device: Device for the subspace tensor.
        threshold: Frequency/density cut-off passed to the subspace loader.
        type: Subspace/operation selector (see above).
        ckpt_name: None to use ``get_sae_subspace``; otherwise a custom checkpoint
            name passed to ``get_custom_sae_subspace``.
        dataset, num_docs: Forwarded to ``get_custom_sae_subspace``.
        hf_hook: If True, cast the subspace to ``dtype`` and return a
            ``(module, inputs, outputs)`` forward hook; otherwise return
            ``fn(activations, hook)``.
        dtype: dtype of the subspace when ``hf_hook`` is True.
        random_indices: Rows of the sparse subspace to keep for sparse-subset
            types; drawn with ``torch.randperm`` when None.

    Returns:
        ``(hook, random_indices)``; ``random_indices`` is the value used (or the input).

    Raises:
        ValueError: For an unrecognised or malformed ``type``. This is raised before
            any SAE or file is loaded.
    """
    if hf_hook:
        model_name = 'gpt2' if 'openai-community/gpt2' in model_name else model_name

    # --- Validate and parse `type` before doing any expensive loading. ---
    if 'dense' in type or 'projection_sparse_subset' in type:
        project = True
    elif type == 'only_sparse' or 'ablation_sparse_subset' in type:
        project = False
    else:
        raise ValueError(f'Unknown type: {type}')

    use_sparse_subset = 'sparse_subset' in type
    use_dense = not use_sparse_subset and (type in ('only_dense', 'only_sparse') or 'dec_score' in type)
    if not (use_sparse_subset or use_dense):
        # Previously this left the subspace unbound and failed later with NameError.
        raise ValueError(f'Unknown type: {type}')

    if use_sparse_subset:
        n_dims = int(type.split('_')[-1])

    dec_score_threshold = None
    if use_dense and 'dense_dec_score' in type:
        match = re.search(r'dec_score_(\d+\.\d+)', type)
        if match is None:
            raise ValueError(f'Could not parse dec score threshold from type: {type}')
        dec_score_threshold = float(match.group(1))

    # --- Load the latent subspace (dense rows, or sparse rows for subsets). ---
    subset_type = 'sparse' if use_sparse_subset else 'dense'
    if ckpt_name is None:
        W_dec_subset = get_sae_subspace(model_name, hook_layer, device, type=subset_type, threshold=threshold)
    else:
        W_dec_subset = get_custom_sae_subspace(model_name, ckpt_name, device, type=subset_type, threshold=threshold, dataset=dataset, num_docs=num_docs)

    if use_sparse_subset:
        # Take a random subset of n_dims rows unless indices were supplied.
        if random_indices is None:
            random_indices = torch.randperm(W_dec_subset.size(0))[:n_dims]
        W_dec_subset = W_dec_subset[random_indices]
        print(f'Using a random sparse subset of {len(random_indices)} latents: {random_indices}')
    elif dec_score_threshold is not None:
        print(f'Dense latents with dec score >{dec_score_threshold}')
        # NOTE(review): the dec-score file is hard-coded to gemma-2-2b and column
        # '4' regardless of model_name/hook_layer. Its rows are assumed to align
        # with the dense-latent rows loaded above; confirm.
        path = f'{project_dir}/data/dec_scores/gemma-2-2b/dense.jsonl'
        df = pd.read_json(path, lines=True)
        dec_scores_4 = torch.tensor(df['4'].to_numpy())
        indices = torch.where(dec_scores_4 > dec_score_threshold)[0]
        W_dec_subset = W_dec_subset[indices]
        print(f'Found {len(indices)} dense latents with dec score >{dec_score_threshold}: {indices}')

    if hf_hook:
        # Match the subspace dtype to the model's activations (e.g. bfloat16).
        W_dec_subset = W_dec_subset.to(dtype)

    subspace_op = project_onto_subspace if project else ablate_subspace

    # Parameter names (activations, hook) are kept so callers can pass `hook=` by keyword.
    def hook_fn(activations, hook, subspace):
        return subspace_op(activations, subspace)

    fn = partial(hook_fn, subspace=W_dec_subset)

    if not hf_hook:
        return fn, random_indices

    def hf_hook_fn(module, inputs, outputs):
        # Tuple outputs: transform only the first element and pass the rest through.
        if isinstance(outputs, tuple):
            return (fn(outputs[0], module),) + outputs[1:]
        return fn(outputs, module)

    return hf_hook_fn, random_indices
    
