import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae import SaeConfig, SaeTrainer, TrainConfig
from sae.data import chunk_and_tokenize
from experiment_config import config
from noise_embedding import NoiseEmbeddingModel
import os
import argparse
from typing import Optional, List

from config_loader import parse_config_overrides, apply_overrides, save_config
from rerandomized_model import RerandomizedModel
from model_adapter import ModelArchitectureAdapter
from pathlib import Path
import shutil

def clear_cuda_memory():
    """Release cached CUDA memory and run a Python garbage-collection pass.

    When CUDA is available, unused blocks held by PyTorch's caching allocator
    are released and the current device is synchronized. ``gc.collect()``
    runs unconditionally.
    """
    import gc

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    gc.collect()

def parse_layer_args():
    """Build the CLI parser that selects which layers get SAEs trained.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--layer-by-layer``,
        ``--layers``, ``--single-layer`` and ``--layer-group-size``. Parsing
        is left to the caller.
    """
    layer_parser = argparse.ArgumentParser(description='Train SAE models')
    option_specs = (
        ('--layer-by-layer', dict(
            action='store_true',
            help='Enable layer-by-layer training mode')),
        ('--layers', dict(
            type=str,
            help='With layer-by-layer mode: comma-separated list of layer indices or range like "0-4"')),
        ('--single-layer', dict(
            type=int,
            help='With layer-by-layer mode: single layer to train')),
        ('--layer-group-size', dict(
            type=int,
            default=None,
            help='Number of layers to train together in layer-by-layer mode')),
    )
    for flag, options in option_specs:
        layer_parser.add_argument(flag, **options)
    return layer_parser

def get_layers_to_train(args, num_layers: int, layer_stride: int) -> List[List[int]]:
    """Determine which layers to train based on command line arguments.

    Args:
        args: Parsed namespace from ``parse_layer_args``. Reads
            ``layer_by_layer``, ``single_layer``, ``layers`` and
            ``layer_group_size``.
        num_layers: Number of layers in the model. Valid indices are
            ``0 .. num_layers - 1``.
        layer_stride: Step used when expanding a ``"start-end"`` range or when
            no explicit layers are given. Ignored for comma-separated lists
            and ``--single-layer``.

    Returns:
        A list of layer groups (each a list of layer indices) to train
        together, or ``[]`` when layer-by-layer mode is disabled.

    Raises:
        ValueError: If the layer spec is malformed, out of range, or an
            inverted range, or if ``layer_group_size`` is negative.
    """
    if not args.layer_by_layer:
        return []  # Normal training mode

    max_index = num_layers - 1

    if args.single_layer is not None:
        if args.single_layer < 0 or args.single_layer >= num_layers:
            raise ValueError(f"Layer index {args.single_layer} out of range [0, {max_index}]")
        return [[args.single_layer]]

    if args.layers:
        if '-' in args.layers:
            # Unpacking into exactly two ints rejects inputs like "0-2-3" or "-1".
            try:
                start, end = (int(part) for part in args.layers.split('-'))
            except ValueError as e:
                raise ValueError(
                    f"Invalid layer range '{args.layers}'. Use 'start-end', e.g. '0-4'. Error: {e}"
                ) from e
            if start < 0 or end >= num_layers:
                raise ValueError(f"Layer range {start}-{end} out of range [0, {max_index}]")
            if start > end:
                # An inverted range would otherwise expand to nothing and silently train no layers.
                raise ValueError(f"Layer range {start}-{end} is empty: start must be <= end")
            layers = list(range(start, end + 1, layer_stride))
        else:
            # Keep the try narrow so only int() parsing failures are reported as format errors.
            try:
                layers = [int(token) for token in args.layers.split(',')]
            except ValueError as e:
                raise ValueError(
                    f"Invalid layer format. Use comma-separated numbers or range. Error: {e}"
                ) from e
            if any(idx < 0 or idx >= num_layers for idx in layers):
                raise ValueError(f"Layer indices must be in range [0, {max_index}]")
            # De-duplicate so a layer listed twice is not trained twice.
            layers = sorted(set(layers))
    else:
        layers = list(range(0, num_layers, layer_stride))

    # Group layers; None or 0 mean one layer per group (same as before).
    group_size = args.layer_group_size or 1
    if group_size < 1:
        # A negative step would make range() empty and silently skip all training.
        raise ValueError(f"Layer group size must be positive, got {group_size}")
    return [layers[i:i + group_size] for i in range(0, len(layers), group_size)]

def create_sae_config() -> SaeConfig:
    """Build an ``SaeConfig`` from the shared experiment ``config``.

    Copies ``expansion_factor``, ``normalize_decoder``, ``num_latents``, ``k``
    and ``multi_topk`` straight across from the experiment config.
    """
    sae_fields = {
        field_name: getattr(config, field_name)
        for field_name in (
            'expansion_factor',
            'normalize_decoder',
            'num_latents',
            'k',
            'multi_topk',
        )
    }
    return SaeConfig(**sae_fields)

def create_train_config(sae_config: SaeConfig, save_dir: Path, layers: List[int], adapter: ModelArchitectureAdapter) -> TrainConfig:
    """Assemble the ``TrainConfig`` for one training run over ``layers``.

    Args:
        sae_config: Per-SAE hyperparameters.
        save_dir: Output directory; its string form is used as the run name.
        layers: Layer indices to train SAEs on.
        adapter: Not used by this function; it is accepted so existing call
            sites keep working.

    Returns:
        A ``TrainConfig`` whose batch size comes from the experiment config.
    """
    run_name = str(save_dir)
    return TrainConfig(
        sae=sae_config,
        layers=layers,
        batch_size=config.batch_size,
        run_name=run_name,
    )

def load_and_prepare_dataset():
    """Load (or tokenize and cache) the training split, cap its size, and shuffle it.

    The tokenized split is cached on disk at
    ``f"{config.tokenized_dataset_path}_{config.train_dataset_split}"``.
    On a cache miss, the raw dataset is loaded and tokenized with
    ``chunk_and_tokenize`` using the model's tokenizer, then saved there.

    The dataset is then truncated to at most ``config.max_tokens`` tokens
    (whole sequences only) and shuffled with ``config.random_seed``.

    Returns:
        The truncated, shuffled tokenized dataset.

    Raises:
        ValueError: If the tokenized dataset contains no sequences.
    """
    tokenized_path = f'{str(config.tokenized_dataset_path)}_{str(config.train_dataset_split)}'

    if Path(tokenized_path).exists():
        print(f"Loading tokenized dataset from {tokenized_path}")
        tokenized = load_from_disk(tokenized_path)
    else:
        print(f"Loading dataset {config.dataset}")
        dataset_args = config.get_dataset_args()
        dataset_args['split'] = config.train_dataset_split

        dataset = load_dataset(**dataset_args,
                               trust_remote_code=True,
                               cache_dir=config.cache_dir)

        print("Tokenizing dataset...")
        # Pass cache_dir explicitly, like the dataset and model downloads do.
        # HF_HOME is set in main() only after the HF libraries are imported,
        # so it may not redirect this download on its own.
        tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=config.cache_dir)
        tokenized = chunk_and_tokenize(dataset, tokenizer, text_key=config.text_key)
        tokenized.save_to_disk(tokenized_path)

    if len(tokenized) == 0:
        raise ValueError(f"Tokenized dataset at {tokenized_path} contains no sequences")

    # Read the first row instead of the whole 'input_ids' column; column access
    # can materialize every sequence just to measure one.
    # NOTE(review): assumes every chunk has the same length (as the token
    # totals below also do) — confirm against chunk_and_tokenize.
    tokens_per_sequence = len(tokenized[0]['input_ids'])
    total_tokens = len(tokenized) * tokens_per_sequence
    print(f"Number of tokens in dataset: {total_tokens:,}")

    # Keep whole sequences only, up to config.max_tokens.
    max_sequences = config.max_tokens // tokens_per_sequence
    tokenized = tokenized.select(range(min(len(tokenized), max_sequences)))

    final_tokens = len(tokenized) * tokens_per_sequence
    print(f"Final number of tokens: {final_tokens:,}")

    # Warn only when the dataset itself is too small. Checking final_tokens
    # would also fire whenever max_tokens isn't a multiple of the sequence length.
    if total_tokens < config.max_tokens:
        print(f"Warning: Less than {config.max_tokens:,} tokens in dataset")

    return tokenized.shuffle(seed=config.random_seed)

def _relocate_random_control_outputs(save_dir: Path, adapter: "ModelArchitectureAdapter", layers: List[int]) -> None:
    """Move per-layer SAE outputs from ``save_dir/model.<prefix>`` to ``save_dir/<prefix>``.

    Used in random-control mode, where the model is wrapped by
    ``NoiseEmbeddingModel``. Outputs are looked for under a ``model.``-prefixed
    directory, presumably because the wrapper adds that attribute prefix —
    TODO confirm.

    An existing destination is replaced only when a source directory exists,
    so a missing source never deletes previously saved results.
    """
    for layer in layers:
        layer_prefix = adapter.get_layer_prefix(layer).lstrip('.')
        source_dir = save_dir / f"model.{layer_prefix}"
        if not source_dir.exists():
            continue
        target_dir = save_dir / layer_prefix
        if target_dir.exists():
            shutil.rmtree(target_dir)
        source_dir.rename(target_dir)


def _print_gpu_memory_status() -> None:
    """Print allocated and reserved CUDA memory in MB; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    print("\nGPU memory status:")
    print(f"Allocated: {torch.cuda.memory_allocated() // 1024**2:,}MB")
    print(f"Cached: {torch.cuda.memory_reserved() // 1024**2:,}MB")


def main():
    """Entry point: apply CLI/config overrides, build the model, and train SAEs.

    Two modes:
      * ``--layer-by-layer``: one ``SaeTrainer`` per layer group returned by
        ``get_layers_to_train``, freeing GPU memory between groups.
      * default: a single ``SaeTrainer`` over every ``config.layer_stride``-th layer.
    """
    # Parse arguments and apply config overrides. Unrecognized args are ignored
    # here; presumably parse_config_overrides reads them itself — TODO confirm.
    parser = parse_layer_args()
    args, _ = parser.parse_known_args()

    overrides = parse_config_overrides()
    apply_overrides(config, overrides)

    # NOTE(review): huggingface_hub/datasets typically read HF_HOME when first
    # imported (top of this file), so this may not change their default cache.
    # The explicit cache_dir arguments are what reliably control it.
    os.environ['HF_HOME'] = str(config.cache_dir)

    # Load and prepare dataset
    tokenized = load_and_prepare_dataset()

    # Initialize model and adapter
    print("\nInitializing model...")
    print(f"Using step0 revision: {config.use_step0}")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        device_map=config.device_map,
        torch_dtype=getattr(torch, config.torch_dtype),
        revision="step0" if config.use_step0 else None,
        cache_dir=config.cache_dir
    )

    # The adapter is built from the unwrapped model, before any rerandomization/noise wrapping.
    adapter = ModelArchitectureAdapter(model)
    print(f"Detected model architecture: {adapter.model_type}")

    if config.rerandomize:
        print("Rerandomizing model parameters:")
        print(f"  - Embeddings: {'included' if config.rerandomize_embeddings else 'preserved'}")
        print(f"  - Layer Norm: {'randomized' if config.rerandomize_layer_norm else 'frozen'}")
        model = RerandomizedModel(
            model,
            rerandomize_embeddings=config.rerandomize_embeddings,
            rerandomize_layer_norm=config.rerandomize_layer_norm,
            seed=config.random_seed
        ).model

    if config.use_random_control:
        print(f"Using random control mode with noise std: {config.noise_std}")
        model = NoiseEmbeddingModel(model, std=config.noise_std)

    # Setup directories and configs
    save_dir = config.save_directory
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"Models will be saved in: {save_dir}")

    sae_config = create_sae_config()

    try:
        if args.layer_by_layer:
            num_layers = adapter.num_layers()
            layers_to_train = get_layers_to_train(args, num_layers, config.layer_stride)

            print("\nLayer-by-layer training mode")
            print(f"Training layers: {layers_to_train}")
            print(f"Total layer groups: {len(layers_to_train)}")

            for i, layer_group in enumerate(layers_to_train, 1):
                group_str = ','.join(map(str, layer_group))
                print(f"\nTraining layer group {group_str} ({i}/{len(layers_to_train)})")

                train_cfg = create_train_config(sae_config, save_dir, layer_group, adapter)
                trainer = SaeTrainer(train_cfg, tokenized, model)

                try:
                    trainer.fit()
                finally:
                    # Drop the trainer before the next group so its GPU memory can be reclaimed.
                    print("Cleaning up trainer...")
                    del trainer
                    clear_cuda_memory()

                if config.use_random_control:
                    _relocate_random_control_outputs(save_dir, adapter, layer_group)

                _print_gpu_memory_status()
        else:
            print("\nNormal training mode (all layers)")
            all_layers = list(range(0, adapter.num_layers(), config.layer_stride))
            train_cfg = create_train_config(sae_config, save_dir, all_layers, adapter)

            trainer = SaeTrainer(train_cfg, tokenized, model)
            trainer.fit()

            if config.use_random_control:
                _relocate_random_control_outputs(save_dir, adapter, all_layers)
    finally:
        print("\nFinal cleanup...")
        clear_cuda_memory()
        if torch.cuda.is_available():
            print(f"Final GPU memory: {torch.cuda.memory_allocated() // 1024**2:,}MB")

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()