from concurrent.futures import ThreadPoolExecutor
import time
import os
from collections import defaultdict

import torch
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np

from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.data.parquet_data import ParquetData

@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Split the training game ids into two random halves and save them.

    Builds the data loader selected by ``cfg.data.dataset_type``
    (``ParquetData`` for ``'glyph'``, ``TTYData`` otherwise), reads the game
    ids from the train dataloader's dataset, shuffles a copy of them, and
    writes the first half to ``datasets/<dataset_name>_split_1.npy`` and the
    remainder to ``datasets/<dataset_name>_split_2.npy`` (paths are relative
    to the current working directory). With an odd number of ids, the second
    file receives the extra one.

    No seed is set here, so the split generally differs from run to run.

    Args:
        cfg: Hydra config; this function reads ``cfg.data``,
            ``cfg.data.dataset_type`` and ``cfg.data.dataset_name``.
    """
    import random

    print(cfg.data)
    if cfg.data.dataset_type == 'glyph':
        data = ParquetData(cfg.data)
        print('Using parquet loader ...')
    else:
        data = TTYData(cfg.data)
        print('Using ttyrec loader ...')

    split_1_path = f"datasets/{cfg.data.dataset_name}_split_1.npy"
    split_2_path = f"datasets/{cfg.data.dataset_name}_split_2.npy"

    with ThreadPoolExecutor(max_workers=50) as tp:
        dataloader = data.get_train_dataloader(tp)

        # Shuffle a copy so the dataset's private id list keeps its order.
        gameids = list(dataloader.dataset._gameids)
        random.shuffle(gameids)

        split_idx = len(gameids) // 2

        # np.save does not create missing parent directories.
        os.makedirs(os.path.dirname(split_1_path), exist_ok=True)
        np.save(split_1_path, gameids[:split_idx])
        np.save(split_2_path, gameids[split_idx:])

        # i = 0
        # for batch in dataloader:
        #     if torch.any(batch['done']):
        #         print(batch['gameids'])
        #         print(batch['done'])
        #         print(batch['prev_action'][batch['done']])
        #         print(batch['gameids'][batch['done']])
                # print(batch['gameids'])
                # print(batch['prev_action'])
                # print(batch['labels'])


        # i = 0
        # start = time.time()
        # act_freq = defaultdict(int)
        # for batch in dataloader:
        #     labels = batch['labels']
        #     for l in labels:
        #         act_freq[l.item()] += 1
        #     i += 1
        #     # breakpoint()
        #     if i % 100 == 0:
        #         break

        # # normalize
        # total = sum(act_freq.values())
        # act_freq = {k: v / total for k, v in act_freq.items()}
        # print(act_freq[35])
        # exit(0)

        
        # gameids = dataloader.dataset._gameids
        # scores = [dataloader.dataset.get_meta(gid)['points'] for gid in gameids]
        
        # scores = np.array(scores)

        # breakpoint()

        # death_categories = {
        #     'killed': 0,
        #     'quit': 0,
        #     'poisoned': 0,
        #     'other': 0,
        #     'corpse': 0
        # }
        # death_killed = defaultdict(int)
        # death_poisoned = defaultdict(int)

        # deaths = []
        # for gid in gameids:
        #     death = dataloader.dataset.get_meta(gid)['death'].strip()
        #     deaths.append(death)

        #     if death.startswith('killed'):
        #         death_categories['killed'] += 1
        #         death_killed[death] += 1
        #     elif death.startswith('quit'):
        #         death_categories['quit'] += 1
        #     elif death.startswith('poisoned'):
        #         death_categories['poisoned'] += 1
        #         death_poisoned[death] += 1
        #     elif 'corpse' in death:
        #         death_categories['corpse'] += 1
        #     else:
        #         death_categories['other'] += 1

        # # normalize
        # total_deaths = sum(death_categories.values())
        # death_categories = {k: v / total_deaths for k, v in death_categories.items()}
        # print(death_categories)

        # # death_killed = {k: v / total_deaths for k, v in death_killed.items()}
        # # death_killed_sorted = list(sorted(death_killed.items(), key=lambda x: x[1], reverse=True))
        # # for k, v in death_killed_sorted[:10]:
        # #     print(f'{k}: {v}')

        # death_poisoned = {k: v / total_deaths for k, v in death_poisoned.items()}
        # death_poisoned_sorted = list(sorted(death_poisoned.items(), key=lambda x: x[1], reverse=True))
        # for k, v in death_poisoned_sorted:
        #     print(f'{k}: {v}')


        # # i = 1
        # # start = time.time()
        # # for batch in dataloader:
        # #     i += 1
        # #     # breakpoint()
        # #     if i % 1000 == 0:
        # #         # fps
        # #         print(1000 * cfg.data.batch_size * cfg.data.unroll_length / (time.time() - start))
        # #         start = time.time()
                

# Script entry point: main() is called without arguments; its `cfg` parameter
# is expected to be supplied by the @hydra.main decorator applied to main().
if __name__ == "__main__":
    main()