from typing import Optional, Union
import torch
import wandb
from tqdm.auto import tqdm
from sae import TrainingConfig
from logs import get_model_performance, log_model_performance, save_checkpoint

from sae import (
    VanillaSAE, 
    TopKSAE, 
    BatchTopKSAE, 
    JumpReLUSAE, 
    SAELoRAWrapper,
    KronSAE
    )

from sae import get_model_flops, count_parameters, set_seed, topk_flops_simple

from activation_store import ActivationsStore
from transformers import get_scheduler

from transformer_lens import HookedTransformer
import os
import json
from dataclasses import asdict
from safetensors.torch import save_file

from sae.train_utils import configure_optimizers


def save_sae(sae, cfg: TrainingConfig, step: Union[int, str] = 'final'):
    """
    Persist an SAE checkpoint as safetensors weights plus a JSON copy of its config.

    Output directory::

        checkpoints/sae_<model_name>_<layer>_<dict_size>_<sae_type>_<wandb_run_suffix>/

    where any ``/`` in ``cfg.model_name`` is replaced by ``_``. Inside it:

    * ``model_<step>.safetensors`` -- the tensors from ``sae.state_dict()``
    * ``config.json`` -- ``asdict(cfg)`` with the ``dtype`` / ``sae_dtype`` entries
      stringified; rewritten on every call.

    Args:
        sae: module whose ``state_dict()`` is written.
        cfg: dataclass config; also determines the output directory name.
        step: label embedded in the weights file name (``'final'`` by default).
    """
    os.makedirs("checkpoints", exist_ok=True)
    model_tag = cfg.model_name.replace('/', '_')
    base_path = f"checkpoints/sae_{model_tag}_{cfg.layer}_{cfg.dict_size}_{cfg.sae_type}_{cfg.wandb_run_suffix}"

    weights = sae.state_dict()
    os.makedirs(base_path, exist_ok=True)
    save_file(weights, f"{base_path}/model_{step}.safetensors")

    # The dtype fields are passed to ``.to(...)`` elsewhere in this file
    # (presumably torch dtypes); store their string form so json can encode them.
    serializable = asdict(cfg)
    serializable.update({key: str(serializable[key]) for key in ('dtype', 'sae_dtype')})

    with open(f"{base_path}/config.json", 'w') as config_file:
        json.dump(serializable, config_file, indent=2)


def count_topk_flops(d, F, k):
    """
    Theoretical forward FLOP count for a TopK-style SAE: ``d * F + d * k``.

    In this file it is called as ``count_topk_flops(act_size, dict_size, topk2)``.
    Presumably the ``d * F`` term is the dense projection onto all ``F`` latents
    and ``d * k`` the reconstruction from the ``k`` active ones -- confirm.
    """
    dense_term = d * F
    sparse_term = d * k
    return dense_term + sparse_term

def count_kronsae_flops(d, k, m, n, h):
    """
    Theoretical forward FLOP count for a KronSAE: ``d * h * (m + n) + d * k``.

    In this file it is called as
    ``count_kronsae_flops(act_size, topk2, num_mkeys, num_nkeys, num_heads)``.
    """
    per_head_keys = m + n
    key_term = d * h * per_head_keys
    active_term = d * k
    return key_term + active_term


# Keys of the ``get_model_performance`` result that are averaged after training.
_EVAL_METRIC_KEYS = (
    "performance/ce_degradation",
    "performance/recovery_from_zero",
    "performance/recovery_from_mean",
)


def _annealed_topk(start: int, target: int, step: int, warmup_steps: int) -> int:
    """Linearly move from ``start`` towards ``target`` over ``warmup_steps`` steps (truncated to int).

    Only called while ``step < warmup_steps``, so ``warmup_steps`` is >= 1 here.
    """
    return int(start - (start - target) * (step / warmup_steps))


def _elapsed_ms(start_event, end_event) -> float:
    """Record ``end_event``, wait for CUDA and return ms since ``start_event``; 0 when timing is disabled."""
    if end_event is None:
        return 0
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event)


def _grad_norm(parameters) -> float:
    """Global L2 norm of the gradients currently stored on ``parameters`` (parameters without grads are skipped)."""
    total = 0.
    for p in parameters:
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5


def _average_final_performance(sae, model, cfg, activations_store, num_eval_iters: int = 150) -> dict:
    """Average the ``_EVAL_METRIC_KEYS`` entries of ``get_model_performance`` over ``num_eval_iters`` calls."""
    totals = {key: 0. for key in _EVAL_METRIC_KEYS}
    for _ in tqdm(range(num_eval_iters)):
        log_dict = get_model_performance(
            model=model,
            config=cfg,
            activations_store=activations_store,
            sae=sae,
        )
        for key in _EVAL_METRIC_KEYS:
            totals[key] += log_dict[key]
    return {key: value / num_eval_iters for key, value in totals.items()}


def train_sae(
    sae: VanillaSAE | TopKSAE | BatchTopKSAE | JumpReLUSAE | KronSAE | SAELoRAWrapper,
    activations_store: ActivationsStore,
    model: HookedTransformer,
    cfg: TrainingConfig,
    train_transcoder: bool = False,
) -> None:
    """
    Train a sparse autoencoder on the activations of a transformer model.

    Runs ``cfg.num_tokens // cfg.batch_size`` optimisation steps with Adam and a
    ``transformers`` LR scheduler. During the configured warmup fractions it
    linearly anneals ``sae.config.topk2`` / ``topk1`` from ``cfg.start_topk*``.
    It logs metrics to wandb when ``cfg.enable_wandb`` is set and periodically
    evaluates and checkpoints. After training it folds stats into the weights,
    averages the final performance metrics and saves a last checkpoint.

    Args:
        sae: the SAE to train. Only KronSAE, TopKSAE and BatchTopKSAE have a
            theoretical FLOP count; other types raise NotImplementedError.
        activations_store: source of training batches (``get_batch``).
        model: transformer handed to the periodic and final performance evaluation.
        cfg: training configuration.
        train_transcoder: if True, each batch is an ``(input, target)`` pair and
            both tensors are passed to ``sae``.

    Raises:
        NotImplementedError: for SAE types without a theoretical FLOP count.
    """
    set_seed(cfg.seed)

    optimizer = torch.optim.Adam(
        params=sae.parameters(),
        lr=cfg.lr,
        betas=(0.9, 0.99),
    )

    # Resolve type-dependent behaviour on the *uncompiled* module: torch.compile
    # returns an OptimizedModule wrapper, so isinstance checks against the SAE
    # classes are always False once the module has been compiled.
    topk_types = (TopKSAE, BatchTopKSAE, KronSAE)
    is_kron = isinstance(sae, KronSAE)
    is_lora = isinstance(sae, SAELoRAWrapper)
    is_topk_variant = isinstance(sae, topk_types) or (
        is_lora and isinstance(sae.base_sae, topk_types)
    )

    # Theoretical per-token forward FLOPs.
    print("Start SAE FLOPS:")
    if is_kron:
        print("Count theoretical Flops")
        forward_flops = count_kronsae_flops(cfg.act_size, cfg.topk2, cfg.num_mkeys, cfg.num_nkeys, cfg.num_heads)
        # Top-k cost at the initial topk1; swapped for the current topk1 each step.
        default_topk_flops = topk_flops_simple(sae.config.topk1, input_shape=(1, cfg.act_size), num_heads=sae.config.num_heads)
    elif isinstance(sae, (TopKSAE, BatchTopKSAE)):
        print("Count theoretical Flops")
        forward_flops = count_topk_flops(cfg.act_size, cfg.dict_size, cfg.topk2)
    else:
        raise NotImplementedError(f"No theoretical FLOP count for {type(sae).__name__}")

    accum_num_flops = 0

    count_parameters(sae)
    total_params = sum(p.numel() for p in sae.parameters())
    if cfg.enable_wandb:
        wandb.config.update({"parameters_count": total_params}, allow_val_change=True)
        wandb.config.update(asdict(cfg), allow_val_change=True)

    # Step budget and warmup lengths, all measured in optimisation steps.
    tokens_per_batch = cfg.batch_size
    total_batches = cfg.num_tokens // tokens_per_batch
    warmup_topk2_batches = int(total_batches * cfg.topk2_warmup_steps_fraction)
    warmup_topk1_batches = int(total_batches * cfg.topk1_warmup_steps_fraction)
    if cfg.warmup_fraction <= 1.0:
        warmup_lr_batches = int(total_batches * cfg.warmup_fraction)
    else:
        # NOTE(review): values > 1 are multiplied by batch_size rather than used
        # as an absolute step count -- confirm this is the intended unit.
        warmup_lr_batches = int(cfg.batch_size * cfg.warmup_fraction)

    scheduler = get_scheduler(
        name=cfg.scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=warmup_lr_batches,
        num_training_steps=total_batches,
        scheduler_specific_kwargs=dict(min_lr=1e-6),
    )

    print(f"tokens per iteration will be: {tokens_per_batch:,}")

    # CUDA events for forward/backward timing; timing is reported as 0 on CPU.
    if torch.cuda.is_available():
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
    else:
        start_event = end_event = None

    pbar = tqdm(range(total_batches), desc="Training SAE")
    assert cfg.batch_size % cfg.model_batch_size == 0, "Batch size must be a multiple of model batch size"
    # Clamp to 1 so the refresh modulo below never divides by zero when one SAE
    # batch is larger than seq_len * model_batch_size.
    num_batches_before_refresh = max(1, (cfg.seq_len * cfg.model_batch_size) // cfg.batch_size)
    sae = torch.compile(sae)
    sae.to(cfg.device)

    n_tokens_processed_true = 0
    for i in pbar:
        # Anneal top-k values during warmup (TopK variants only).
        # NOTE(review): changing these ints every step may trigger torch.compile
        # recompilations -- confirm this is acceptable.
        if is_topk_variant and i < warmup_topk2_batches:
            sae.config.topk2 = _annealed_topk(cfg.start_topk2, cfg.topk2, i, warmup_topk2_batches)
        else:
            sae.config.topk2 = cfg.topk2

        if is_topk_variant and i < warmup_topk1_batches:
            sae.config.topk1 = _annealed_topk(cfg.start_topk1, cfg.topk1, i, warmup_topk1_batches)
        else:
            sae.config.topk1 = cfg.topk1

        add_new = (i + 1) % num_batches_before_refresh == 0
        batch = activations_store.get_batch(add_new=add_new)

        # Forward pass (timed).
        if start_event is not None:
            start_event.record()
        if train_transcoder:
            loss_dict = sae(batch[0].to(cfg.sae_dtype), batch[1].to(cfg.sae_dtype))
        else:
            loss_dict = sae(batch.to(cfg.sae_dtype))
        forward_time_ms = _elapsed_ms(start_event, end_event)

        # Backward pass, clipping and decoder renormalisation (timed).
        if start_event is not None:
            start_event.record()
        optimizer.zero_grad()
        loss_dict["loss"].backward()
        torch.nn.utils.clip_grad_norm_(sae.parameters(), max_norm=cfg.max_grad_norm)
        if not is_lora:
            sae.make_decoder_weights_and_grad_unit_norm()
        backward_time_ms = _elapsed_ms(start_event, end_event)

        optimizer.step()
        scheduler.step()

        # Drop the large tensors before the rest of loss_dict is logged.
        del loss_dict["sae_out"]
        del loss_dict["feature_acts"]
        if is_kron:
            forward_flops_current = (
                forward_flops
                - default_topk_flops
                + topk_flops_simple(sae.config.topk1, input_shape=(1, cfg.act_size), num_heads=sae.config.num_heads)
            )
        else:
            forward_flops_current = forward_flops
        accum_num_flops += forward_flops_current

        # Norm of the gradients after clipping and renormalisation.
        total_norm = _grad_norm(sae.parameters())

        n_tokens_processed = (i + 1) * tokens_per_batch
        # Transcoder batches are (input, target) pairs; count rows of the input.
        n_tokens_processed_true += (batch[0] if train_transcoder else batch).shape[0]
        metrics_to_log = {
            "n_tokens": n_tokens_processed,
            "n_tokens_true": n_tokens_processed_true,
            "forward_time_ms": forward_time_ms,
            "backward_time_ms": backward_time_ms,
            "forward_flops": forward_flops_current,
            "accum_num_flops": accum_num_flops,
            "grad_norm": total_norm,
            **loss_dict,
        }
        if is_topk_variant:
            metrics_to_log["current_topk2"] = sae.config.topk2

        if cfg.enable_wandb:
            wandb.log(metrics_to_log)

        pbar.set_postfix({
            "l2_loss": f"{loss_dict['l2_loss']:.3f}",
            "explained_var": f"{loss_dict['explained_variance']:.3f}",
            "forward_ms": f"{forward_time_ms}",
            "backward_ms": f"{backward_time_ms}",
            "topk2": sae.config.topk2,
            "topk1": sae.config.topk1,
        })

        step = i + 1
        if step % cfg.performance_log_steps == 0:
            log_model_performance(
                wandb_run=wandb.run,
                step=step,
                model=model,
                config=cfg,
                activations_store=activations_store,
                sae=sae,
            )
        if step % cfg.save_checkpoint_steps == 0:
            save_checkpoint(
                wandb_run=wandb.run,
                sae=sae,
                cfg=cfg,
                step=step)

    sae = sae.fold_stats_into_weights()
    log_dict_total = _average_final_performance(sae, model, cfg, activations_store, num_eval_iters=150)

    if cfg.enable_wandb:
        wandb.summary.update({f"eval/{key}": value for key, value in log_dict_total.items()})
        wandb.log(log_dict_total)

    # Label the final checkpoint with the number of training steps taken
    # (not a loop counter from the evaluation above).
    save_checkpoint(
        wandb_run=wandb.run,
        sae=sae,
        cfg=cfg,
        step=total_batches,
    )

    pbar.close()
    wandb.finish()


# SAE class used for each ``sae_type`` string accepted by ``train_sae_group``.
_SAE_CLASSES_BY_TYPE = {
    "vanilla": VanillaSAE,
    "topk": TopKSAE,
    "batchtopk": BatchTopKSAE,
    "jumprelu": JumpReLUSAE,
}


def _train_group_member(base_cfg, sae_type, field_name, value):
    """Train one SAE on a fresh copy of ``base_cfg`` with ``sae_type`` and ``field_name = value`` applied."""
    run_cfg = TrainingConfig(**vars(base_cfg))
    run_cfg.sae_type = sae_type
    setattr(run_cfg, field_name, value)

    model = HookedTransformer.from_pretrained(run_cfg.model_name).to(run_cfg.dtype).to(run_cfg.device)
    run_cfg.act_size = model.cfg.d_model
    activations_store = ActivationsStore(model, run_cfg)

    sae_cls = _SAE_CLASSES_BY_TYPE[sae_type]
    train_sae(sae_cls(run_cfg), activations_store, model, run_cfg)


def train_sae_group(
    cfg: TrainingConfig,
    sae_types: list[str],
    topk2_values: Optional[list[int]] = None,
    dict_sizes: Optional[list[int]] = None
) -> None:
    """
    Train multiple SAEs with different configurations.

    For each entry of ``sae_types`` (in order), one run is started per value of
    ``topk2_values`` (overriding ``topk2``), followed by one run per value of
    ``dict_sizes`` (overriding ``dict_size``). Empty or ``None`` lists are skipped.

    Each run copies ``cfg``, reloads the transformer, sets ``act_size`` from
    its ``d_model``, builds a new ActivationsStore, instantiates the SAE class
    for the type and calls ``train_sae``. ``cfg`` itself is not modified.

    Raises:
        KeyError: if a type is not "vanilla", "topk", "batchtopk" or "jumprelu".
            This is raised only after that run's model and activation store
            have been built.
    """
    sweeps = (("topk2", topk2_values), ("dict_size", dict_sizes))
    for sae_type in sae_types:
        for field_name, values in sweeps:
            if not values:
                continue
            for value in values:
                _train_group_member(cfg, sae_type, field_name, value)
