import numpy as np
import torch as th

from learned_planners.interp.collect_dataset import DatasetStore  # noqa: F401  # pyright: ignore
from learned_planners.interp.train_probes import ActivationsDataset

# Location of the source activations dataset. The output file name written at
# the end of this script is derived from this same string.
path = "/training/activations_dataset/hard/0ts_alternative_boxes_direction_map_20000_skip5.pt"
# Load through `path` rather than repeating the literal, so the file that is
# read and the name the output is derived from can never drift apart.
# NOTE(review): th.load unpickles arbitrary Python objects, so only point this
# at a trusted file. Newer PyTorch releases default to weights_only=True, which
# may refuse to unpickle a full dataset object; pass weights_only=False if the
# installed version requires it -- confirm against the pinned torch version.
acts_ds = th.load(path)

# Settings carried over unchanged from the loaded dataset to the new one.
inherited_settings = dict(
    dataset_path=acts_ds.dataset_path,
    labels_type=acts_ds.labels_type,
    keys=acts_ds.keys,
    num_data_points=acts_ds.num_data_points,
    fetch_all_boxing_data_points=acts_ds.fetch_all_boxing_data_points,
    skip_first_n=acts_ds.skip_first_n,
)

# Construct the target dataset with load_data=False; its `data` and `gt_output`
# attributes are assigned explicitly further down in this script.
new_ds = ActivationsDataset(**inherited_settings, multioutput=True, load_data=False)

# Flags and keys are set directly on the instance after construction.
# NOTE(review): `keys` was already passed to the constructor; the reassignment
# is kept in case the constructor alters it -- confirm whether it is needed.
new_ds.grid_wise = True
new_ds.classification = True
new_ds.keys = acts_ds.keys
# Flatten every sample into one entry per grid cell.
#
# For each sample, `data` receives one dict per cell mapping each cache key to
# that key's activations sliced at the cell (`value[:, i, j]`), and
# `gt_output_data` receives the matching label `gt_output[i, j]`.
#
# Side length of the square grid that is enumerated cell by cell.
# NOTE(review): hard-coded; assumes the last two axes of every cached
# activation and the first two axes of gt_output are at least 10 long -- confirm.
GRID_SIDE = 10

data = []
gt_output_data = []
# Index with range(len(...)) rather than iterating the dataset directly: only
# __len__ and __getitem__ are known to be supported by the loaded object.
for idx in range(len(acts_ds)):
    # Fetch the sample once instead of re-indexing the dataset for every cell
    # and component (assumes __getitem__ is deterministic).
    sample = acts_ds[idx]
    cache, gt_output = sample[0], sample[1]
    for cell_idx in range(GRID_SIDE * GRID_SIDE):
        # `i` varies fastest, so cells are emitted in (i, j) order
        # (0, 0), (1, 0), ..., (9, 0), (0, 1), ...
        i, j = cell_idx % GRID_SIDE, cell_idx // GRID_SIDE
        data.append({key: value[:, i, j] for key, value in cache.items()})
        cell_label = gt_output[i, j]
        # Single-element labels are unwrapped to Python scalars; larger labels
        # are kept as arrays.
        # NOTE(review): `.size` is an int on NumPy arrays but a method on torch
        # tensors, so this check presumes NumPy labels -- confirm.
        gt_output_data.append(cell_label.item() if cell_label.size == 1 else cell_label)
# Combine the per-cell labels into one array whose first axis is the cell index.
gt_output_data = np.stack(gt_output_data)

# Attach the flattened per-cell inputs and labels to the new dataset object.
new_ds.data = data
new_ds.gt_output = gt_output_data

# Derive the output file name from the input path. Only a trailing ".pt" is
# rewritten: a plain str.replace would also alter ".pt" occurring elsewhere in
# the path (e.g. in a directory name), and would return the input path
# unchanged when there is no ".pt" at all, causing the save below to overwrite
# the source dataset. This form always yields a name different from `path`.
pt_suffix = ".pt"
path_stem = path[: -len(pt_suffix)] if path.endswith(pt_suffix) else path
new_path = f"{path_stem}_gridwise_multioutput{pt_suffix}"
th.save(new_ds, new_path)
