import glob
import multiprocessing as mp
import os
import shutil
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
import pyrallis


@dataclass
class Config:
    """Options for splitting a directory of learning histories into thirds."""

    # Random seed (not read by the dataset-splitting code in this file).
    seed: int = 0
    # Directory holding the source ``*.npz`` learning histories; also the
    # prefix of the ``_early`` / ``_mid`` / ``_late`` output directories.
    full_dataset: str = "trajectories"
    # Feature widths of a single state / action.
    states_dim: int = 1
    actions_dim: int = 1


def load_learning_histories_with_name(path: str):
    """Load every ``*.npz`` learning history stored directly in ``path``.

    Args:
        path: Directory containing ``.npz`` files. Each file must contain the
            arrays ``states``, ``actions``, ``rewards``, ``dones``, ``goal``
            and ``returns``.

    Returns:
        A list of ``(history, filename)`` tuples sorted by filename, where
        ``history`` maps each of the six keys above to the loaded array and
        ``filename`` is the path the data was read from.

    Raises:
        KeyError: if a file lacks one of the required arrays.
    """
    keys = ("states", "actions", "rewards", "dones", "goal", "returns")
    # Escape the directory so characters like "[" are matched literally, and
    # sort because glob's order depends on the filesystem.
    files = sorted(glob.glob(os.path.join(glob.escape(path), "*.npz")))

    learning_histories = []
    for filename in files:
        # NOTE: allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code -- only point this at trusted data.
        with np.load(filename, allow_pickle=True) as f:
            # Index inside the `with` so every array is read before closing.
            learning_histories.append(({key: f[key] for key in keys}, filename))

    return learning_histories


def split_to_episodes(learning_history):
    """Cut a flat learning history into a list of per-episode dicts.

    A step whose ``dones`` entry is truthy closes the current episode. Steps
    after the last such step belong to no episode and are dropped.

    Args:
        learning_history: Mapping with per-step ``states``, ``actions``,
            ``dones`` and ``rewards``, plus ``returns`` (indexed by episode
            number) and ``goal``.

    Returns:
        ``(episodes, goal)``. Each episode dict holds ``states``, ``actions``,
        ``dones`` and ``rewards`` as arrays stacked from that episode's steps,
        plus ``returns`` taken from ``learning_history["returns"][i]`` for the
        i-th episode. ``goal`` is passed through unchanged.
    """
    step_keys = ("states", "actions", "dones", "rewards")
    episodes = []
    buffer = {key: [] for key in step_keys}

    for step, done in enumerate(learning_history["dones"]):
        for key in step_keys:
            buffer[key].append(learning_history[key][step])
        if not done:
            continue
        episode = {key: np.array(values) for key, values in buffer.items()}
        # len(episodes) is the index of the episode being closed.
        episode["returns"] = learning_history["returns"][len(episodes)]
        episodes.append(episode)
        buffer = {key: [] for key in step_keys}

    return episodes, learning_history["goal"]


def flatten(trajectories, goal):
    """Concatenate a list of per-episode dicts into one flat dataset dict.

    Args:
        trajectories: Non-empty list of episode dicts with keys ``states``,
            ``actions``, ``rewards``, ``returns`` and ``dones`` (as produced by
            ``split_to_episodes``).
        goal: Stored unchanged under the ``"goal"`` key.

    Returns:
        Dict with each of the five keys concatenated along axis 0 across
        episodes, plus ``goal``.

    Raises:
        ValueError: if ``trajectories`` is empty.
    """
    if not trajectories:
        raise ValueError("flatten() needs at least one trajectory")

    def cat(key):
        # atleast_1d: a per-episode return is a 0-d scalar when the source
        # ``returns`` array is 1-D, and np.concatenate rejects 0-d inputs.
        return np.concatenate([np.atleast_1d(traj[key]) for traj in trajectories])

    return {
        "states": cat("states"),
        "actions": cat("actions"),
        "rewards": cat("rewards"),
        "returns": cat("returns"),
        "dones": cat("dones"),
        "goal": goal,
    }


def dump_trajectories(savedir: str, name: str, trajectories, states_dim=1, actions_dim=1):
    """Save a flattened dataset dict to ``os.path.join(savedir, name)`` via ``np.savez``.

    Stored arrays:
        states  -- float, reshaped to ``(-1, states_dim)``
        actions -- original dtype, reshaped to ``(-1, actions_dim)``
        rewards -- float, reshaped to ``(-1, 1)``
        dones   -- int32, reshaped to ``(-1, 1)``
        goal    -- as given
        returns -- float, reshaped to ``(-1, 1)``
    """
    def float_column(key):
        return np.array(trajectories[key], dtype=float).reshape(-1, 1)

    payload = {}
    payload["states"] = np.array(trajectories["states"], dtype=float).reshape(-1, states_dim)
    payload["actions"] = np.array(trajectories["actions"]).reshape(-1, actions_dim)
    payload["rewards"] = float_column("rewards")
    payload["dones"] = np.int32(trajectories["dones"]).reshape(-1, 1)
    payload["goal"] = np.array(trajectories['goal'])
    payload["returns"] = float_column("returns")

    target = os.path.join(savedir, name)
    np.savez(target, **payload)


def merge_data(trajs1, trajs2, goal):
    """Join two flattened datasets along axis 0, ``trajs1`` first.

    ``states``, ``actions``, ``rewards``, ``returns`` and ``dones`` are
    concatenated; ``goal`` is stored as given rather than taken from either
    input.
    """
    merged = {}
    for key in ("states", "actions", "rewards", "returns", "dones"):
        merged[key] = np.concatenate([trajs1[key], trajs2[key]])
    merged["goal"] = goal
    return merged


@pyrallis.wrap()
def build_datasets(config: Config):
    """Split each learning history into early / mid / late episode chunks and save them.

    Every ``*.npz`` file in ``config.full_dataset`` is cut into complete
    episodes. With ``k = n_episodes // 3``, three chunks of ``k`` episodes are
    written under the original file name:

    * ``<full_dataset>_early``: the first ``k`` episodes,
    * ``<full_dataset>_mid``: episodes ``k`` .. ``2k - 1``,
    * ``<full_dataset>_late``: the last ``k`` episodes.

    When ``n_episodes`` is not a multiple of 3, the episodes between the mid
    and late chunks are not written. Files with fewer than 3 complete
    episodes cannot be split and are skipped with a printed message.

    Args:
        config: ``full_dataset`` is the input directory and the prefix of the
            three output directories. ``states_dim`` / ``actions_dim`` are the
            reshape widths used when a file stores states / actions as 1-D
            arrays; otherwise the width is read from the data.
    """
    histories = load_learning_histories_with_name(config.full_dataset)

    out_dirs = {split: f"{config.full_dataset}_{split}" for split in ("early", "mid", "late")}
    for out_dir in out_dirs.values():
        os.makedirs(out_dir, exist_ok=True)

    for learning_history, path in histories:
        name = os.path.basename(path)
        trajectories, goal = split_to_episodes(learning_history)
        n_episodes = len(trajectories)
        chunk = n_episodes // 3
        if chunk == 0:
            # Empty chunks cannot be flattened, so there is nothing to write.
            print(f"Skipping {name}: only {n_episodes} complete episode(s), need at least 3")
            continue

        # Widths are inferred per file so one file's shape never leaks into
        # the next; 1-D storage falls back to the configured widths.
        first = trajectories[0]
        sd = first["states"].shape[-1] if first["states"].ndim > 1 else config.states_dim
        ad = first["actions"].shape[-1] if first["actions"].ndim > 1 else config.actions_dim

        splits = {
            "early": trajectories[:chunk],
            "mid": trajectories[chunk:2 * chunk],
            # Explicit start index: a negative slice [-chunk:] selects the
            # whole list when chunk == 0.
            "late": trajectories[n_episodes - chunk:],
        }
        # merge_data() can combine two of these splits (e.g. early + mid) if
        # merged datasets are needed.
        for split, episodes in splits.items():
            dump_trajectories(out_dirs[split], name, flatten(episodes, goal), sd, ad)


if __name__ == "__main__":
    # The pyrallis.wrap() decorator supplies the Config argument.
    build_datasets()
