import argparse

from hypo_interp.config import ExperimentConfig
from hypo_interp.tasks import (
    DocstringTask,
    GreaterThanTask,
    InductionTask,
    IoITask,
    TracrProportionTask,
    TracrReverseTask,
)
from hypo_interp.test_executor import TestExecutor

#############################################
# Constants
#############################################

INDUCTION_NAME = "induction"
TRACR_PROP_NAME = "tracr-proportion"
TRACR_REV_NAME = "tracr-reverse"
GREATER_THAN_NAME = "greater-than"
IOI_NAME = "ioi"
DOCSTRING_NAME = "docstring"

# Maps each CLI task name (the value passed to --task) to the task class
# that is instantiated for it in main().
TASK_NAME_TO_TASK = {
    INDUCTION_NAME: InductionTask,
    TRACR_PROP_NAME: TracrProportionTask,
    TRACR_REV_NAME: TracrReverseTask,
    GREATER_THAN_NAME: GreaterThanTask,
    IOI_NAME: IoITask,
    DOCSTRING_NAME: DocstringTask,
}

# Valid values for --task. Derived from the mapping above so a task cannot be
# accepted by the parser without also having a class to run it (dicts keep
# insertion order, so the choices are listed in the order written above).
TASK_NAMES = list(TASK_NAME_TO_TASK)

FAITHFULNESS_NAME = "faithfulness"
MINIMALITY_NAME = "minimality"
INDEPENDENCE_NAME = "independence"
WILCOXON_NAME = "wilcoxon"

# Valid values for --test. Built from the constants above (rather than
# repeating the literals) so the accepted choices always match the names
# main() checks with ``X_NAME in args.test``.
TEST_NAMES = [FAITHFULNESS_NAME, INDEPENDENCE_NAME, WILCOXON_NAME, MINIMALITY_NAME]

#############################################
#  Helper functions
#############################################


def str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    Used as the ``type=`` of boolean options instead of ``bool``. argparse
    calls ``type`` on the raw option string, and ``bool("False")`` is ``True``
    because every non-empty string is truthy. With ``type=bool``, passing
    ``--per-prompt False`` would therefore silently *enable* the option.

    Args:
        value: The raw option value. Values that are already bools are
            returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior of a bare
            ``parse_args()`` call.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        Hyphenated flags are exposed with underscores, e.g.
        ``--num-random-circuits`` -> ``args.num_random_circuits``.
    """
    parser = argparse.ArgumentParser()

    # ----------------------------
    # Required arguments
    # ----------------------------
    parser.add_argument(
        "--test",
        nargs="+",
        choices=TEST_NAMES,
        help="Which test(s) to run.",
        required=True,
    )

    parser.add_argument(
        "--task",
        choices=TASK_NAMES,
        help="Task to run.",
        required=True,
    )

    # ------------------------------------
    # Optional arguments - Faithfulness
    # ------------------------------------
    parser.add_argument(
        "--num-random-circuits",
        type=int,
        help="Number of random circuits to generate for faithfulness test.",
        default=100,
        required=False,
    )

    # Note the leading space on the second fragment: the two sentences are
    # concatenated into a single help string.
    random_size_msg = "Number of edges to include in the random circuit."
    random_size_msg += " If not set, then the random circuit will be the same size as the original circuit."
    parser.add_argument(
        "--random_proportion",
        type=int,
        help=random_size_msg,
        default=None,
        required=False,
    )
    parser.add_argument(
        "--alpha",
        type=float,
        help="significance level",
        default=0.05,
        required=False,
    )
    parser.add_argument(
        "--quantile",
        type=float,
        help="test quantile",
        default=0.1,
        required=False,
    )
    # Boolean options use str_to_bool, not bool: bool("False") is True.
    parser.add_argument(
        "--per-prompt",
        type=str_to_bool,
        help="per prompt (true/false)",
        default=True,
        required=False,
    )
    parser.add_argument(
        "--use-mean",
        type=str_to_bool,
        help="use mean (true/false)",
        default=False,
        required=False,
    )
    parser.add_argument(
        "--invert",
        type=str_to_bool,
        help="use the candidate circuit or the complement of the candidate circuit (true/false)",
        default=False,
        required=False,
    )

    # ------------------------------------
    # Optional arguments - Minimality
    # ------------------------------------
    parser.add_argument(
        "--subset_edge_size",
        type=int,
        help="the size of edge of test.",
        default=1,
        required=False,
    )
    parser.add_argument(
        "--num_edge_test",
        type=int,
        help="the number of edge of test.",
        default=1,
        required=False,
    )

    parser.add_argument(
        "--base_distribution_size_minimality",
        type=int,
        help="the minimum number of reference sets we need to knock out.",
        default=100,
        required=False,
    )

    parser.add_argument(
        "--inflate_percent",
        type=float,
        help="how much more to inflate ",
        default=1,
        required=False,
    )

    # ------------------------------------
    # Optional arguments - Independence (permutation test)
    # ------------------------------------
    parser.add_argument(
        "--num_permutations",
        type=int,
        help="Number of permutations to run for the independence test.",
        default=1000,
        required=False,
    )

    # ----------------------------
    # Optional arguments - Misc
    # ----------------------------
    save_pth_msg = "Directory to write results and json file to"
    parser.add_argument(
        "--save-path",
        type=str,
        required=False,
        help=save_pth_msg,
        default="results_default",
    )

    device_msg = "Device to run on."
    device_msg += " If not set cpu will be used."
    parser.add_argument(
        "--device",
        type=str,
        help=device_msg,
        default="cpu",
        required=False,
    )

    num_examples_msg = "Number of examples to use in the dataset."
    num_examples_msg += " If not set 50 will be used."
    parser.add_argument(
        "--num-examples",
        type=int,
        help=num_examples_msg,
        required=False,
        default=50,
    )

    seed_msg = "Seed to use for random number generation."
    parser.add_argument(
        "--seed",
        type=int,
        help=seed_msg,
        required=False,
        default=None,
    )

    save_scores_msg = "Whether to save scores for random circuits (true/false)."
    parser.add_argument(
        "--save-scores",
        type=str_to_bool,
        help=save_scores_msg,
        required=False,
        default=False,
    )
    parser.add_argument(
        "--scores-path",
        type=str,
        required=False,
        help="path to stored scores",
        default=None,
    )

    return parser.parse_args(argv)


#############################################
# Main
#############################################


def main():
    """Run the hypothesis tests requested on the command line.

    Parses the CLI arguments, builds an ExperimentConfig and the selected
    task, and wraps them in a TestExecutor whose candidate circuit is the
    task's ``canonical_circuit``. Each test named in ``--test`` is then run.
    The tests always execute in the order faithfulness, independence,
    wilcoxon, minimality, regardless of the order given on the command line.
    """
    args = parse_args()

    # Zero ablation is enabled only for the induction task. The same flag is
    # passed to both the config and the task constructor.
    zero_ablation = args.task == INDUCTION_NAME

    config = ExperimentConfig(
        device=args.device,
        # ----------------------------
        # Faithfulness
        # ----------------------------
        num_random_circuits=args.num_random_circuits,
        random_proportion=args.random_proportion,
        # ----------------------------
        # Minimality
        # ----------------------------
        inflate_percentage=args.inflate_percent,
        base_distribution_size_minimality=args.base_distribution_size_minimality,
        subset_edge_size=args.subset_edge_size,
        # ----------------------------
        # Output / reproducibility
        # ----------------------------
        save_scores=args.save_scores,
        scores_path=args.scores_path,
        save_path=args.save_path,
        seed=args.seed,
        # ----------------------------
        # Shared test settings
        # ----------------------------
        alpha=args.alpha,
        quantile=args.quantile,
        per_prompt=args.per_prompt,
        use_mean=args.use_mean,
        zero_ablation=zero_ablation,
        invert=args.invert,
    )

    task = TASK_NAME_TO_TASK[args.task](
        device=config.device,
        num_examples=args.num_examples,
        zero_ablation=zero_ablation,
    )

    handler = TestExecutor(
        config=config, task=task, candidate_circuit=task.canonical_circuit
    )

    if FAITHFULNESS_NAME in args.test:
        handler.test_faithfulness(
            quantile=args.quantile,
            alpha=args.alpha,
            per_prompt=args.per_prompt,
            use_mean=args.use_mean,
        )
    if INDEPENDENCE_NAME in args.test:
        # The independence test is carried out by TestExecutor.test_sufficiency.
        handler.test_sufficiency(num_permutations=args.num_permutations)

    if WILCOXON_NAME in args.test:
        handler.test_faithfulness_two_sample()

    if MINIMALITY_NAME in args.test:
        handler.test_minimality_via_quantile(
            num_edge_test=args.num_edge_test,
            quantile=args.quantile,
            per_prompt=args.per_prompt,
            use_mean=args.use_mean,
        )


if __name__ == "__main__":
    main()
