# Imports
import numpy as np
import pandas as pd
import torch
import os
# OpenDataVal: data loading/noise utilities and data-valuation algorithms
from opendataval.dataloader import Register, DataFetcher, mix_labels, add_gauss_noise
from opendataval.dataval import (
    GLOC,
    AME,
    DVRL,
    BetaShapley,
    DataBanzhaf,
    DataOob,
    DataShapley,
    InfluenceSubsample,
    KNNShapley,
    LavaEvaluator,
    LeaveOneOut,
    RandomEvaluator,
    RobustVolumeShapley,
)

from opendataval.experiment import ExperimentMediator
# Expose only the first GPU to this process. This must run before CUDA is
# initialised here, i.e. before the torch.cuda.is_available() query below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# --- Experiment configuration ----------------------------------------------
dataset_name = "fried"
train_count, valid_count, test_count = 1000, 100, 3000
# Forwarded to mix_labels as `noise_rate` (presumably the fraction of training
# labels that get corrupted -- confirm against opendataval docs).
noise_rate = 0.1
noise_kwargs = {'noise_rate': noise_rate}
model_name = "LogisticRegression"
metric_name = "accuracy"
train_kwargs = {"epochs": 3, "batch_size": 100, "lr": 0.01}
# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Build the data splits (with injected label noise) and the model factory.
# `device` is passed explicitly; previously it was computed but never used,
# so training always ran on the mediator's default device.
exper_med = ExperimentMediator.model_factory_setup(
    dataset_name=dataset_name,
    cache_dir="../data_files/",
    force_download=False,
    train_count=train_count,
    valid_count=valid_count,
    test_count=test_count,
    add_noise=mix_labels,
    noise_kwargs=noise_kwargs,
    train_kwargs=train_kwargs,
    model_name=model_name,
    metric_name=metric_name,
    device=torch.device(device),
)

# Data valuators to run. Uncomment entries to include more methods; the ones
# marked "slow" are expensive.
data_evaluators = [
    # RandomEvaluator(),
    # LeaveOneOut(), # leave one out
    # InfluenceSubsample(num_models=1000), # influence function
    # DVRL(rl_epochs=2000), # Data valuation using Reinforcement Learning
    KNNShapley(k_neighbors=valid_count), # KNN-Shapley
    # DataShapley(cache_name=f"cached"), # Data-Shapley ## slow
    # BetaShapley(cache_name=f"cached"), # Beta-Shapley ## slow
    # DataBanzhaf(num_models=1000), # Data-Banzhaf
    GLOC(num_models=500), # GLOC
    AME(num_models=500), # Average Marginal Effects
    # DataOob(num_models=1000), # Data-OOB
    # LavaEvaluator(),
]

# Fit every evaluator and compute per-sample data values.
exper_med = exper_med.compute_data_values(data_evaluators=data_evaluators)

from opendataval.experiment.exper_methods import (
    discover_corrupted_sample,
    noisy_detection,
    remove_high_low,
    save_dataval
)
from matplotlib import pyplot as plt

# --- Save and report results -------------------------------------------------
output_dir = f"../tmp/withoutlocal{dataset_name}_{noise_rate}/"
exper_med.set_output_directory(output_dir)
# A bare `output_dir` expression only echoes inside a notebook; print it so the
# destination is visible when this file is run as a script.
print(f"Saving results to {output_dir}")

# Presumably measures how well each evaluator's data values flag the
# noise-injected samples -- confirm against opendataval docs.
noisy_detection_df = exper_med.evaluate(noisy_detection, save_output=True)
print(noisy_detection_df)

# `col=2` is forwarded to the plotting helper (likely the subplot column count).
fig = plt.figure(figsize=(15, 25))
_, fig = exper_med.plot(discover_corrupted_sample, fig, col=2, save_output=True)

fig = plt.figure(figsize=(15, 25))
df_resp, fig = exper_med.plot(remove_high_low, fig, col=2, save_output=True)
# Bare `df_resp` was a notebook display leftover; print it explicitly instead.
print(df_resp)

# Persist the computed data values themselves.
exper_med.evaluate(save_dataval, save_output=True)