import logging

import sys
import torch

import json

from model_configs import var_exper
from general_utils.hf_utils import get_model_and_tokenizer
from general_utils.config import config
from ml_collections import ConfigDict
from var_exp import var_exp

def process_args(args):
    """Log the experiment arguments and apply device/dtype to the global config.

    Args:
        args: Attribute-style container of experiment settings. It must be
            accepted by ``vars()`` and expose ``device`` (ignored when falsy)
            and ``dtype`` (one of ``"fp16"``, ``"bf16"``, ``"fp32"``).

    Raises:
        ValueError: If ``args.dtype`` is not one of the supported names.
    """
    # Lazy %-formatting: the message is only rendered when DEBUG is enabled.
    for arg, argv in vars(args).items():
        logging.debug('%s = %s', arg, argv)

    if args.device:
        config.device = torch.device(args.device)

    dtype_by_name = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    # Resolve the dtype before touching ``config`` so that only a failed
    # lookup (unknown or unhashable name) is reported as a ValueError.
    try:
        dtype = dtype_by_name[args.dtype]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unsupported dtype {args.dtype!r}; "
            f"expected one of: {', '.join(dtype_by_name)}"
        ) from None
    config.dtype = dtype

def measure_var_explained(args) -> None:
    """Run ``var_exp`` on the validation prompts and print the result.

    Loads the model and tokenizer named by ``args.model`` with the dtype
    stored in ``config.dtype``, reads the ``"text"`` field of every record in
    ``./data/probe_valid.jsonl``, prefixes each prompt with the tokenizer's
    BOS token (followed by a space) when the tokenizer defines one, and passes
    the prompts to ``var_exp`` with ``eps=0.2``.

    Args:
        args: Experiment settings; only ``args.model`` is read here.
    """
    model, tokenizer = get_model_and_tokenizer(args.model, dtype=config.dtype)

    file_path = "./data/probe_valid.jsonl"
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, and skip whitespace-only lines so a stray blank line does not
    # abort the whole run inside json.loads.
    with open(file_path, 'r', encoding="utf-8") as f:
        prompts = [json.loads(line)["text"] for line in f if line.strip()]

    if tokenizer.bos_token is not None:
        prompts = [f"{tokenizer.bos_token} {prompt}" for prompt in prompts]

    var_exp_stats = var_exp(model, tokenizer, prompts, eps=0.2)

    # CHANGE BELOW AS NEEDED TO SAVE/PROCESS MEASUREMENTS
    print(var_exp_stats)

# Example usage
if __name__ == "__main__":
    # The first CLI argument is a 1-based index into ``var_exper``.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} RUN_ID")
    try:
        run_id = int(sys.argv[1])
    except ValueError:
        sys.exit(f"RUN_ID must be an integer, got {sys.argv[1]!r}")

    # Reject values below 1: ``run_id - 1`` would be negative, which on a
    # sequence silently selects an experiment counted from the end.
    if run_id < 1:
        sys.exit(f"RUN_ID must be >= 1, got {run_id}")
    try:
        exper_config = var_exper[run_id - 1]
    except (IndexError, KeyError):
        sys.exit(f"RUN_ID {run_id} does not match any entry in var_exper")

    exper_args = ConfigDict(exper_config)
    process_args(exper_args)
    measure_var_explained(exper_args)
