# %%
import os
import sys
from omegaconf import DictConfig
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
import hydra
from transformers import AutoModelForCausalLM
import re

# Locate this script's directory and its parent, then put both on the module
# search path so the project-local imports further down can be resolved.
script_dir = os.path.split(os.path.abspath(__file__))[0]
project_dir = os.path.join(script_dir, os.pardir)
sys.path.extend([script_dir, project_dir])

from sparsify.sparsify import SaeConfig, Trainer, TrainConfig
from hooks import get_subspace_ablation
from utils import load_tokenized_data


# %%
def _resolve_dtype(name):
    """Return the torch dtype selected by the ``dtype`` config string.

    Raises:
        ValueError: if ``name`` is neither ``'float32'`` nor ``'bfloat16'``.
    """
    if name == 'float32':
        return torch.float32
    if name == 'bfloat16':
        return torch.bfloat16
    raise ValueError(f'Unknown dtype: {name}')


def _parse_model_hook(model_hook):
    """Translate a ``model_hook`` config string into ``(ablation_type, threshold)``.

    The branches are tested in order, so the more specific
    ``ablation_subset_sparse_subspace`` is checked before
    ``subset_sparse_subspace`` (which is a substring of it).

    Raises:
        ValueError: if the hook name matches no known pattern, if an
            ``only_*_subspace`` hook carries no decimal threshold, or if a
            ``*subset_sparse_subspace`` hook does not end in an integer.
    """
    def threshold_from_name():
        # The threshold is embedded in the hook name as `subspace_<int>.<frac>`.
        match = re.search(r'subspace_(\d+\.\d+)', model_hook)
        if match is None:
            raise ValueError(
                f'Model hook {model_hook!r} must contain a decimal threshold '
                f"of the form 'subspace_<float>' (e.g. 'subspace_0.1')"
            )
        return float(match.group(1))

    if 'only_dense_subspace' in model_hook:
        print('Dense subspace ablation')
        return 'only_dense', threshold_from_name()
    if 'only_sparse_subspace' in model_hook:
        print('Sparse subspace ablation')
        return 'only_sparse', threshold_from_name()
    if 'ablation_subset_sparse_subspace' in model_hook:
        print('Subset sparse subspace ablation')
        n_dims = int(model_hook.split('_')[-1])
        return f'ablation_sparse_subset_{n_dims}', 0.1
    if 'subset_sparse_subspace' in model_hook:
        print('Subset sparse subspace ablation')
        n_dims = int(model_hook.split('_')[-1])
        return f'projection_sparse_subset_{n_dims}', 0.1
    if 'dense_dec_score' in model_hook:
        # The hook name itself is forwarded as the ablation type.
        return model_hook, 0.1
    raise ValueError(f'Unknown model hook: {model_hook}')


@hydra.main(config_path=f'{project_dir}/config', config_name='train')
def train_sae(args: DictConfig):
    """Train a sparse autoencoder on one hookpoint of a causal language model.

    Steps, in order: print the config, seed torch, load the tokenized data and
    the model, optionally register a subspace-ablation forward hook on the
    hooked layer, build the SAE/trainer configs, write the run's config (and
    the hook's indices, if any) under ``sae-ckpts/<run name>``, then train.

    Args:
        args: Hydra config. Fields read here: ``seed``, ``model_name``,
            ``d_sae``, ``k``, ``n_sequences``, ``batch_size``, ``hookpoint``,
            ``cache_dir``, ``dataset_path``, ``context_length``, ``dtype``,
            ``model_hook``, ``ckpt_name``, ``density_dataset_path``,
            ``density_num_docs``, ``normalize_decoder``, ``lr``,
            ``auxk_alpha``, ``log_to_wandb``, ``save_every`` and
            ``wandb_log_frequency``.

    Raises:
        ValueError: on an unknown ``dtype`` or an unrecognised/malformed
            ``model_hook``.
    """
    print(OmegaConf.to_yaml(args))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # NOTE(review): `-h` is used for the hookpoint here and again for the model
    # hook below. Left as-is because the run name determines the checkpoint
    # directory, so changing it would rename existing runs.
    run_name = f'{args.model_name} -d_sae {args.d_sae} -k {args.k} -n {args.n_sequences} -b {args.batch_size} -h {args.hookpoint} -s {args.seed}'

    tokenized = load_tokenized_data(args.cache_dir, args.dataset_path, args.n_sequences, args.context_length, args.model_name)

    print(f"Tokenized {len(tokenized)} sequences * {args.context_length} tokens = {len(tokenized) * args.context_length} tokens")

    dtype = _resolve_dtype(args.dtype)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map={"": device},
        torch_dtype=dtype,
        cache_dir=args.cache_dir,
    )

    maybe_random_indices = None
    noise_dim = 0
    variance = None

    if args.model_hook == 'none':
        print('No hook')
    else:
        ablation_type, threshold = _parse_model_hook(args.model_hook)

        # The hookpoint is `<dotted path to the block container>.<layer index>`.
        blocks_name, _, layer_token = args.hookpoint.rpartition('.')
        layer_idx = int(layer_token)

        hook_fn, maybe_random_indices = get_subspace_ablation(
            args.model_name,
            layer_idx,
            device,
            threshold=threshold,
            type=ablation_type,
            ckpt_name=args.ckpt_name,
            dataset=args.density_dataset_path,
            num_docs=args.density_num_docs,
            dtype=dtype,
        )

        run_name += f' -h {args.model_hook}'

        if args.ckpt_name != 'none':
            run_name += f' -ckpt {args.ckpt_name}'

        # Resolve the container by name instead of eval()-ing a config string.
        blocks = model.get_submodule(blocks_name)
        # The hook is intentionally left registered for the whole training run,
        # so the returned handle is not kept.
        blocks[layer_idx].register_forward_hook(hook_fn)

    cfg = SaeConfig(
        num_latents=args.d_sae,
        k=args.k,
        normalize_decoder=args.normalize_decoder,
    )
    train_cfg = TrainConfig(cfg,
        hookpoints=[args.hookpoint],
        lr=args.lr,
        init_seeds=[args.seed],
        auxk_alpha=args.auxk_alpha,
        batch_size=args.batch_size,
        run_name=run_name,
        log_to_wandb=args.log_to_wandb,
        save_every=args.save_every,
        wandb_log_frequency=args.wandb_log_frequency,
        noise_dim=noise_dim,
        noise_variance=variance,
    )

    save_run_name = train_cfg.run_name.replace(' ', '_').replace('openai-community/', '')
    path = f"sae-ckpts/{save_run_name}"
    # create the directory if it doesn't exist
    os.makedirs(path, exist_ok=True)
    # Record the run's config next to the checkpoints. NOTE: despite the
    # `.json` file name, the content written is the YAML rendering of `args`.
    with open(f"{path}/args.json", "w") as f:
        f.write(OmegaConf.to_yaml(args))

    # Persist whatever indices the ablation hook returned (None when no hook
    # is active or the hook produced none).
    if maybe_random_indices is not None:
        torch.save(maybe_random_indices, f"{path}/random_indices_sparse_subspace.pt")

    trainer = Trainer(train_cfg, tokenized, model)
    trainer.fit()

# %%
if __name__ == "__main__":
    # Run training only when executed as a script; Hydra supplies the config
    # argument to train_sae through its decorator.
    train_sae()