import argparse
from collections import defaultdict

import os
import glob
import json
import pandas as pd


# argparse
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_name", type=str, required=True, help='Name of the sweep, ideally {YYYY}{MM}{DD}_sweep')
# Not `required`: argparse never applies `default` to a required option, so the
# two settings are mutually exclusive. Omitting the flag falls back to results/.
parser.add_argument("--results_dir", type=str, default='results/', help='Directory containing the sweep results')
parser.add_argument("--n_best", type=int, default=5, help='Keeps the best n sets of hyperparameters for each (algorithm, game) pair')
parser.add_argument("--n_seeds", type=int, default=10, help='Number of seeds to run for each set of best hyperparameters')
parser.add_argument("--max_steps", type=int, default=10_000_000, help='Total number of training steps')
parser.add_argument("--compute_exploitability_every", type=int, default=2_000_000, help='How often to compute exploitability')
args = parser.parse_args()

# Expected layout of the results directory:
#   {args.results_dir}/{algorithm}/{game}/{timestamp}/(exploitability.csv|metadata.json)
# Every metadata.json found marks one run; the algorithm name is the 4th path
# component from the end and the game name the 3rd.
metadata_paths = glob.glob(os.path.join(args.results_dir, "*/*/*/metadata.json"))
path_components = [path.split('/') for path in metadata_paths]
# algorithm directories whose name contains "old" are ignored
algorithms = sorted({parts[-4] for parts in path_components if "old" not in parts[-4]})
games = sorted({parts[-3] for parts in path_components})
print(f'Found {len(metadata_paths)} runs')
print(f'with algorithms: {algorithms}')
print(f'with games: {games}')
print()

# Load every run and group final exploitabilities by hyperparameter setting:
#   results[(algorithm, game)][cmd_without_seed] = [final expl of each seed...]
# Arguments that differ between seeds of one setting, or that are overwritten
# when the jobs are regenerated, are replaced in the stored command by
# `name=$(NAME)` placeholders so all seeds of one setting share the same key.
placeholder_args = {'seed', 'group_name', 'max_steps', 'compute_exploitability_every', 'save_dir', 'job_id'}
results = defaultdict(lambda: defaultdict(list))
for algorithm in algorithms:
    for game in games:
        # Only directories are runs: glob's "*" also matches stray files, on
        # which os.listdir below would raise NotADirectoryError.
        run_dirs = [
            path for path in glob.glob(os.path.join(args.results_dir, algorithm, game, "*"))
            if os.path.isdir(path)
        ]
        print(f'- ({algorithm}, {game}) : Found {len(run_dirs)} runs')
        for run_dir in run_dirs:
            # load metadata and final exploitability
            expl_path = os.path.join(run_dir, "exploitability.csv")
            metadata_path = os.path.join(run_dir, "metadata.json")
            # Skip incomplete runs; without this `continue` a missing
            # metadata.json would raise at open() below.
            if not os.path.isfile(expl_path) or not os.path.isfile(metadata_path):
                print(f'[ERROR] Missing file in {run_dir}: {os.listdir(run_dir)}')
                continue
            try:
                df = pd.read_csv(expl_path)
            except FileNotFoundError:
                print(f'[ERROR] File not found: {expl_path}')
                continue
            except pd.errors.EmptyDataError:
                # zero-byte csv: nothing was ever written
                print(f'[ERROR] Empty file: {expl_path}')
                continue
            if df.empty:
                # header only: no exploitability row to read with .iloc[-1]
                print(f'[ERROR] No exploitability rows in {expl_path}')
                continue
            # value in the last row of the csv
            expl = float(df['avg_score_response'].iloc[-1])
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
            # remove seed from cmd and things we'll modify later
            new_cmd = []
            for part in metadata['cmd'].split():
                arg_name = part.split('=')[0]
                if arg_name in placeholder_args:
                    new_cmd.append(f'{arg_name}=$({arg_name.upper()})')
                else:
                    new_cmd.append(part)
            # save expl
            results[(algorithm, game)][' '.join(new_cmd)].append(expl)
print()

# We now have exploitabilities for all seeds, grouped by the (placeholder)
# command that produced them. Average over seeds and keep, for each
# (algorithm, game) pair, the n_best commands with the lowest average.
best_seeds = defaultdict(list)  # (algorithm, game) -> [commands of the best hparam sets...]
for (algorithm, game), cmd_to_expls in results.items():
    buffer = []  # list of dict(cmd=..., expl=avg expl over seeds)
    n_seeds = set()  # distinct seed counts across hparam sets; ideally exactly one value
    for cmd, expls in cmd_to_expls.items():
        n_seeds.add(len(expls))
        buffer.append(dict(cmd=cmd, expl=sum(expls) / len(expls)))
    buffer.sort(key=lambda x: x['expl'])  # we want to keep smallest exploitabilities
    kept_seeds = buffer[:args.n_best]
    best_seeds[(algorithm, game)] = [x['cmd'] for x in kept_seeds]
    # buffer is sorted ascending: buffer[0] is the minimum, buffer[-1] the maximum
    print(f'{algorithm} > {game} - #seeds={n_seeds}, expl range=[{round(buffer[0]["expl"], 2)}, {round(buffer[-1]["expl"], 2)}]', end=', ')
    print('keeping seeds with expls:', [round(x['expl'], 2) for x in kept_seeds])
    if len(n_seeds) != 1:
        # f-prefix required so the names and counts are actually interpolated
        print(f'[ERROR] Different number of seeds for algorithm {algorithm} and game {game}: {n_seeds}')
print()

# Emit the selected commands as bash `case` entries for SLURM job arrays,
# split into two memory tiers (60GB / 120GB); each tier becomes its own script.
#   out[mem_gb]     -> accumulated case entries for that tier
#   out_idx[mem_gb] -> number of entries so far (= next SLURM_ARRAY_TASK_ID)
out = defaultdict(str)
out_idx = defaultdict(int)

for (algorithm, game), cmds in best_seeds.items():
    # NFSP always gets 120GB; PSRO and parallel ESCHER get it only on the
    # abrupt games; everything else runs with 60GB.
    if algorithm == 'nfsp' or (
        game in ('abrupt_dark_hex', 'abrupt_phantom_ttt')
        and algorithm in ('psro', 'escher_parallel')
    ):
        mem_gb = 120
    else:
        mem_gb = 60
    banner = '#' * 50
    out[mem_gb] += f"\n\t{banner}\n\t# {algorithm.upper()} - {game.upper()}\n\t{banner}\n\n"

    # one case entry per (hparam set, seed)
    for rank, template in enumerate(cmds, start=1):
        # Placeholders shared by every seed of this hparam set, applied in order.
        # The job id still contains $(JOB_ID) and $(SEED), filled in per seed.
        substitutions = (
            ('$(GROUP_NAME)', args.sweep_name),
            ('$(MAX_STEPS)', str(args.max_steps)),
            ('$(COMPUTE_EXPLOITABILITY_EVERY)', str(args.compute_exploitability_every)),
            ('$(SAVE_DIR)', '/scratch/USERNAME/log_dir/2p0s/sweeps'),
            ('$(JOB_ID)', f'{mem_gb}gb_$(JOB_ID)_{rank}th_seed$(SEED)'),
        )
        for placeholder, value in substitutions:
            template = template.replace(placeholder, value)
        for seed_offset in range(1, args.n_seeds + 1):
            task_id = out_idx[mem_gb]
            job_cmd = template.replace('$(SEED)', str(10000 + seed_offset)).replace('$(JOB_ID)', str(task_id))
            out[mem_gb] += (
                f'\t{task_id})\n'
                f'\t\t# {rank}-th best hparams, seed {seed_offset} ({algorithm}, {game})\n'
                '\t\t' + job_cmd.replace('\n', '\n\t\t') + '\n\t;;\n'
            )
            out_idx[mem_gb] = task_id + 1

print(f'{sum(out_idx.values())} jobs generated.')
print('\nRun the following to launch:\n')
# Write one sbatch script per memory tier that received at least one job.
for k, v in out.items():
    file_path = f'job_{k}GB_{args.sweep_name}.sh'
    if isinstance(k, int) and out_idx[k] > 0:
        # SLURM does not create the parent directory of --output/--error files,
        # so make sure logs/{sweep_name}/ exists. This assumes sbatch is run from
        # the current working directory, as the printed command implies.
        os.makedirs(os.path.join('logs', args.sweep_name), exist_ok=True)
        print(f'sbatch {file_path}')
        with open(file_path, 'w') as f:
            f.write(
f"""#!/bin/bash
#SBATCH --account=pr_100_tandon_priority
#SBATCH --time=48:00:00
#SBATCH --mem={k}G
#SBATCH --cpus-per-task=8
#SBATCH --array=0-{str(out_idx[k] - 1)}
#SBATCH --output=logs/{args.sweep_name}/output_%A_%a.out
#SBATCH --error=logs/{args.sweep_name}/error_%A_%a.err

# TOTAL: {sum(out_idx.values())} jobs
#    = {len(algorithms)} algorithms
#    x {len(games)} games
#    x {args.n_best} hyperparameter sets (this could be less if there are not enough samples in the results dir)
#    x {args.n_seeds} seeds

# IN THIS FILE ({k} GB): {out_idx[k]} jobs

case "$SLURM_ARRAY_TASK_ID" in

    {v.strip()}    
    *)
        echo "Invalid SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
        exit 1
    ;;
esac
""")
