from gfn_subtb_grid.agents import TabularGFlowNet, TemperedTabularGFlowNet
from gfn_subtb_grid.agents.losses import SubtrajectoryBalanceLoss
from gfn_subtb_grid.learners import BasicLearner
from gfn_subtb_grid.envs import HypergridEnv
from gfn_subtb_grid.metrics import Losses, DistributionMetrics, GradientCosineSimilarity
from gfn_subtb_grid.buffers import UniformFifoBuffer
from ray import tune
import numpy as np
import torch

# Learner class selected for this experiment.
TRAINER = BasicLearner

# Neither a search algorithm nor a scheduler is configured; both are None.
SEARCH_ALGORITHM = SCHEDULER_ALGORITHM = None

def _adam(lr):
    """Return a fresh optimizer section selecting Adam with rate ``lr``."""
    return {'type': torch.optim.Adam, 'lr': lr}


def _subtb_loss(lam):
    """Return a fresh loss section for SubtrajectoryBalanceLoss."""
    return {'type': SubtrajectoryBalanceLoss, 'lambda': lam}


# Environment section: HypergridEnv with side length 8 in 2 dimensions.
_ENV_SECTION = {
    'type': HypergridEnv,
    'side_length': 8,
    'num_dims': 2,
    'R_0': 1e-4,
    'R_1': 1.0,
    'R_2': 3.0,
}

# Buffer section.
# NOTE(review): capacity equals 'train_batch_size' (both 64) -- confirm this
# is intentional.
_BUFFER_SECTION = {
    'type': UniformFifoBuffer,
    'capacity': 64,
}

# Target agent section: tabular GFlowNet, separate Adam settings for log Z
# and for the remaining parameters.
_TARGET_AGENT_SECTION = {
    'type': TabularGFlowNet,
    'loss_fxn_config': _subtb_loss(0.8),
    'param_backward_policy': True,
    'init_log_Z_val': 1.0,
    'log_Z_optim_config': _adam(1.0),
    'optim_config': _adam(0.007),
}

# Behavior agent section: tempered tabular GFlowNet (temperature 2) with its
# own loss and optimizer settings.
_BEHAVIOR_AGENT_SECTION = {
    'type': TemperedTabularGFlowNet,
    'temperature': 2,
    'loss_fxn_config': _subtb_loss(0.1),
    'param_backward_policy': True,
    'init_log_Z_val': 1.0,
    'log_Z_optim_config': _adam(0.1),
    'optim_config': _adam(0.1),
}

# Metrics section, one entry per metric class.
_METRICS_SECTION = [
    {'type': Losses},
    {
        'type': GradientCosineSimilarity,
        'compute_period': 100,
        'batch_size': 1024,
    },
    # NOTE(review): 200_000 is far larger than the 8**2 = 64 cells that
    # side_length/num_dims above would suggest -- presumably just an upper
    # bound; confirm against DistributionMetrics.
    {'type': DistributionMetrics, 'num_states_to_track': 200_000},
]

# Full experiment configuration. 'seed' is a Ray Tune grid search over the
# integers 0..29; every other entry is a fixed value.
CONFIG = {
    'seed': tune.grid_search(list(range(30))),
    'num_gpus': 0.5,
    'train_batch_size': 64,
    'num_target_train_batches_per_step': 1,
    'env_config': _ENV_SECTION,
    'buffer_config': _BUFFER_SECTION,
    'target_agent_config': _TARGET_AGENT_SECTION,
    'behavior_agent_config': _BEHAVIOR_AGENT_SECTION,
    'metrics_config': _METRICS_SECTION,
}
