import multiprocessing as mp
import os
import shutil
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import pyrallis

import src.envs
from src.envs.dark_key_to_door import train_test_goals
from src.envs.dark_room import train_test_goals_dr
from src.utils.q_learning import q_learning
from src.utils.misc import set_seed


@dataclass
class Config:
    """Options for trajectory-dataset generation.

    ``generate_dataset`` receives an instance of this class via
    ``pyrallis.wrap()``, so every field can be overridden from the command line.
    """

    # Seeds the global RNG (set_seed) and the train/test goal split.
    seed: int = 0
    # Gymnasium environment id; ids containing "Key" use the key-to-door goal split.
    env_name: str = "Q-Dark-Key2Door-9x9-v0"
    # Forwarded to the train/test goal split function as num_train_goals.
    num_train_goals: int = 60
    # Total number of Q-learning runs; must be >= the number of train goals.
    num_histories: int = 2_424
    # Forwarded to q_learning as num_episodes.
    num_episodes: int = 1_000
    # Output directory; removed at the start of every generation run.
    savedir: str = "trajectories"
    # Learning rate forwarded to q_learning.
    lr: float = 5e-5
    # NOTE(review): not forwarded to q_learning by Worker in this module.
    eps_coef: float = 1.0
    # Discount factor forwarded to q_learning.
    discount: float = 0.9
    # Forwarded to q_learning; presumably switches to random-policy data — confirm.
    random_data: bool = False


class Worker:
    """Callable that runs one Q-learning job for a single goal.

    An instance is passed to ``pool.map`` in ``generate_dataset``, which pickles
    it and calls it once per goal in the pool's worker processes.
    """

    def __init__(self, config):
        # The config is the only state the worker carries.
        self.config = config

    def __call__(self, goal):
        """Build the environment for ``goal`` and run q_learning on it.

        Args:
            goal: goal position forwarded to ``gym.make`` as ``goal_pos``
                (one row of ``goals.tolist()`` when called from
                ``generate_dataset``).
        """
        env = gym.make(self.config.env_name, goal_pos=goal)
        try:
            # exist_ok: many worker processes may race to create the directory.
            os.makedirs(self.config.savedir, exist_ok=True)
            # NOTE(review): config.eps_coef is not forwarded here; confirm
            # whether q_learning accepts an exploration coefficient.
            q_learning(
                env,
                lr=self.config.lr,
                num_episodes=self.config.num_episodes,
                savedir=self.config.savedir,
                discount=self.config.discount,
                random_data=self.config.random_data,
            )
        finally:
            # Release environment resources even if training raises.
            env.close()


@pyrallis.wrap()
def generate_dataset(config: Config):
    """Generate Q-learning trajectories for the training goals in parallel.

    Builds the train goal split for ``config.env_name`` and samples
    ``config.num_histories`` goals. Every train goal is used at least once, and
    the remaining slots are drawn uniformly with replacement. The function then
    removes ``config.savedir`` and runs one ``Worker`` call per sampled goal on
    a process pool. ``config`` is populated from the command line by
    ``pyrallis.wrap()``.

    Raises:
        ValueError: if ``config.num_histories`` is smaller than the number of
            train goals, because every train goal must be used at least once.
    """
    set_seed(config.seed)

    # Env ids containing "Key" use the key-to-door goal split; others use dark room.
    split_function = train_test_goals if "Key" in config.env_name else train_test_goals_dr

    # The env is instantiated only to read its grid size; close it afterwards.
    probe_env = gym.make(config.env_name)
    try:
        grid_size = probe_env.unwrapped.size
    finally:
        probe_env.close()

    train_goals, _ = split_function(
        grid_size=grid_size,
        num_train_goals=config.num_train_goals,
        seed=config.seed,
    )
    # Ensure array semantics so the fancy indexing below works even if the
    # split function returns a list (no-op for an ndarray).
    train_goals = np.asarray(train_goals)
    num_train = len(train_goals)
    print(num_train, config.num_histories)

    # Explicit check instead of `assert`: this validates user config and must
    # not be stripped under `python -O`.
    if config.num_histories < num_train:
        raise ValueError(
            f"num_histories ({config.num_histories}) must be at least the "
            f"number of train goals ({num_train})"
        )

    # Every train goal appears at least once; the remaining slots are filled by
    # sampling train goals uniformly with replacement.
    goal_inds = np.random.choice(num_train, size=config.num_histories - num_train, replace=True)
    goals = np.vstack([train_goals, train_goals[goal_inds]])
    # Internal sanity check: all (distinct) train goals survive in `goals`.
    assert len(np.unique(goals, axis=0)) >= num_train

    print("Generating data for goals:")
    print(goals, goals.shape)
    # Remove previous output so this run starts from an empty directory.
    if os.path.exists(config.savedir):
        shutil.rmtree(config.savedir)

    # NOTE(review): under the "fork" start method all pool processes inherit
    # the same global RNG state; if q_learning draws from np.random, runs in
    # different processes may share random streams — confirm whether
    # per-task seeding is needed.
    with mp.Pool(processes=os.cpu_count()) as pool:
        pool.map(Worker(config), goals.tolist())


if __name__ == "__main__":
    # pyrallis.wrap() builds the Config from command-line arguments, so the
    # entry point is called with no arguments. The guard also keeps pool
    # child processes (e.g. under the "spawn" start method, which re-imports
    # this module) from re-running dataset generation.
    generate_dataset()
