#%%
from folktexts.dataset import Dataset
from folktexts.benchmark import Benchmark
from folktexts.benchmark import BenchmarkConfig
from folktexts.acs.acs_tasks import ACSTaskMetadata
from folktexts.acs.acs_dataset import ACSDataset
from pathlib import Path

#%% Default params
# Name of the ACS task used when no other task is requested.
DEFAULT_ACS_TASK = "ACSIncome"
# "~/data" with the user's home directory expanded, as a resolved path.
DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()

# Default inference / reproducibility settings.
DEFAULT_BATCH_SIZE = 16
DEFAULT_CONTEXT_SIZE = 600
DEFAULT_SEED = 42


# Command-line argument specifications.
# Each entry is (flag, type, help text), optionally followed by a fourth
# element (always False here) and then a default value.
# NOTE(review): the fourth element is presumably a `required` flag -- confirm
# against whatever consumes this list (it is not shown in this file).
cli_args = [
    (
        "--model",
        str,
        "[str] Model name or path to model saved on disk",
    ),
    (
        "--results-dir",
        str,
        "[str] Directory under which this experiment's results will be saved",
    ),
    (
        "--data-dir",
        str,
        "[str] Root folder to find datasets on",
    ),
    (
        "--task",
        str,
        "[str] Name of the ACS task to run the experiment on",
        False,
        DEFAULT_ACS_TASK,
    ),
    (
        "--few-shot",
        int,
        "[int] Use few-shot prompting with the given number of shots",
        False,
    ),
    (
        "--batch-size",
        int,
        "[int] The batch size to use for inference",
        False,
        DEFAULT_BATCH_SIZE,
    ),
    (
        "--context-size",
        int,
        "[int] The maximum context size when prompting the LLM",
        False,
        DEFAULT_CONTEXT_SIZE,
    ),
    (
        "--fit-threshold",
        int,
        "[int] Whether to fit the prediction threshold, and on how many samples",
        False,
    ),
    (
        "--subsampling",
        float,
        "[float] Which fraction of the dataset to use (if omitted will use all data)",
        False,
    ),
    (
        "--seed",
        int,
        "[int] Random seed -- to set for reproducibility",
        False,
        DEFAULT_SEED,
    ),
]

# parser.add_argument(
#     "--use-web-api-model",
#     help="[bool] Whether to use a model hosted on a web API (instead of a local model)",
#     action="store_true",
#     default=False,
# )

# parser.add_argument(
#     "--dont-correct-order-bias",
#     help="[bool] Whether to avoid correcting ordering bias, by default will correct it",
#     action="store_true",
#     default=False,
# )

# parser.add_argument(
#     "--numeric-risk-prompting",
#     help="[bool] Whether to prompt for numeric risk-estimates instead of multiple-choice Q&A",
#     action="store_true",
#     default=False,
# )

# parser.add_argument(
#     "--reuse-few-shot-examples",
#     help="[bool] Whether to reuse the same samples for few-shot prompting (or sample new ones every time)",
#     action="store_true",
#     default=False,
# )

# parser.add_argument(
#     "--balance-few-shot-examples",
#     help="[bool] Whether to sample evenly from all classes in few-shot prompting",
#     action="store_true",
#     default=False,
# )

# # Optionally, receive a list of features to use (subset of original list)
# parser.add_argument(
#     "--use-feature-subset",
#     type=list_of_strings,
#     help="[str] Optional subset of features to use for prediction, comma separated",
#     required=False,
# )

# parser.add_argument(
#     "--use-population-filter",
#     type=list_of_strings,
#     help=(
#         "[str] Optional population filter for this benchmark; must follow "
#         "the format 'column_name=value' to filter the dataset by a specific value."
#     ),
#     required=False,
# )

# parser.add_argument(
#     "--max-api-rpm",
#     type=int,
#     help="[int] Maximum number of API requests per minute (if using a web-hosted model)",
#     required=False,
# )

# parser.add_argument(
#     "--logger-level",
#     type=str,
#     help="[str] The logging level to use for the experiment",
#     choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
#     required=False,
#     default="WARNING",
# )

# Hard-coded stand-ins for the command-line options, grouped by purpose.

# Few-shot prompting options.
# NOTE(review): `few_shot` is set to False rather than a shot count or None --
# confirm BenchmarkConfig treats False as "few-shot disabled".
few_shot = False
reuse_few_shot_examples = False
balance_few_shot_examples = False

# Prompt construction options.
numeric_risk_prompting = False
correct_order_bias = True

# Data selection options.
# NOTE(review): `feature_subset` is False rather than None or a list of
# feature names -- confirm BenchmarkConfig accepts this as "no subset".
feature_subset = False
population_filter = None

# Inference and reproducibility settings (same values as DEFAULT_BATCH_SIZE,
# DEFAULT_CONTEXT_SIZE and DEFAULT_SEED defined earlier in this file).
batch_size = 16
context_size = 600
seed = 42

# %%
# Bundle every option above into a single benchmark configuration object.
config = BenchmarkConfig(
    few_shot=few_shot,
    reuse_few_shot_examples=reuse_few_shot_examples,
    balance_few_shot_examples=balance_few_shot_examples,
    numeric_risk_prompting=numeric_risk_prompting,
    correct_order_bias=correct_order_bias,
    feature_subset=feature_subset,
    population_filter=population_filter,
    batch_size=batch_size,
    context_size=context_size,
    seed=seed,
)


#%%

# bench = Benchmark.make_acs_benchmark(
#         task_name=args.task,
#         model=model,
#         tokenizer=tokenizer,
#         data_dir=args.data_dir,
#         config=config,
#         subsampling=args.subsampling,
#         max_api_rpm=args.max_api_rpm,
#     )



# def make_acs_benchmark(
#         cls,
#         task_name: str,
#         *,
#         model: AutoModelForCausalLM | str,
#         tokenizer: AutoTokenizer = None,
#         data_dir: str | Path = None,
#         max_api_rpm: int = None,
#         config: BenchmarkConfig = BenchmarkConfig.default_config(),
#         **kwargs,
#     ) -> Benchmark:

# # Handle non-standard ACS arguments
# Work on a copy of the benchmark's ACS dataset settings rather than on the
# class attribute itself. The commented-out lines below appear to be the
# kwargs-override logic from `Benchmark.make_acs_benchmark`, kept for
# reference; here the settings are used unmodified.
acs_dataset_configs = Benchmark.ACS_DATASET_CONFIGS.copy()
# for arg in acs_dataset_configs:
#     if arg in kwargs and kwargs[arg] != Benchmark.ACS_DATASET_CONFIGS[arg]:
#         logging.warning(
#             f"Received non-standard ACS argument '{arg}' (using "
#             f"{arg}={kwargs[arg]} instead of default {arg}={cls.ACS_DATASET_CONFIGS[arg]}). "
#             f"This may affect reproducibility.")
#         acs_dataset_configs[arg] = kwargs.pop(arg)

# # Update config with any additional kwargs
# config = config.update(**kwargs)

# Fetch ACS task and dataset
# The task is looked up by name (always DEFAULT_ACS_TASK here); the numeric-QA
# switch follows the `numeric_risk_prompting` setting of `config`.
acs_task = ACSTaskMetadata.get_task(
    name=DEFAULT_ACS_TASK,
    use_numeric_qa=config.numeric_risk_prompting)

# Build the dataset for that task. The cache directory is the relative path
# 'data' (resolved against the current working directory), not
# DEFAULT_DATA_DIR.
acs_dataset = ACSDataset.make_from_task(
    task=acs_task,
    cache_dir='data',
    **acs_dataset_configs)
# %%
# Bare expression: shows the dataset object when run as an interactive cell.
acs_dataset
# %%
# Retrieve the training split as a (features, labels) pair.
X_train, y_train = acs_dataset.get_train()
# Select only half of the training data at random
# train_size = len(X_train) // 2
# indices = np.random.RandomState(42).choice(len(X_train), train_size, replace=False)
# X_train = X_train[indices]
# y_train = y_train[indices]

# Retrieve the validation and test splits the same way. The validation split
# is not used by the cells visible further down in this file.
X_val, y_val = acs_dataset.get_val()
X_test, y_test = acs_dataset.get_test()
# %%
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import glest

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, log_loss, brier_score_loss

# Load the `gle_results.csv` table stored under the Mixtral-8x7B-v0.1 /
# ACSIncome results folder. The path is relative, so this must be run from the
# directory that contains `folktexts-results/`.
df = pd.read_csv('folktexts-results/folktexts-results/model-Mixtral-8x7B-v0.1_task-ACSIncome/Mixtral-8x7B-v0.1_bench-838239935/figures/gle_results.csv')


#%%

# Bare expression: shows the loaded table when run as an interactive cell.
df
#%%
def train_and_evaluate_model(model_type, X_train, y_train, X_test, y_test):
    """
    Train a baseline classifier and estimate its grouping loss on held-out data.

    The classifier is fit on the training data and scores the test data with
    its positive-class probability. The test data is then split three ways:

    * a calibration part, used to fit a logistic-regression recalibration of
      the scores;
    * a partition-training part, used to fit a decision tree on the residuals
      ``y - recalibrated score`` (the tree's leaves define the partition);
    * an evaluation part, on which the grouping-loss estimator is fit and
      accuracy / AUC are computed.

    Args:
        model_type: str, one of 'random_forest', 'gradient_boosting',
            'logistic_regression'
        X_train, y_train: data used to fit the classifier
        X_test, y_test: held-out data used for recalibration, partitioning
            and evaluation; binary labels are assumed (column 1 of
            ``predict_proba`` is taken as the positive class)

    Returns:
        tuple: (gle, accuracy, auc_score) where ``gle`` is the fitted
        ``glest.core.GLEstimatorResiduals`` object, ``accuracy`` thresholds the
        raw scores at 0.5 and ``auc_score`` uses the recalibrated scores, both
        on the evaluation part of the test data.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Initialize model based on type
    if model_type == 'random_forest':
        model = RandomForestClassifier(n_estimators=100, random_state=42)
    elif model_type == 'gradient_boosting':
        model = HistGradientBoostingClassifier(max_iter=100, random_state=42)
    elif model_type == 'logistic_regression':
        model = LogisticRegression(max_iter=100, random_state=42)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Train model
    model.fit(X_train, y_train)

    # Score with positive-class probabilities, not hard labels from
    # `predict`: recalibration, the residual tree and the AUC below all need a
    # continuous score, and hard labels collapse them to two distinct values.
    y_scores = model.predict_proba(X_test)[:, 1]

    # Split test data: first half is kept for recalibration + partition
    # training, second half is the evaluation set.
    X_train_cal, X_test_cal, y_train_cal, y_test_cal, S_train, S_test = train_test_split(
        X_test, y_test, y_scores, test_size=0.5, random_state=0
    )

    # Carve the calibration set out of the first half: 20% of it, but never
    # fewer than 4000 samples (so the first half must hold more than 4000 rows,
    # otherwise train_test_split raises).
    X_train_cal, X_cal, y_train_cal, y_cal, S_train, S_cal = train_test_split(
        X_train_cal, y_train_cal, S_train,
        test_size=max(int(len(X_train_cal) * 0.2), 4000),
        random_state=0
    )

    # Calibration: logistic regression mapping raw score -> recalibrated score
    calibrated_classifier = LogisticRegression()
    calibrated_classifier.fit(S_cal.reshape(-1, 1), y_cal)

    c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1, 1))[:, 1]
    c_hat_test = calibrated_classifier.predict_proba(S_test.reshape(-1, 1))[:, 1]

    # GLE estimation: a tree fit on the recalibrated residuals provides the
    # partition (leaf ids) of the evaluation samples.
    residuals_train = y_train_cal - c_hat_train
    dt = DecisionTreeRegressor(max_depth=7, min_samples_leaf=10)
    dt.fit(X_train_cal, residuals_train)
    leaf_ids = dt.apply(X_test_cal)

    # NOTE(review): the two positional `None` arguments and the `fit` keywords
    # follow the original call; confirm against the installed glest version.
    gle = glest.core.GLEstimatorResiduals(None, None)
    gle.fit(X_test_cal, y_test_cal, y_scores_cal=c_hat_test, partition=leaf_ids)

    # Calculate metrics on the evaluation set
    acc = accuracy_score(y_test_cal, S_test > 0.5)
    auc = roc_auc_score(y_test_cal, c_hat_test)

    return gle, acc, auc

# Fit every baseline on the same train / test splits.
baseline_splits = (X_train, y_train, X_test, y_test)
rf_gl, rf_acc, rf_auc = train_and_evaluate_model('random_forest', *baseline_splits)
gb_gl, gb_acc, gb_auc = train_and_evaluate_model('gradient_boosting', *baseline_splits)
lr_gl, lr_acc, lr_auc = train_and_evaluate_model('logistic_regression', *baseline_splits)

# %%
# Build one result row per baseline: display name, accuracy, AUC and the 'GL'
# entry of the estimator's metrics (None when that key is absent).
baseline_rows = [
    ('RandomForest', rf_acc, rf_auc, rf_gl),
    ('GradientBoosting', gb_acc, gb_auc, gb_gl),
    ('LogisticRegression', lr_acc, lr_auc, lr_gl),
]
new_results = pd.DataFrame([
    {
        'model_name': label,
        'accuracy': accuracy,
        'auc': auc_value,
        'grouping_loss': estimator.metrics().get('GL', None),
    }
    for label, accuracy, auc_value, estimator in baseline_rows
])

# Append the baseline rows to the previously loaded results table.
df = pd.concat([df, new_results], ignore_index=True)
# %%
df
# %%
# Persist the combined table (relative path, no index column).
df.to_csv('combined_results_baseline.csv', index=False)
# %%
