from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from experiment_config import config
from config_loader import parse_config_overrides, apply_overrides
from noise_embedding import NoiseEmbeddingModel
from rerandomized_model import RerandomizedModel
from dataclasses import dataclass
from sae_auto_interp.config import CacheConfig
from sae_auto_interp.utils import load_tokenized_data
from sae import Sae

from sae_eval import evaluate_sae, EvalConfig

class SimpleActivationStore:
    """Minimal sequential batch provider over a tokenized dataset.

    Rows are served in order; when fewer rows than the requested batch
    remain, reading restarts from the beginning of the dataset.

    Args:
        dataset: Object supporting ``len()``, slice indexing that yields a
            mapping with an ``'input_ids'`` entry, and (for shuffling only)
            ``select(indices)``.
        batch_size: Default number of rows per batch.
        ctx_len: Context length; stored on the instance but not used by
            this class itself.
    """

    def __init__(self, dataset, batch_size, ctx_len):
        self.dataset = dataset
        self.batch_size = batch_size
        self.ctx_len = ctx_len
        self.current_idx = 0  # index of the next row to serve

    def get_batch_tokens(self, batch_size=None):
        """Return the next batch of ``input_ids`` as a tensor.

        Args:
            batch_size: Number of rows to read. Defaults to the
                ``batch_size`` given at construction time.

        Returns:
            ``torch.Tensor`` built from the ``'input_ids'`` entry of the
            dataset slice.
        """
        if batch_size is None:
            batch_size = self.batch_size

        # Wrap to the start instead of serving a short trailing batch.
        if self.current_idx + batch_size > len(self.dataset):
            self.current_idx = 0

        start = self.current_idx
        batch = self.dataset[start:start + batch_size]
        self.current_idx = start + batch_size
        return torch.tensor(batch['input_ids'])

    def shuffle_input_dataset(self, seed=None):
        """Replace ``self.dataset`` with a randomly permuted selection of it.

        Note that a non-``None`` ``seed`` is applied through
        ``torch.manual_seed`` and therefore reseeds torch's global RNG.
        The read position (``current_idx``) is left untouched.

        Args:
            seed: Optional seed making the permutation reproducible.
        """
        if seed is not None:
            torch.manual_seed(seed)
        indices = torch.randperm(len(self.dataset))
        self.dataset = self.dataset.select(indices.tolist())

def _print_key_metrics(layer_name, results):
    """Print the headline metrics present in one layer's evaluation results."""
    print(f"\nKey metrics for {layer_name}:")
    if "metrics" not in results:
        return

    # (metric group, key inside the group, label used in the printed line)
    headline = (
        ("model_behavior_preservation", "kl_div_score", "KL divergence score"),
        ("reconstruction_quality", "explained_variance", "Explained variance"),
        ("sparsity", "l0", "L0 sparsity"),
    )
    all_metrics = results["metrics"]
    for group, key, label in headline:
        if group in all_metrics and key in all_metrics[group]:
            print(f"{label}: {all_metrics[group][key]:.4f}")


def main():
    """Evaluate every saved SAE against the configured base model.

    Steps, in order: apply command-line config overrides, load the base
    model and tokenizer (optionally wrapped by ``RerandomizedModel`` and/or
    ``NoiseEmbeddingModel``), tokenize the evaluation split, load the SAEs
    stored under ``config.save_directory``, then call ``evaluate_sae`` once
    per SAE and print a short metric summary for each.
    """
    # Function-scope import: the file's import block does not import ``sys``
    # and ``os.sys`` is an undocumented re-export that should not be relied on.
    import sys

    # Load and apply configuration overrides
    overrides = parse_config_overrides()
    if "--no-reinit_non_embedding" in sys.argv:
        overrides["reinit_non_embedding"] = False
    apply_overrides(config, overrides)

    # NOTE(review): transformers is already imported at module load time;
    # if the Hugging Face libraries read HF_HOME on import, this assignment
    # comes too late to redirect the cache -- confirm.
    os.environ['HF_HOME'] = str(config.cache_dir)

    print(f"\nEvaluating model in {config.save_directory}")
    print(f"Random control mode: {'enabled' if config.use_random_control else 'disabled'}")
    print(f"Using model: {config.model_name}")
    print(f"Using dataset: {config.dataset}")
    if config.dataset_name:
        print(f"Dataset config: {config.dataset_name}")

    # Load base model
    print("\nLoading base model...")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        device_map=config.device_map,
        torch_dtype=getattr(torch, config.torch_dtype),
        revision="step0" if config.use_step0 else None
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    if config.rerandomize:
        print(f"Rerandomizing model parameters (embeddings: {config.rerandomize_embeddings})")
        model = RerandomizedModel(
            model,
            rerandomize_embeddings=config.rerandomize_embeddings,
            seed=config.random_seed
        ).model

    if config.use_random_control:
        print(f"Applying random control with noise std: {config.noise_std}")
        model = NoiseEmbeddingModel(model, std=config.noise_std)

    # Load dataset for evaluation
    print("\nLoading evaluation dataset...")
    # NOTE(review): dataset_args is never used below, so config.dataset_name
    # is printed above but not forwarded to the loader -- confirm whether it
    # should be.
    dataset_args = config.get_dataset_args()

    cfg = CacheConfig(
        dataset_repo=config.dataset,
        dataset_split=config.test_dataset_split,
        batch_size=config.cache_batch_size,
        ctx_len=config.cache_ctx_len,
        n_tokens=config.cache_n_tokens,
        n_splits=5,
    )

    # Chunk and tokenize the dataset. The tokenizer loaded above is passed
    # explicitly: a transformers causal LM has no ``tokenizer`` attribute.
    eval_dataset = load_tokenized_data(
        ctx_len=cfg.ctx_len,
        tokenizer=tokenizer,
        dataset_repo=cfg.dataset_repo,
        dataset_split=cfg.dataset_split,
        dataset_row=config.text_key
    )

    # Load saved SAE models using Sae.load_many, onto the model's device
    device = next(model.parameters()).device

    print(f"Loading SAE models from {config.save_directory}")
    sae_models = Sae.load_many(
        name=str(config.save_directory),
        local=True,
        device=device,
    )

    # Setup evaluation config
    eval_config = EvalConfig(
        batch_size_prompts=8,
        n_eval_reconstruction_batches=10,
        n_eval_sparsity_variance_batches=2
    )

    # Create activation store
    activation_store = SimpleActivationStore(
        dataset=eval_dataset,
        batch_size=eval_config.batch_size_prompts,
        ctx_len=config.cache_ctx_len
    )

    # Create evaluation directory
    eval_dir = config.eval_directory / "eval_results"
    eval_dir.mkdir(exist_ok=True, parents=True)

    # Special tokens to ignore during evaluation. Tokenizers without a
    # pad/bos token report the id as None, which is not a token id, so
    # those entries are dropped.
    special_ids = (
        tokenizer.pad_token_id,
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
    )
    ignored_tokens = {tok for tok in special_ids if tok is not None}

    # Run evaluations for each layer
    print("\nRunning evaluations...")
    for layer_name, sae in sae_models.items():
        # The layer index is the last dot-separated component of the SAE name.
        layer_num = int(layer_name.split('.')[-1])
        hook_name = f".gpt_neox.layers.{layer_num}"
        print(f"\nEvaluating {layer_name}")

        # Run evaluation. save_dir points at eval_dir so that the output
        # location (presumably where evaluate_sae writes -- confirm) matches
        # the directory created above and reported in the final message.
        results = evaluate_sae(
            sae=sae,
            model=model,
            activation_store=activation_store,
            save_dir=eval_dir,
            eval_config=eval_config,
            hook_name=hook_name,
            ignore_tokens=ignored_tokens,
            verbose=True
        )

        # Print key metrics
        _print_key_metrics(layer_name, results)

    print("\nEvaluation complete! Results saved to:", eval_dir)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()