import os
# set up logging
import logging
from functools import partial

from wandb.old.summary import h5py

logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)
import time
import csv
import pdb
import argparse
from typing import List, Dict, Tuple, Optional, Any

import pickle
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

from src.dataset import ProbeDataLoader, LMDataloader, GPTDataloaderForInference, ObjectLocationProbeDataLoader, BinaryProbeDataLoader, SpanProbeDataLoader, PhraseProbeDataLoader
from src.model import T5ForProbing, GPTForProbing, LlamaForProbing
from src.probe_trainer import Trainer, TrainerConfig
from src.probe_model import BatteryProbeClassification, ObjectLocationProbeClassification, BatteryProbeClassificationTwoLayer
from src.utils import get_token_pos_given_span_types, get_object_mapping, get_quantization_config

# Maximum tokenized input length per supported --model_type. main() looks this
# up with args.model_type and passes it to the dataset constructors as the
# maximum source length.
_MAX_SOURCE_TEXT_LENGTH = {
    "t5": 512,
    "gpt": 512,
    "llama": 2048,
    "Llama-3.1-8B": 4096,
    "Llama-3.1-70B": 2048,
    "Llama-3.1-405B": 2048,
    "CodeLlama-13b-hf": 4096,
    "gemma-2-2b": 8192,
    "Qwen3-14B": 5120
}

# Maximum target length; only passed to LMDataloader on the T5 path in main().
_MAX_TARGET_TEXT_LENGTH = 100


# Probe input dimension per supported --model_type. The keys also serve as the
# allowed choices of the --model_type command-line argument, so adding a model
# here makes it selectable (it needs a matching _MAX_SOURCE_TEXT_LENGTH entry).
_INPUT_DIMENSIONS = {
    "t5": 768,
    "gpt": 1600,
    "llama": 5120,  # codellama13b
    "Llama-3.1-8B": 4096,
    "Llama-3.1-70B": 8192,
    "Llama-3.1-405B": 16384,
    "CodeLlama-13b-hf": 5120,
    "gemma-2-2b": 2304,
    "Qwen3-14B": 5120
}

# Seed torch's global RNG at import time so runs are repeatable.
# NOTE: only torch is seeded here; numpy / random are not touched.
torch.manual_seed(0)

# save_representation() compares the number of cached entries against this
# limit: above it, activations are written as per-layer h5 files; otherwise
# the whole list is pickled into a single file.
LARGE_FILE_LIMIT = 10000  # if more than this number of activations to cache, use h5


def get_activations_from_data(act_all_container, act_container, args, end_idx, inference_device, dataloader, model, tokenizer, object_list):
    """Run ``model`` over ``dataloader`` and collect hidden activations in place.

    For every batch, ``prefix_ids`` / ``prefix_attn_masks`` are moved to
    ``inference_device``, optionally truncated with ``end_idx``, and fed to the
    model. The activations at the selected token positions are detached, moved
    to CPU and appended to the output containers. Nothing is returned.

    Only element 0 of each batch is read (``a[0, ...]``), so a batch size of 1
    is presumably expected -- TODO confirm against the caller.

    Args:
        act_all_container: List extended only when
            ``args.save_model_representation`` is set; receives, per batch, a
            list with one activation tensor per entry of the model output.
        act_container: List that receives ``act[args.layer - 1]`` at the
            selected positions for every batch.
        args: Parsed arguments; reads ``condition_on``, ``layer``,
            ``save_model_representation`` and ``ndif_remote``.
        end_idx: If not ``None``, ids and mask of batch element 0 are sliced
            to ``[:end_idx]`` before the forward pass.
        inference_device: Device the input tensors are moved to.
        dataloader: Iterable of dicts with ``prefix_ids`` and
            ``prefix_attn_masks`` tensors.
        model: Either one of the project probing wrappers (called with
            ``return_all_layers=True``) or a Hugging Face model (called with
            ``output_hidden_states=True``).
        tokenizer: Forwarded to ``get_token_pos_given_span_types``.
        object_list: Forwarded to ``get_token_pos_given_span_types``.

    Raises:
        NotImplementedError: If ``args.ndif_remote`` is set and ``model`` is
            not one of the project probing wrappers.
    """
    probing_classes = (LlamaForProbing, GPTForProbing, T5ForProbing)
    condition_on = args.condition_on
    # Span ("-") and phrase ("_") conditions are cached on a subsampled
    # dataset, so placeholder entries are added to keep indices aligned.
    pad_for_subsampling = "-" in condition_on or "_" in condition_on

    for data in tqdm(dataloader, total=len(dataloader), desc="caching activations"):
        ids = data['prefix_ids'].to(inference_device, dtype=torch.long)
        mask = data['prefix_attn_masks'].to(inference_device, dtype=torch.long)
        if end_idx is not None:
            ids = ids[0][:end_idx].unsqueeze(0)
            mask = mask[0][:end_idx].unsqueeze(0)

        # Decide which token positions to read the activations from.
        if "-" in condition_on or "period_comma" in condition_on:
            batch_token_pos = [
                get_token_pos_given_span_types(input_ids, tokenizer, condition_on, object_list)
                for input_ids in ids
            ]
            if "-" not in condition_on and "period_comma_prior" in condition_on:
                # "prior" variants read the token right before each match.
                batch_token_pos = [[pos - 1 for pos in token_pos] for token_pos in batch_token_pos]
        elif "number_all" in condition_on or "object_all" in condition_on:
            # The last 4 tokens are excluded from the position search.
            batch_token_pos = [
                get_token_pos_given_span_types(input_ids[:-4], tokenizer, condition_on, object_list)
                for input_ids in ids
            ]
        else:
            batch_token_pos = [-1] * len(ids)

        # Inference only: the outputs are detached below, so building an
        # autograd graph would just waste memory.
        with torch.no_grad():
            # A single isinstance() against the tuple of wrapper classes. The
            # original non-saving branch tested a list for truthiness, which
            # is always True and made the Hugging Face branch unreachable.
            if isinstance(model, probing_classes):
                act = model(input_ids=ids, attention_mask=mask, return_all_layers=True)
            elif getattr(args, "ndif_remote", False):
                # TODO: get model activation for this example
                raise NotImplementedError
            else:  # hf models
                act = model(input_ids=ids, attention_mask=mask, output_hidden_states=True).hidden_states

        if args.save_model_representation:
            act_all_container.append([a[0, batch_token_pos, :].detach().cpu() for a in act])
        act_container.append(act[args.layer - 1][0, batch_token_pos, :].detach().cpu())

        if args.save_model_representation and pad_for_subsampling:
            # add 6 empty caches so the total number of datapoints matches
            # the original data without subsetting (caller caches every 7th)
            for _ in range(6):
                act_all_container.append([])
                act_container.append([])


def save_representation(rep: List, split:str, args):
    """Persist cached activations for one data split and, for pickles, empty ``rep``.

    Output goes under ``args.model_representation_path`` (created if missing).
    File names carry a ``_subset`` suffix when the split was subsampled
    (``args.dataset_subset`` for train; ``args.dataset_subset`` or
    ``args.dataset_subset_test_only`` for test).

    * Up to ``LARGE_FILE_LIMIT`` entries: ``rep`` is pickled into
      ``representations_{split}{suffix}.p`` and then cleared in place.
    * More entries: a directory ``representations_{split}{suffix}`` is filled
      with one ``representations_l{k}.h5`` file per layer; each non-empty
      entry ``i`` is stored as dataset ``activations/activations_{i}``.
      ``rep`` is left untouched in this case.

    NOTE(review): ``h5py`` reaches this module through
    ``from wandb.old.summary import h5py`` -- confirm that indirect import is
    intended.

    Args:
        rep: List of per-example activation lists (empty lists are skipped
            when writing h5).
        split: Split name used in the output file names, e.g. ``"train"``.
        args: Parsed arguments; reads ``model_representation_path``,
            ``dataset_subset`` and ``dataset_subset_test_only``.
    """
    out_root = args.model_representation_path
    if not os.path.exists(out_root):
        os.makedirs(out_root, exist_ok=True)

    if split == 'train':
        is_subset = args.dataset_subset
    elif split == 'test':
        is_subset = args.dataset_subset or args.dataset_subset_test_only
    else:
        is_subset = False
    suffix = '_subset' if is_subset else ''

    # Small caches: a single pickle is enough.
    if len(rep) <= LARGE_FILE_LIMIT:
        pickle_path = os.path.join(out_root, f"representations_{split}{suffix}.p")
        with open(pickle_path, "wb") as rep_f:
            pickle.dump(rep, rep_f)
            rep.clear()
        return

    # Large caches: one h5 file per layer inside a split-specific directory.
    # (An earlier variant wrote all layers into a single h5 file; rank-specific
    # sub-directories were also tried and are not used here.)
    layer_dir = os.path.join(out_root, f"representations_{split}{suffix}")
    os.makedirs(layer_dir, exist_ok=True)

    layer_files = [
        os.path.join(layer_dir, f"representations_l{layer_no}.h5")
        for layer_no in range(1, len(rep[0]) + 1)
    ]

    # Touch every layer file once so the "activations" group exists.
    for layer_file in layer_files:
        if os.path.exists(layer_file):
            continue
        with h5py.File(layer_file, "w") as f:
            f.create_group("activations")

    # Files are opened and closed around every single write. Keeping all of
    # them open is faster but made large saves fail on the cluster when run
    # as a torchrun distributed process.
    progress = tqdm(enumerate(rep), total=len(rep), desc="Saving layer-wise activations")
    for act_i, act in progress:
        if len(act) == 0:
            continue
        for layer_i, layer_file in enumerate(layer_files):
            with h5py.File(layer_file, "a") as f:  # 'a' == read/write, create if missing
                f["activations"].create_dataset(
                    f"activations_{act_i}",
                    data=act[layer_i][0].numpy(),
                    compression="gzip",
                )


def fsdp_model(args):
    """Build a Llama causal LM wrapped in FSDP and load its checkpoint.

    Experimental: the ``--fsdp`` flag that routes here is marked as still
    failing in the argument parser.

    Steps:
      1. Read ``LOCAL_RANK`` from the environment and select ``cuda:{rank}``.
      2. Instantiate the model from its config under ``init_empty_weights()``.
      3. Verify that every parameter and buffer lives on the meta device.
      4. Wrap the model in ``FullyShardedDataParallel`` with one wrap unit per
         ``LlamaDecoderLayer``.
      5. Load the checkpoint at ``args.model_path`` via
         ``load_checkpoint_and_dispatch`` in bfloat16.

    Args:
        args: Parsed arguments; only ``model_path`` is read.

    Returns:
        The object returned by ``load_checkpoint_and_dispatch``.

    Raises:
        NotImplementedError: If ``args.model_path`` does not contain "llama"
            (case-insensitive).
        KeyError: If the ``LOCAL_RANK`` environment variable is not set.
        SystemError: If any parameter or buffer is not on the meta device
            after model creation.
    """
    if "llama" in args.model_path.lower():
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
    else:
        raise NotImplementedError("FSDP layering not implemented for non-llama model yet")

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # Only the config is read here; no weights are loaded yet.
    config = AutoConfig.from_pretrained(args.model_path)

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config) #,device_map=None)

    # NOTE(review): named_parameters() yields dotted names, so this setattr on
    # the root module probably does not replace the nested parameter. The
    # diagnostic below is what actually guards against materialized tensors.
    for name, p in model.named_parameters(recurse=True):
        if p.device.type != "meta":
            setattr(model, name, p.to("meta"))

    # diagnostic — run immediately AFTER init_empty_weights() and model creation
    bad = [
        (name, p.device, p.size())
        for name, p in model.named_parameters(recurse=True)
        if p.device.type != "meta"
    ]
    bad.extend(
        ("buffer:" + name, b.device, b.size())
        for name, b in model.named_buffers(recurse=True)
        if b.device.type != "meta"
    )

    if bad:
        print("Found non-meta tensors BEFORE to_empty():")
        for n, dev, sz in bad[:20]:
            print(n, dev, sz)
        raise SystemError("Some parameters/buffers were materialized before to_empty() — fix the creation pipeline.")
    print("All params & buffers are meta. Safe to call to_empty().")
    # A leftover pdb.set_trace() used to sit here; it stopped every call and
    # cannot be answered interactively by torchrun workers.

    # model.to_empty(device=device)

    # Create auto-wrap policy
    auto_wrap_policy = partial(transformer_auto_wrap_policy,transformer_layer_cls={LlamaDecoderLayer})

    # Wrap in FSDP
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=False),
        use_orig_params=True,
        device_id=device,
        param_init_fn=lambda module: None,  # <---- prevents reset_parameters()
    )

    model = load_checkpoint_and_dispatch(
        model,
        args.model_path,
        device_map={"": f"cuda:{local_rank}"},
        no_split_module_classes=["LlamaDecoderLayer"],
        dtype=torch.bfloat16,
        offload_state_dict=True,
    )

    return model


def main():

    parser = argparse.ArgumentParser(description='Train classification network')
    parser.add_argument("--model_type",
                        required=True,
                        choices=_INPUT_DIMENSIONS.keys(),
                        help=f"{_INPUT_DIMENSIONS.keys()} supported.")
    parser.add_argument("--load_in_8bit", action="store_true", help="Load model with 8-bit")
    parser.add_argument("--load_in_4bit", action="store_true", help="Load model with 4-bit")
    parser.add_argument("--dataset_path",
                        required=True,
                        type=str)
    parser.add_argument("--dataset_subset",
                        dest='dataset_subset',
                        action='store_true'
                        )
    parser.add_argument("--dataset_subset_test_only",
                        dest='dataset_subset_test_only',
                        action='store_true'
                        )
    parser.add_argument('--dataset_loader_num_workers',
                        default=4,
                        type=int)
    parser.add_argument("--model_path",
                        required=False,
                        default=None,
                        type=str)
    parser.add_argument('--checkpoint_root',
                        default="./probe_checkpoints", type=str)
    parser.add_argument(
            "--object_vocabulary_file",
            type=str,
            default="data/objects_with_bnc_frequency.csv",
            help='Path to a .csv file with a string field "object_names".')
    parser.add_argument('--layer',
                        required=True,
                        default=-1,
                        type=int)
    parser.add_argument('--epo',
                        default=16,
                        type=int)
    parser.add_argument('--condition_on', 
                        choices=["box", "period", "the", "number", "contains",
                                 "number-object-remove", "number-object-exist",  # span probs
                                 "number-remove", "object-remove",  # span probs (using only one of the tokens)
                                 "period_comma_local", "period_comma_cumulative", # period or comma, only 1 token, ternary probe
                                 "period_comma_prior_local", "period_comma_prior_cumulative", # 1 token before period or comma (box id), only 1 token , ternary probe (should use number_all)
                                 "number_all_local", "number_all_cumulative", # ternary probe, condition on all box ids (1 at a time)
                                 "object_all_local", "object_all_cumulative", # ternary probe, condition on all objects (1 at a time)
                                 ],
                        type=str,
                        dest='condition_on',
                        default='number')
    parser.add_argument('--max_train_data',
                        type=int,
                        default=None)
    parser.add_argument('--max_test_data',
                        type=int,
                        default=None)
    parser.add_argument('--num_prior_state',
                        default=-1,
                        type=int)
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3)
    parser.add_argument("--overwrite_cache",
                        dest='overwrite_cache',
                        action='store_true')

    # span probe args (when condition_on=="number-object")
    parser.add_argument("--expand_query_box",
                        action=argparse.BooleanOptionalAction,
                        default=False,
                        help="whether to cache query box id token (in addition to other latest tokens)")
    parser.add_argument("--balance_label_sampling",
                        action=argparse.BooleanOptionalAction,
                        default=True,
                        help="whether to balance label distribution by downsampling no-remove cases")
    parser.add_argument('--same_phrase_only',
                        choices=['train', 'test', 'both', 'neither'],
                        default='neither',
                        type=str,
                        )

    # Non-linear probe
    parser.add_argument('--mid_dim',
                        default=128,
                        type=int)
    parser.add_argument('--twolayer',
                        dest='twolayer', 
                        action='store_true')
    parser.add_argument('--object_location',
                        dest='object_location', 
                        action='store_true')    
    parser.add_argument('--binary_probe',
                        dest='binary_probe', 
                        action='store_true')
    parser.add_argument('--random',
                        dest='random', 
                        action='store_true')
    parser.add_argument('--eval_only',
                        dest='eval_only', 
                        action='store_true')
    parser.add_argument('--exclude_empty',
                        dest='exclude_empty', 
                        action='store_true')
    parser.add_argument('--condition_on_obj',
                        default=0,
                        type=int)
    parser.add_argument('--debug_train',
                        action='store_true')

    parser.add_argument('--model_representation_path',
                        default=None,
                        type=str)

    parser.add_argument('--save_model_representation',
                        dest="save_model_representation",
                        action="store_true")
    
    parser.add_argument('--load_model_representation',
                        dest="load_model_representation",
                        action="store_true")

    parser.add_argument('--include_prompt',
                        type=str,
                        default=False,
                        choices=[False, "PROMPT", "PROMPT_ALTFORM", "PROMPT_ALLBOX_ALTFORM", "INSTRUCTION", ]
                        )
    parser.add_argument(
        "--chat", action="store_true", help="format prompt into chat templates"
    )
    
    # distributed inference ( for caching embedding) 
    parser.add_argument('--distributed',
                        dest="distributed",
                        action="store_true")
    parser.add_argument("--local-rank", "--local_rank", type=int)
    parser.add_argument('--fsdp',
                        dest="fsdp", # TODO attempts to use fsdp to load llama 70b into 4GPU runs, still failing
                        action="store_true")
    parser.add_argument('--ndif_remote',
                        dest="ndif_remote",
                        action="store_true",
                        help="whether to use NDIF remote option to save model hiddens"
                        )

    args, _ = parser.parse_known_args()

    if (args.condition_on not in ["number", "contains"]) and args.model_type == 't5':
        raise ValueError("--condition_on must be set to 'number' or 'contains' when training a probe on T5.")
    # if args.eval_only:
    #     raise ValueError("--eval_only is buggy, do not use for now")

    if args.exclude_empty and not args.binary_probe:
        raise ValueError("--exclude_empty only works with --binary_probe")

    if args.exclude_empty and args.condition_on not in ["contains", "the"]:
        raise ValueError("--exclude_empty can only be used with --condition_on 'contains' or 'the'")
    
    if args.condition_on in ["contains", "the"] and not args.exclude_empty:
        raise ValueError("--condition_on 'contains' or 'the' can only be used with --exclude_empty")

    if args.save_model_representation and args.model_representation_path is None:
        raise ValueError("--save_model_representation requires --model_representation_path to be set")
    
    if args.load_model_representation and args.model_representation_path is None:
        raise ValueError("--load_model_representation requires --model_representation_path to be set")
    
    if args.load_model_representation and args.save_model_representation:
        raise ValueError("--load_model_representation and --save_model_representation cannot be used together")

    if args.max_train_data is not None:
        assert args.max_train_data % 7 == 0, "number of data points must be divisible by 7"

    if args.max_test_data is not None:
        assert args.max_test_data % 7 == 0, "number of data points must be divisible by 7"

    if args.dataset_subset or args.dataset_subset_test_only:
        assert os.path.exists(os.path.join(args.dataset_path,f'test-subsample-states-{"t5" if args.model_type == "t5" else "gpt"}.jsonl'))
        if "movecontent" in args.dataset_path.lower() or "move_content" in args.dataset_path.lower():
            assert os.path.exists(os.path.join(args.dataset_path, f'train-subsample-states-mask.p'))

    if args.fsdp:
        assert args.distributed, "--fsdp requires --distributed"

    folder_name = f"probing/state"
    if args.include_prompt:
        folder_name = folder_name + f"_fs{args.include_prompt}"
    if args.chat:
        folder_name = folder_name + f"_chat"
    if args.twolayer:
        folder_name = folder_name + f"_tl{args.mid_dim}"  # tl for probes without batchnorm
    if args.random:
        folder_name = folder_name + "_random"
    if args.object_location:
        folder_name = folder_name + "_object_location"
    if args.binary_probe:
        folder_name = folder_name + "_binary"
    if args.exclude_empty:
        folder_name = folder_name + "_exclude_empty"
    if args.condition_on_obj > 0:
        folder_name = folder_name + f"_condition_on_obj_{args.condition_on_obj}"
    if args.num_prior_state != -1:
        folder_name = folder_name + f"_prior_state_{args.num_prior_state}"
    if "-" in args.condition_on:
        folder_name = folder_name + f"_span_{args.condition_on}"
    if "_" in args.condition_on:
        folder_name = folder_name + f"_{args.condition_on}"

    training_file = os.path.join(args.checkpoint_root, folder_name, f"layer{args.layer}_token1", "tensorboard.txt")
    if not args.overwrite_cache and os.path.exists(training_file) and len(open(training_file).readlines())>0:
        print(f"Found trained probe, skipping: {training_file}")
        exit(0)

    print(f"Running experiment for {folder_name}")
    print("[Data]: Reading data...\n")

    device = 'cuda' if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else 'cpu'
    print(f'Training Device: {device}')


    # Load data
    data_type = "t5" if args.model_type == "t5" else "gpt"
    dataset_path_train = os.path.join(args.dataset_path, f'train{"-subsample-states" if args.dataset_subset else ""}-{data_type}.jsonl')
    print("Train dataset:", dataset_path_train)
    dataset_path_test = os.path.join(args.dataset_path, f'test{"-subsample-states" if (args.dataset_subset or args.dataset_subset_test_only) else ""}-{data_type}.jsonl')

    train_df = pd.read_json(dataset_path_train, orient='records', lines=True)
    test_df = pd.read_json(dataset_path_test, orient='records', lines=True)

    if args.eval_only:
        train_df = train_df.head(0)

    if args.model_type == "t5":
        train_df = train_df[["sentence_masked", "masked_content"]]
        test_df = test_df[["sentence_masked", "masked_content"]]

    # Load object names
    object_map, object_list = get_object_mapping(args.object_vocabulary_file)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    act_container_train = []
    act_all_container_train = []
    act_container_test = []
    act_all_container_test = []
    act_all_h5_train_dir = None
    act_all_h5_test_dir = None

    if args.load_model_representation:

        # load pre-computed representations
        train_rep_path = os.path.join(args.model_representation_path, f"representations_train{'_subset' if args.dataset_subset else ''}.p")
        test_rep_path = os.path.join(args.model_representation_path, f"representations_test{'_subset' if (args.dataset_subset or args.dataset_subset_test_only) else ''}.p")

        act_all_h5_train_dir = train_rep_path.replace(".p", "")
        act_all_h5_test_dir = test_rep_path.replace(".p", "")
        
        if os.path.isdir(act_all_h5_train_dir): # and os.path.exists(exploded_train_path):
            print("h5 cache for train representation found, skipping loading ...")
        else:
            act_all_h5_train_dir = None
            print(f"loading cached train representations from {train_rep_path} ...")
            with open(train_rep_path, "rb") as rep_f:
                act_all_container_train = pickle.load(rep_f)

            for act in act_all_container_train:
                if len(act) == 0 and ("-" in args.condition_on or "_" in args.condition_on):
                    act_container_train.append([])
                else:
                    act_container_train.append(act[args.layer - 1])

            act_all_container_train.clear()
            
        if os.path.exists(act_all_h5_test_dir): # and os.path.exists(exploded_test_path):
            print("h5 cache for test representation found, skipping loading ...")
        else:
            act_all_h5_test_dir = None
            print(f"loading cached test representations from {test_rep_path} ...")
            with open(test_rep_path, "rb") as rep_f:
                act_all_container_test = pickle.load(rep_f)

            for act in act_all_container_test:
                if len(act) == 0 and ("-" in args.condition_on or "_" in args.condition_on):
                    act_container_test.append([])
                else:
                    act_container_test.append(act[args.layer - 1])

            act_all_container_test.clear()

    else:
        # set up distributed inference
        if args.distributed:
            rank = int(os.environ.get("RANK",'0'))
            print(f"{rank=}")
            inference_device = torch.device(f"cuda:{rank}")
            if not args.fsdp:
                torch.cuda.set_device(inference_device)  # https://github.com/pytorch/pytorch/issues/146767
            torch.distributed.init_process_group("nccl", device_id=inference_device)
            tp_plan = "auto"
        else:
            inference_device = device
            tp_plan = None
        model_kwargs = {"tp_plan": tp_plan}

        # Load model to compute representations
        if args.distributed:
            if args.load_in_8bit:
                model_kwargs["load_in_8bit"] = True
            elif args.load_in_4bit:
                model_kwargs["load_in_4bit"] = True
        else:
            model_kwargs["quantization_config"] = get_quantization_config(args)
        if model_kwargs.get("quantization_config") is None and transformers.utils.is_torch_bf16_gpu_available():
            model_kwargs["torch_dtype"] = torch.bfloat16

        if args.model_type == "t5":
            model = T5ForProbing.from_pretrained(args.model_path)
        elif args.model_type == "gpt":
            model = GPTForProbing.from_pretrained(args.model_path)
        elif "llama" in args.model_type.lower() and not args.fsdp and not args.ndif_remote:
            model = LlamaForProbing.from_pretrained(args.model_path, **model_kwargs) #,device_map="auto",
        elif args.fsdp:
            model = fsdp_model(args)
        elif args.ndif_remote:
            # TODO here, initialize ndif remote model
            pass
        else:
            model = AutoModelForCausalLM.from_pretrained(args.model_path, **model_kwargs) #,device_map="auto",

        if model_kwargs.get("quantization_config") is None and not args.distributed:
            model = model.to(device)

        # Set probe layer (1-indexed)
        model.probe_layer = args.layer
        inference_device = model.device
        model.eval()

        # initialze LM test dataset
        if args.model_type == "t5":
            train_dataset = LMDataloader(train_df, tokenizer, _MAX_SOURCE_TEXT_LENGTH[args.model_type], _MAX_TARGET_TEXT_LENGTH, "sentence_masked", "masked_content", include_empty=not args.exclude_empty, min_prev_objects=args.condition_on_obj)
            test_dataset = LMDataloader(test_df, tokenizer, _MAX_SOURCE_TEXT_LENGTH[args.model_type], _MAX_TARGET_TEXT_LENGTH, "sentence_masked", "masked_content", include_empty=not args.exclude_empty, min_prev_objects=args.condition_on_obj)
        else: #if args.model_type in ["gpt", "llama"]:
            train_dataset = GPTDataloaderForInference(train_df, tokenizer, max_length=_MAX_SOURCE_TEXT_LENGTH[args.model_type], include_empty=not args.exclude_empty, condition_on=args.condition_on, min_prev_objects=args.condition_on_obj, include_prompt=args.include_prompt, args=args)
            test_dataset = GPTDataloaderForInference(test_df, tokenizer, max_length=_MAX_SOURCE_TEXT_LENGTH[args.model_type], include_empty=not args.exclude_empty, condition_on=args.condition_on, min_prev_objects=args.condition_on_obj, include_prompt=args.include_prompt, args=args)

        # truncate dataset if needed
        if args.max_train_data is not None:
            train_dataset = torch.utils.data.Subset(train_dataset, range(args.max_train_data))
        if args.max_test_data is not None:
            test_dataset = torch.utils.data.Subset(test_dataset, range(args.max_test_data))

        if ("-" in args.condition_on or "_" in args.condition_on) and args.save_model_representation:
            # for span probe, we really just need the representation for each unique data point (disregarding query box)
            # so on caching runs, inference every 7 datapoints to get the representation to save some space
            train_dataset = torch.utils.data.Subset(train_dataset, range(0, len(train_dataset), 7))
            test_dataset = torch.utils.data.Subset(test_dataset, range(0, len(test_dataset), 7))

        loader_train = DataLoader(train_dataset, shuffle=False, pin_memory=True, batch_size=1, num_workers=1)
        loader_test = DataLoader(test_dataset, shuffle=False, pin_memory=True, batch_size=1, num_workers=1)

        # compute hidden representations  (deleted T5 ones)
        end_idx = None
        # I've checked that " n" between 0 and 7 are all single tokens.
        # But this is probably a patchy solution if box num >= 8
        if args.condition_on == "box":
            end_idx = -3  # originally -1
        # "Box" is one token, "Box 3" is 3 tokens
        elif args.condition_on == "period":
            end_idx = -4  # originally -2, our data ends with "contains", so one extra word, also 'Box 3' is 3 tokens (space is one)

        get_activations_from_data(act_all_container_train, act_container_train, args, end_idx, inference_device, loader_train, model, tokenizer, object_list)
        if args.save_model_representation:  # save them as we go, cheaper in memory
            if not args.distributed or torch.distributed.get_rank() == 0:
                save_representation(act_all_container_train, split="train", args=args)

        get_activations_from_data(act_all_container_test, act_container_test, args, end_idx, inference_device, loader_test, model, tokenizer, object_list)
        if args.save_model_representation:
            if not args.distributed or torch.distributed.get_rank() == 0:
                save_representation(act_all_container_test, split="test", args=args)

        if args.distributed:
            torch.distributed.destroy_process_group()

        print("caching done! existing program for now. re-run with load_model_representation! ")
        exit()



    probe_class = 8 if not args.binary_probe else 2
    input_dim = _INPUT_DIMENSIONS[args.model_type]

    # if moveContent split, need to load the whole dataset, then apply subsample mask after computing prior states
    train_subset_mask, test_subset_mask = None, None
    if args.num_prior_state != -1 and ("movecontent" in dataset_path_train.lower() or "move_content" in dataset_path_train.lower()):
        dataset_path_train = dataset_path_train.replace("-subsample-states", "")
        dataset_path_test = dataset_path_test.replace("-subsample-states", "")
        with open(os.path.join(args.dataset_path, f'train-subsample-states-mask.p'), "rb") as rep_f:
            train_subset_mask = pickle.load(rep_f)
        with open(os.path.join(args.dataset_path, f'test-subsample-states-mask.p'), "rb") as rep_f:
            test_subset_mask = pickle.load(rep_f)

    if args.object_location:
        probing_dataset_train = ObjectLocationProbeDataLoader(act_container_train, dataset_path_train, max_data=args.max_train_data)
        probing_dataset_test = ObjectLocationProbeDataLoader(act_container_test, dataset_path_test, max_data=args.max_test_data)
    elif args.binary_probe:
        if "-" in args.condition_on:  # span probe
            probing_dataset_train = SpanProbeDataLoader(act_container_train, dataset_path_train, object_map,include_empty=not args.exclude_empty,min_prev_objects=args.condition_on_obj,max_data=args.max_train_data, tokenizer=tokenizer,expand_query_box=args.expand_query_box,balance_label_sampling=args.balance_label_sampling, span_probe_type=args.condition_on, args=args, split="train", same_phrase_only=args.same_phrase_only in ["train", "both"])
            probing_dataset_test = SpanProbeDataLoader(act_container_test, dataset_path_test, object_map,include_empty=not args.exclude_empty,min_prev_objects=args.condition_on_obj,max_data=args.max_test_data, tokenizer=tokenizer,expand_query_box=args.expand_query_box,balance_label_sampling=args.balance_label_sampling, span_probe_type=args.condition_on, args=args, split="test", same_phrase_only=args.same_phrase_only in ["test", "both"])
        else:
            probing_dataset_train = BinaryProbeDataLoader(act_container_train, dataset_path_train, object_map, include_empty=not args.exclude_empty, min_prev_objects=args.condition_on_obj, max_data=args.max_train_data, local_operation_order=args.num_prior_state, subset_mask=train_subset_mask)
            probing_dataset_test = BinaryProbeDataLoader(act_container_test, dataset_path_test, object_map, include_empty=not args.exclude_empty, min_prev_objects=args.condition_on_obj, max_data=args.max_test_data, local_operation_order=args.num_prior_state, subset_mask=test_subset_mask)
    elif "_" in args.condition_on:  # phrase probes, takes 1 hidden state to predict 3 classes
        probing_dataset_train = PhraseProbeDataLoader(act_container_train, dataset_path_train, object_map,include_empty=not args.exclude_empty, max_data=args.max_train_data, tokenizer=tokenizer, args=args, split="train", activation_h5_path=act_all_h5_train_dir)
        probing_dataset_test = PhraseProbeDataLoader(act_container_test, dataset_path_test, object_map,include_empty=not args.exclude_empty, max_data=args.max_test_data, tokenizer=tokenizer, args=args, split="test", activation_h5_path=act_all_h5_test_dir)
    else:
        probing_dataset_train = ProbeDataLoader(act_container_train, dataset_path_train, object_map, max_data=args.max_train_data)
        probing_dataset_test = ProbeDataLoader(act_container_test, dataset_path_test, object_map, max_data=args.max_test_data)

    train_dataset, test_dataset = probing_dataset_train, probing_dataset_test

    if args.object_location:
        if args.twolayer:
            raise ValueError("Parameter --twolayer is not supported when using the object location probe.")
        probe = ObjectLocationProbeClassification(device,
            input_dim=input_dim,
            probe_class=probe_class,
            ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32))
    elif "-" in args.condition_on:
        if args.twolayer:
            probe = BatteryProbeClassificationTwoLayer(device,
                input_dim=input_dim*2 if args.condition_on.startswith("number-object") else input_dim,
                probe_class=probe_class,
                num_task=1,
                mid_dim=args.mid_dim,
                ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32),
                # dtype=probing_dataset_train.activations[0].dtype
                )
        else:
            probe = BatteryProbeClassification(device,
                input_dim=input_dim*2 if args.condition_on.startswith("number-object") else input_dim,
                probe_class=probe_class,
                num_task=1,
                ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32),
                # dtype=probing_dataset_train.activations[0].dtype
                )
    elif "_" in args.condition_on:
        if args.twolayer:
            probe = BatteryProbeClassificationTwoLayer(device,
                input_dim=input_dim,
                probe_class=3,  # ternary probe of [exist, non-exist, removed]
                num_task=7 * 100,  # not 8 because non-exist covers the case
                mid_dim=args.mid_dim,
                ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32),
                )
        else:
            probe = BatteryProbeClassification(device,
                input_dim=input_dim,
                probe_class=3,   # ternary probe of [exist, non-exist, removed]
                num_task=7 * 100,  # not 8 because non-exist covers the case
                ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32),
                )
    else:
        if args.twolayer:
            probe = BatteryProbeClassificationTwoLayer(device,
                input_dim=input_dim,
                probe_class=probe_class,
                num_task=100,
                mid_dim=args.mid_dim,
                ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32),
                )             
        else: 
            probe = BatteryProbeClassification(device,
                input_dim=input_dim,
                probe_class=probe_class,
                num_task=100,
                ce_weights=probing_dataset_train.get_weights().to(device, dtype=torch.float32),
                )        

    max_epochs = args.epo
    t_start = time.strftime("_%Y%m%d_%H%M%S")
    tconf = TrainerConfig(
        max_epochs=max_epochs, batch_size=1024, learning_rate=args.lr,#1e-3,
        betas=(.9, .999), 
        lr_decay=True, warmup_tokens=len(train_dataset)*5, 
        final_tokens=len(train_dataset)*max_epochs,
        num_workers=args.dataset_loader_num_workers, # originally 4, if loading h5 files during training, need more worker because IO is bottleneck
        weight_decay=0.,
        ckpt_path=os.path.join(args.checkpoint_root, folder_name, f"layer{args.layer}_token1"),
        debug_train=args.debug_train,
    )
    os.makedirs(tconf.ckpt_path, exist_ok=True)
    trainer = Trainer(probe, train_dataset, test_dataset, tconf)
    if not args.eval_only:
        predictions_matrix = trainer.train(prt=True).astype(int)
        trainer.save_traces()
        trainer.save_checkpoint()
    else:
        trainer.load_checkpoint()
        predictions_matrix = trainer.predict(prt=True).astype(int)
    
    predictions_file = os.path.join(tconf.ckpt_path, "predictions.txt")

    if "-" in args.condition_on:
        predictions_df = pd.DataFrame({
            "input": test_dataset.analysis_strings,
            "prediction": predictions_matrix.squeeze().tolist(),
            "label": [i.item() for i in test_dataset.examples],
            "same_phrase": [i.item() for i in test_dataset.mentioned_objects],
        })
        predictions_df.to_json(predictions_file.replace(".txt", ".jsonl"),lines=True, orient="records")
    elif "_" in args.condition_on:
        header = " ".join(object_list) * 7
        # np.savetxt(predictions_file, predictions_matrix, delimiter=" ", fmt='%i', header=header, comments="")
        np.save(predictions_file.replace(".txt", ".npy"), predictions_matrix)
        test_input_path = os.path.join(tconf.ckpt_path, "test_inputs.txt")
        if not os.path.isfile(test_input_path):
            with open(test_input_path, "w") as f:
                f.writelines("\n".join(test_dataset.analysis_strings))
    else:
        header = " ".join(object_list)
        np.savetxt(predictions_file, predictions_matrix, delimiter=" ", fmt='%i', header=header, comments="")

    # save plot
    fig_file = os.path.join(tconf.ckpt_path, "predictions.png")
    trainer.flush_plot(fig_file)





# Script entry point: invoke main() only when this file is executed directly
# (e.g. `python <this file>`), not when it is imported as a module.
if __name__ == "__main__":
    main()
    

