import sys, os
from datasets.datasets import SyntheticFACE, SyntheticMoons, Dataset, CaliforniaHousing, GermanCreditv2, GiveMeSomeCredit, AdultIncome
from visualisation import *

from models.mlp_pytorch import PyTorchMLP
from conformal.split_conformal import SplitConformalPrediction
from conformal.localised_conformal_baselcp import BaseLCP
from conformal.localised_conformal_tree import ConformalCONFEXTree

from counterfactual_explanations.counterfactual_benchmarker import *
from counterfactual_explanations.gradient_based.auxillary_models import *
from counterfactual_explanations.gradient_based.cf_gradient_based import *
from counterfactual_explanations.tree.cf_featuretweak import FeatureTweakGenerator
from counterfactual_explanations.tree.cf_focus import FOCUSGenerator
from counterfactual_explanations.gradient_based.losses import *

from counterfactual_explanations.milp_based.cf_conformal import *
from counterfactual_explanations.milp_based.cf_mindist import *
from counterfactual_explanations.dim_reduction import *
import argparse

from datetime import datetime

if __name__ == "__main__":
    # Entry point: parse CLI options, build the dataset / model factory /
    # metrics / counterfactual generators, then run the CFBenchmarker pipeline.

    # Explicit import so `Path` does not depend on what the wildcard imports
    # above happen to re-export.
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Run counterfactual benchmarking experiments.")
    parser.add_argument("--dataset", type=str, choices=["CaliforniaHousing", "GermanCredit", "GiveMeSomeCredit", "AdultIncome"], default="CaliforniaHousing", help="Dataset to use (CaliforniaHousing, GermanCredit, GiveMeSomeCredit, AdultIncome)")
    parser.add_argument("--model", type=str, choices=["MLP", "RandomForest"], default="MLP", help="Model to use (RandomForest, MLP)")
    args = parser.parse_args()
    is_rf = args.model.lower() == "randomforest"

    print(f"Start {datetime.now()}")

    # dataset name -> (dataset class, MLP config, random-forest config).
    # GiveMeSomeCredit and AdultIncome use fewer MLP epochs with larger batches
    # and a capped random forest; the other datasets use the defaults.
    dataset_table = {
        "CaliforniaHousing": (CaliforniaHousing, {"epochs": 100, "batch_size": 64}, {}),
        "GermanCredit": (GermanCreditv2, {"epochs": 100, "batch_size": 64}, {}),
        "GiveMeSomeCredit": (GiveMeSomeCredit, {"epochs": 50, "batch_size": 256}, {"max_n_leaves": 500, "n_estimators": 5}),
        "AdultIncome": (AdultIncome, {"epochs": 50, "batch_size": 256}, {"max_n_leaves": 500, "n_estimators": 5}),
    }
    try:
        dataset_cls, mlp_config, rf_config = dataset_table[args.dataset]
    except KeyError:
        # Unreachable via argparse `choices`, kept as a defensive guard.
        raise ValueError(f"Unknown dataset: {args.dataset}") from None

    # NOTE(review): the three positional proportions are presumably
    # train / calibration / test splits — confirm against the Dataset class.
    dataset = dataset_cls(0.6, 0.2, 0.2)

    # Exactly one model family is benchmarked per run, selected by --model.
    model_cls, model_config = (RandomForestSKLearn, rf_config) if is_rf else (PyTorchMLP, mlp_config)
    model_factories = [
        ModelFactory(model_cls, dataset.input_properties, config=model_config, config_multi={}),
    ]

    n_factuals_main = 100
    n_repeats = 2
    path = Path("results_v")
    use_pretrained = True

    metrics = [
        FailuresMetric(),
        DistanceMetric(),
        ValidityMetric(),
        ImplausibilityMetric(included_prop=0.1),
        LOFMetric(n_neighbours=20, stratified=True),
        SensitivityMetric(n_sensitivity=25, n_neighbours=4, budget=0.001),
        StabilityMetric(n_neighbours=8, budget=0.001),
    ]

    # Grid for the tree-based localised conformal method; the score function
    # differs between random forests and MLPs.
    conformal_config = {
        "alpha": [0.01, 0.05, 0.1], "scorefn_name": ["linear_logits" if is_rf else "linear2"], "kernel_bandwidth": [0.05, 0.1, 0.15, 0.2],
    }
    # AdultIncome additionally passes `idx_cat_groups_to_ignore`. Parentheses
    # make the precedence explicit: `|` binds tighter than `if/else`.
    tree_conformal_config = (
        conformal_config | {"idx_cat_groups_to_ignore": [[1, 2, 3, 4]]}
        if args.dataset == "AdultIncome"
        else conformal_config
    )

    # Generators shared by both model families.
    generators = [
        GeneratorFactory([MinDistanceCF], config={}, config_multi={}),

        GeneratorFactory([ConformalCF], config={"conformal_class": SplitConformalPrediction}, config_multi={"conformal_config": {"alpha": [0.01, 0.05, 0.1]}}),

        GeneratorFactory([ConformalCF], config={"conformal_class": ConformalCONFEXTree}, config_multi={
            "conformal_config": tree_conformal_config
        }),
    ]

    # Model-family-specific generators. `extend` (not `append`): list.append
    # accepts a single element, so two factories must be added via extend.
    if is_rf:
        generators.extend([
            GeneratorFactory([FOCUSGenerator], config={"n_iter": 200}, config_multi={}),
            GeneratorFactory([FeatureTweakGenerator], config={"epsilon": 0.01}, config_multi={}),
        ])
    else:
        generators.extend([
            GeneratorFactory([WachterGenerator], config={"mad": True}, config_multi={}),
            GeneratorFactory([SchutGenerator], config={"new": True}, config_multi={}),
            GeneratorFactory([ECCCOGenerator], config={}, config_multi={"conformal_config": {"alpha": [0.01, 0.05, 0.1]}}),
        ])

    ## Do not modify below
    print("Initializing CFBenchmarker...")
    benchmarker = CFBenchmarker(dataset, n_factuals_main, n_repeats, metrics, model_factories, generators, path, use_pretrained=use_pretrained)

    print("Setting up models...")
    benchmarker.setup_models()

    print("Evaluating models...")
    benchmarker.evaluate_models()

    print("Setting factuals...")
    benchmarker.set_factuals()

    print("Initializing generators...")
    benchmarker.initialise_generators()

    print("Generating counterfactuals...")
    benchmarker.get_counterfactuals(reset=not use_pretrained)

    print("Evaluating counterfactuals...")
    df_out = benchmarker.evaluate_counterfactuals()

    print("Test conformal...")
    benchmarker.test_conformal()

    print(f"Evaluation complete. See {path}")
    print(f"End {datetime.now()}")


