import numpy as np
import os
from os import path
from agent.config import ALL_CONFIGS
from tqdm import tqdm
import torch
from utils.common import get_device
from pathlib import Path
from agent.common import RAW_DATASET_FOLDER, NEW_DATASET_FOLDER
from agent.args import task, DEVICE1

# Module-level alias for DEVICE1 from agent.args.
# NOTE(review): the live code in this file does not reference DEVICE (tensors
# are built and saved on the default CPU device); presumably kept for other
# importers or debugging — confirm before removing.
DEVICE = DEVICE1


def load_data_set(task: str):
    """Pack the raw ``.npz`` episodes of *task* into tensors and save them.

    The dataset layout is read from ``ALL_CONFIGS[task]['dataset']``
    (``episode_num``, ``episode_steps``, ``state_dim``, ``action_dim``,
    ``physics_dim``, ``dataset_filename``).  Every file inside a directory
    whose path ends with ``'buffer'`` under
    ``RAW_DATASET_FOLDER/<dataset_filename>`` is loaded as one episode, and
    its ``observation``, ``action`` and ``physics`` arrays are copied into
    preallocated float32 tensors.  Directories and files are visited in
    sorted order, so the episode index of each file is reproducible across
    machines.

    The results are written with ``torch.save`` as ``states.pt``,
    ``actions.pt`` and ``physics.pt`` inside ``NEW_DATASET_FOLDER/<task>/``
    (created if missing).

    Args:
        task: Key into ``ALL_CONFIGS`` selecting the dataset to convert.

    Raises:
        FileNotFoundError: If the raw dataset folder does not exist.
        ValueError: If a buffer directory contains a non-``.npz`` file, or
            the number of episodes found differs from ``episode_num``.
    """
    cfg = ALL_CONFIGS[task]['dataset']

    episode_num = cfg['episode_num']
    episode_steps = cfg['episode_steps']
    state_dim = cfg['state_dim']
    action_dim = cfg['action_dim']
    physics_dim = cfg['physics_dim']

    states = torch.empty((episode_num, episode_steps, state_dim),
                         dtype=torch.float32)
    physics = torch.empty((episode_num, episode_steps, physics_dim),
                          dtype=torch.float32)
    # One fewer action than observations per episode: the last recorded
    # action is dropped when copying below.
    actions = torch.empty((episode_num, episode_steps - 1, action_dim),
                          dtype=torch.float32)

    dataset_folder = Path(RAW_DATASET_FOLDER) / cfg['dataset_filename']
    if not dataset_folder.exists():
        raise FileNotFoundError(f'dataset folder not found: {dataset_folder}')

    cnt = 0
    with tqdm(total=episode_num, desc=f'load {task} dataset...') as pbar:
        for root, dirs, files in os.walk(dataset_folder):
            # Sorting in place makes os.walk descend in a deterministic order;
            # otherwise episode indices depend on filesystem listing order.
            dirs.sort()
            if not root.endswith('buffer'):
                continue
            for f in sorted(files):
                if not f.endswith('.npz'):
                    raise ValueError(
                        f'unexpected non-.npz file in buffer directory: '
                        f'{os.path.join(root, f)}')
                # Fail with a clear message instead of an IndexError on the
                # preallocated tensors when there are too many episodes.
                if cnt >= episode_num:
                    raise ValueError(
                        f'found more than {episode_num} episodes in '
                        f'{dataset_folder}')
                # NpzFile keeps the archive open; the context manager closes it.
                with np.load(os.path.join(root, f)) as episode:
                    states[cnt] = torch.as_tensor(
                        episode['observation'], dtype=torch.float32)
                    actions[cnt] = torch.as_tensor(
                        episode['action'][:-1], dtype=torch.float32)
                    physics[cnt] = torch.as_tensor(
                        episode['physics'], dtype=torch.float32)
                pbar.update(1)
                cnt += 1

    if cnt != episode_num:
        raise ValueError(
            f'expected {episode_num} episodes, found {cnt} in {dataset_folder}')

    out_dir = Path(NEW_DATASET_FOLDER) / task
    out_dir.mkdir(exist_ok=True, parents=True)
    torch.save(states, out_dir / 'states.pt')
    torch.save(actions, out_dir / 'actions.pt')
    torch.save(physics, out_dir / 'physics.pt')


if __name__ == '__main__':
    # Convert only the task selected via agent.args; looping over
    # ALL_CONFIGS.keys() instead would convert every configured task.
    selected_task = task
    load_data_set(selected_task)
