from tree_exp.env import get_random_tree_env, solve_tree
from tree_exp.policy import PolicyLinProgTree


def run_tree(depth, seed=42):
    """Build a random tree environment, solve it, and train a policy on it.

    The function prints the solved root utilities and the share of
    non-trivial actions, then constructs a ``PolicyLinProgTree`` and calls
    its ``train()`` method. Nothing is returned.

    Note: ``gamma`` and ``outcome_dist_known`` are not parameters; they are
    read from module-level globals, which this file assigns only inside its
    ``if __name__ == '__main__'`` section.

    Args:
        depth: Passed to ``get_random_tree_env`` as ``max_depth``. Verbose
            policy output is enabled only when ``depth <= 3``.
        seed: Seed forwarded to ``get_random_tree_env``.
    """
    tree_env = get_random_tree_env(max_depth=depth, seed=seed)
    solve_tree(tree_env.root, gamma)

    print(f'\ndepth = {depth}, seed = {seed}')

    principal_utility = round(tree_env.root.utility_principal, 3)
    agent_utility = round(tree_env.root.utility_agent, 3)
    print(
        f'Optimal utility: Principal {principal_utility}, Agent {agent_utility}')

    # Every entry of actions_count except the last one is counted as
    # "non-trivial" in the printed fraction.
    counts = tree_env.root.actions_count
    nontrivial_fraction = round(counts[:-1].sum() / counts.sum(), 3)
    print(
        f'Fraction of non-trivial actions: {nontrivial_fraction}')

    total_batches = 20000
    policy_kwargs = dict(
        hid_size=256,
        n_hid_layers=1,
        gamma=gamma,
        lr_start=1e-3,
        lr_end=1e-4,
        outcome_dist_known=outcome_dist_known,
        lr_dist_start=1e-3,
        lr_dist_end=1e-4,
        eps_start=1,
        eps_end=0,
        batch_size=128,
        n_batches=total_batches,
        n_interactions=8,
        buffer_size=int(1e5),
        n_warm_start_batches=100,
        target_update_freq=100,
        log_freq=1000,
        verbose=depth <= 3,
        log_wandb=True,
    )
    learner = PolicyLinProgTree(tree_env, **policy_kwargs)
    learner.train()


if __name__ == '__main__':
    # Experiment grid: each depth is run on seeds 0..n_trees-1, and every
    # (depth, seed) pair is repeated n_runs times.
    depths = [10]
    n_trees = 5
    n_runs = 5
    # run_tree() reads these two names as module-level globals, so they must
    # keep these exact names and be assigned before any run starts.
    outcome_dist_known = True
    gamma = 1.

    # True: submit runs as Ray tasks; False: execute them one by one in-process.
    RAY = True
    if RAY:
        import ray
        ray.init()

        @ray.remote(num_cpus=1, num_gpus=0)
        def run_tree_remote(depth, seed):
            """Ray task wrapper around run_tree (one CPU, no GPU per task)."""
            return run_tree(depth, seed)

        for depth in depths:
            # Submit every (seed, repetition) task for this depth first, then
            # wait once. Previously ray.get() sat inside the seed loop, which
            # blocked the driver after each seed's submissions and re-fetched
            # the already finished results on every iteration.
            result_ids = [
                run_tree_remote.remote(depth, seed)
                for seed in range(n_trees)
                for _ in range(n_runs)
            ]
            ray.get(result_ids)
    else:
        for depth in depths:
            for seed in range(n_trees):
                for _ in range(n_runs):
                    run_tree(depth, seed)
