import os
import pdb
import random
from itertools import product, chain

from typing import Dict, List, Tuple, Any, Optional, Union, Callable
import argparse

from tqdm import tqdm
import numpy as np
import torch
import nnsight
from nnsight import CONFIG, LanguageModel, util
import pandas as pd  # this has to be after nnsight or throws gcc error
from transformers.pytorch_utils import find_pruneable_heads_and_indices

import matplotlib.pyplot as plt
import seaborn as sns

import sys

sys.path.append("../pp_experiment")
from utils import get_model_and_tokenizer, fix_random_seed, get_random_circuit, get_circuit, eval_circuit_performance, \
    MODEL_TO_SHORT, stupid_pad
sys.path.append("../nnsight_patching_experiment")
from run_patching import build_parser, post_arg_parse_fix, get_model_and_dataset, maybe_patch_or_load_cache


def run_logit_diff(model, dataloader, args) -> List[float]:
    """
    Placeholder for the logit-diff computation.

    The body is not implemented yet: none of ``model``, ``dataloader`` or
    ``args`` is used, and the call evaluates to ``None`` rather than the
    ``List[float]`` promised by the annotation.
    """
    # TODO: implement; until then callers receive None.
    return None


def logit_diff_main(args: argparse.Namespace):
    """
    Run behavioral testing for a model on a specific dataset.

    The script is meant to be generic enough to accommodate different
    experiments, sampling parameters, metrics, etc.

    Steps performed:
      1. Obtain the dataloader, dataset and model via ``get_model_and_dataset(args)``.
      2. Create ``args.output_dir`` if it does not exist yet.
      3. Call ``maybe_patch_or_load_cache`` with the output directory and
         ``run_logit_diff``, forwarding ``model``, ``dataloader`` and ``args``
         as keyword arguments.

    Nothing is returned; the value produced in step 3 is discarded.
    """
    # The dataset itself (middle element) is not needed below.
    dataloader, _dataset, model = get_model_and_dataset(args)

    os.makedirs(args.output_dir, exist_ok=True)

    out_location = f"{args.output_dir}"
    forwarded = {
        "model": model,
        "dataloader": dataloader,
        "args": args,
    }
    maybe_patch_or_load_cache(out_location, run_logit_diff, **forwarded)


def add_args(parser: argparse.ArgumentParser):
    """
    Register this script's command-line options on ``parser``.

    Options added:
        --metric    one of ``metric_choices`` below; defaults to "accuracy".
        --sampling  "greedy" or "sampling"; defaults to "greedy".

    Args:
        parser: the argument parser to extend (modified in place).

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    # argparse never checks a string default against ``choices``. The default
    # "accuracy" was therefore accepted silently while an explicit
    # ``--metric accuracy`` was rejected. Listing it as a choice keeps every
    # existing invocation working and makes the default selectable.
    # NOTE(review): confirm downstream code actually handles "accuracy"; if it
    # does not, the default should instead become one of the other metrics.
    metric_choices = [
        "accuracy",
        "all_obj_accuracy",
        "avg_obj_accuracy",
        "first_token_argmax_any",
    ]
    parser.add_argument(
        "--metric",
        type=str,
        default="accuracy",
        choices=metric_choices,
        help="Metric name (default: %(default)s).",
    )
    parser.add_argument(
        "--sampling",
        type=str,
        default="greedy",
        choices=["greedy", "sampling"],
        help="Sampling mode (default: %(default)s).",
    )
    return parser


# Script entry point: parse CLI arguments, post-process them, seed, then run.
if __name__ == "__main__":
    # Start from the parser supplied by run_patching.build_parser and add the
    # --metric / --sampling options defined in add_args above.
    parser = add_args(build_parser())
    args = parser.parse_args()
    # Printed before post_arg_parse_fix runs, so this shows the arguments as parsed.
    print(f"ARGS: {args}")
    post_arg_parse_fix(args)
    # NOTE(review): ``args.seed`` is not defined in this file; presumably it is
    # registered by build_parser — confirm.
    fix_random_seed(args.seed)
    logit_diff_main(args)
