from nnsight import LanguageModel
from sae_auto_interp.autoencoders import load_eai_autoencoders
from sae_auto_interp.config import CacheConfig
from sae_auto_interp.features import FeatureCache
from sae_auto_interp.utils import load_tokenized_data
from noise_embedding import NoiseEmbeddingNNsight
import os
import torch
from pathlib import Path
from experiment_config import config
from config_loader import parse_config_overrides, apply_overrides
from rerandomized_model import RerandomizedModel
from model_adapter import ModelArchitectureAdapter
import json

def _load_model():
    """Load the configured language model and apply any optional perturbations.

    Returns:
        (model, adapter): ``model`` is the plain nnsight ``LanguageModel`` unless
        ``config.rerandomize`` replaced it with the rerandomized model and/or
        ``config.use_random_control`` wrapped it in ``NoiseEmbeddingNNsight``.
        ``adapter`` is built from the freshly loaded model, before any
        rerandomization or wrapping.
    """
    model = LanguageModel(
        config.model_name,
        device_map=config.device_map,
        dispatch=True,
        torch_dtype=getattr(torch, config.torch_dtype),
        revision="step0" if config.use_step0 else None,
        cache_dir=config.cache_dir,
    )

    adapter = ModelArchitectureAdapter(model)
    print(f"Detected model architecture: {adapter.model_type}")

    if config.rerandomize:
        print("Rerandomizing model parameters:")
        print(f"  - Embeddings: {'included' if config.rerandomize_embeddings else 'preserved'}")
        print(f"  - Layer Norm: {'randomized' if config.rerandomize_layer_norm else 'frozen'}")
        model = RerandomizedModel(
            model,
            rerandomize_embeddings=config.rerandomize_embeddings,
            rerandomize_layer_norm=config.rerandomize_layer_norm,
            seed=config.random_seed,
        ).model

    if config.use_random_control:
        print(f"Applying random control with noise std: {config.noise_std}")
        model = NoiseEmbeddingNNsight(model, std=config.noise_std)

    return model, adapter


def _load_autoencoders(model, layer_indices, weights_dir, module_str):
    """Load autoencoders (module type "res") for ``layer_indices`` into ``model``.

    When the random control is active, ``model`` is the NoiseEmbeddingNNsight
    wrapper. The autoencoders are then loaded into the wrapped ``model.model``,
    the result is stored back on the wrapper, and the wrapper itself is returned.

    Returns:
        (submodule_dict, model) as produced by ``load_eai_autoencoders``.
    """
    target = model.model if config.use_random_control else model
    submodule_dict, edited = load_eai_autoencoders(
        target,
        layer_indices,
        weight_dir=str(weights_dir),
        module="res",
        module_str=module_str,
    )
    if config.use_random_control:
        model.model = edited
        return submodule_dict, model
    return submodule_dict, edited


def _write_metadata(cfg_dict, adapter, submodule_dict, latents_dir):
    """Write ``feature_width.txt`` and ``config_dict.json`` for this cache run."""
    # Extra fields recorded only in config_dict.json, not in cache.save_config.
    if config.use_random_control:
        cfg_dict['noise_std'] = config.noise_std
        cfg_dict['embedding_type'] = 'pure_gaussian_noise'
    cfg_dict['num_layers'] = adapter.num_layers()

    # Layer 0 is always among the loaded layers (layer_indices starts at 0),
    # so its encoder's output size gives the SAE feature width.
    first_layer = adapter.get_layer_prefix(0)
    feature_width = submodule_dict[first_layer].ae.ae.encoder.out_features
    with open(latents_dir / "feature_width.txt", "w", encoding="utf-8") as f:
        f.write(str(feature_width))

    # NOTE(review): this file goes to the SAE weights directory
    # (config.save_directory), not latents_dir. This matches the original
    # behavior; confirm that the location is intended.
    with open(config.save_directory / "config_dict.json", "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, indent=2)


def process_model():
    """Cache SAE latent activations for the configured model and save them.

    Loads the model (optionally rerandomized and/or wrapped with the noise
    embedding control), loads autoencoders from ``config.save_directory`` for
    every ``config.layer_stride``-th layer, and runs a ``FeatureCache`` over the
    tokenized test split. It then saves the splits, the cache config and
    ``feature_width.txt`` to ``config.latents_directory``, and writes
    ``config_dict.json`` to ``config.save_directory``.

    Raises:
        FileNotFoundError: if ``config.save_directory`` does not exist.
    """
    weights_dir = config.save_directory

    if not weights_dir.exists():
        raise FileNotFoundError(f"Model directory not found: {weights_dir}")

    print(f"Using model: {config.model_name}")
    print(f"Using dataset: {config.dataset}")
    if config.dataset_name:
        print(f"Dataset config: {config.dataset_name}")
    print(f"Dataset split: {config.test_dataset_split}")

    model, adapter = _load_model()

    # try/finally so the noise-control hook is removed even if caching fails.
    try:
        print("Loading autoencoders...")
        layer_indices = list(range(0, adapter.num_layers(), config.layer_stride))
        submodule_dict, model = _load_autoencoders(
            model, layer_indices, weights_dir, adapter.return_module_str
        )

        cfg = CacheConfig(
            dataset_repo=config.dataset,
            dataset_split=config.test_dataset_split,
            batch_size=config.cache_batch_size,
            ctx_len=config.cache_ctx_len,
            n_tokens=config.cache_n_tokens,
            n_splits=5,
        )

        tokens = load_tokenized_data(
            ctx_len=cfg.ctx_len,
            tokenizer=model.tokenizer,
            dataset_repo=cfg.dataset_repo,
            dataset_split=cfg.dataset_split,
            dataset_row=config.text_key,
        )

        # The cache hooks the autoencoders in submodule_dict into the model.
        cache = FeatureCache(
            model,
            submodule_dict,
            batch_size=cfg.batch_size,
        )

        # Snapshot the cache config before running, as the original code did.
        cfg_dict = cfg.to_dict()

        print("Processing tokens...")
        cache.run(cfg.n_tokens, tokens)

        latents_dir = config.latents_directory
        print(f"Saving results to {latents_dir}")
        latents_dir.mkdir(parents=True, exist_ok=True)

        cache.save_splits(
            n_splits=cfg.n_splits,
            save_dir=latents_dir,
        )

        cache.save_config(
            save_dir=latents_dir,
            cfg=cfg,
            model_name=config.model_name,
        )

        _write_metadata(cfg_dict, adapter, submodule_dict, latents_dir)
    finally:
        if config.use_random_control:
            model.remove_hook()

def main():
    """Command-line entry point: apply config overrides, then run extraction.

    Overrides come from ``parse_config_overrides()``. Separately, the bare flag
    ``--no-reinit_non_embedding`` on the command line forces
    ``reinit_non_embedding`` to False before the overrides are applied to the
    shared ``config``.
    """
    import sys  # use sys directly; os.sys is an undocumented implementation detail

    overrides = parse_config_overrides()
    if "--no-reinit_non_embedding" in sys.argv:
        overrides["reinit_non_embedding"] = False
    apply_overrides(config, overrides)

    # NOTE(review): huggingface_hub / datasets usually read HF_HOME when they are
    # first imported, which happens at the top of this file. This assignment may
    # therefore not affect them. The model load passes cache_dir explicitly.
    os.environ['HF_HOME'] = str(config.cache_dir)
    process_model()


if __name__ == "__main__":
    main()