from hydra import compose, initialize
import argparse
import random
from collections import defaultdict


# Command-line interface. Options are registered in the order they should
# appear in --help; the parsed namespace is exposed module-wide as `args`.
parser = argparse.ArgumentParser()

# Sweep identity and size.
parser.add_argument(
    "--sweep_name",
    type=str,
    required=True,
    help='Name of the sweep, ideally {YYYY}{MM}{DD}_sweep',
)
parser.add_argument(
    "--n_samples",
    type=int,
    default=50,
    help='Number of random sets of hyperparameters to generate for each (algorithm, game) pair',
)
parser.add_argument(
    "--n_seeds",
    type=int,
    default=3,
    help='Number of seeds to run for each set of hyperparameters',
)
parser.add_argument(
    "--n_powers_of_2",
    type=int,
    default=3,
    help='Multiply random params by 2^n for n in [-n_powers_of_2, n_powers_of_2]',
)
parser.add_argument(
    "--sleep",
    type=float,
    default=1,
    help='Sleep duration (in seconds) between launching each job.',
)

# Which (game, algorithm) pairs to cover.
parser.add_argument(
    "--games",
    type=str,
    default='all',
    help='List of comma-separated games to run, or "all" to run all games',
)
parser.add_argument(
    "--algs",
    type=str,
    default='all',
    help='List of comma-separated algorithms to run, or "all" to run all algorithms',
)

# Training budget forwarded to every generated command.
parser.add_argument(
    "--max_steps",
    type=int,
    default=10_000_000,
    help='Total number of training steps',
)
parser.add_argument(
    "--compute_exploitability_every",
    type=int,
    default=2_000_000,
    help='How often to compute exploitability',
)

# Launch mode. The --dry value is passed through verbatim to submit.py.
parser.add_argument("--dry", type=int, default=0)
parser.add_argument(
    "--array",
    action='store_true',
    default=False,
    help='If set, generate a job array file to run with sbatch.',
)

args = parser.parse_args()

# Registry of sweepable hyperparameters: maps an algorithm name to the ordered
# list of Param entries that the sweep randomizes for that algorithm.
# The insertion order of the keys defines the order used for "--algs all".
algorithms: dict = {}

class Param:
    """Description of one hyperparameter that the sweep randomizes.

    Args:
        name: Key of the parameter under the ``algorithm`` config node. Dotted
            names (e.g. ``"inner_rl_agent.batch_size"``) address nested nodes.
        default: Value the parameter is expected to have in the config; the
            generator compares it against the loaded config default.
        min: Optional lower clamp applied to the sampled value.
        max: Optional upper clamp applied to the sampled value.
        mode: How the random factor is applied to the default value:
            ``'default'`` multiplies the value by the factor, ``'exponent'``
            raises the value to the power of the factor.

    Raises:
        ValueError: If ``mode`` is not one of ``VALID_MODES``.
    """

    VALID_MODES = ('default', 'exponent')

    # NOTE: `min` / `max` shadow the builtins inside __init__, but they are part
    # of the keyword interface used by callers, so they cannot be renamed.
    def __init__(self, name, default, min=None, max=None, mode='default'):
        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # which would let a typo in `mode` silently fall back to multiplication.
        if mode not in self.VALID_MODES:
            raise ValueError(
                f"Invalid mode {mode!r} for param {name!r}; expected one of {self.VALID_MODES}"
            )
        self.name = name
        self.default = default
        self.min = min
        self.max = max
        self.mode = mode

    def __repr__(self):
        return (
            f"{type(self).__name__}(name={self.name!r}, default={self.default!r}, "
            f"min={self.min!r}, max={self.max!r}, mode={self.mode!r})"
        )

# Per-algorithm sweep tables. Each Param is given as (name, default[, options]).
# The order of entries within a list is the order in which the parameters are
# emitted on the generated command line.

# PPO
algorithms["ppo"] = [
    Param("learning_rate", 0.00025),
    Param("num_steps", 128),
    # Currently not swept:
    #   Param("gamma", 0.99, mode='exponent')
    #   Param("gae_lambda", 0.95, mode='exponent')
    Param("num_minibatches", 4),
    Param("update_epochs", 4),
    Param("clip_coef", 0.1),
    Param("ent_coef", 0.05),
    Param("vf_coef", 0.5),
    Param("max_grad_norm", 0.5),
]

# MMD: every PPO parameter plus the KL coefficient.
algorithms["mmd"] = [
    *algorithms["ppo"],
    Param("kl_coef", 0.05),
]

# PPG: every PPO parameter plus the PPG-specific ones.
algorithms["ppg"] = [
    *algorithms["ppo"],
    Param("n_iteration", 32),
    Param("e_policy", 1),
    Param("v_value", 1),
    Param("e_auxiliary", 6),
    Param("beta_clone", 1.0),
    Param("num_aux_rollouts", 4),
    Param("n_aux_grad_accum", 1),
]

# NFSP
algorithms["nfsp"] = [
    Param("reservoir_buffer_capacity", 2_000_000, max=4_000_000),
    Param("min_buffer_size_to_learn", 1_000),
    Param("anticipatory_param", 0.1, mode='exponent'),
    Param("batch_size", 128),
    Param("learn_every", 64),
    Param("sl_learning_rate", 0.01),
    # Currently not swept:
    #   Param("rl_learning_rate", 0.01)
    Param("inner_rl_agent.replay_buffer_capacity", 200_000, max=500_000),
    Param("inner_rl_agent.batch_size", 128),
    Param("inner_rl_agent.learning_rate", 0.01),
    Param("inner_rl_agent.update_target_network_every", 19_200),
    Param("inner_rl_agent.epsilon_decay_duration", 10_000_000),
    Param("inner_rl_agent.epsilon_start", 0.06),
    Param("inner_rl_agent.epsilon_end", 0.001),
    # Currently not swept:
    #   Param("inner_rl_agent.learn_every", 10)
]

# PSRO
algorithms["psro"] = [
    Param("sims_per_entry", 1_000),
    Param("number_training_episodes", 1_000),
    Param("inner_rl_agent.batch_size", 128),
    Param("inner_rl_agent.learning_rate", 0.01),
    Param("inner_rl_agent.update_target_network_every", 1_000),
    Param("inner_rl_agent.epsilon_decay_duration", 10_000_000),
    Param("inner_rl_agent.epsilon_start", 0.06),
    Param("inner_rl_agent.epsilon_end", 0.001),
    Param("inner_rl_agent.replay_buffer_capacity", 200_000, max=500_000),
    # Currently not swept:
    #   Param("inner_rl_agent.learn_every", 10)
]

# RNAD
algorithms["rnad"] = [
    Param("batch_size", 256),
    Param("learning_rate", 5e-5),
    Param("clip_gradient", 10_000),
    Param("target_network_avg", 0.001),
    Param("eta_reward_transform", 0.2),
    Param("entropy_schedule_size_value", 50_000),
    Param("c_vtrace", 1.0),
]

# ESCHER
algorithms["escher_parallel"] = [
    Param("num_traversals", 1_000),
    Param("num_val_fn_traversals", 1_000),
    Param("regret_train_steps", 5_000),
    Param("val_train_steps", 5_000),
    Param("policy_net_train_steps", 10_000),
    Param("batch_size_regret", 2_048),
    Param("batch_size_val", 2_048),
    Param("learning_rate", 1e-3),
    # Currently not swept:
    #   Param("expl", 1.0)
    Param("val_expl", 0.01),
]


# Template for the header of each generated sbatch job-array script.
# `$(MEM_GB)` and `$(IDX_MAX)` are placeholders filled in with str.replace()
# when the job files are written; the sweep name is baked in right away.
SBATCH_HEADER = "\n".join([
    "#!/bin/bash",
    "#SBATCH --account=pr_100_tandon_priority",
    "#SBATCH --time=48:00:00",
    "#SBATCH --mem=$(MEM_GB)G",
    "#SBATCH --cpus-per-task=8",
    "#SBATCH --array=0-$(IDX_MAX)",
    f"#SBATCH --output=logs/{args.sweep_name}/output_%A_%a.out",
    f"#SBATCH --error=logs/{args.sweep_name}/error_%A_%a.err",
    "",  # keeps the trailing newline after the last directive
])


# Games to sweep over: either every known game, or the comma-separated subset
# requested with --games (names are whitespace-stripped and validated in order).
all_games = [
    "classical_phantom_ttt",
    "abrupt_phantom_ttt",
    "classical_dark_hex",
    "abrupt_dark_hex",
]
if args.games == 'all':
    games = all_games
else:
    games = [requested.strip() for requested in args.games.split(',')]
    for game in games:
        if game not in all_games:
            raise ValueError(f"Unrecognized game {game}")

# Algorithms to sweep over: either every key of `algorithms`, or the
# comma-separated subset requested with --algs (stripped and validated in order).
all_algs = list(algorithms.keys())
if args.algs == 'all':
    algs = all_algs
else:
    algs = [requested.strip() for requested in args.algs.split(',')]
    for alg in algs:
        if alg not in all_algs:
            raise ValueError(f"Unrecognized algorithm {alg}")

def get_nested(cfg, key):
    """Look up a dotted key path in a nested mapping-like config.

    Args:
        cfg: Root object exposing a dict-style ``.get(key)`` method.
        key: Dot-separated path, e.g. ``"inner_rl_agent.batch_size"``.

    Returns:
        The value stored at the path, or ``None`` if any segment is missing,
        maps to ``None``, or would have to descend into a non-mapping leaf.
    """
    value = cfg
    for part in key.split('.'):
        getter = getattr(value, 'get', None)
        if getter is None:
            # The path tries to go below a leaf (e.g. an int or a string):
            # report it as "not found" instead of raising AttributeError.
            return None
        value = getter(part)
        if value is None:
            return None
    return value

# Accumulators for the generated shell text.
#   out:     script text keyed by 'sweep' (plain mode) or by memory tier in GB
#            (array mode).
#   out_idx: next SLURM array index to hand out, per memory tier.
out, out_idx = defaultdict(str), defaultdict(int)

total_runs = len(algs) * len(games) * args.n_samples * args.n_seeds
n_jobs_str = f'{total_runs} runs generated.'
print(n_jobs_str)

if args.array:
    # Each memory tier gets its own job-array script built around a bash
    # `case` statement that dispatches on the array task id.
    for mem_gb in (60, 120):
        out[mem_gb] = 'case "$SLURM_ARRAY_TASK_ID" in\n'
else:
    out['sweep'] += f'# {n_jobs_str}\n\n'

# Drawn once per invocation and mixed into every per-parameter RNG seed, so
# separate runs of this script produce different random samples.
base_seed = random.randint(0, 10_000_000_000)

# Generate one command per (algorithm, game, sample, seed) and append it to the
# appropriate output buffer.
for alg in algs:
    # get default hyperparameters for the algorithm
    with initialize(version_base=None, config_path="configs"):
        default_cfg = compose(config_name="experiment", overrides=[f"algorithm={alg}"])
    for game in games:
        # Memory tier (GB) requested for this (algorithm, game) pair.
        mem_gb = 60
        if game in ['abrupt_dark_hex', 'abrupt_phantom_ttt'] and alg in ['nfsp', 'psro', 'escher_parallel']:
            mem_gb = 120
        if alg == 'nfsp':
            mem_gb = 120

        for sample in range(args.n_samples):
            # `$SEED` (and `$JOBID` in array mode) are placeholders replaced
            # once per seed further down.
            cmd = f"# ALG={alg} GAME={game} SAMPLE={sample} SEED=$SEED\n"
            if args.array:
                cmd += f"python main.py algorithm={alg} game={game} seed=$SEED group_name={args.sweep_name} max_steps={args.max_steps} \\\n"
            else:
                cmd += f"python submit.py --compute_config configs/cluster/greene.yaml --mem '{mem_gb}gb' --main main.py --save_dir /scratch/USERNAME/log_dir/2p0s/{args.sweep_name} \\\n"
                cmd += f"--alg_args algorithm={alg} game={game} seed=$SEED group_name={args.sweep_name} max_steps={args.max_steps} \\\n"

            n = 0  # number of overrides on the current output line
            for sweep_param in algorithms[alg]:
                # get default param value
                default_val = get_nested(default_cfg.algorithm, sweep_param.name)

                # Check for an unknown parameter first, so that a typo in the
                # sweep table is reported as such rather than as a mismatch
                # between None and the expected default.
                if default_val is None:
                    raise ValueError(f"Unrecognized argument algorithm.{sweep_param.name} for algorithm {alg.upper()}")

                if 'inner_rl_agent' in sweep_param.name:
                    # inner_rl_agent params may differ from the table default.
                    # NOTE: the value loaded from the config (not
                    # sweep_param.default) is what gets scaled below.
                    out[mem_gb if args.array else 'sweep'] += "# Changing inner_rl_agent default value.\n"
                elif default_val != sweep_param.default:
                    # Explicit raise rather than `assert`, which is stripped
                    # under `python -O`.
                    raise ValueError(f"Default values for {alg}.{sweep_param.name} are different: {default_val} and {sweep_param.default}")

                # randomize it
                # we want that two params of the same name get the same value across algorithms, for a given game and sample
                random.seed(f"{sweep_param.name}{sample}{game}{base_seed}")
                # Sample 0 always keeps the default values (factor 1).
                multiplier = 1 if sample == 0 else 2 ** (random.randint(-args.n_powers_of_2, args.n_powers_of_2))
                is_int = isinstance(default_val, int)
                if sweep_param.mode == 'exponent':
                    val = default_val ** multiplier
                else:
                    val = default_val * multiplier
                if sweep_param.min is not None:
                    val = max(sweep_param.min, val)
                if sweep_param.max is not None:
                    val = min(sweep_param.max, val)
                if is_int:
                    # Integer-valued params stay integers and never drop below 1.
                    val = max(1, int(val))

                # add it to command, wrapping the line after every 5 overrides
                cmd += f"algorithm.{sweep_param.name}={val} "
                n += 1
                if n % 5 == 0:
                    cmd += "\\\n"
                    n = 0
            if n > 0:
                # Close a partially filled line with a continuation.
                cmd += "\\\n"

            cmd += f"compute_exploitability=True compute_exploitability_every={args.compute_exploitability_every} "
            if args.array:
                cmd += f"save_dir=/scratch/USERNAME/log_dir/2p0s/sweeps job_id={mem_gb}gb_$JOBID"
            else:
                cmd += f"--dry {args.dry}"

            # different seeds (1-based) reuse the same sampled hyperparameters
            for seed in range(args.n_seeds):
                if args.array:
                    # One `case` clause per array task id within this memory tier.
                    out[mem_gb] += f'\t{out_idx[mem_gb]})\n'
                    out[mem_gb] += '\t\t' + cmd.replace("$SEED", str(seed + 1)).replace("$JOBID", str(out_idx[mem_gb])).replace('\n', '\n\t\t') + '\n\t;;\n\n'
                    out_idx[mem_gb] += 1
                else:
                    out['sweep'] += cmd.replace("$SEED", str(seed + 1)) + '\n\n'
                    if args.sleep:
                        out['sweep'] += f'sleep {args.sleep}\n\n'


# Write the generated script(s) to the current directory and print the
# command(s) needed to launch them.
print('\nRun the following to launch:\n')
if not args.array:
    # Plain mode: a single bash script containing every submit command.
    with open('sweep.sh', 'w') as script:
        script.write(out['sweep'])
        print('bash sweep.sh')
else:
    # Default clause that closes the `case` statement of each job-array script.
    case_fallback = (
        '    *)\n'
        '        echo "Invalid SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"\n'
        '        exit 1\n'
        '        ;;\n'
        'esac'
    )
    for mem_tier, case_body in out.items():
        # Only memory tiers (int keys) that received at least one job get a file.
        if not isinstance(mem_tier, int) or out_idx[mem_tier] <= 0:
            continue
        header = SBATCH_HEADER.replace('$(IDX_MAX)', str(out_idx[mem_tier] - 1)).replace('$(MEM_GB)', str(mem_tier))
        job_file = f'job_{mem_tier}GB.sh'
        with open(job_file, 'w') as script:
            script.write(header + '\n\n')
            script.write(case_body)
            script.write(case_fallback)
            print(f'sbatch {job_file}')
