import os
import yaml
# Machine-level settings shared by the scripts in this repo (cache location and
# Hugging Face credentials). This runs before transformer_lens is imported
# below — presumably so HF_TOKEN is already in the environment when the
# Hugging Face libraries initialise (confirm).
with open('global_config.yaml', encoding='utf-8') as global_stream:
    global_cfg = yaml.safe_load(global_stream)
CACHE_DIR = global_cfg['CACHE_DIR']
# Treat the token as optional: when the config omits it (or leaves it null),
# keep whatever HF_TOKEN the environment already has instead of failing with a
# KeyError at import time.
if global_cfg.get('hf_access_token') is not None:
    os.environ["HF_TOKEN"] = global_cfg['hf_access_token']
from transformer_lens import HookedTransformer
import yaml
import torch
import argparse
from tqdm import tqdm
from stitching.losses import next_token_cross_entropy_loss
import json
import numpy as np
import gc
# Prefer the GPU, but fall back to CPU so the script can still run (slowly) on
# machines without CUDA instead of failing inside model loading.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# This script only runs forward passes; disable autograd globally so no
# computation graphs are recorded.
torch.set_grad_enabled(False)

@torch.inference_mode()
def cache_acts(acts_dir, dataloader, model_name, layer, n_batches_per_file, ctx_size=None, remove_padding=True, save_tokens=True):
    """Run a TransformerLens model over ``dataloader`` and save activations in chunks.

    For each batch the activations are whatever
    ``model(sample, stop_at_layer=layer)`` returns. After every
    ``n_batches_per_file`` batches, and once more after the final batch, the
    accumulated activations are concatenated along dim 0. They are written to
    ``{acts_dir}/{model_name}_layer_{layer}_cached_activations_{filenum}.pt``,
    where ``filenum`` counts up from 1.

    Args:
        acts_dir: Output directory; created if it does not exist.
        dataloader: Sized iterable of token-id tensors. ``len()`` is used to
            detect the last batch. Assumed to yield ``(batch, seq)`` integer
            tensors — TODO confirm against the dataset builder.
        model_name: Name passed to ``HookedTransformer.from_pretrained``.
            Names containing ``'gemma'`` are loaded in float16, all others in
            float32. A ``/`` in the name becomes a sub-directory of
            ``acts_dir``.
        layer: Passed as ``stop_at_layer`` to the model's forward pass.
        n_batches_per_file: Number of batches aggregated into each saved file.
            The last file may hold fewer. Must be >= 1.
        ctx_size: If given, each batch is truncated to its first ``ctx_size``
            entries along the last dimension before the forward pass.
        remove_padding: If True, drop every position whose token equals the
            tokenizer's pad, EOS or BOS id. Note this also drops EOS/BOS tokens
            that occur mid-sequence. Boolean-mask indexing collapses the
            masked dimensions, so for 2-D token batches each saved tensor is a
            flat list of per-token activations. If False, activations keep
            their per-batch shape.
        save_tokens: If True, also save the (possibly truncated) token batches
            of each chunk to ``{acts_dir}/tokens_{filenum}.pt``. Tokens are
            saved unmasked, even when ``remove_padding`` is True.

    Raises:
        ValueError: If ``n_batches_per_file`` is less than 1.
    """
    if n_batches_per_file < 1:
        raise ValueError(f"n_batches_per_file must be >= 1, got {n_batches_per_file}")
    os.makedirs(acts_dir, exist_ok=True)
    torch_dtype = torch.float16 if 'gemma' in model_name else torch.float32
    print("Loading", model_name)
    model = HookedTransformer.from_pretrained(model_name=model_name, device=device, cache_dir=CACHE_DIR, torch_dtype=torch_dtype)
    tokenizer = model.tokenizer

    n_batches = len(dataloader)
    acts = []
    toks = []
    filenum = 1
    # Wrap the dataloader (not the enumerate object) so tqdm can report a total.
    for i, sample in enumerate(tqdm(dataloader)):
        if ctx_size is not None:
            sample = sample[..., :ctx_size]
        batch_acts = model(sample, stop_at_layer=layer)
        if remove_padding:
            mask = (sample != tokenizer.pad_token_id) & (sample != tokenizer.eos_token_id) & (sample != tokenizer.bos_token_id)
            # Index on the activations' device, then move only the kept rows to CPU.
            batch_acts = batch_acts[mask.to(batch_acts.device)].cpu()
        else:
            batch_acts = batch_acts.cpu()
        acts.append(batch_acts)
        toks.append(sample)
        # Flush after every n_batches_per_file batches (i is 0-based, hence
        # i + 1) and always after the final batch so nothing is lost.
        if (i + 1) % n_batches_per_file == 0 or i == n_batches - 1:
            print(f"Saving at {i} iterations")
            acts_path = os.path.join(acts_dir, f"{model_name}_layer_{layer}_cached_activations_{filenum}.pt")
            # model_name may contain '/', which turns into a sub-directory.
            os.makedirs(os.path.dirname(acts_path), exist_ok=True)
            torch.save(torch.cat(acts, dim=0), acts_path)
            if save_tokens:
                torch.save(torch.cat(toks, dim=0), os.path.join(acts_dir, f"tokens_{filenum}.pt"))
            acts = []
            toks = []
            filenum += 1
        

if __name__ == '__main__':
    # Script entry: read a YAML run config, load the pre-tokenized dataset and
    # cache activations for it via cache_acts.
    parser = argparse.ArgumentParser()
    parser.add_argument('cfg_filename', help="YAML config describing the dataset, model and caching options")
    args = parser.parse_args()
    with open(args.cfg_filename) as stream:
        cfg = yaml.safe_load(stream)
    print("Config", cfg)

    # Double quotes around the f-string so the single-quoted keys inside are
    # valid syntax before Python 3.12 (PEP 701).
    dataset_path = f"data/{cfg['dataset_name']}_tokenized_dataset_200000_{cfg['dataset_key']}_512.pt"
    tokenized_dataset = torch.load(dataset_path, weights_only=True)
    # Seed before building the DataLoader so its shuffling is reproducible.
    torch.manual_seed(cfg['seed'])

    # Both a missing key and an explicit null (`key:` in YAML) mean "default True".
    remove_padding = cfg.get('remove_padding')
    if remove_padding is None:
        remove_padding = True
    save_tokens = cfg.get('save_tokens')
    if save_tokens is None:
        save_tokens = True

    dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=cfg['batch_size'], shuffle=True)
    os.makedirs(cfg['acts_dir'], exist_ok=True)
    print(f"Made {cfg['acts_dir']}")
    # ctx_size is optional; cache_acts treats None as "no truncation".
    cache_acts(cfg['acts_dir'], dataloader, cfg['model_name'], cfg['layer'], cfg['n_batches_per_file'], ctx_size=cfg.get('ctx_size'), remove_padding=remove_padding, save_tokens=save_tokens)
