import wandb
from transformer_lens import HookedTransformer
import torch
import os
import numpy as np
from torch.utils.flop_counter import FlopCounterMode


from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner, SAE
from dataclasses import dataclass
from stitching.stitching_utils import open_experiment
from stitching.modified_sae_training import ModifiedSAETrainingRunner
from stitching.sae_utils import BaseSAE

import yaml
import argparse
# Load repository-wide settings from `global_config.yaml`. The path is relative,
# so it is resolved against the current working directory at import time.
with open('global_config.yaml') as global_stream:
    global_cfg = yaml.safe_load(global_stream)
# Forwarded to the model loader as `cache_dir` (see `model_from_pretrained_kwargs`
# in `run_job`).
CACHE_DIR = global_cfg['CACHE_DIR']
# presumably disables HuggingFace tokenizers' internal parallelism to avoid
# fork-related warnings in data-loading workers — TODO confirm
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _find_final_subdir(init_dir):
    """Return the name of the entry in ``init_dir`` whose name contains 'final'.

    If several entries match, the last one in ``os.listdir`` order is returned.

    Raises:
        FileNotFoundError: if no entry of ``init_dir`` contains 'final'.
    """
    matches = [entry for entry in os.listdir(init_dir) if 'final' in entry]
    if not matches:
        raise FileNotFoundError(
            f"No entry containing 'final' found in {init_dir!r}; "
            "cannot load the initial SAE."
        )
    return matches[-1]


def _resolve_checkpoint_thresholds(spec, total_training_steps):
    """Turn the ``checkpoint_thresholds`` config value into a list of steps.

    An integer ``spec`` is read as "number of evenly spaced checkpoints" and is
    expanded to ``range(0, total, total // spec)`` without the leading 0.
    Any other value is returned unchanged.

    A non-positive integer yields no checkpoints, and the stride is clamped to
    at least 1 so that ``spec > total_training_steps`` cannot produce a zero
    range step.
    """
    if not isinstance(spec, int):
        return spec
    if spec <= 0:
        return []
    stride = max(1, total_training_steps // spec)
    return list(range(0, total_training_steps, stride))[1:]


def run_job(cfg):
    """Train a top-k SAE, optionally initialized from a stitched, pre-trained SAE.

    Args:
        cfg: mapping of job settings (loaded from YAML by the script entry point).
            Always read: ``device``, ``replace_params``, ``layer``, ``model_name``,
            ``d_model``, ``feature_dim``, ``checkpoints_dir``,
            ``total_training_steps``, ``batch_size``, ``lr_warmup_div``,
            ``lr_decay_div``, ``seed``, ``k``, ``normalize_activations``, ``lr``,
            ``dead_feature_window``, ``wandb``, ``wandb_project``,
            ``checkpoint_thresholds``.
            Read only when ``replace_params`` is truthy: ``init_dir``,
            ``transfer_id``, ``d_orig_model``, ``model_a_name``, ``model_b_name``.
            Optional: ``normalize_decoder``, ``count_flops``,
            ``unfold_estimated_norm_factor``.

    Side effects:
        Creates the checkpoint directories, writes ``initialization.txt`` (and
        ``flop_count.txt`` when ``count_flops`` is set) into the runner's
        checkpoint path, and runs the training loop.

    Raises:
        FileNotFoundError: if ``replace_params`` is set and ``init_dir`` holds
            no entry whose name contains 'final'.
    """
    device = cfg['device']
    print("Using device:", device)
    replace_params = cfg['replace_params']

    # Optional flags: a missing key or an explicit null both mean False.
    normalize_decoder = cfg.get('normalize_decoder')
    if normalize_decoder is None:
        normalize_decoder = False
    count_flops = cfg.get('count_flops')
    if count_flops is None:
        count_flops = False

    if replace_params:
        init_dir = cfg['init_dir']
        transfer_id = cfg['transfer_id']
        init_path = os.path.join(init_dir, _find_final_subdir(init_dir))
        print(f"Loading from {init_path}")
        # Load onto the configured device so the weights live on the same
        # device as the stitching matrices returned by open_experiment.
        init_sae = SAE.load_from_pretrained(init_path, device=device)

        init_skeleton = BaseSAE(
            init_sae.W_enc.detach().clone(),
            init_sae.W_dec.detach().clone(),
            init_sae.b_enc.detach().clone(),
            init_sae.b_dec.detach().clone(),
            init_sae.activation_fn,
            apply_b_dec=init_sae.cfg.apply_b_dec_to_input
        )
        init_skeleton.get_rid_of_decoder_sub()

        # Built from two adjacent literals so the expression is valid on
        # Python < 3.12 (no same-quote nesting inside an f-string).
        stitch_dir = (
            f"checkpoints/stitch_training_{cfg['model_a_name']}"
            f"_to_{cfg['model_b_name']}_bidirectional_mse/"
        )
        P, Pinv, beta, bias, biasinv = open_experiment(
            cfg['d_orig_model'], cfg['d_model'],
            stitch_dir,
            transfer_id,
            biases=True,
            device=device
        )
        # Compose the stitching maps with the original SAE weights:
        # Pinv/biasinv are folded into the encoder, P/bias into the decoder.
        skeleton = BaseSAE(
            Pinv @ init_skeleton.W_enc.clone(),
            init_skeleton.W_dec.clone() @ P,
            init_skeleton.b_enc.clone() + biasinv @ init_skeleton.W_enc.clone(),
            init_skeleton.b_dec.clone() @ P + bias,
            init_skeleton.activation_fn,
            requires_grad=True,
            apply_b_dec=False
        )
        if normalize_decoder:
            print("Normalizing decoder vectors")
            skeleton.normalize_decoder_vectors()

    # NOTE: no L1 coefficient is configured; the SAE uses the "topk" architecture.

    layer = cfg['layer']
    model_name = cfg['model_name']
    d_model = cfg['d_model']
    feature_dim = cfg['feature_dim']
    checkpoints_dir = cfg['checkpoints_dir']
    os.makedirs(checkpoints_dir, exist_ok=True)
    total_training_steps = cfg['total_training_steps']
    batch_size = cfg['batch_size']
    # Warm-up / decay lengths are given as divisors of the total step count.
    lr_warm_up_steps = total_training_steps // cfg['lr_warmup_div']
    lr_decay_steps = total_training_steps // cfg['lr_decay_div']
    total_training_tokens = total_training_steps * batch_size
    print(total_training_tokens, "lr warmup", lr_warm_up_steps, "lr decay", lr_decay_steps)

    seed = cfg['seed']

    run_cfg = LanguageModelSAERunnerConfig(
        # Data Generating Function (Model + Training Distribution)
        model_name=model_name,  # model options: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
        # `layer` is shifted by one to form the hook's block index.
        hook_name=f"blocks.{layer-1}.hook_resid_post",  # hook points: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points
        hook_layer=layer-1,
        d_in=d_model,
        d_sae=feature_dim,
        dataset_path="Skylion007/openwebtext",
        is_dataset_tokenized=False,
        streaming=True,
        # SAE Parameters (initialization matters little when replace_params is
        # set, since the parameters are overwritten below)
        architecture="topk",
        activation_fn_kwargs={"k": cfg['k']},
        mse_loss_normalization=None,  # the MSE loss is not normalized
        b_dec_init_method="zeros",
        apply_b_dec_to_input=False,  # the decoder bias is not applied to the input
        normalize_sae_decoder=normalize_decoder,
        decoder_heuristic_init=False,
        init_encoder_as_decoder_transpose=True,
        normalize_activations=cfg['normalize_activations'],
        # Training Parameters
        lr=cfg['lr'],
        adam_beta1=0.9,
        adam_beta2=0.999,
        lr_scheduler_name="constant",  # constant learning rate; warm-up/decay lengths below
        lr_warm_up_steps=lr_warm_up_steps,
        lr_decay_steps=lr_decay_steps,

        # Data stuff
        train_batch_size_tokens=batch_size,
        context_size=128,  # length of the prompts fed to the model
        # Activation Store Parameters
        n_batches_in_buffer=128,  # how many activation batches are stored / shuffled
        training_tokens=total_training_tokens,
        store_batch_size_prompts=32,
        # Resampling protocol
        use_ghost_grads=False,
        feature_sampling_window=1000,  # window for feature sparsity reporting
        dead_feature_window=cfg['dead_feature_window'],
        dead_feature_threshold=1e-8,
        # WANDB
        log_to_wandb=cfg['wandb'],
        wandb_project=cfg['wandb_project'],
        wandb_log_frequency=1,
        eval_every_n_wandb_logs=200,
        # Misc
        device=device,
        seed=seed,
        n_checkpoints=0,
        checkpoint_path=checkpoints_dir,
        dtype="float32",
        model_from_pretrained_kwargs={'cache_dir': CACHE_DIR}
    )
    training_runner = ModifiedSAETrainingRunner(run_cfg)
    checkpoint_path = training_runner.cfg.checkpoint_path
    print(checkpoint_path)
    # The runner's path is created separately from `checkpoints_dir` above in
    # case the config object rewrote it; exist_ok avoids a check-then-create race.
    os.makedirs(checkpoint_path, exist_ok=True)
    training_runner.cfg.dir_id = os.path.basename(os.path.normpath(checkpoint_path))

    # Reset to our parameters.
    if replace_params:
        sae = training_runner.sae
        for param_name, param_value in skeleton.named_parameters():
            target = getattr(sae, param_name)
            if param_name == 'W_enc':
                # Fill only the leading feature columns with the stitched encoder.
                target.data[:, :param_value.shape[1]] = param_value
            elif param_name == 'W_dec':
                # Fill only the leading feature rows with the stitched decoder.
                target.data[:param_value.shape[0], :] = param_value
            elif param_name == 'b_enc':
                # Use this parameter's own length rather than a value left over
                # from a previous iteration, so the result does not depend on
                # the order in which named_parameters() yields entries.
                target.data[:param_value.shape[0]] = param_value
            else:
                # b_dec and any remaining parameter are replaced wholesale.
                target.data = param_value
        print("replaced params")
        initialization_str = f"init_sae: {init_dir} stitch: {transfer_id}"
        unfold = cfg.get('unfold_estimated_norm_factor')
        if unfold is None:
            # Default: unfold only for the "expected_average_only_in" mode.
            unfold = (training_runner.cfg.normalize_activations == "expected_average_only_in")
        training_runner.cfg.unfold_estimated_norm_factor = unfold
    else:
        initialization_str = 'random'
        training_runner.cfg.unfold_estimated_norm_factor = False
    training_runner.cfg.initialization = initialization_str
    with open(os.path.join(checkpoint_path, 'initialization.txt'), 'w') as f:
        f.write(initialization_str)
    print(initialization_str)

    # Thresholds are expressed in training steps here.
    # NOTE(review): the original comment says a steps-to-tokens conversion
    # exists elsewhere — presumably in ModifiedSAETrainingRunner; confirm.
    training_runner.cfg.checkpoint_thresholds = _resolve_checkpoint_thresholds(
        cfg['checkpoint_thresholds'], total_training_steps
    )
    print("Saving run at", training_runner.cfg.checkpoint_thresholds)

    if count_flops:
        flop_counter = FlopCounterMode(display=False)
        with flop_counter:
            training_runner.run()
        total_flops = flop_counter.get_total_flops()
        print("Total FLOPs", f"{total_flops:e}")
        # Store both the exact count and its scientific-notation form.
        with open(os.path.join(checkpoint_path, 'flop_count.txt'), 'w') as f:
            f.write(str(total_flops))
            f.write('\n')
            f.write(f"{total_flops:e}")
    else:
        training_runner.run()

if __name__ == '__main__':
    # Command line: a single positional argument naming the YAML job config.
    cli = argparse.ArgumentParser()
    cli.add_argument('cfg_filename')
    cfg_path = cli.parse_args().cfg_filename

    # Parse the job configuration and hand it to the training entry point.
    with open(cfg_path) as cfg_file:
        job_cfg = yaml.safe_load(cfg_file)
    run_job(job_cfg)
