from pathlib import Path
import torch
import os
import json
from nnsight import LanguageModel
from sae_auto_interp.autoencoders import load_eai_autoencoders
import sae_bench_utils.general_utils as general_utils
import evals.core.main as core
import evals.sparse_probing.main as sparse_probing
import custom_saes.custom_sae_config as custom_sae_config
from experiment_config import config
from config_loader import parse_config_overrides, apply_overrides
from rerandomized_model import RerandomizedModel
from noise_embedding import NoiseEmbeddingNNsight
from model_adapter import ModelArchitectureAdapter


def process_model():
    """Run the SAE bench core evaluation on the SAEs stored in ``config.save_directory``.

    The function performs these steps in order:

    1. Verify that ``config.save_directory`` exists and create the output
       folders below ``config.eval_directory``.
    2. Load the language model named by ``config.model_name`` and, depending
       on ``config.rerandomize`` / ``config.use_random_control``, replace it
       with a rerandomized and/or noise-embedding variant.
    3. Load the autoencoders for every ``config.layer_stride``-th layer.
    4. Temporarily replace ``encode`` / ``decode`` / ``forward`` on each SAE
       with the wrappers defined below, and attach the ``W_enc`` and ``cfg``
       attributes.
    5. Call ``core.multiple_evals`` and afterwards put the original methods
       back, even if the evaluation raises.

    Returns:
        None. Returns early (after printing a message) when no SAE was found
        for any of the selected layers.

    Raises:
        FileNotFoundError: If ``config.save_directory`` does not exist.
    """
    save_dir = config.save_directory
    if not save_dir.exists():
        raise FileNotFoundError(f"Model directory not found: {save_dir}")

    print(f"Using model: {config.model_name}")
    print(f"Using dataset: {config.dataset}")
    if config.dataset_name:
        print(f"Dataset config: {config.dataset_name}")

    # Setup SAE bench environment (the returned device is not needed here).
    general_utils.setup_environment()

    # Use config's evaluation directory
    eval_dir = config.eval_directory
    output_folders = {
        "core": eval_dir / "sae_bench/core",
        "sparse_probing": eval_dir / "sae_bench/sparse_probing",
    }
    for folder in output_folders.values():
        folder.mkdir(parents=True, exist_ok=True)

    # Initialize model
    model = LanguageModel(
        config.model_name,
        device_map=config.device_map,
        dispatch=True,
        torch_dtype=getattr(torch, config.torch_dtype),
        revision="step0" if config.use_step0 else None,
        cache_dir=config.cache_dir,
    )

    # The adapter is built from the freshly loaded model, before any
    # rerandomization or noise wrapping replaces ``model``.
    adapter = ModelArchitectureAdapter(model)
    print(f"Detected model architecture: {adapter.model_type}")

    if config.rerandomize:
        print("Rerandomizing model parameters...")
        model = RerandomizedModel(
            model,
            rerandomize_embeddings=config.rerandomize_embeddings,
            rerandomize_layer_norm=config.rerandomize_layer_norm,
            seed=config.random_seed,
        ).model

    if config.use_random_control:
        print(f"Applying random control with noise std: {config.noise_std}")
        model = NoiseEmbeddingNNsight(model, std=config.noise_std)

    print("Loading autoencoders...")
    # Load autoencoders with correct layer indices
    layer_indices = list(range(0, adapter.num_layers(), config.layer_stride))
    load_kwargs = {
        "weight_dir": str(save_dir),
        "module": "res",
        "module_str": adapter.return_module_str,
    }
    if config.use_random_control:
        # With random control, ``model`` is the noise wrapper; the autoencoders
        # are loaded against (and written back to) its ``.model`` attribute.
        submodule_dict, model.model = load_eai_autoencoders(
            model.model, layer_indices, **load_kwargs
        )
    else:
        submodule_dict, model = load_eai_autoencoders(
            model, layer_indices, **load_kwargs
        )

    def wrapped_forward(self, x):
        """Encode then decode ``x`` using the (patched) encode/decode."""
        return self.decode(self.encode(x))

    def wrapped_encode(self, x):
        """Return a dense activation tensor that is zero outside the top-k.

        ``self.pre_acts`` / ``self.select_topk`` come from the loaded SAE;
        presumably ``pre_acts`` already applies the nonlinearity -- TODO confirm.
        """
        feat_acts = self.pre_acts(x)
        top_acts, top_indices = self.select_topk(feat_acts)
        # Scatter the selected activations into an all-zero tensor shaped
        # like the full activation tensor.
        dense = torch.zeros_like(feat_acts)
        return dense.scatter_(dim=-1, index=top_indices, src=top_acts)

    def wrapped_decode(self, x):
        """Decode dense activations with ``W_dec`` and ``b_dec``."""
        return (x @ self.W_dec) + self.b_dec

    # Convert loaded SAEs to format expected by SAE bench
    selected_saes = []
    # (sae, original methods) pairs, recorded *before* patching so that every
    # patched SAE can be restored regardless of which layers were found.
    patched = []

    try:
        for layer in layer_indices:
            layer_key = adapter.get_layer_prefix(layer)
            if layer_key not in submodule_dict:
                continue

            sae = submodule_dict[layer_key].ae.ae

            # Add W_enc attribute that SAE bench expects
            sae.W_enc = sae.encoder.weight.T

            # Store original methods
            patched.append(
                (
                    sae,
                    {
                        'decode': sae.decode,
                        'encode': sae.encode,
                        'forward': sae.forward,
                    },
                )
            )

            # Apply wrappers (bound to this SAE instance)
            sae.original_decode = sae.decode
            sae.decode = wrapped_decode.__get__(sae)
            sae.encode = wrapped_encode.__get__(sae)
            sae.forward = wrapped_forward.__get__(sae)

            # Create SAE bench config
            sae.cfg = custom_sae_config.CustomSAEConfig(
                model_name=config.model_name,
                d_in=sae.d_in,
                d_sae=sae.num_latents,
                hook_name=adapter.get_hook_pattern(layer),
                hook_layer=layer,
                dtype=config.torch_dtype,
                architecture="trained_sae",
                training_tokens=config.max_tokens,
            )

            selected_saes.append((f"trained_sae_layer_{layer}", sae))
            print(f"Loaded SAE for layer {layer}")

        if not selected_saes:
            print("No valid SAEs found to evaluate!")
            return

        # Run evaluations
        print("\nRunning core evaluations...")
        core.multiple_evals(
            selected_saes=selected_saes,
            n_eval_reconstruction_batches=config.n_eval_reconstruction_batches,
            n_eval_sparsity_variance_batches=config.n_eval_sparsity_variance_batches,
            eval_batch_size_prompts=config.eval_batch_size_prompts,
            compute_featurewise_density_statistics=False,
            compute_featurewise_weight_based_metrics=False,
            exclude_special_tokens_from_reconstruction=True,
            dataset=config.dataset,
            context_size=config.cache_ctx_len,
            output_folder=output_folders["core"],
            verbose=True,
            force_rerun=True,
            dtype=config.torch_dtype,
            config=config,
            gpt_model=model,
            random_control=config.use_random_control,
        )
    finally:
        # Restore original methods, also when the evaluation failed.
        for sae, methods in patched:
            sae.decode = methods['decode']
            sae.encode = methods['encode']
            sae.forward = methods['forward']

    print("\nSAE bench evaluation complete!")

def main():
    """Apply configuration overrides, set ``HF_HOME``, and run ``process_model``.

    Overrides come from ``parse_config_overrides()``; in addition, the literal
    command-line flag ``--no-reinit_non_embedding`` forces the
    ``reinit_non_embedding`` override to ``False``.
    """
    # Function-scope import: the file imports ``os`` but not ``sys``, and
    # ``os.sys`` is an undocumented implementation detail of the ``os`` module.
    import sys

    # Load and apply configuration overrides
    overrides = parse_config_overrides()
    if "--no-reinit_non_embedding" in sys.argv:
        overrides["reinit_non_embedding"] = False
    apply_overrides(config, overrides)

    # NOTE(review): HF_HOME is set after the module-level imports have run;
    # libraries that read it at import time may not pick this up -- confirm.
    os.environ['HF_HOME'] = str(config.cache_dir)
    process_model()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()