import collections
import copy
import functools
import json
import os

import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.nn.functional as F

from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from argparse import Namespace

import jacinle.random as random
import jacinle.io as io

from difflogic.cli import format_args
from difflogic.dataset.utils import ValidActionDataset
from difflogic.envs.blocksworld import make as make_env
from difflogic.nn.baselines import MemoryNet
from difflogic.nn.neural_logic import InputTransform, LogicInference, LogicMachine, LogitsInference
from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self
from difflogic.nn.dlm.layer import DifferentiableLogicMachine
from difflogic.nn.dlm.neural_logic import DLMInferenceBase
from difflogic.nn.rl.reinforce import REINFORCELoss, REINFORCELogLoss
from difflogic.train import MiningTrainerBase

from jacinle.cli.argument import JacArgumentParser
from jacinle.logging import get_logger
from jacinle.logging import set_output_file
from jacinle.utils.container import GView
from jacinle.utils.meter import GroupMeters
from jactorch.optim.quickaccess import get_optimizer
from jactorch.utils.meta import as_cuda
from jactorch.utils.meta import as_numpy
from jactorch.utils.meta import as_tensor

# Supported Blocks World tasks; used as the valid choices of ``--task``.
TASKS = ['final', 'stack']

# Module-level argument parser. It is populated below and parsed at import
# time, so the resulting ``args`` object is visible to every definition in
# this file.
parser = JacArgumentParser()

parser.add_argument(
    '--model',
    default='dlm',
    choices=['nlm', 'memnet', 'dlm'],
    help='model choices, nlm: Neural Logic Machine, memnet: Memory Networks, dlm: Differentiable Logic Machine')

# NLM parameters, used when model is 'nlm'. The 'dlm' model is also built
# with prefix='nlm' and therefore reads these same options.
nlm_group = parser.add_argument_group('Neural Logic Machines')
LogicMachine.make_nlm_parser(
    nlm_group, {
        'depth': 7,
        'breadth': 2,
        'residual': True,
        'exclude_self': True,
        'logic_hidden_dim': []
    },
    prefix='nlm')
nlm_group.add_argument(
    '--nlm-attributes',
    type=int,
    default=8,
    metavar='N',
    help='number of output attributes in each group of each layer of the LogicMachine'
)

# MemNN parameters, works when model is 'memnet'.
memnet_group = parser.add_argument_group('Memory Networks')
MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet')

parser.add_argument(
    '--task', required=True, choices=TASKS, help='tasks choices')

# Optuna parameters controlling the hyperparameter search itself.
optuna_group = parser.add_argument_group('Optuna')
optuna_group.add_argument(
    '--study-name',
    default='study',
    help='name of the study in the database')
optuna_group.add_argument(
    '--n-trials',
    # Use an int default so it agrees with ``type=int`` instead of relying on
    # argparse converting a string default.
    default=1,
    type=int,
    help='number of hyperparameters sampled')

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy. This converter accepts the usual
    spellings instead and raises ``ValueError`` (which argparse reports as a
    usage error) for anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise ValueError('invalid boolean value: {!r}'.format(value))


method_group = parser.add_argument_group('Method')
method_group.add_argument(
    '--concat-worlds',
    type=_str2bool,
    default=False,
    help='concat the features of objects of same id among two worlds accordingly'
)
method_group.add_argument(
    '--pred-depth',
    type=int,
    default=None,
    metavar='N',
    help='the depth of nlm used for prediction task')
method_group.add_argument(
    '--pred-weight',
    type=float,
    default=0.1,
    metavar='F',
    help='the linear scaling factor for prediction task')

# Trainer options are declared by the base trainer; the dict below overrides
# its defaults for this script.
MiningTrainerBase.make_trainer_parser(
    parser, {
        'epochs': 500,
        'epoch_size': 100,
        'test_epoch_size': 10,
        'test_number_begin': 3,
        'test_number_step': 2,
        'test_number_end': 13,
        'curriculum_start': 2,
        'curriculum_step': 1,
        'curriculum_graduate': 12,
        'curriculum_thresh_relax': 0.005,
        'curriculum_thresh': 1,
        'sample_array_capacity': 3,
        'enable_mining': True,
        'mining_interval': 10,
        'mining_epoch_size': 1000,
        'mining_dataset_size': 100,
        # major edit
        'inherit_neg_data': False,
        'disable_balanced_sample': True,
        'prob_pos_data': 0.6
    })

train_group = parser.add_argument_group('Train')
train_group.add_argument('--seed', type=int, default=None, metavar='SEED')
train_group.add_argument(
    '--use-gpu', action='store_true', help='use GPU or not')
train_group.add_argument(
    '--optimizer',
    default='AdamW',
    choices=['SGD', 'Adam', 'AdamW'],
    help='optimizer choices')
train_group.add_argument(
    '--lr',
    type=float,
    default=0.005,
    metavar='F',
    help='initial learning rate')
train_group.add_argument(
    '--lr-decay',
    type=float,
    default=0.9,
    metavar='F',
    help='exponential decay of learning rate per lesson')
train_group.add_argument(
    '--ntrajectory',
    type=int,
    default=1,
    metavar='N',
    help='number of trajectories to compute gradient')
train_group.add_argument(
    '--batch-size',
    type=int,
    default=4,
    metavar='N',
    help='batch size for extra prediction')
train_group.add_argument(
    '--candidate-relax',
    type=int,
    default=0,
    metavar='N',
    help='number of thresh relaxation for candidate')
train_group.add_argument(
    '--extract-path', action='store_true', help='extract path or not')
# Per-epoch multiplicative decay factors of the DLM stochasticity parameters
# (applied in Model.stoch_decay).
train_group.add_argument(
    '--gumbel-noise-decay',
    type=float,
    default=0.995)
train_group.add_argument(
    '--dropout-prob-decay',
    type=float,
    default=0.995)
train_group.add_argument(
    '--tau-decay',
    type=float,
    default=0.995)
train_group.add_argument(
    '--last-tau-decay',
    type=float,
    default=1.0)
# Initial values of the DLM stochasticity parameters (read in Model.__init__).
train_group.add_argument(
    '--gumbel-noise-begin',
    type=float,
    default=1)
train_group.add_argument(
    '--dropout-prob-begin',
    type=float,
    default=0.1)
train_group.add_argument(
    '--tau-begin',
    type=float,
    default=1)
train_group.add_argument(
    '--last-tau-begin',
    type=float,
    default=0.1)
train_group.add_argument(
    '--norm-rewards',
    type=_str2bool,
    default=False)

rl_group = parser.add_argument_group('Reinforcement Learning')
rl_group.add_argument(
    '--gamma',
    type=float,
    default=0.99,
    metavar='F',
    help='discount factor for accumulated reward function in reinforcement learning'
)
rl_group.add_argument(
    '--penalty',
    type=float,
    default=-0.01,
    metavar='F',
    help='a small penalty each step')
rl_group.add_argument(
    '--entropy-beta',
    type=float,
    default=0.2,
    metavar='F',
    help='entropy loss scaling factor')
rl_group.add_argument(
    '--entropy-beta-decay',
    type=float,
    default=0.8,
    metavar='F',
    help='entropy beta exponential decay factor')
rl_group.add_argument(
    '--dlm-noise',
    type=int,
    default=1,
    metavar='N',
    help='dlm noise handling')
rl_group.add_argument(
    '--reinforce-log',
    type=_str2bool,
    default=False)


# Input/Output options: where run artifacts go and which checkpoint to load.
io_group = parser.add_argument_group('Input/Output')
io_group.add_argument('--dump-dir', default=None, metavar='DIR', help='dump dir')
io_group.add_argument(
    '--dump-play', action='store_true',
    help='dump the trajectory of the plays for visualization')
io_group.add_argument('--dump-fail-only', action='store_true', help='dump failure cases only')
io_group.add_argument(
    '--load-checkpoint', default=None, metavar='FILE',
    help='load parameters from checkpoint')

# Schedule options: checkpoint/test cadence and the per-lesson epoch budget.
schedule_group = parser.add_argument_group('Schedule')
schedule_group.add_argument(
    '--save-interval', type=int, default=100, metavar='N',
    help='the interval(number of epochs) to save checkpoint')
schedule_group.add_argument(
    '--test-interval', type=int, default=25, metavar='N',
    help='the interval(number of epochs) to do test')
schedule_group.add_argument('--test-only', action='store_true', help='test-only mode')
schedule_group.add_argument(
    '--early-drop-epochs', type=int, default=100, metavar='N',
    help='epochs could spend for each lesson, early drop')

# ARGS PARSING
# The command line is parsed at import time; the definitions below read the
# module-level ``args``.
args = parser.parse_args()

# Silently fall back to CPU when CUDA is unavailable.
if args.use_gpu and not torch.cuda.is_available():
    args.use_gpu = False
# Dumping plays needs a directory to write into.
if args.dump_dir is None:
    args.dump_play = False

if args.seed is not None:
    random.reset_global_seed(args.seed)

# Snapshot of the parsed arguments, used as the base for every Optuna trial.
main_args = copy.deepcopy(args)

# Environment factory with the options shared by every episode.
make_env = functools.partial(
    make_env, random_order=True, exclude_self=True, fix_ground=True)

# cost_weight[k] is the cumulative, normalised share of the first k test
# sizes (with a leading 0.0); larger test sizes get a larger share.
cost_weight = np.array(range(
    args.test_number_begin,
    args.test_number_end + args.test_number_step,
    args.test_number_step))
cost_weight = np.insert((cost_weight / cost_weight.sum()).cumsum(), 0, 0.0)

logger = get_logger(__file__)


class Model(nn.Module):
    """The model for blocks world tasks.

    The network turns the raw block states into binary relations with an
    input transform followed by a feature extractor (Memory Network, Neural
    Logic Machine or Differentiable Logic Machine, selected by
    ``args.model``), then scores every ordered pair of distinct objects with
    two heads: ``pred`` (the policy over moves) and ``pred_valid`` (an
    auxiliary predictor of whether a move is valid).

    All configuration is read from the module-level ``args``.
    """

    def __init__(self):
        super().__init__()

        self.transform = InputTransform('cmp', exclude_self=False)

        # The 4 dimensions are: world_id, block_id, coord_x, coord_y
        if args.task == 'final':
            input_dim = 4
            # current_dim = 4 * 3 = 12
            current_dim = transformed_dim = self.transform.get_output_dim(input_dim)
            self.feature_axis = 1 if args.concat_worlds else 2
        elif args.task == 'stack':
            input_dim = 2
            current_dim = transformed_dim = self.transform.get_output_dim(input_dim)
            self.feature_axis = 2
        else:
            raise ValueError('unknown task: {!r}'.format(args.task))

        if args.model == 'memnet':
            self.feature = MemoryNet.from_args(
                current_dim, self.feature_axis, args, prefix='memnet')
            current_dim = self.feature.get_output_dim()
        elif args.model in ('nlm', 'dlm'):
            # Both logic machines are configured from the same 'nlm' options
            # and receive the transformed input as binary predicates.
            machine_cls = (LogicMachine if args.model == 'nlm'
                           else DifferentiableLogicMachine)
            input_dims = [0 for _ in range(args.nlm_breadth + 1)]
            input_dims[2] = current_dim
            self.features = machine_cls.from_args(
                input_dims, args.nlm_attributes, args, prefix='nlm')
            current_dim = self.features.output_dims[self.feature_axis]
        else:
            raise ValueError('unknown model: {!r}'.format(args.model))

        self.final_transform = InputTransform('concat', exclude_self=False)
        if args.task == 'final':
            if args.concat_worlds:
                current_dim = (self.final_transform.get_output_dim(current_dim) +
                               transformed_dim) * 2

        if args.model == 'dlm':
            self.pred_valid = DLMInferenceBase(current_dim, 1, False, 'root_valid')
            self.pred = DLMInferenceBase(current_dim, 1, False, 'root')

            # Stochasticity parameters, decayed once per epoch by stoch_decay().
            self.last_tau = args.last_tau_begin
            self.tau = args.tau_begin
            self.dropout_prob = args.dropout_prob_begin
            self.gumbel_prob = args.gumbel_noise_begin

            self.update_stoch()
        else:
            self.pred_valid = LogicInference(current_dim, 1, [])
            self.pred = LogitsInference(current_dim, 1, [])
        if args.reinforce_log:
            self.loss = REINFORCELogLoss()
        else:
            self.loss = REINFORCELoss()
        self.pred_loss = nn.BCELoss()

    def update_stoch(self):
        """Push the current tau / gumbel noise / dropout values to the sub-modules.

        Only called on the 'dlm' path (from ``__init__`` and ``stoch_decay``);
        it relies on ``self.features`` and on the DLM update methods.
        ``self.last_tau`` is not propagated here, it is applied directly in
        ``forward``.
        """
        modules = (self.features, self.pred_valid, self.pred)
        for module in modules:
            module.update_tau(self.tau)
        for module in modules:
            module.update_gumbel_noise(self.gumbel_prob)
        for module in modules:
            module.update_dropout_prob(self.dropout_prob)

    def _set_noise(self, enabled):
        """Switch the DLM noise sources on or off; no-op for other models."""
        if args.model != 'dlm':
            return
        # NOTE(review): the heads are configured by attribute assignment while
        # the feature extractor is configured through method calls; this
        # mirrors the original code, confirm both APIs are the intended ones.
        for head in (self.pred, self.pred_valid):
            head.independant_noise_per_sample = enabled
            head.with_gumbel = enabled
            head.with_dropout = enabled

        self.features.independant_noise_per_sample(enabled)
        self.features.with_gumbel(enabled)
        self.features.with_dropout(enabled)

    def lowernoise(self):
        """Disable per-sample noise, gumbel noise and dropout (DLM only)."""
        self._set_noise(False)

    def restorenoise(self):
        """Re-enable per-sample noise, gumbel noise and dropout (DLM only)."""
        self._set_noise(True)

    def stoch_decay(self):
        """Apply one step of multiplicative decay to the DLM stochasticity parameters."""
        if args.model == 'dlm':
            self.tau = self.tau * args.tau_decay
            self.last_tau = self.last_tau * args.last_tau_decay
            self.gumbel_prob = self.gumbel_prob * args.gumbel_noise_decay
            self.dropout_prob = self.dropout_prob * args.dropout_prob_decay

            self.update_stoch()

    def forward(self, feed_dict):
        """Compute the policy, and the training loss when training data is present.

        Args:
            feed_dict: mapping with at least ``states``. For a training step it
                also carries ``actions``, ``discount_rewards``,
                ``entropy_beta``, ``pred_states``, ``pred_actions`` and
                ``valid``.

        Returns:
            ``dict(policy=..., logits=...)`` when the module is in eval mode or
            ``pred_states`` is absent; otherwise the tuple
            ``(loss, monitors, dict())`` where ``loss`` is the REINFORCE loss
            plus ``args.pred_weight`` times the validity-prediction loss.
        """
        feed_dict = GView(feed_dict)

        states = feed_dict.states.float()
        f = self.get_binary_relations(states)
        if args.model == 'dlm':
            # The DLM head returns a sequence; its first element holds the logits.
            logits = self.pred(f)[0].squeeze(dim=-1).view(states.size(0), -1)
            # last_tau acts as the softmax temperature of the policy.
            policy = F.softmax(logits / self.last_tau, dim=-1).clamp(min=1e-20)
        else:
            logits = self.pred(f).squeeze(dim=-1).view(states.size(0), -1)
            policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        if not self.training or 'pred_states' not in feed_dict.raw().keys():
            return dict(policy=policy, logits=logits)

        pred_states = feed_dict.pred_states.float()
        f = self.get_binary_relations(pred_states, depth=args.pred_depth)
        if args.model == 'dlm':
            f = self.pred_valid(f)[0].squeeze(dim=-1).view(pred_states.size(0), -1)
        else:
            f = self.pred_valid(f).squeeze(dim=-1).view(pred_states.size(0), -1)
        # Set minimal value to avoid loss to be nan.
        valid = f[range(pred_states.size(0)), feed_dict.pred_actions].clamp(min=1e-20)

        loss, monitors = self.loss(policy, feed_dict.actions,
                                   feed_dict.discount_rewards,
                                   feed_dict.entropy_beta)
        pred_loss = self.pred_loss(valid, feed_dict.valid)
        monitors['pred/accuracy'] = feed_dict.valid.eq((valid > 0.5).float()).float().mean()
        if args.model == 'dlm':
            # The original wrote the 'tau' entry three times; once is enough.
            monitors.update({'tau': np.array(self.tau)})
            monitors.update({'last_tau': np.array(self.last_tau)})
        loss = loss + args.pred_weight * pred_loss
        return loss, monitors, dict()

    def get_binary_relations(self, states, depth=None):
        """get binary relations given states, up to certain depth.

        Args:
            states: float tensor of object states; dimension 1 indexes the
                objects (both worlds stacked for the 'final' task).
            depth: optional depth forwarded to the logic machine.

        Returns:
            The relation features after ``meshgrid_exclude_self``.

        Raises:
            ValueError: if ``args.task`` is not a supported task.
        """
        # total = 2 * the number of objects in each world
        total = states.size()[1]
        f = self.transform(states)
        if args.model == 'memnet':
            f = self.feature(f)
        else:
            inp = [None for i in range(args.nlm_breadth + 1)]
            inp[2] = f
            features = self.features(inp, depth=depth)
            if args.model == 'dlm':
                f = features[0][self.feature_axis]
            else:
                f = features[self.feature_axis]

        if args.task == 'final':
            assert total % 2 == 0
            nr_objects = total // 2
            if args.concat_worlds:
                # To concat the properties of blocks with the same id in both world.
                f = torch.cat([f[:, :nr_objects], f[:, nr_objects:]], dim=-1)
                states = torch.cat([states[:, :nr_objects], states[:, nr_objects:]], dim=-1)
                transformed_input = self.transform(states)
                # And perform a 'concat' transform to binary representation (relations).
                f = torch.cat([self.final_transform(f), transformed_input], dim=-1)
            else:
                f = f[:, :nr_objects, :nr_objects].contiguous()
        elif args.task == 'stack':
            nr_objects = total
            f = f[:, :nr_objects, :nr_objects].contiguous()
        else:
            raise ValueError('unknown task: {!r}'.format(args.task))

        f = meshgrid_exclude_self(f)
        return f


def make_data(traj, gamma):
    """Turn a finished trajectory into a tensor batch for RL optimization.

    The discounted return of each step is computed backwards through the
    episode as ``q_t = reward_t + gamma * q_{t+1}``. ``traj`` is updated in
    place: ``states`` and ``actions`` are converted to tensors and a float
    ``discount_rewards`` tensor is added. The same object is returned.
    """
    running_return = 0
    returns = []
    for step_reward in reversed(traj['rewards']):
        running_return = running_return * gamma + step_reward
        returns.append(running_return)
    # Returns were accumulated from the last step; put them in time order.
    returns = returns[::-1]

    for key in ('states', 'actions'):
        traj[key] = as_tensor(np.array(traj[key]))
    traj['discount_rewards'] = as_tensor(np.array(returns)).float()
    return traj


def run_episode(env,
                model,
                mode,
                number,
                play_name='',
                dump=False,
                dataset=None,
                eval_only=False,
                use_argmax=False,
                need_restart=False,
                entropy_beta=0.0):
    """Play one episode with $number blocks and collect its trajectory.

    Args:
        env: the environment; its outermost wrapper must expose ``mapping``.
        model: the policy model.
        mode: one of the trainer modes; with ``--dlm-noise 1`` the DLM noise
            is lowered for 'mining', 'inherit' and 'test'.
        number: number of blocks.
        play_name: prefix of the dumped json file name.
        dump: dump the play when ``args.dump_play`` is also set.
        dataset: optional dataset that receives (state, action, valid)
            samples when not in eval-only mode.
        eval_only: run the model without gradients.
        use_argmax: pick the most likely action instead of sampling.
        need_restart: restart the environment before playing.
        entropy_beta: value passed to the model in the feed dict.

    Returns:
        ``(succ, score, traj, length)``: success flag (0/1), summed reward,
        the trajectory dict of lists and the number of steps played.
    """
    if need_restart:
        env.restart()

    traj = collections.defaultdict(list)
    score = 0
    done = False
    nr_objects = number + 1

    # If dump_play=True, store the states and actions in a json file
    # for visualization.
    record_play = args.dump_play and dump
    if record_play:
        initial_state = env.unwrapped.current_state
        moves, new_pos, policies = [], [], []

    uses_dlm = args.model == 'dlm'
    if uses_dlm:
        # by default network isn't in training mode during data collection
        # but with dlm we don't want to use argmax only
        model.train(True)

        quiet_modes = ('mining', 'inherit', 'test')
        if args.dlm_noise == 2 or (args.dlm_noise == 1 and mode in quiet_modes):
            model.lowernoise()

    while not done:
        state = env.current_state
        feed_dict = as_tensor({
            'states': np.array([state]),
            'entropy_beta': as_tensor(entropy_beta).float(),
        })

        with torch.set_grad_enabled(not eval_only):
            output_dict = model(feed_dict)
        probs = as_numpy(output_dict['policy'].data[0])
        if use_argmax:
            action = probs.argmax()
        else:
            action = random.choice(len(probs), p=probs)

        # Need to ensure that the env.utils.MapActionProxy is the outermost class.
        mapped_x, mapped_y = env.mapping[action]
        # env.unwrapped to get the innermost Env class.
        # Validity must be queried before the action changes the world.
        valid = env.unwrapped.world.moveable(mapped_x, mapped_y)
        reward, done = env.action(action)

        if record_play:
            moves.append([mapped_x, mapped_y])
            coords = tuple(env.current_state[mapped_x][2:])
            new_pos.append((int(coords[0]), int(coords[1])))

            logits = as_numpy(output_dict['logits'].data[0])
            # Ten most likely actions, most likely first.
            best = np.argsort(probs)[-10:][::-1]
            policies.append([
                (env.mapping[a], float(probs[a]), float(logits[a])) for a in best
            ])

        # For now, assume reward=1 only when succeed, otherwise reward=0.
        # Manipulate the reward and get success information according to reward.
        if reward == 0 and args.penalty is not None:
            reward = args.penalty
        succ = 1 if done and reward > 0.99 else 0

        score += reward
        traj['states'].append(state)
        traj['rewards'].append(reward)
        traj['actions'].append(action)
        if not eval_only and dataset is not None and mapped_x != mapped_y:
            dataset.append(nr_objects, state, action, valid)

    # Dump json file as record of the playing.
    skip_dump = args.dump_fail_only and succ
    if record_play and not skip_dump:
        layout = initial_state[:, 2:].astype('int32').tolist()
        layout = [layout[:nr_objects], layout[nr_objects:]]
        # Pass indent=True to json.dumps for an indented view of json files.
        payload = json.dumps(dict(
            array=layout, moves=moves, new_pos=new_pos, policies=policies))
        file_name = '{}_blocks{}.json'.format(play_name, env.unwrapped.nr_blocks)
        with open(os.path.join(args.current_dump_dir, file_name), 'w') as f:
            f.write(payload)

    length = len(traj['rewards'])

    if uses_dlm:
        model.restorenoise()

    return succ, score, traj, length


class MyTrainer(MiningTrainerBase):
    """Curriculum/mining trainer for the Blocks World policy.

    Plays episodes with ``run_episode``, builds REINFORCE batches from them
    and reports intermediate scores to the Optuna trial stored in
    ``self.trial`` (attached by ``main``).
    """

    def save_checkpoint(self, name):
        """Save ``checkpoint_<name>.pth`` when a checkpoints directory is configured."""
        if args.checkpoints_dir is None:
            return
        file_name = 'checkpoint_{}.pth'.format(name)
        super().save_checkpoint(os.path.join(args.checkpoints_dir, file_name))

    def _dump_meters(self, meters, mode):
        """Append the averaged meters as one json line to the summary file, if any."""
        if args.summary_file is None:
            return
        record = meters._canonize_values('avg')
        record['mode'] = mode
        record['epoch'] = self.current_epoch
        with open(args.summary_file, 'a') as f:
            f.write(io.dumps_json(record))
            f.write('\n')

    def _prepare_dataset(self, epoch_size, mode):
        """Nothing to prepare: data is produced by playing episodes."""
        pass

    def _get_player(self, number, mode):
        """Create and restart an environment with ``number`` blocks."""
        env = make_env(args.task, number)
        env.restart()
        return env

    def _get_result_given_player(self, index, meters, number, player, mode):
        """Play with ``player`` and return either a train batch or a test result.

        In 'train' mode ``args.ntrajectory`` episodes are played and merged
        into one feed dict, which is returned. In every other mode a single
        evaluation episode is played and ``(message, info)`` is returned,
        where ``info`` holds ``succ``, ``number`` and ``backup``.
        """
        assert mode in ['train', 'test', 'mining', 'inherit']
        training = mode == 'train'
        episode_kwargs = {
            'eval_only': not training,
            'number': number,
            'play_name': '{}_epoch{}_episode{}'.format(mode, self.current_epoch, index),
        }
        backup = None
        if training:
            episode_kwargs['dataset'] = self.valid_action_dataset
            episode_kwargs['entropy_beta'] = self.entropy_beta
            meters.update(lr=self.lr, entropy_beta=self.entropy_beta)
        elif mode == 'test':
            episode_kwargs['dump'] = True
            episode_kwargs['use_argmax'] = True
        else:
            # Mining / inherit: keep a copy of the environment for the caller.
            backup = copy.deepcopy(player)
            episode_kwargs['use_argmax'] = self.is_candidate

        # Episodes are always played on CPU.
        if args.use_gpu:
            self.model.cpu()

        if not training:
            succ, score, traj, length = run_episode(player, self.model, mode, **episode_kwargs)
            meters.update(number=number, succ=succ, score=score, length=length)

            message = ('> {} iter={iter}, number={number}, succ={succ}, '
                       'score={score:.4f}, length={length}').format(
                mode, iter=index, **meters.val)
            return message, dict(succ=succ, number=number, backup=backup)

        batches = []
        for attempt in range(args.ntrajectory):
            # The player is fresh for the first trajectory only.
            succ, score, traj, length = run_episode(
                player, self.model, mode, need_restart=(attempt != 0), **episode_kwargs)
            meters.update(number=number, succ=succ, score=score, length=length)

            feed_dict = make_data(traj, args.gamma)
            feed_dict['entropy_beta'] = as_tensor(self.entropy_beta).float()
            batches.append(feed_dict)

        # Merge every trajectory into the last feed dict. 'rewards' is not
        # used by the loss and 'entropy_beta' is identical for all of them.
        unmerged = ("rewards", "entropy_beta")
        for key in list(feed_dict):
            if key in unmerged:
                continue
            feed_dict[key] = torch.cat([batch[key] for batch in batches], dim=0)
        feed_dict['entropy_beta'] = batches[0]['entropy_beta']

        # content from valid_move dataset
        states, actions, labels = self.valid_action_dataset.sample_batch(args.batch_size)
        feed_dict['pred_states'] = as_tensor(states)
        feed_dict['pred_actions'] = as_tensor(actions)
        feed_dict['valid'] = as_tensor(labels).float()

        returns = feed_dict['discount_rewards']
        if args.norm_rewards and returns.shape[0] > 1:
            # Standardise the returns; the epsilon guards against a zero std.
            feed_dict['discount_rewards'] = (returns - returns.mean()) / (returns.std() + 10 ** -7)

        if args.use_gpu:
            feed_dict = as_cuda(feed_dict)
            self.model.cuda()
        return feed_dict

    def _extract_info(self, extra):
        """Unpack the info dict produced by ``_get_result_given_player``."""
        return extra['succ'], extra['number'], extra['backup']

    def _get_accuracy(self, meters):
        """Accuracy is the average success rate."""
        return meters.avg['succ']

    def _get_threshold(self):
        """Base threshold, relaxed by ``args.candidate_relax`` steps unless a candidate."""
        relax_steps = 0 if self.is_candidate else args.candidate_relax
        return super()._get_threshold() - self.curriculum_thresh_relax * relax_steps

    def _upgrade_lesson(self):
        """Move to the next lesson and decay the learning rate and entropy beta."""
        super()._upgrade_lesson()
        # Adjust lr & entropy_beta w.r.t different lesson progressively.
        self.lr *= args.lr_decay
        self.entropy_beta *= args.entropy_beta_decay
        self.set_learning_rate(self.lr)

    def _train_epoch(self, epoch_size):
        """Train one epoch, then decay noise, checkpoint, test and report to Optuna.

        Raises:
            optuna.exceptions.TrialPruned: when the pruner decides to stop
                the trial after an intermediate report.
        """
        meters = super()._train_epoch(epoch_size)
        self.model.stoch_decay()

        epoch = self.current_epoch
        if args.save_interval is not None and epoch % args.save_interval == 0:
            self.save_checkpoint(str(epoch))
        if args.test_interval is not None and epoch % args.test_interval == 0:
            test_meters = self.test()
            succ_rates = np.array([m.avg['succ'] for m in test_meters])
            solved = succ_rates > 0.2
            if solved.any():
                # Largest test size still solved more than 20% of the time:
                # all smaller sizes count fully, this one proportionally.
                last = solved.nonzero()[0][-1]
                span = cost_weight[last + 1] - cost_weight[last]
                self.trial.report(span * succ_rates[last] + cost_weight[last], epoch)
            else:
                self.trial.report(0.0, epoch)
            if self.trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        return meters

    def _early_stop(self, meters):
        """Stop when the per-lesson epoch budget (``--early-drop-epochs``) is exceeded."""
        budget = args.early_drop_epochs
        if budget is not None and self.current_epoch > budget * (self.nr_upgrades + 1):
            return True
        return super()._early_stop(meters)

    def train(self):
        """Reset the per-run training state and run the base training loop."""
        self.valid_action_dataset = ValidActionDataset()
        self.lr = args.lr
        self.entropy_beta = args.entropy_beta
        return super().train()


def main(run_id, run_str, trial):
    """Build, train and test one model with the current module-level ``args``.

    Args:
        run_id: identifier used in the run directory name.
        run_str: suffix of the run directory name.
        trial: Optuna trial attached to the trainer for reporting/pruning.

    Returns:
        ``(graduated, test_meters)``; ``graduated`` is ``None`` in test-only
        mode.
    """
    if args.dump_dir is None:
        args.checkpoints_dir = None
        args.summary_file = None
    else:
        io.mkdir(args.dump_dir)

        # Every run gets its own sub-directory under the dump directory.
        run_dir = os.path.join(args.dump_dir, 'run_{}_{}'.format(run_id, run_str))
        args.current_dump_dir = run_dir
        io.mkdir(run_dir)

        args.checkpoints_dir = os.path.join(run_dir, 'checkpoints')
        io.mkdir(args.checkpoints_dir)

        args.summary_file = os.path.join(run_dir, 'summary.json')
        args.log_file = os.path.join(run_dir, 'log.log')
        set_output_file(args.log_file)

    #args.epoch_size = args.epoch_size // args.ntrajectory
    logger.info(format_args(args))

    model = Model()
    trainer = MyTrainer.from_args(
        model, get_optimizer(args.optimizer, model, args.lr), args)
    trainer.trial = trial

    if args.load_checkpoint is not None:
        trainer.load_checkpoint(args.load_checkpoint)

    if args.test_only:
        trainer.current_epoch = 0
        return None, trainer.test()

    graduated = trainer.train()
    trainer.save_checkpoint('last')
    return graduated, trainer.test()


if __name__ == '__main__':
    # Hyperparameter search driver: every Optuna trial samples a
    # configuration, trains a model with it and is scored on the test sizes.
    n_startup_trials = 50
    seed = None
    sampler = TPESampler(n_startup_trials=n_startup_trials, seed=seed)
    pruner = MedianPruner()  # less aggressive pruning
    # The storage URL is read from the SQLURI environment variable.
    study = optuna.create_study(sampler=sampler, pruner=pruner, study_name=args.study_name,
                                direction="maximize", storage=os.getenv('SQLURI'), load_if_exists=True)

    def sample_hyperparams(trial):
        """Sample one configuration; returns a dict of ``args`` overrides.

        The key order of the returned dict is significant: ``objective``
        joins the values, in order, to build the run directory name.
        """
        # nlm_residual is currently fixed; the else-branch documents the
        # search space used when residual connections are enabled.
        nlm_residual = False
        if not nlm_residual:
            nlm_attributes = trial.suggest_categorical('nlm_attributes', [8, 16])
            concat_worlds = False
        else:
            nlm_attributes = 8
            concat_worlds = trial.suggest_categorical('concat_worlds', [True, False])
        pred_weight = trial.suggest_categorical('pred_weight', [0.0, 0.1])
        entropy_beta = trial.suggest_categorical('entropy_beta', [0.0, 0.2])
        dlm_noise = trial.suggest_categorical('dlm_noise', [0, 1, 2])
        curriculum_thresh = trial.suggest_uniform('curriculum_thresh', 0.75, 0.99)
        gumbel_noise_decay = trial.suggest_categorical('gumbel_noise_decay', [0.998, 0.999, 1.0])
        gumbel_noise_begin = trial.suggest_uniform('gumbel_noise_begin', 0.005, 0.2)
        dropout_prob_begin = trial.suggest_uniform('dropout_prob_begin', 0.001, 0.05)
        tau_begin = trial.suggest_uniform('tau_begin', 0.01, 1.0)
        last_tau_begin = trial.suggest_loguniform('last_tau_begin', 0.005, 1.0)
        ntrajectory = trial.suggest_categorical('ntrajectory', [3, 5, 10])
        if pred_weight != 0.0:
            batch_size = trial.suggest_categorical('batch_size', [4, 8, 16, 32])
        else:
            # The prediction batch does not matter when its loss weight is 0.
            batch_size = 4

        # Hyperparameters that are currently fixed rather than searched.
        norm_rewards = True
        reinforce_log = True
        lr_decay = 1.0
        entropy_beta_decay = 0.8
        tau_decay = 1.0
        last_tau_decay = 1.0
        dropout_prob_decay = 1.0
        # don't learn on several number of blocks when using several trajec.
        sample_array_capacity = 1

        return {
            'nlm_attributes': nlm_attributes,
            'nlm_residual': nlm_residual,
            'concat_worlds': concat_worlds,
            'lr_decay': lr_decay,
            'pred_weight': pred_weight,
            'entropy_beta': entropy_beta,
            'entropy_beta_decay': entropy_beta_decay,
            'dlm_noise': dlm_noise,
            'curriculum_thresh': curriculum_thresh,
            'gumbel_noise_decay': gumbel_noise_decay,
            'dropout_prob_decay': dropout_prob_decay,
            'tau_decay': tau_decay,
            'last_tau_decay': last_tau_decay,
            'gumbel_noise_begin': gumbel_noise_begin,
            'dropout_prob_begin': dropout_prob_begin,
            'tau_begin': tau_begin,
            'last_tau_begin': last_tau_begin,
            'ntrajectory': ntrajectory,
            'batch_size': batch_size,
            'norm_rewards': norm_rewards,
            'reinforce_log': reinforce_log,
            'sample_array_capacity': sample_array_capacity,
        }

    def objective(trial):
        """Train one model with sampled hyperparameters and return its score.

        The score is 0.0 when no test size is solved more than 20% of the
        time. Otherwise, with ``ind`` the largest such test size, it is
        ``cost_weight[ind]`` plus that size's share of the weight scaled by
        its success rate.
        """
        # main(), Model and MyTrainer all read the module-level ``args``.
        global args

        stats = []
        nr_graduated = 0

        # Copy the dict: vars() returns the live __dict__, and updating it in
        # place would overwrite the pristine ``main_args`` snapshot.
        hparams = dict(vars(main_args))
        sampled_hparams = sample_hyperparams(trial)
        print()
        print("Sample new hyperparameters:")
        print(sampled_hparams)
        print()
        hparams.update(sampled_hparams)
        args = Namespace(**hparams)

        graduated, test_meters = main(trial._trial_id, '_'.join([str(s) for s in sampled_hparams.values()]), trial)
        logger.info('run {}'.format(trial._trial_id))

        if test_meters is not None:
            for j, meters in enumerate(test_meters):
                if len(stats) <= j:
                    stats.append(GroupMeters())
                stats[j].update(number=meters.avg['number'], test_succ=meters.avg['succ'])

            for meters in stats:
                logger.info('number {}, test_succ {}'.format(meters.avg['number'], meters.avg['test_succ']))

        if not args.test_only:
            nr_graduated += int(graduated)
            if graduated:
                for j, meters in enumerate(test_meters):
                    stats[j].update(grad_test_succ=meters.avg['succ'])
            if nr_graduated > 0:
                for meters in stats:
                    logger.info('number {}, grad_test_succ {}'.format(meters.avg['number'], meters.avg['grad_test_succ']))

        cost = np.array([meters.avg['test_succ'] for meters in stats])
        if (cost > 0.2).any():
            ind = (cost > 0.2).nonzero()[0][-1]
            cost = cost[ind]
            return (cost_weight[ind + 1] - cost_weight[ind]) * cost + cost_weight[ind]
        else:
            return 0.0

    try:
        study.optimize(objective, n_trials=args.n_trials, n_jobs=1)
    except KeyboardInterrupt:
        # Deliberate: an interrupted search still reports the best trial so far.
        pass

    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    try:
        trial = study.best_trial
    except ValueError:
        # Raised by Optuna when no trial has completed yet, e.g. after an
        # early interrupt or when every trial was pruned.
        print('    no completed trial')
    else:
        print('Value: ', trial.value)

        print('Params: ')
        for key, value in trial.params.items():
            print('    {}: {}'.format(key, value))

    # NOTE(review): the returned dataframe is discarded; print or save it if
    # it is meant to be part of the report.
    study.trials_dataframe()
