import os
import math
import glob
from typing import List

import numpy as np
import torch
import nle.dataset as nld
import wandb
from omegaconf import DictConfig

from il_scale.nethack.data.parquet_data_loader import ParquetDataLoader
import il_scale.nethack.utils.constants as CONSTANTS

class ParquetData():
    """Resolve the game ids of a parquet dataset and build train data loaders.

    On construction the train/dev game ids are resolved once (see
    ``_get_gameids``) and the training batch size / sequence length are derived
    from ``cfg``. ``get_train_dataloader`` then hands each distributed rank its
    own contiguous shard of the train game ids.
    """

    def __init__(
        self,
        cfg: "DictConfig"
    ):
        """Store ``cfg`` and resolve game ids and train loader parameters.

        Args:
            cfg: Configuration object. This class reads ``dataset_name``,
                ``env``, ``obs_frame_stack``, ``max_episode_steps``,
                ``train_gameids``, ``batch_size`` and ``unroll_length``.
        """
        self.cfg = cfg
        print(f'Using dataset: {self.cfg.dataset_name}')

        self.dbfilename = CONSTANTS.DBFILENAME

        # Load train & dev ids
        self.gameids = self._get_gameids()
        self.train_gameids = self.gameids['train_gameids']
        self.dev_gameids = self.gameids['dev_gameids']

        # Train dataloader params. The batch size is capped at the total
        # number of train games.
        self.train_batch_size = min(self.cfg.batch_size, len(self.train_gameids))
        self.train_seq_len = self.cfg.unroll_length

    def _get_gameids(self):
        """Return the train and dev game ids as a dict.

        A throwaway ``ParquetDataLoader`` (batch size 1, sequence length 1, on
        CPU) is built only to read its ``gameids`` attribute. When
        ``cfg.train_gameids`` is truthy, the ids are instead loaded from that
        path with ``np.load`` and converted to a list.

        Returns:
            ``{"train_gameids": ..., "dev_gameids": []}`` -- the dev split is
            always empty here.
        """
        # NOTE(review): the loader is constructed even when the ids are later
        # overridden from file; whether construction has side effects that are
        # required is not visible from this file, so the order is preserved.
        data = ParquetDataLoader(
            self.cfg.env,
            device=torch.device('cpu'),
            dataset_name=self.cfg.dataset_name,
            batch_size=1,
            seq_length=1,
            dbfilename=self.dbfilename,
            threadpool=None,
            shuffle=True,
            obs_frame_stack=self.cfg.obs_frame_stack,
            max_episode_steps=self.cfg.max_episode_steps
        )

        train_gameids = data.gameids

        if self.cfg.train_gameids:
            print(f'Loading gameids from file {self.cfg.train_gameids}')
            train_gameids = np.load(self.cfg.train_gameids).tolist()

        dev_gameids = []

        return {
            "train_gameids": train_gameids,
            "dev_gameids": dev_gameids
        }

    @staticmethod
    def _shard_bounds(num_games: int, rank: int, world_size: int) -> tuple:
        """Return the ``(start, stop)`` slice bounds of ``rank``'s shard.

        The ``num_games`` ids are split into ``world_size`` contiguous shards
        whose sizes differ by at most one, so no rank ends up empty as long as
        ``num_games >= world_size``. (Chunks of ``ceil(num_games / world_size)``
        can leave trailing ranks with nothing, e.g. 5 games over 4 ranks.)
        When ``num_games`` is a multiple of ``world_size`` the bounds are the
        plain equal-sized chunks.

        Raises:
            ValueError: if ``world_size < 1`` or ``rank`` is not in
                ``[0, world_size)``.
        """
        if world_size < 1:
            raise ValueError(f"world_size must be >= 1, got {world_size}")
        if not 0 <= rank < world_size:
            raise ValueError(
                f"rank must be in [0, {world_size}), got {rank}"
            )
        start = (rank * num_games) // world_size
        stop = ((rank + 1) * num_games) // world_size
        return start, stop

    def get_train_dataloader(self, tp = None, rank: int = 0, world_size: int = 1, loop_forever: bool = True) -> "ParquetDataLoader":
        """Build the training ``ParquetDataLoader`` for one distributed rank.

        Args:
            tp: Forwarded to the loader as ``threadpool``.
            rank: Index of this worker in ``[0, world_size)``; selects the
                shard of train game ids and is also forwarded to the loader.
            world_size: Total number of workers sharing the train game ids.
            loop_forever: Forwarded to the loader unchanged.

        Returns:
            A shuffling ``ParquetDataLoader`` restricted to this rank's
            contiguous shard of ``self.train_gameids``.

        Raises:
            ValueError: if ``world_size < 1`` or ``rank`` is out of range.
        """
        start, stop = self._shard_bounds(len(self.train_gameids), rank, world_size)
        # NOTE(review): the shard can be smaller than ``self.train_batch_size``
        # (which is only capped by the total number of games) -- confirm the
        # loader tolerates batch_size > number of games in the shard.
        shard_gameids = self.train_gameids[start:stop]
        # NOTE(review): ``rank`` is passed positionally as the loader's second
        # argument, while ``_get_gameids`` supplies ``device=`` by keyword --
        # presumably the same parameter; confirm against ParquetDataLoader.
        return ParquetDataLoader(
            self.cfg.env,
            rank,
            dataset_name=self.cfg.dataset_name,
            batch_size=self.train_batch_size,
            seq_length=self.train_seq_len,
            dbfilename=self.dbfilename,
            threadpool=tp,
            shuffle=True,
            gameids=shard_gameids,
            loop_forever=loop_forever,
            obs_frame_stack=self.cfg.obs_frame_stack,
            max_episode_steps=self.cfg.max_episode_steps
        )