from concurrent.futures import ThreadPoolExecutor
import time
import os
from collections import defaultdict, Counter

import torch
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np

from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.data.parquet_data import ParquetData

# Deepest dungeon level listed in the "reached at least level k" table.
_MAX_REPORTED_DLVL = 21
# Number of most frequent 'killed ...' causes listed at the end of the report.
_TOP_KILLERS = 10
# Coarse death categories, in the order they are written to the report.
_DEATH_CATEGORIES = ('killed', 'quit', 'poisoned', 'other', 'corpse')


def _fraction_reaching_each_level(dlvls, max_level=_MAX_REPORTED_DLVL):
    """Return ``{level: fraction of games whose max level is >= level}``.

    Levels run from 1 to ``max_level`` inclusive. ``dlvls`` must be non-empty,
    otherwise the normalisation divides by zero.
    """
    dlvl_freq = Counter(dlvls)
    total = sum(dlvl_freq.values())
    return {
        level: sum(n for reached, n in dlvl_freq.items() if reached >= level) / total
        for level in range(1, max_level + 1)
    }


def _death_category(death):
    """Map a (stripped) death description to one of ``_DEATH_CATEGORIES``.

    Prefix checks take priority over the 'corpse' substring check, so e.g. a
    description starting with 'poisoned' that also mentions a corpse is
    counted as 'poisoned'.
    """
    if death.startswith('killed'):
        return 'killed'
    if death.startswith('quit'):
        return 'quit'
    if death.startswith('poisoned'):
        return 'poisoned'
    if 'corpse' in death:
        return 'corpse'
    return 'other'


def _write_ranked_fractions(f, counts, total, limit=None):
    """Write ``cause: fraction`` lines to ``f``, most frequent cause first.

    ``counts`` maps a death description to its number of occurrences and
    ``total`` is the number of games used for normalisation. The sort is
    stable, so ties keep first-seen order. ``limit`` caps the number of lines
    written (``None`` writes all of them).
    """
    fractions = {cause: n / total for cause, n in counts.items()}
    ranked = sorted(fractions.items(), key=lambda item: item[1], reverse=True)
    for cause, fraction in ranked[:limit]:
        f.write(f'{cause}: {fraction:.4f}\n')


@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Write summary statistics of the training games to a text file.

    The report is written to ``<cfg.data.dataset_name>_stats.txt`` in the
    current working directory and contains, in order:

    1. mean and median of the per-game ``points``;
    2. for each dungeon level 1..21, the fraction of games whose ``maxlvl``
       is at least that level;
    3. the fraction of games per coarse death category;
    4. every distinct 'poisoned ...' death, most frequent first;
    5. the ten most frequent 'killed ...' deaths.

    Args:
        cfg: Hydra config; ``cfg.data`` selects and configures the loader
            (``dataset_type == 'glyph'`` uses ``ParquetData``, anything else
            uses ``TTYData``).

    Raises:
        ValueError: If the dataset exposes no game ids, since every statistic
            in the report is undefined in that case.
    """
    print(cfg.data)
    if cfg.data.dataset_type == 'glyph':
        data = ParquetData(cfg.data)
        print('Using parquet loader ...')
    else:
        data = TTYData(cfg.data)
        print('Using ttyrec loader ...')
    with ThreadPoolExecutor(max_workers=50) as tp:
        dataloader = data.get_train_dataloader(tp)
        dataset = dataloader.dataset
        # NOTE: relies on the dataset's private `_gameids` attribute.
        gameids = dataset._gameids

        # Fetch each game's metadata exactly once and pull out every field the
        # report needs, instead of one get_meta() call per statistic.
        scores, dlvls, deaths = [], [], []
        for gid in gameids:
            meta = dataset.get_meta(gid)
            scores.append(meta['points'])
            dlvls.append(meta['maxlvl'])
            deaths.append(meta['death'].strip())

        # Fail early and clearly, before creating the output file, rather than
        # with a ZeroDivisionError halfway through writing it.
        if not deaths:
            raise ValueError(
                f'No games found for dataset {cfg.data.dataset_name!r}; '
                'cannot compute statistics.'
            )

        num_games = len(deaths)
        category_counts = Counter(_death_category(death) for death in deaths)
        poisoned_counts = Counter(d for d in deaths if d.startswith('poisoned'))
        killed_counts = Counter(d for d in deaths if d.startswith('killed'))

        # Explicit encoding: death descriptions come from game data and should
        # not depend on the locale's default encoding.
        with open(f'{cfg.data.dataset_name}_stats.txt', 'w', encoding='utf-8') as f:
            # Score statistics.
            score_array = np.array(scores)
            f.write(f'Mean: {np.mean(score_array)}\n')
            f.write(f'Median: {np.median(score_array)}\n')
            f.write('\n')

            # Fraction of games reaching at least each dungeon level.
            for level, fraction in _fraction_reaching_each_level(dlvls).items():
                f.write(f'{level}: {fraction:.3f}\n')
            f.write('\n')

            # Coarse death categories, normalised by the number of games.
            for category in _DEATH_CATEGORIES:
                f.write(f'{category}: {category_counts[category] / num_games:.3f}\n')
            f.write('\n')

            # All poisoning causes, then the most common killers; both are
            # fractions of all games, not of their own category.
            _write_ranked_fractions(f, poisoned_counts, num_games)
            f.write('\n')

            _write_ranked_fractions(f, killed_counts, num_games, limit=_TOP_KILLERS)

        # # i = 1
        # # start = time.time()
        # # for batch in dataloader:
        # #     i += 1
        # #     # breakpoint()
        # #     if i % 1000 == 0:
        # #         # fps
        # #         print(1000 * cfg.data.batch_size * cfg.data.unroll_length / (time.time() - start))
        # #         start = time.time()
                

# Run only when executed as a script. main() is wrapped by @hydra.main, which
# supplies the `cfg` argument, so it is called here without arguments.
if __name__ == "__main__":
    main()