import os
import glob
from collections import defaultdict

# ROLLOUT_FOLDER = '/n/fs/nlp-il-scale/il-scaling/logs/nethack/rollouts/scale_mamba_h1024_l24_16k_2xlj0o01/mon-hum-neu-mal/softmax/temp_1.0_topp_0.9_topk_1_steps_1000000_penalty_0_tag_model_84.0.tar_16k/rollouts'
# ROLLOUT_FOLDER = '/n/fs/nlp-il-scale/il-scaling/logs/nethack/rollouts/scale_mamba_h1024_l24_16k_2xlj0o01/mon-hum-neu-mal/softmax/temp_1.0_topp_0.9_topk_1_steps_1000000_penalty_0_tag_model_84.0.tar/rollouts'
# ROLLOUT_FOLDER = '/scratch/gpfs/anonymous/il-scaling/logs/nethack/rollouts/2llgx1zr_2llgx1zr/mon-hum-neu-mal/softmax/temp_1.0_topp_0.9_topk_1_steps_1000000_penalty_0_tag_model_46.0.tar/rollouts'
# ROLLOUT_FOLDER = '/scratch/gpfs/anonymous/il-scaling/logs/nethack/rollouts/mhytjyru_mhytjyru/mon-hum-neu-mal/softmax/temp_1.0_topp_0.9_topk_1_steps_1000000_penalty_0_tag_model_74.0.tar/rollouts'
# Directory scanned by this script: every entry in it is treated as one rollout
# folder, and each rollout folder is searched for files ending in `.xlogfile`.
ROLLOUT_FOLDER = '/scratch/gpfs/anonymous/il-scaling/logs/nethack/rollouts/mhytjyru_mhytjyru/mon-hum-neu-mal/softmax/temp_1.0_topp_0.9_topk_1_steps_1000000_penalty_0_tag_model_97.0.tar/rollouts'

# Raw death strings, one per parsed xlogfile record.
deaths = list()

# Coarse per-category tallies; every category starts at zero.
death_categories = dict.fromkeys(
    ('killed', 'quit', 'poisoned', 'other', 'corpse'),
    0,
)

# Fine-grained tallies keyed by the exact death string / `while=` value.
# Each one is its own independent counter.
death_killed, death_poisoned, death_while = (
    defaultdict(int) for _ in range(3)
)

def _xlog_field(record, key):
    """Return the value of the ``key=<value>`` field in an xlogfile record.

    The value runs from just after ``key=`` up to the next tab (or the end of
    the record). Returns ``None`` when the record has no ``key=`` field, so
    callers can tell "absent" apart from "present but empty".
    """
    _, marker, tail = record.partition(key + '=')
    if not marker:
        return None
    return tail.split('\t', 1)[0]


# Walk every rollout folder, read the first record of each `.xlogfile` in it
# and tally the death reason (and the `while=` field when there is one).
rollouts = os.listdir(ROLLOUT_FOLDER)
print('number of rollouts', len(rollouts))
for rollout in rollouts:
    r_folder = os.path.join(ROLLOUT_FOLDER, rollout)
    if not os.path.isdir(r_folder):
        # A stray file next to the rollout folders would make os.listdir raise.
        continue

    # find the xlog file
    for r_file in os.listdir(r_folder):
        if not r_file.endswith('.xlogfile'):
            continue

        xlog_path = os.path.join(r_folder, r_file)
        with open(xlog_path, 'r') as f:
            # Only the first record is analysed, so do not read the whole file.
            line = f.readline().strip()
        if not line:
            # Empty file (or blank first line): nothing to tally.
            continue

        death = _xlog_field(line, 'death')
        if death is None:
            # Malformed record; skip it instead of crashing the whole scan.
            print('skipping record without death= field:', xlog_path)
            continue
        deaths.append(death)

        # Order matters: a death such as "poisoned by a rotted ... corpse"
        # is counted as 'poisoned', not 'corpse'.
        if death.startswith('killed'):
            death_categories['killed'] += 1
            death_killed[death] += 1
        elif death.startswith('quit'):
            death_categories['quit'] += 1
        elif death.startswith('poisoned'):
            death_categories['poisoned'] += 1
            death_poisoned[death] += 1
        elif 'corpse' in death:
            death_categories['corpse'] += 1
        else:
            death_categories['other'] += 1

        # Look for the actual `while=` field; the bare word "while" elsewhere
        # in the record must not be mistaken for it.
        wh = _xlog_field(line, 'while')
        if wh is not None:
            death_while[wh.strip()] += 1

# normalize
# Turn the raw tallies into fractions of all categorised deaths. If no death
# record was tallied at all the total is 0; report 0.0 for every category
# rather than failing with ZeroDivisionError.
total_deaths = sum(death_categories.values())
death_categories = {
    k: (v / total_deaths if total_deaths else 0.0)
    for k, v in death_categories.items()
}
print(death_categories)

# death_killed = {k: v / total_deaths for k, v in death_killed.items()}
# death_killed_sorted = list(sorted(death_killed.items(), key=lambda x: x[1], reverse=True))
# for k, v in death_killed_sorted:
#     print(f'{k}: {v}')

# death_poisoned = {k: v / total_deaths for k, v in death_poisoned.items()}
# death_poisoned_sorted = list(sorted(death_poisoned.items(), key=lambda x: x[1], reverse=True))
# for k, v in death_poisoned_sorted:
#     print(f'{k}: {v}')

# total_while = sum(death_while.values())
# death_while = {k: v / total_deaths for k, v in death_while.items()}
# death_while_sorted = list(sorted(death_while.items(), key=lambda x: x[1], reverse=True))
# for k, v in death_while_sorted:
#     print(f'{k}: {v}')



                
