#%%

from decision.xp.test_xp import test_utility_table, test_regret_table, test_fig_gain_times, test_fig_gain_post_training_vs_metrics_excess2, test_fig_gain_post_training_vs_metrics2
from decision.xp.data.base import ForwardedMixin, ds_registry, ds_rename
from decision.xp.model.base import PretrainedMixin, model_registry, model_rename

from decision.xp.test_forward import forward_ds_model
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
from joblib import Parallel, delayed
import os
import glob
# import dataset


#%%

# ds_name = "merged_hate_check"
# model_name = "mistral_instruct"

# #%%
# ds = ds_registry[ds_name]()
# model = model_registry[model_name]()
# data = forward_ds_model(ds, model, batch_size=32)


#%%

# test_utility_table("decision/utility_table", model_name, ds_name, "recal_isotonic", skip_existing=False)
#%%

# Print the keys of the dataset registry (the names accepted by
# `ds_registry[name]()` in the commented-out cells of this script).
print(ds_registry.keys())
#%%
# skip = ["hate", "hate_check", "merged_hate_check",
#  "hate_en_davidson", "hate_en_gender",
#  "hate_en_twitter", "hate_en_tweets",
#  "hate_en_speech_off", "hate_en_check",
#  "hate_dyn_gen", "hate_en_open",
#  "hate_en_speech18"
#  ]

# skip = ["hate_merged_large", "hate_merged_large_no_en", "hate_merged_no_en"]
# datasets = ["hate_en_frenk", "hate_merged_en", "hate_merged_en2", "hate_merged_large_en", "hate_merged_no_en2"]
# #%%
# # Loop through all combinations of datasets and models
# for ds_name in datasets:
#     for model_name in ["mistral_instruct"]:
#         print(f"Processing {ds_name} with {model_name}")
#         try:
#             ds = ds_registry[ds_name]()
#             model = model_registry[model_name]()
#             data = forward_ds_model(ds, model)
#             print(f"Successfully processed {ds_name} with {model_name}")
#         except Exception as e:
#             print(f"Error processing {ds_name} with {model_name}: {e}")
#%%

# data["text"]
# #%%
# data[0]
# dataset = data.map(lambda examples: {"labels": examples["label"]}, batched=True)
# #%%

# import torch

# dataset.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

# dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
# #%%
# res = model.process(data["text"])


# #%%

# res
# #%%

# model.truncated_model(res)

# Names of the post-training methods this script works with.
# Entries that are commented out are disabled; uncomment a line to re-enable it.
# NOTE(review): in this file the list is only consumed by the commented-out
# parallel `test_utility_table` sweep — confirm whether it is still needed.
post_training_methods = [
    "recal_isotonic",
    "recal_sigmoid",
    # "recal_hist10",
    # "recal_hist15",
    # "recal_scalbin10",
    # "recal_scalbin15",
    # "recal_metacal_acc90",
    # "recal_metacal_mis05",
    # "stack_logistic",
    # "stack_rf",
    # "stack_hgb",
    # "stack_naive_bayes",
    # "stack_gaussian_process",
    # "finetuning_sigmoid",
]

# def process_combination(ds_name, model_name, post_method):
#     print(f"Processing {ds_name} with {model_name} and {post_method}")
#     try:
#         test_utility_table("decision/utility_table", model_name, ds_name, post_method, skip_existing=False)
#         return f"Successfully processed {ds_name} with {model_name} and {post_method}"
#     except Exception as e:
#         return f"Error processing {ds_name} with {model_name} and {post_method}: {e}"

# # Create all combinations and run in parallel
# results = Parallel(n_jobs=8)(
#     delayed(process_combination)(ds_name, model_name, post_method)
#     for ds_name in ds_registry.keys()
#     for model_name in model_registry.keys()
#     for post_method in post_training_methods
# )

# # Print results
# for result in results:
#     print(result)


#%%


# for ds_name in ds_registry.keys():
#     for model_name in model_registry.keys():
#         print(f"Processing {ds_name} with {model_name}")
#         try:
#             test_regret_table("decision/regret_table", "useless", model_name, ds_name, skip_existing=False)
#             print(f"Successfully processed {ds_name} with {model_name}")
#         except Exception as e:
#             print(f"Error processing {ds_name} with {model_name}: {e}")
#%%



# Run the "excess" gain-vs-metrics figure routine for the
# "finetuning_sigmoid" post-training method.
# NOTE(review): presumably reads inputs from `inp` and writes the figure under
# `out` — confirm against decision.xp.test_xp.
test_fig_gain_post_training_vs_metrics_excess2(
    out="decision/fig_gain_times",
    inp="decision/",
    post_training="finetuning_sigmoid",
)
# #%%


# test_fig_gain_post_training_vs_metrics2(out ="decision/fig_gain_times", inp = "decision/", post_training= "finetuning_sigmoid")
# # # from decision.xp.data.base import ForwardedMixin, ds_registry, ds_rename
# # # print(ds_registry)
# # # %%
# from decision.xp.model.base import PretrainedMixin, model_registry, model_rename

# print(model_registry)
# # %%
# import os

# os.environ["WORKING_DIR"]
# # %%

# %%
# Find and remove all files with "mistral" in their name in decision/utility_table
# regret_table_dir = "decision/utility_table"
# if os.path.exists(regret_table_dir):
#     mistral_files = glob.glob(os.path.join(regret_table_dir, "*mistral*"))
#     for file_path in mistral_files:
#         try:
#             os.remove(file_path)
#             print(f"Deleted: {file_path}")
#         except Exception as e:
#             print(f"Error deleting {file_path}: {e}")
#     print(f"Finished cleaning mistral files from {regret_table_dir}")
# else:
#     print(f"Directory {regret_table_dir} does not exist")
# %%
