import os
import logging
from models.cartography.cartography.selection.train_dy_filtering import compute_train_dy_metrics, get_filtered_data, plot_data_map
from models.cartography.cartography.selection.selection_utils import read_training_dynamics

def print(*args, **kwargs):
    """Route ``print``-style calls in this module to ``logging.info``.

    This intentionally shadows the builtin ``print`` within this module.
    Each positional argument is converted with ``str`` and the pieces are
    joined with single spaces. Keyword arguments (``sep``, ``end``, ``file``,
    ...) are accepted for call compatibility but ignored.
    """
    message = " ".join(str(arg) for arg in args)
    logging.info(message)


def run_filter(config, save_dir, plots_dir, data_numbers, data_name, model_name):
    """Compute training-dynamics metrics, then optionally filter and plot.

    Args:
        config: Options object. This function reads ``burn_out``, ``filter``,
            ``metric`` and ``plot`` from it. ``burn_out`` is clamped in place
            to the number of epochs found. The whole object is also passed to
            ``compute_train_dy_metrics`` and ``get_filtered_data``.
        save_dir: Directory handed to ``read_training_dynamics``. The metrics
            file ``td_metrics_e<total_epochs>.jsonl`` is written here. When
            filtering is enabled, the ``filtering`` sub-directory is also
            created here.
        plots_dir: Directory for the data-map plot; created if missing.
        data_numbers: Forwarded unchanged to ``get_filtered_data``.
        data_name: Used as the plot title.
        model_name: Forwarded as ``model`` to ``plot_data_map``.

    Returns:
        The result of ``get_filtered_data`` when ``config.filter`` is truthy,
        otherwise ``None``.

    Raises:
        ValueError: If no training dynamics were read, or if filtering is
            requested while ``config.metric`` is unset/falsy.
    """
    training_dynamics = read_training_dynamics(save_dir,
                                               strip_last=False,
                                               burn_out=config.burn_out)
    if not training_dynamics:
        raise ValueError(f"No training dynamics found in {save_dir}")

    # Epoch count is taken from the first record's logits; assumes every
    # record has the same number of epochs — TODO confirm.
    total_epochs = len(next(iter(training_dynamics.values()))["logits"])
    if config.burn_out > total_epochs:
        # Cannot burn out past the epochs that exist. The clamped value stays
        # on config, which is passed on to compute_train_dy_metrics below.
        config.burn_out = total_epochs
        print(f"Total epochs found: {config.burn_out}")
    train_dy_metrics, _ = compute_train_dy_metrics(training_dynamics, config)

    # The previous `burn_out > total_epochs` branch was unreachable after the
    # clamp above, so the file name has always used total_epochs; keep it.
    # NOTE(review): if burn_out < total_epochs is possible, the name may not
    # reflect the epochs the metrics were computed over — confirm intent.
    train_dy_filename = os.path.join(save_dir, f"td_metrics_e{total_epochs}.jsonl")
    # train_dy_metrics is presumably a pandas DataFrame (it exposes to_json).
    train_dy_metrics.to_json(train_dy_filename,
                             orient='records',
                             lines=True)
    print(f"Metrics based on Training Dynamics written to {train_dy_filename}")

    # Bound up front so the function can return when filtering is disabled.
    selected_ids = None
    if config.filter:
        # Validate before touching the filesystem; `assert` would be stripped
        # under `python -O`.
        if not config.metric:
            raise ValueError("config.metric must be set when config.filter is enabled")
        filtering_output_dir = os.path.join(save_dir, "filtering")
        os.makedirs(filtering_output_dir, exist_ok=True)
        selected_ids = get_filtered_data(config, filtering_output_dir, train_dy_metrics, data_numbers)

    if config.plot:
        os.makedirs(plots_dir, exist_ok=True)
        plot_data_map(train_dy_metrics, plots_dir, title=data_name, show_hist=True, model=model_name)

    return selected_ids