import logging

import sys
import torch

from ml_collections import ConfigDict
from probe_configs import probe_exper
from general_utils.config import config
from general_utils.hf_utils import get_model_and_tokenizer
from data.data_utils import generate_prompts_cities
from probes import sink_decomp, gen_histogram, get_mass_mean_probes, calc_probe_acc

def process_args(args):
    """Log the experiment arguments and apply device/dtype settings to ``config``.

    Each entry of ``vars(args)`` is logged at DEBUG level. If ``args.device``
    is truthy it is converted with ``torch.device`` and stored in
    ``config.device``. ``args.dtype`` is mapped to the matching torch dtype
    and stored in ``config.dtype``.

    Args:
        args: Object accepted by ``vars()`` that exposes ``device`` and
            ``dtype`` attributes. ``dtype`` must be ``"fp16"``, ``"bf16"``
            or ``"fp32"``.

    Raises:
        ValueError: If ``args.dtype`` is not one of the supported names.
    """
    for arg, argv in vars(args).items():
        logging.debug('%s = %s', arg, argv)

    # The device is applied before the dtype is validated, so a bad dtype
    # still leaves config.device updated.
    if args.device:
        config.device = torch.device(args.device)

    supported_dtypes = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    try:
        selected_dtype = supported_dtypes[args.dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which would simply have failed
        # every string comparison and ended in the same ValueError.
        raise ValueError(
            f"Unsupported dtype {args.dtype!r}; "
            f"expected one of: {', '.join(supported_dtypes)}"
        ) from None
    config.dtype = selected_dtype

def measure_sink_metrics(args) -> None:
    """Fit mass-mean probes on the cities prompts and print the tag-probe accuracy.

    Steps, in order:
      1. Print the configured device and the CUDA device count.
      2. Load the model and tokenizer named by ``args.model`` using
         ``config.dtype``.
      3. Generate true/false statements; the first 200 of each are used for
         training and items 200-299 for validation.
      4. Run ``sink_decomp`` (eps=0.2) on both training sets, pass the two
         histogram outputs to ``gen_histogram``, and build three probe sets
         with ``get_mass_mean_probes``.
      5. Run ``sink_decomp`` with ``eval=True`` on the validation sets and
         score each probe set with ``calc_probe_acc``.

    Only the first accuracy (labelled "Tag") is printed; the other two
    results are computed but not used here.

    Args:
        args: Experiment settings; only ``args.model`` is read.
    """
    # Single source for the 0.2 value used by every call below.
    # NOTE(review): assumed to be the same threshold everywhere -- confirm
    # against sink_decomp / calc_probe_acc.
    eps = 0.2
    train_span = slice(None, 200)
    val_span = slice(200, 300)

    print(f"PyTorch device: {config.device}")
    print(f"Number of available cuda devices: {torch.cuda.device_count()}")

    model, tokenizer = get_model_and_tokenizer(args.model, dtype=config.dtype)

    true_statements, false_statements = generate_prompts_cities()
    true_train = true_statements[train_span]
    false_train = false_statements[train_span]
    true_val = true_statements[val_span]
    false_val = false_statements[val_span]

    # Training-side decomposition: (sinks, no-sinks, activations, histogram).
    sinks_t, no_sinks_t, acts_t, hist_t = sink_decomp(
        model, tokenizer, true_train, eps=eps
    )
    sinks_f, no_sinks_f, acts_f, hist_f = sink_decomp(
        model, tokenizer, false_train, eps=eps
    )

    # Resolves to "./Histograms/0_2/".
    hist_dir = './Histograms/' + str(eps).replace(".", "_") + "/"
    gen_histogram(model, hist_t, hist_f, hist_dir)

    probes_tags = get_mass_mean_probes(model, sinks_t, sinks_f)
    probes_no_tags = get_mass_mean_probes(model, no_sinks_t, no_sinks_f)
    probes_full = get_mass_mean_probes(model, acts_t, acts_f)

    # Validation-side decomposition; the histogram output is discarded.
    v_tags_t, v_no_tags_t, v_acts_t, _ = sink_decomp(
        model, tokenizer, true_val, eps=eps, eval=True
    )
    v_tags_f, v_no_tags_f, v_acts_f, _ = sink_decomp(
        model, tokenizer, false_val, eps=eps, eval=True
    )

    # Parse probe accuracies as desired
    tag_acc = calc_probe_acc(model, probes_tags, v_tags_t, v_tags_f, eps, "Tag")
    # The remaining two accuracies are still computed, but their return
    # values are not used in this function.
    calc_probe_acc(model, probes_no_tags, v_no_tags_t, v_no_tags_f, eps, "No Tag")
    calc_probe_acc(model, probes_full, v_acts_t, v_acts_f, eps, "Activation")

    print(tag_acc)

# Script entry point. Usage: python <this file> RUN_ID
# RUN_ID is a 1-based index into ``probe_exper``.
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} RUN_ID")

    run_id = int(sys.argv[1])
    # Reject 0 and negatives explicitly: ``run_id - 1`` would otherwise be a
    # negative index and silently pick an experiment from the end.
    if run_id < 1:
        sys.exit(f"RUN_ID must be >= 1, got {run_id}")

    try:
        exper_config = probe_exper[run_id - 1]
    except IndexError:
        sys.exit(
            f"RUN_ID {run_id} is out of range; "
            f"probe_exper has {len(probe_exper)} entries"
        )

    exper_args = ConfigDict(exper_config)
    process_args(exper_args)
    measure_sink_metrics(exper_args)
