import json
import os

import numpy as np


### Load JSON files
def load_json(path_to_file):
    """Read the JSON document stored at ``path_to_file`` and return the decoded object.

    Args:
        path_to_file: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's default locale encoding.
    with open(path_to_file, encoding="utf-8") as f:
        return json.load(f)


# Training split: challenges and their matching solutions.
f_train, f_train_solution = (
    load_json("src/datasets/json/arc-agi_training_challenges.json"),
    load_json("src/datasets/json/arc-agi_training_solutions.json"),
)

# Evaluation split: challenges and their matching solutions.
f_eval, f_eval_solution = (
    load_json("src/datasets/json/arc-agi_evaluation_challenges.json"),
    load_json("src/datasets/json/arc-agi_evaluation_solutions.json"),
)

# Test challenges and the sample submission file.
f_submission, f_sample_submission = (
    load_json("src/datasets/json/arc-agi_test_challenges.json"),
    load_json("src/datasets/json/sample_submission.json"),
)

### Number of train pairs per task
# for i in range(400):
#     print(len(list(f_train.values())[i]["train"]), end="\t")

### Number of test inputs/outputs per task
# num_test_inputs | num_test_outputs
# for task_id in f_train.keys():
#     print(len(f_train[task_id]["test"]), len(f_train_solution[task_id]), end="\t")

# Build and save the padded ARC datasets (train, eval and a small dummy set).


def _pad_grid(grid, max_rows, max_cols):
    """Zero-pad a 2-D grid at the bottom/right so that it becomes (max_rows, max_cols).

    Args:
        grid: 2-D ``np.ndarray`` to pad.
        max_rows: Target number of rows.
        max_cols: Target number of columns.

    Returns:
        A new array of shape ``(max_rows, max_cols)``; the added cells are 0.

    Raises:
        ValueError: If the grid is larger than the target size in either dimension.
    """
    rows, cols = grid.shape[0], grid.shape[1]
    if rows > max_rows or cols > max_cols:
        # np.pad would also raise ValueError on a negative pad width, but with an opaque message.
        raise ValueError(f"Grid of shape {grid.shape} does not fit into {max_rows}x{max_cols}")
    return np.pad(grid, ((0, max_rows - rows), (0, max_cols - cols)))


def _build_dataset(challenges, solutions, max_rows, max_cols, mini_batch_size, skip_oversized=False):
    """Turn ARC tasks into fixed-size mini-batches of padded (input, output) grid pairs.

    For each task, the "train" pairs are processed first, followed by the "test" inputs
    paired with their entries in ``solutions[task_id]``. Pairs are accumulated per task and
    emitted every time ``mini_batch_size`` of them have been collected, so test pairs can
    complete a batch started by train pairs. Pairs left over at the end of a task (fewer
    than ``mini_batch_size``) are discarded; batches never mix tasks.

    Args:
        challenges: Mapping task_id -> task dict with "train" and "test" lists.
        solutions: Mapping task_id -> list of output grids for the task's "test" inputs.
        max_rows: Number of rows every grid is padded to.
        max_cols: Number of columns every grid is padded to.
        mini_batch_size: Number of pairs per emitted batch.
        skip_oversized: If True, silently skip pairs whose input or output exceeds
            ``max_rows`` x ``max_cols``; if False such a pair raises ``ValueError``.

    Returns:
        ``(grids, shapes)`` as ``np.ndarray``. When at least one batch was produced,
        ``grids`` has shape ``(num_batches, mini_batch_size, max_rows, max_cols, 2)`` and
        ``shapes`` has shape ``(num_batches, mini_batch_size, 2, 2)``; in both, index 0 of
        the last axis refers to the input and index 1 to the output. ``shapes`` stores the
        original (rows, cols) of each grid before padding.
    """
    grids, shapes = [], []
    for task_id, task in challenges.items():
        # Demonstration pairs come first, then the test inputs with their solutions.
        examples = [(example["input"], example["output"]) for example in task["train"]]
        examples += [(data["input"], solution) for data, solution in zip(task["test"], solutions[task_id])]

        pair_list, shape_list = [], []
        for raw_input, raw_output in examples:
            grid_in, grid_out = np.array(raw_input), np.array(raw_output)
            if skip_oversized and any(
                grid.shape[0] > max_rows or grid.shape[1] > max_cols for grid in (grid_in, grid_out)
            ):
                continue
            padded_in = _pad_grid(grid_in, max_rows, max_cols)
            padded_out = _pad_grid(grid_out, max_rows, max_cols)
            pair_list.append(np.stack([padded_in, padded_out], axis=-1))
            shape_list.append(np.stack([grid_in.shape, grid_out.shape], axis=-1))
            if len(pair_list) == mini_batch_size:
                grids.append(pair_list)
                shapes.append(shape_list)
                pair_list, shape_list = [], []
    return np.array(grids), np.array(shapes)


def _save_dataset(dataset_dir, name, grids, shapes):
    """Save ``grids`` and ``shapes`` as ``grids.npy`` / ``shapes.npy`` inside ``dataset_dir``.

    The directory is created if needed and one confirmation line is printed per file.

    Args:
        dataset_dir: Directory to write into.
        name: Dataset label used in the printed messages (e.g. "train").
        grids: Array written to ``grids.npy``.
        shapes: Array written to ``shapes.npy``.

    Returns:
        ``(grid_path, shape_path)``: the paths of the two files written.
    """
    os.makedirs(dataset_dir, exist_ok=True)
    grid_path = os.path.join(dataset_dir, "grids.npy")
    np.save(grid_path, grids)
    print(f"Saved ARC {name} dataset grids of shape", grids.shape, f"at '{grid_path}'")
    shape_path = os.path.join(dataset_dir, "shapes.npy")
    np.save(shape_path, shapes)
    print(f"Saved ARC {name} dataset shapes of shape", shapes.shape, f"at '{shape_path}'")
    return grid_path, shape_path


# Make ARC Train Set
max_rows, max_cols = 30, 30
mini_batch_size = 3

arc_train_set_grids, arc_train_set_shapes = _build_dataset(
    f_train, f_train_solution, max_rows, max_cols, mini_batch_size
)

dataset_dir = f"src/datasets/storage/arc_train_{len(arc_train_set_grids)}_{mini_batch_size}"
grid_path, shape_path = _save_dataset(dataset_dir, "train", arc_train_set_grids, arc_train_set_shapes)

# Make ARC Eval Set
max_rows, max_cols = 30, 30
mini_batch_size = 3

arc_eval_set_grids, arc_eval_set_shapes = _build_dataset(
    f_eval, f_eval_solution, max_rows, max_cols, mini_batch_size
)

dataset_dir = f"src/datasets/storage/arc_eval_{len(arc_eval_set_grids)}_{mini_batch_size}"
grid_path, shape_path = _save_dataset(dataset_dir, "eval", arc_eval_set_grids, arc_eval_set_shapes)

# Make ARC Dummy Set (for testing purposes)
# Only pairs whose input AND output fit into max_rows x max_cols are kept.
max_rows, max_cols = 4, 4
mini_batch_size = 3

# Train and eval tasks are merged; on a duplicate task id the eval entry wins.
f_train_f_eval = {**f_train, **f_eval}
f_train_solution_f_eval_solution = {**f_train_solution, **f_eval_solution}
arc_dummy_set_grids, arc_dummy_set_shapes = _build_dataset(
    f_train_f_eval,
    f_train_solution_f_eval_solution,
    max_rows,
    max_cols,
    mini_batch_size,
    skip_oversized=True,
)

dataset_dir = (
    f"src/datasets/storage/arc_dummy_{max_rows}x{max_cols}_{len(arc_dummy_set_grids)}_{mini_batch_size}"
)
grid_path, shape_path = _save_dataset(dataset_dir, "dummy", arc_dummy_set_grids, arc_dummy_set_shapes)
