from matplotlib import pyplot as plt

from tabicl.config.config_benchmark_sweep import ConfigBenchmarkSweep
from tabicl.results.dataset_plot_combined import make_combined_dataset_plot, make_combined_dataset_plot_data
from tabicl.results.dataset_plot_separate import make_separate_dataset_plot_data, make_separate_dataset_plots
from tabicl.results.merge_ds_bench_whytrees_and_sweeps import merge_ds_bench_whytrees_and_sweeps
from tabicl.results.random_sequence import create_random_sequences_from_dataset
from tabicl.results.results_sweep import ResultsSweep


def make_dataset_plots(cfg: ConfigBenchmarkSweep, results_sweep: ResultsSweep) -> None:
    """Build the combined and separate dataset plots and write them to disk.

    Four files are written under ``cfg.output_dir``:

    * ``dataset_plot_combined.nc`` / ``dataset_plot_combined.png``
    * ``dataset_plot_separate.nc`` / ``dataset_plot_separate.png``

    The ``.nc`` files hold the plot data that each figure is drawn from, and
    are written before the corresponding figure is created.

    Note:
        ``results_sweep.ds.attrs['model_plot_name']`` is overwritten in place
        with ``cfg.model_plot_name`` before the dataset is merged.

    Args:
        cfg: Sweep configuration; this function reads ``model_plot_name`` and
            ``output_dir`` and forwards ``cfg`` to every helper.
        results_sweep: Sweep results whose ``ds`` dataset is plotted.
    """
    # Tag the sweep dataset with its display name before merging.
    # NOTE(review): presumably the merge step reads this attribute to label
    # the model in the plots -- confirm in merge_ds_bench_whytrees_and_sweeps.
    results_sweep.ds.attrs['model_plot_name'] = cfg.model_plot_name
    ds = merge_ds_bench_whytrees_and_sweeps(cfg, [results_sweep.ds])
    ds = create_random_sequences_from_dataset(cfg, ds)

    plot_data_combined = make_combined_dataset_plot_data(cfg, ds)
    plot_data_combined.to_netcdf(cfg.output_dir / "dataset_plot_combined.nc")
    fig_combined = make_combined_dataset_plot(cfg, plot_data_combined)
    try:
        fig_combined.savefig(cfg.output_dir / "dataset_plot_combined.png")
    finally:
        # Always release the figure, even if saving fails; pyplot otherwise
        # keeps it alive for the rest of the process.
        plt.close(fig_combined)

    plot_data_separate = make_separate_dataset_plot_data(cfg, ds)
    plot_data_separate.to_netcdf(cfg.output_dir / "dataset_plot_separate.nc")
    fig_separate = make_separate_dataset_plots(cfg, plot_data_separate)
    try:
        fig_separate.savefig(cfg.output_dir / "dataset_plot_separate.png")
    finally:
        plt.close(fig_separate)