import wandb
from transformer_lens import HookedTransformer
import torch
from torch.utils.flop_counter import FlopCounterMode

import os
import sys

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner, SAE
from dataclasses import dataclass
from stitching.modified_sae_training import ModifiedSAETrainingRunner

import yaml
import argparse
# Repository-wide settings are read once, at import time, from
# ``global_config.yaml``. The path is relative, so it is resolved against the
# current working directory: launch the script from the directory that holds
# that file. The file is decoded as UTF-8 explicitly instead of relying on the
# platform default encoding.
with open('global_config.yaml', encoding='utf-8') as global_stream:
    global_cfg = yaml.safe_load(global_stream)
# Forwarded by ``run_job`` to the model loader as its ``cache_dir``.
CACHE_DIR = global_cfg['CACHE_DIR']
# NOTE(review): presumably silences the Hugging Face tokenizers
# fork/parallelism warning -- confirm this is still required.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _resolve_checkpoint_thresholds(total_training_steps, spec):
    """Turn the ``checkpoint_thresholds`` config entry into a list of steps.

    Args:
        total_training_steps: Total number of training steps in the run.
        spec: Either an ``int`` or an explicit collection of step indices.
            An ``int`` ``n`` places a threshold every
            ``total_training_steps // n`` steps, excluding step 0 and anything
            at or beyond ``total_training_steps``. Any other value is returned
            unchanged.

    Returns:
        The value to store as the runner's ``checkpoint_thresholds``. A
        non-positive ``n`` yields an empty list.
    """
    if not isinstance(spec, int):
        return spec
    if spec <= 0:
        return []
    # Clamp the stride to at least 1: when more checkpoints are requested
    # than there are steps the floor division gives 0, and ``range`` raises
    # ValueError for a zero step.
    stride = max(1, total_training_steps // spec)
    return list(range(stride, total_training_steps, stride))


def run_job(cfg):
    """Configure and launch one ``ModifiedSAETrainingRunner`` run.

    Builds a ``LanguageModelSAERunnerConfig`` for a ``"topk"`` SAE hooked at
    ``blocks.{layer-1}.hook_resid_post`` of ``cfg['model_name']``, attaches a
    few extra attributes to the runner's config, and calls ``run()``.

    Args:
        cfg: Mapping of job settings. Required keys: ``device``, ``layer``,
            ``model_name``, ``d_model``, ``feature_dim``, ``checkpoints_dir``,
            ``total_training_steps``, ``batch_size``, ``lr_warmup_div``,
            ``lr_decay_div``, ``seed``, ``k``, ``normalize_activations``,
            ``lr``, ``dead_feature_window``, ``wandb``, ``wandb_project`` and
            ``checkpoint_thresholds`` (an int or an explicit list of steps).
            Optional key: ``count_flops`` (treated as false when missing or
            ``None``).

    Side effects:
        Creates the checkpoint directories, writes ``initialization.txt``
        and, when FLOP counting is enabled, ``flop_count.txt`` into the
        runner's checkpoint path, and prints progress information.

    Raises:
        KeyError: If a required key is missing from ``cfg``.
    """
    device = cfg['device']
    print("Using device:", device)
    layer = cfg['layer']
    model_name = cfg['model_name']
    d_model = cfg['d_model']
    feature_dim = cfg['feature_dim']
    checkpoints_dir = cfg['checkpoints_dir']
    os.makedirs(checkpoints_dir, exist_ok=True)

    total_training_steps = cfg['total_training_steps']
    batch_size = cfg['batch_size']
    # Warm-up and decay lengths are expressed in the config as divisors of
    # the total step count.
    lr_warm_up_steps = total_training_steps // cfg['lr_warmup_div']
    lr_decay_steps = total_training_steps // cfg['lr_decay_div']
    total_training_tokens = total_training_steps * batch_size
    print(total_training_tokens, "lr warmup", lr_warm_up_steps, "lr decay", lr_decay_steps)
    seed = cfg['seed']
    # A missing key and an explicit ``None`` both mean "do not count FLOPs".
    count_flops = bool(cfg.get('count_flops'))

    run_cfg = LanguageModelSAERunnerConfig(
        # Data-generating function (model + training distribution).
        # Model names: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
        model_name=model_name,
        # ``layer`` is 1-based in the job config; block indices are 0-based.
        # Hook points: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points
        hook_name=f"blocks.{layer-1}.hook_resid_post",
        hook_layer=layer-1,
        d_in=d_model,
        d_sae=feature_dim,
        dataset_path="Skylion007/openwebtext",
        is_dataset_tokenized=False,
        prepend_bos=True,
        streaming=True,
        # SAE parameters.
        architecture="topk",
        activation_fn_kwargs={"k": cfg['k']},
        mse_loss_normalization=None,  # the MSE loss is left unnormalised
        b_dec_init_method="geometric_median",  # initialisation of the decoder bias (b_dec)
        apply_b_dec_to_input=False,  # b_dec is not applied to the input
        normalize_sae_decoder=True,
        decoder_heuristic_init=False,
        init_encoder_as_decoder_transpose=True,
        normalize_activations=cfg['normalize_activations'],
        # Training parameters.
        lr=cfg['lr'],
        adam_beta1=0.9,
        adam_beta2=0.999,

        # Learning-rate schedule: constant, with the warm-up and decay
        # lengths computed above.
        lr_scheduler_name="constant",
        lr_warm_up_steps=lr_warm_up_steps,
        lr_decay_steps=lr_decay_steps,

        # Data.
        train_batch_size_tokens=batch_size,
        context_size=128,  # length of the prompts fed to the model
        # Activation store parameters.
        n_batches_in_buffer=128,  # how many batches of activations are stored / shuffled
        training_tokens=total_training_tokens,
        store_batch_size_prompts=32,
        # Dead-feature handling (ghost grads are disabled).
        use_ghost_grads=False,
        feature_sampling_window=1000,  # window for reporting feature sparsity stats
        dead_feature_window=cfg['dead_feature_window'],
        dead_feature_threshold=1e-8,
        # Weights & Biases logging.
        log_to_wandb=cfg['wandb'],
        wandb_project=cfg['wandb_project'],
        wandb_log_frequency=1,
        eval_every_n_wandb_logs=500,
        # Misc.
        device=device,
        seed=seed,
        n_checkpoints=0,
        checkpoint_path=checkpoints_dir,
        dtype="float32",
        model_from_pretrained_kwargs={'cache_dir': CACHE_DIR}
    )
    training_runner = ModifiedSAETrainingRunner(run_cfg)

    # The path is re-read from the runner's config instead of reusing
    # ``checkpoints_dir``, in case the config adjusted it; ``exist_ok`` avoids
    # a check-then-create race when several jobs start at once.
    checkpoint_path = training_runner.cfg.checkpoint_path
    print(checkpoint_path)
    os.makedirs(checkpoint_path, exist_ok=True)
    training_runner.cfg.dir_id = os.path.basename(os.path.normpath(checkpoint_path))

    # Extra attributes attached to the runner config.
    # NOTE(review): presumably read by ModifiedSAETrainingRunner -- confirm.
    training_runner.cfg.unfold_estimated_norm_factor = False
    training_runner.cfg.initialization = 'random'
    with open(os.path.join(checkpoint_path, 'initialization.txt'), 'w', encoding='utf-8') as f:
        f.write("random")

    training_runner.cfg.checkpoint_thresholds = _resolve_checkpoint_thresholds(
        total_training_steps, cfg['checkpoint_thresholds']
    )
    print("Saving run at", training_runner.cfg.checkpoint_thresholds, "iterations")

    if not count_flops:
        training_runner.run()
        return

    # Run the whole training loop under torch's FLOP counter and record the
    # total both as a plain integer and in scientific notation.
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        training_runner.run()
    total_flops = flop_counter.get_total_flops()
    print("Total FLOPs", f"{total_flops:e}")
    with open(os.path.join(checkpoint_path, 'flop_count.txt'), 'w', encoding='utf-8') as f:
        f.write(f"{total_flops}\n{total_flops:e}")

if __name__ == '__main__':
    # Command-line entry point: takes the path of a YAML job configuration
    # and hands the parsed mapping to run_job.
    parser = argparse.ArgumentParser(
        description="Run one SAE training job described by a YAML config file."
    )
    parser.add_argument('cfg_filename', help="path to the YAML job configuration")
    args = parser.parse_args()
    with open(args.cfg_filename, encoding='utf-8') as stream:
        cfg = yaml.safe_load(stream)
    # An empty or non-mapping YAML document would otherwise only fail later,
    # as an opaque TypeError on the first cfg[...] lookup inside run_job.
    if not isinstance(cfg, dict):
        parser.error(
            f"{args.cfg_filename} must contain a YAML mapping, "
            f"got {type(cfg).__name__}"
        )
    run_job(cfg)
