from itertools import islice
from typing import Optional, Union
import torch
from torch.nn.parallel import DistributedDataParallel
import wandb
from tqdm.auto import tqdm
from distibuted_activation_store import DistributedActivationsStore
from sae import TrainingConfig
from logs import get_distributed_model_performance, get_model_performance, log_model_performance, save_checkpoint, save_checkpoint_accelerator

from sae import (
    VanillaSAE, 
    TopKSAE, 
    BatchTopKSAE, 
    JumpReLUSAE, 
    SAELoRAWrapper,
    KronSAE, 
    )

import accelerate
from accelerate import (
    AutocastKwargs,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    DataLoaderConfiguration,
)

from sae import get_model_flops, count_parameters, set_seed, topk_flops_simple

from transformers import get_scheduler

from transformer_lens import HookedTransformer
import os
import json
from dataclasses import asdict
from safetensors.torch import save_file

from sae.model import BaseSAE
from sae.train_utils import configure_optimizers


@torch.no_grad()
def make_decoder_weights_and_grad_unit_norm(W_dec: torch.Tensor):
    """Renormalise the decoder in place and strip the matching gradient component.

    Every vector along the last dimension of ``W_dec`` is rescaled to unit L2
    norm. The part of ``W_dec.grad`` that is parallel to that unit vector is
    subtracted from the gradient in place.

    Args:
        W_dec: Decoder weight tensor. Its ``.grad`` must already be populated
            (i.e. this is called after the backward pass).
    """
    # Per-vector L2 norms, kept as a trailing singleton dim for broadcasting.
    row_norms = W_dec.norm(dim=-1, keepdim=True)
    unit_rows = W_dec / row_norms

    # Scalar projection of each gradient vector onto its unit direction.
    parallel_coeff = (W_dec.grad * unit_rows).sum(-1, keepdim=True)

    # Drop the parallel component, then install the normalised weights.
    W_dec.grad -= parallel_coeff * unit_rows
    W_dec.data = unit_rows


def save_sae(sae, cfg: TrainingConfig, step: Union[int, str] = 'final'):
    """Write the SAE weights (safetensors) and the training config (JSON) to disk.

    Output goes to a run directory under ``checkpoints/`` whose name is built
    from the model name (with ``/`` replaced by ``_``), layer, dictionary size,
    SAE type and wandb run suffix taken from ``cfg``. The weights file is named
    ``model_<step>.safetensors``; ``config.json`` is rewritten on every call.

    Args:
        sae: Object exposing ``state_dict()``; the result is passed to ``save_file``.
        cfg: Training config dataclass. Its ``dtype`` and ``sae_dtype`` entries
            are converted with ``str()`` before being dumped as JSON.
        step: Label embedded in the weights file name.
    """
    os.makedirs("checkpoints", exist_ok=True)
    model_tag = cfg.model_name.replace('/', '_')
    run_dir = f"checkpoints/sae_{model_tag}_{cfg.layer}_{cfg.dict_size}_{cfg.sae_type}_{cfg.wandb_run_suffix}"

    # Weights go out in safetensors format.
    weights = sae.state_dict()
    os.makedirs(run_dir, exist_ok=True)
    save_file(weights, f"{run_dir}/model_{step}.safetensors")

    # The config is stored as JSON; the dtype fields are stringified first.
    serializable_cfg = asdict(cfg)
    for dtype_key in ('dtype', 'sae_dtype'):
        serializable_cfg[dtype_key] = str(serializable_cfg[dtype_key])

    with open(f"{run_dir}/config.json", 'w') as config_file:
        json.dump(serializable_cfg, config_file, indent=2)


def train_step(
        cfg,
        batch,
        accelerator: accelerate.Accelerator,
        sae: BaseSAE,
        optimizer,
        scheduler,
        start_event,
        train_transcoder,
        end_event
        ):
    """Run one optimisation step of the SAE and time its forward/backward passes.

    Args:
        cfg: Training configuration; only ``cfg.max_grad_norm`` is read here.
        batch: Input for the SAE. When ``train_transcoder`` is truthy the SAE is
            called with ``batch[0]`` and ``batch[1]``; otherwise with ``batch``.
        accelerator: Accelerator used for the backward pass and gradient clipping.
        sae: The SAE being trained. It may be wrapped (the underlying module is
            then reached through ``sae.module``) or passed bare.
        optimizer: Optimizer to zero and step.
        scheduler: Learning-rate scheduler, stepped once per call.
        start_event: Timing event recorded at the start of each timed section,
            or ``None`` to disable timing.
        train_transcoder: Selects the two-argument SAE call (see ``batch``).
        end_event: Timing event recorded at the end of each timed section,
            or ``None`` to disable timing.

    Returns:
        ``(loss_dict, forward_time_ms, backward_time_ms, grad_norm)``.
        ``loss_dict`` is the SAE output with the ``"sae_out"`` entry removed.
        Both timings are ``0`` when timing is disabled. ``grad_norm`` is the
        value returned by ``accelerator.clip_grad_norm_``, or ``None`` on steps
        where ``accelerator.sync_gradients`` is false and no clipping happens.
    """
    # Timing needs both events; treating a lone event as "disabled" avoids
    # calling ``elapsed_time`` on ``None``.
    timing_enabled = start_event is not None and end_event is not None

    # ---- forward pass (timed) ----
    if timing_enabled:
        start_event.record()
    if train_transcoder:
        loss_dict = sae(batch[0], batch[1])
    else:
        loss_dict = sae(batch)
    if timing_enabled:
        end_event.record()
        torch.cuda.synchronize()
        forward_time_ms = start_event.elapsed_time(end_event)
    else:
        forward_time_ms = 0

    # ---- backward pass (timed; excludes optimizer/scheduler step) ----
    if timing_enabled:
        start_event.record()

    optimizer.zero_grad()
    accelerator.backward(loss_dict["loss"])

    # Previously unbound when gradients were not being synchronised, which
    # made the ``return`` below raise UnboundLocalError.
    grad_norm = None
    if accelerator.sync_gradients:
        grad_norm = accelerator.clip_grad_norm_(
            sae.parameters(), max_norm=cfg.max_grad_norm
        )

    # Look through a wrapper exposing ``.module`` so that the LoRA check sees
    # the real model and a bare (unwrapped) SAE works as well.
    inner_sae = getattr(sae, "module", sae)
    wraps_lora = isinstance(sae, SAELoRAWrapper) or isinstance(inner_sae, SAELoRAWrapper)
    if not wraps_lora:
        make_decoder_weights_and_grad_unit_norm(inner_sae.W_dec)

    if timing_enabled:
        end_event.record()
        torch.cuda.synchronize()
        backward_time_ms = start_event.elapsed_time(end_event)
    else:
        backward_time_ms = 0

    optimizer.step()
    scheduler.step()

    # The reconstruction is not needed by the caller; tolerate its absence.
    loss_dict.pop("sae_out", None)

    return loss_dict, forward_time_ms, backward_time_ms, grad_norm

def _unwrap_sae(sae):
    """Return ``sae.module`` when that attribute exists, otherwise ``sae`` itself."""
    return getattr(sae, "module", sae)


def _is_topk_sae(sae) -> bool:
    """Report whether ``sae`` is a TopK-style SAE, directly or as the base of a LoRA wrapper."""
    topk_types = (TopKSAE, BatchTopKSAE, KronSAE)
    if isinstance(sae, topk_types):
        return True
    return isinstance(sae, SAELoRAWrapper) and isinstance(sae.base_sae, topk_types)


def _find_kron_sae(sae):
    """Return the KronSAE behind ``sae`` (itself or its unwrapped module), or ``None``."""
    for candidate in (sae, _unwrap_sae(sae)):
        if isinstance(candidate, KronSAE):
            return candidate
    return None


def train_sae(
    sae: DistributedDataParallel|VanillaSAE | TopKSAE | BatchTopKSAE | JumpReLUSAE,
    distributed_activations_store: DistributedActivationsStore,
    iterator_dl, #: torch.utils.data.DataLoader,
    optimizer,
    model: HookedTransformer,
    cfg: TrainingConfig,
    accelerator: accelerate.Accelerator,
    train_transcoder: bool = False,
) -> None:
    """
    Train a sparse autoencoder on the activations of a transformer model.

    The number of steps is ``cfg.num_tokens // (cfg.model_batch_size * cfg.seq_len)``.
    Each step pulls a batch from ``distributed_activations_store``, sets the
    SAE's ``topk1`` / ``topk2`` config values (linearly annealed during their
    warmup windows for TopK-style SAEs), runs ``train_step`` and gathers the
    resulting metrics. The main process additionally logs metrics, updates the
    progress bar, and periodically logs model performance and saves checkpoints.
    A final checkpoint is saved on the main process after the loop.

    Args:
        sae: The SAE to train; it is passed through ``accelerator.prepare``.
        distributed_activations_store: Source of training batches via
            ``get_batch(model, iterator_dl, add_new=...)``.
        iterator_dl: Passed through to the activation store and to the
            performance evaluation.
        optimizer: Optimizer for ``sae``; passed through ``accelerator.prepare``.
        model: Transformer whose activations are being modelled.
        cfg: Training configuration.
        accelerator: Accelerator driving distributed execution and logging.
        train_transcoder: Forwarded to ``train_step``.
    """
    # Step budget and warmup lengths, all derived from the token budget.
    tokens_per_batch = cfg.model_batch_size * cfg.seq_len
    total_batches = cfg.num_tokens // tokens_per_batch
    warmup_topk2_batches = int(total_batches * cfg.topk2_warmup_steps_fraction)
    warmup_topk1_batches = int(total_batches * cfg.topk1_warmup_steps_fraction)
    if cfg.warmup_fraction <= 1.0:
        warmup_lr_batches = int(total_batches * cfg.warmup_fraction)
    else:
        # NOTE(review): values above 1.0 are scaled by cfg.batch_size rather
        # than total_batches, and the message below is only printed on this
        # branch — confirm both are intended.
        warmup_lr_batches = int(cfg.batch_size * cfg.warmup_fraction)
        accelerator.print(f"tokens per iteration will be: {tokens_per_batch:,}")

    scheduler = get_scheduler(
        name=cfg.scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=warmup_lr_batches,
        num_training_steps=total_batches,
        scheduler_specific_kwargs=dict(min_lr=1e-6),
    )

    # FLOP accounting: forward_flops is only computed (and only used) on the
    # main process.
    accelerator.print("Start SAE FLOPS counting:")
    accelerator.print(f"SAE device: {sae.device}")
    if accelerator.is_main_process:
        with torch.no_grad():
            forward_flops = get_model_flops(sae, input_shape=(1, cfg.act_size))

    # Baseline top-k FLOPs for KronSAE, so the per-step figure can be corrected
    # as topk1 changes during annealing.
    default_topk_flops = None
    kron_sae = _find_kron_sae(sae)
    if kron_sae is not None:
        default_topk_flops = topk_flops_simple(kron_sae.config.topk1, input_shape=(1, cfg.act_size), num_heads=kron_sae.config.num_heads)
    print("------>", accelerator.device)
    accelerator.print("[*] Wait all processess ..... [*]")
    accelerator.wait_for_everyone()
    accum_num_flops = 0

    accelerator.print("[*] Prepare optimizer, SAE and scheduler ")
    sae, optimizer, scheduler = accelerator.prepare(sae, optimizer, scheduler)

    # ``prepare`` may hand back a wrapper. Type checks and config access must
    # go through the underlying module, otherwise they silently never match.
    bare_sae = _unwrap_sae(sae)
    kron_sae = _find_kron_sae(sae) if default_topk_flops is not None else None

    accelerator.print("[*] Count parameters and update WandB ")
    if accelerator.is_main_process:
        count_parameters(sae)
        total_params = sum(p.numel() for p in sae.parameters())
        if cfg.enable_wandb:
            wandb.config.update({"parameters_count": total_params}, allow_val_change=True)
            wandb.config.update(asdict(cfg), allow_val_change=True)

    # CUDA events for timing, when CUDA is available.
    if torch.cuda.is_available():
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
    else:
        start_event = end_event = None

    pbar = tqdm(range(total_batches), desc="Training SAE")
    assert cfg.batch_size % cfg.model_batch_size == 0, "Batch size must be a multiple of model batch size"
    # Clamp to 1: the floor division yields 0 when cfg.batch_size exceeds the
    # tokens in one model batch, which would make the modulo below divide by zero.
    num_batches_before_refresh = max(
        1, (cfg.seq_len * cfg.model_batch_size) // cfg.batch_size
    )

    is_topk_variant = _is_topk_sae(sae) or _is_topk_sae(bare_sae)

    accelerator.print("[*] Starting training 👨‍🏫 [*]")

    for step_idx in pbar:
        step_no = step_idx + 1

        # Ask the store to add new activations every num_batches_before_refresh steps.
        add_new = step_no % num_batches_before_refresh == 0
        batch = distributed_activations_store.get_batch(model, iterator_dl, add_new=add_new)

        # Linearly anneal topk2 / topk1 from their start values to their targets
        # during warmup (TopK variants only); otherwise pin them to the targets.
        if is_topk_variant and step_idx < warmup_topk2_batches:
            bare_sae.config.topk2 = int(
                cfg.start_topk2 - (cfg.start_topk2 - cfg.topk2) * (step_idx / warmup_topk2_batches)
            )
        else:
            bare_sae.config.topk2 = cfg.topk2

        if is_topk_variant and step_idx < warmup_topk1_batches:
            bare_sae.config.topk1 = int(
                cfg.start_topk1 - (cfg.start_topk1 - cfg.topk1) * (step_idx / warmup_topk1_batches)
            )
        else:
            bare_sae.config.topk1 = cfg.topk1

        loss_dict, forward_time_ms, backward_time_ms, total_norm = train_step(
            cfg, batch, accelerator, sae, optimizer, scheduler,
            start_event, train_transcoder, end_event,
        )
        with torch.no_grad():
            float_metrics = {k: v.float() for k, v in loss_dict.items()}
            accumulated_metrics = accelerator.gather_for_metrics(float_metrics)

        if not accelerator.is_main_process:
            continue

        # ---- main-process-only logging from here on ----
        with torch.no_grad():
            accumulated_metrics = {k: v.mean() for k, v in accumulated_metrics.items()}

        sae.eval()
        if kron_sae is not None:
            forward_flops_current = forward_flops - default_topk_flops + topk_flops_simple(
                kron_sae.config.topk1,
                input_shape=(1, cfg.act_size),
                num_heads=kron_sae.config.num_heads,
            )
        else:
            forward_flops_current = forward_flops
        accum_num_flops += forward_flops_current

        metrics_to_log = {
            "n_tokens": step_no * tokens_per_batch,
            "forward_time_ms": forward_time_ms,
            "backward_time_ms": backward_time_ms,
            "forward_flops": forward_flops_current,
            "accum_num_flops": accum_num_flops,
            "grad_norm": total_norm,
            **accumulated_metrics,
        }
        if is_topk_variant:
            metrics_to_log["current_topk2"] = bare_sae.config.topk2

        accelerator.log(metrics_to_log)

        # One progress-bar update per step (previously set twice, with the
        # second call overwriting the first).
        pbar.set_postfix({
            "l2_loss": f"{accumulated_metrics['l2_loss'].item():.3f}",
            "explained_var": f"{accumulated_metrics['explained_variance'].item():.3f}",
            "forward_ms": f"{forward_time_ms}",
            "backward_ms": f"{backward_time_ms}",
            "topk2": bare_sae.config.topk2,
            "topk1": bare_sae.config.topk1,
        })

        if step_no % cfg.performance_log_steps == 0:
            # NOTE(review): this runs on the main process only; confirm that
            # get_distributed_model_performance does not need every rank.
            logs = get_distributed_model_performance(
                model=model,
                config=cfg,
                activations_store=distributed_activations_store,
                iterator_dl=iterator_dl,
                sae=sae,
            )
            accelerator.log(logs, step=step_no)

        if step_no % cfg.save_checkpoint_steps == 0:
            save_checkpoint(
                wandb_run=accelerator.get_tracker("wandb"),
                sae=sae,
                cfg=cfg,
                step=step_no)

        sae.train()

    if accelerator.is_main_process:
        # total_batches equals the last step number and, unlike the loop
        # variable, is defined even when the loop body never ran.
        save_checkpoint(
            wandb_run=wandb.run,
            sae=sae,
            cfg=cfg,
            step=total_batches)
        # Call on the underlying module: a wrapper does not expose this method.
        bare_sae.fold_stats_into_weights()
    # TODO: fold_W_dec_norm() is still disabled pending a fix.
    pbar.close()
    accelerator.wait_for_everyone()
    accelerator.end_training()