from gfn_subtb_grid.agents import BasicGFlowNet, EpsilonNoisyGFlowNet
from gfn_subtb_grid.agents.losses import SubtrajectoryBalanceLoss
from gfn_subtb_grid.learners import BasicLearner
from gfn_subtb_grid.envs import HypergridEnv
from gfn_subtb_grid.metrics import Losses, DistributionMetrics, GradientCosineSimilarity
from gfn_subtb_grid.buffers import UniformFifoBuffer
from ray import tune
import numpy as np
import torch

# Learner class used to run every trial of this sweep.
TRAINER = BasicLearner

# Neither a custom search algorithm nor a trial scheduler is configured.
# NOTE(review): presumably the runner then falls back to Ray Tune's default
# variant generation (plain expansion of the grid_search entries in CONFIG)
# and default FIFO scheduling -- confirm against the launcher.
SEARCH_ALGORITHM = SCHEDULER_ALGORITHM = None

# Ray Tune search space for subtrajectory-balance GFlowNet training on a
# hypergrid environment. Ray Tune expands every `tune.grid_search` below into
# a Cartesian product: 3 seeds x 3 lambdas x 7 log-Z learning rates x 7 policy
# learning rates = 1,323 trials.
CONFIG = dict(
    seed=tune.grid_search(list(range(3))),
    # NOTE(review): presumably the fraction of a GPU reserved per trial (two
    # trials per device) -- confirm against the runner.
    num_gpus=0.5,
    train_batch_size=16,
    num_target_train_batches_per_step=1,
    env_config=dict(
        type=HypergridEnv,
        # 24 ** 4 = 331,776 cells, assuming side_length/num_dims describe a
        # full grid -- confirm with HypergridEnv.
        side_length=24,
        num_dims=4,
        R_0=1e-3,
        R_1=0.5,
        R_2=2.0,
    ),
    buffer_config=dict(type=UniformFifoBuffer, capacity=64),
    # Agent whose hyperparameters are swept (lambda and both learning rates).
    target_agent_config=dict(
        type=BasicGFlowNet,
        # Kept as a literal mapping: `lambda` is a Python keyword and cannot
        # be passed to dict() as a keyword argument.
        loss_fxn_config={
            'type': SubtrajectoryBalanceLoss,
            'lambda': tune.grid_search([0.8, 0.9, 0.99]),
        },
        hidden_layer_dim=256,
        num_hidden_layers=2,
        param_backward_policy=True,
        init_log_Z_val=1.0,
        log_Z_optim_config=dict(
            type=torch.optim.Adam,
            lr=tune.grid_search([0.005, 0.0075, 0.01, 0.03, 0.05, 0.075, 0.1]),
        ),
        optim_config=dict(
            type=torch.optim.Adam,
            lr=tune.grid_search([0.0005, 0.00075, 0.001, 0.003, 0.005, 0.0075, 0.01]),
        ),
    ),
    # Behavior agent: every hyperparameter here is fixed (nothing is swept).
    # NOTE(review): unlike the target agent, no hidden_layer_dim /
    # num_hidden_layers are given, so whatever defaults EpsilonNoisyGFlowNet
    # defines apply -- confirm this is intended.
    behavior_agent_config=dict(
        type=EpsilonNoisyGFlowNet,
        epsilon=0.01,
        loss_fxn_config={'type': SubtrajectoryBalanceLoss, 'lambda': 0.1},
        param_backward_policy=True,
        init_log_Z_val=1.0,
        log_Z_optim_config=dict(type=torch.optim.Adam, lr=1e-1),
        optim_config=dict(type=torch.optim.Adam, lr=1e-1),
    ),
    metrics_config=[
        dict(type=Losses),
        # NOTE(review): 200,000 is fewer than the 24 ** 4 = 331,776 hypergrid
        # cells -- confirm the metric is meant to track only a subset.
        dict(type=DistributionMetrics, num_states_to_track=200_000),
    ],
)
