from typing import Dict, Any, cast
import json
from utils.env import seed
from agent.args import args
from agent.algm import Robert
from agent.params import PARAMS
from agent.config import CONFIGS
from os import path
import torch
from agent.common import NEW_DATASET_FOLDER
import math

# Absolute path of ``<directory of this file>/../datasets``. Not used by
# ``main`` below, which reads from ``NEW_DATASET_FOLDER`` instead.
DATASET_FOLDER = path.abspath(path.join(path.dirname(__file__), "..", "datasets"))

# Number of episodes held out for the validation split.
VALID_SIZE = 6


# %%
def _random_masks(num_rows: int, length: int, keep: int) -> torch.Tensor:
    """Stack ``num_rows`` independent random subsets of ``range(length)``.

    Each row is the first ``keep`` entries of a fresh ``torch.randperm(length)``,
    so indices within a row are distinct and lie in ``[0, length)``. Rows are
    drawn in order from torch's global RNG, one ``randperm`` call per row.

    Args:
        num_rows: Number of rows to sample (must be positive for ``torch.stack``).
        length: Size of the index range; must be an ``int`` for ``randperm``.
        keep: Number of leading permutation entries kept per row.

    Returns:
        Tensor of shape ``(num_rows, min(keep, length))``.
    """
    return torch.stack(
        [torch.randperm(length)[:keep] for _ in range(num_rows)],
        dim=0,
    )


def main(args: Dict[str, Any]) -> None:
    """Split the episodes of ``args["task"]`` into train and validation sets.

    Loads ``states.pt`` and ``actions.pt`` from ``NEW_DATASET_FOLDER/<task>``.
    ``VALID_SIZE`` randomly chosen episodes are held out for validation, each
    reduced to one randomly chosen index along the second dimension. The
    remaining episodes form the training set. For every validation episode,
    two random index masks over ``floor(post_steps * eval_ratio)`` positions
    are also sampled:
    ``valid_half_masks`` keeps ``floor(eval_ratio * post_steps / 5)`` of them,
    ``valid_full_masks`` keeps ``floor(eval_ratio * post_steps / 2)``.

    Writes ``states_train.pt``, ``actions_train.pt``, ``states_valid.pt``,
    ``actions_valid.pt``, ``valid_half_masks.pt`` and ``valid_full_masks.pt``
    back into the same directory.

    Args:
        args: Parsed CLI arguments; reads the ``"seed"``, ``"eval_ratio"`` and
            ``"task"`` keys.

    Raises:
        ValueError: If the dataset has fewer than ``VALID_SIZE`` episodes.
    """
    RANDOM_SEED = args["seed"]
    seed(RANDOM_SEED)
    eval_length_ratio = args["eval_ratio"]

    read_dir = f'{NEW_DATASET_FOLDER}/{args["task"]}'
    print(f"start to split dataset into valid in {args['task']}")
    states = torch.load(f"{read_dir}/states.pt")
    actions = torch.load(f"{read_dir}/actions.pt")

    print(f"states shape: {states.shape}, actions shape: {actions.shape}")
    num_episodes = states.size(0)
    # randperm(n)[:VALID_SIZE] silently returns fewer than VALID_SIZE indices
    # when n is too small; report that explicitly before sampling.
    if num_episodes < VALID_SIZE:
        raise ValueError(
            f"need at least {VALID_SIZE} episodes to build the validation split, "
            f"got {num_episodes}"
        )
    valid_idx = torch.randperm(num_episodes)[:VALID_SIZE].tolist()

    print(f"valid idx is: {valid_idx}")
    # Ascending training indices, built explicitly rather than relying on
    # set iteration order.
    held_out = set(valid_idx)
    train_idx = [i for i in range(num_episodes) if i not in held_out]

    # One random index along dim 1 per held-out episode.
    # NOTE(review): assumes CONFIGS["dataset"]["episode_length"] does not exceed
    # the second dimension of states/actions; a larger draw raises IndexError.
    episode_idx = torch.randint(0, CONFIGS["dataset"]["episode_length"], (VALID_SIZE,))
    valid_states, valid_actions = (
        states[valid_idx, episode_idx],
        actions[valid_idx, episode_idx],
    )
    train_states, train_actions = states[train_idx], actions[train_idx]
    print(
        f"train states shape: {train_states.shape}, train actions shape: {train_actions.shape}"
    )
    print(
        f"valid states shape: {valid_states.shape}, valid actions shape: {valid_actions.shape}"
    )

    # Floor so a float eval_ratio still hands randperm the int it requires
    # (a no-op when both factors are ints).
    mask_length = math.floor(PARAMS["post_steps"] * eval_length_ratio)
    # Half masks are drawn before full masks; changing this order changes the
    # masks produced for a given seed.
    valid_half_masks = _random_masks(
        VALID_SIZE,
        mask_length,
        math.floor(eval_length_ratio * PARAMS["post_steps"] / 5),
    )
    valid_full_masks = _random_masks(
        VALID_SIZE,
        mask_length,
        math.floor(eval_length_ratio * PARAMS["post_steps"] / 2),
    )

    # randperm(mask_length) only yields values in [0, mask_length). Check the
    # tight bound before anything is written, so a failure leaves no partial split.
    for masks in (valid_half_masks, valid_full_masks):
        assert torch.all((masks >= 0) & (masks < mask_length))

    torch.save(train_states, f"{read_dir}/states_train.pt")
    torch.save(train_actions, f"{read_dir}/actions_train.pt")
    torch.save(valid_states, f"{read_dir}/states_valid.pt")
    torch.save(valid_actions, f"{read_dir}/actions_valid.pt")
    torch.save(valid_half_masks, f"{read_dir}/valid_half_masks.pt")
    torch.save(valid_full_masks, f"{read_dir}/valid_full_masks.pt")


if __name__ == "__main__":
    # Hand the parsed command-line namespace to main() as a plain dict.
    cli_options = vars(args)
    main(cli_options)
