import collections
import copy
import functools
import json
import os

import numpy as np
import optuna
import torch
import torch.nn as nn
import torch.nn.functional as F

from optuna.pruners import MedianPruner, SuccessiveHalvingPruner
from optuna.samplers import TPESampler
from argparse import Namespace

import jacinle.random as random
import jacinle.io as io

from keras.utils.np_utils import to_categorical
from difflogic.cli import format_args
from difflogic.dataset.utils import ValidActionDataset
from difflogic.envs.blocksworld import make as make_env
from difflogic.nn.baselines import MemoryNet
from difflogic.nn.critic import *
from difflogic.nn.neural_logic import InputTransform
from difflogic.nn.neural_logic import LogicInference
from difflogic.nn.neural_logic import LogicMachine
from difflogic.nn.neural_logic import LogitsInference
from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self
from difflogic.nn.dlm.layer import DifferentiableLogicMachine
from difflogic.nn.dlm.neural_logic import DLMInferenceBase
from difflogic.nn.rl.ppo import PPOLoss
from difflogic.train import MiningTrainerBase

from jacinle.cli.argument import JacArgumentParser
from jacinle.logging import get_logger
from jacinle.logging import set_output_file
from jacinle.utils.container import GView
from jacinle.utils.meter import GroupMeters
from jactorch.optim.accum_grad import AccumGrad
from jactorch.optim.quickaccess import get_optimizer
from jactorch.utils.meta import as_cuda
from jactorch.utils.meta import as_numpy
from jactorch.utils.meta import as_tensor
from difflogic.tqdm_utils import tqdm_for

TASKS = ['final', 'stack']

parser = JacArgumentParser()

parser.add_argument(
    '--model',
    default='nlm',
    choices=['nlm', 'memnet', 'dlm'],
    help='model choices, nlm: Neural Logic Machine, memnet: Memory Networks, dlm: Differentiable Logic Machine')

# NLM parameters (--nlm-*). Used when model is 'nlm', and also when model is
# 'dlm': the DifferentiableLogicMachine is built from the same 'nlm' prefix.
nlm_group = parser.add_argument_group('Neural Logic Machines')
LogicMachine.make_nlm_parser(
    nlm_group, {
        'depth': 7,
        'breadth': 2,
        'residual': True,
        'exclude_self': True,
        'logic_hidden_dim': []
    },
    prefix='nlm')
nlm_group.add_argument(
    '--nlm-attributes',
    type=int,
    default=8,
    metavar='N',
    help='number of output attributes in each group of each layer of the LogicMachine'
)

# Parameters (--nlmcrit-*) of the LogicMachine inside the NLM-based critics
# (critic types 3 and 7).
nlmcritic_group = parser.add_argument_group('Neural Logic Machines critic')
LogicMachine.make_nlm_parser(
    nlmcritic_group, {
        'depth': 4,
        'breadth': 2,
        'residual': True,
        'exclude_self': True,
        'logic_hidden_dim': []
    },
    prefix='nlmcrit')
nlmcritic_group.add_argument(
    '--nlm-attributes-critic',
    type=int,
    default=2,
    metavar='N',
    help='number of output attributes in each group of each layer of the critic LogicMachine'
)

# Optuna parameters
optuna_group = parser.add_argument_group('Optuna')
optuna_group.add_argument(
    '--study-name',
    default='study',
    help='name of the study in the database')
optuna_group.add_argument(
    '--n-trials',
    default=1,  # an int default for an int option (was the string '1')
    type=int,
    help='number of hyperparameters sampled')

# MemNN parameters, works when model is 'memnet'.
memnet_group = parser.add_argument_group('Memory Networks')
MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet')

def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is an argparse trap. ``bool('False')`` is True because any
    non-empty string is truthy, so ``--flag False`` silently meant True.

    This parser accepts the usual spellings, case-insensitively. Any other
    value raises ``ValueError``, which argparse reports as an invalid
    argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, as it was under the old bool('') conversion.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError('invalid boolean value: {!r}'.format(value))


parser.add_argument(
    '--task', required=True, choices=TASKS, help='tasks choices')

method_group = parser.add_argument_group('Method')
method_group.add_argument(
    '--concat-worlds',
    type=_str2bool,
    default=True,
    help='concat the features of objects of same id among two worlds accordingly'
)
method_group.add_argument(
    '--pred-depth',
    type=int,
    default=None,
    metavar='N',
    help='the depth of nlm used for prediction task')
method_group.add_argument(
    '--pred-weight',
    type=float,
    default=0.1,
    metavar='F',
    help='the linear scaling factor for prediction task')

MiningTrainerBase.make_trainer_parser(
    parser, {
        'epochs': 500,
        'epoch_size': 100,
        'test_epoch_size': 10,
        'test_number_begin': 3,
        'test_number_step': 1,
        'test_number_end': 13,
        'curriculum_start': 2,
        'curriculum_step': 1,
        'curriculum_graduate': 12,
        'curriculum_thresh_relax': 0.005,
        'curriculum_thresh': 1,
        'enable_mining': True,
        'mining_interval': 10,
        'mining_epoch_size': 1000,
        'mining_dataset_size': 100,
        # Major edit for dlm. The previous settings were
        # sample_array_capacity=1, inherit_neg_data=False and
        # disable_balanced_sample=True.
        'sample_array_capacity': 3,
        'inherit_neg_data': True,
        'disable_balanced_sample': False,
        'prob_pos_data': 0.6
    })

train_group = parser.add_argument_group('Train')
train_group.add_argument('--seed', type=int, default=None, metavar='SEED')
train_group.add_argument(
    '--use-gpu', action='store_true', help='use GPU or not')
train_group.add_argument(
    '--optimizer',
    default='AdamW',
    choices=['SGD', 'Adam', 'AdamW'],
    help='optimizer choices')
train_group.add_argument(
    '--lr',
    type=float,
    default=0.005,
    metavar='F',
    help='initial learning rate')
train_group.add_argument(
    '--lr-decay',
    type=float,
    default=0.9,
    metavar='F',
    help='exponential decay of learning rate per lesson')
train_group.add_argument(
    '--accum-grad',
    type=int,
    default=1,
    metavar='N',
    help='accumulated gradient (default: 1)')
train_group.add_argument(
    '--batch-size',
    type=int,
    default=4,
    metavar='N',
    help='batch size for extra prediction, also used as the PPO minibatch size')
train_group.add_argument(
    '--candidate-relax',
    type=int,
    default=0,
    metavar='N',
    help='number of thresh relaxation for candidate')

rl_group = parser.add_argument_group('Reinforcement Learning')
rl_group.add_argument(
    '--gamma',
    type=float,
    default=0.99,
    metavar='F',
    help='discount factor for accumulated reward function in reinforcement learning'
)
rl_group.add_argument(
    '--penalty',
    type=float,
    default=-0.01,
    metavar='F',
    help='a small penalty each step')
rl_group.add_argument(
    '--entropy-beta',
    type=float,
    default=0.2,
    metavar='F',
    help='entropy loss scaling factor')
rl_group.add_argument(
    '--entropy-beta-decay',
    type=float,
    default=0.8,
    metavar='F',
    help='entropy beta exponential decay factor')
rl_group.add_argument(
    '--critic-type',
    type=int,
    default=5,
    help='critic architecture: 0 MLP, 1 GRU, 2 GRU with shuffled index, '
         '3 NLM+MLP, 4 Conv, 5 mixed GRU, 6 MLP-Q, 7 NLM+MLP-Q, '
         '8 ConvReduce, 9 ConvReduce-Q (Q variants output one value per action)')
rl_group.add_argument(
    '--noptepochs',
    type=int,
    default=4,
    help='number of optimization passes over each collected batch')
rl_group.add_argument(
    '--epsilon',
    type=float,
    default=0.2,
    help='epsilon passed to PPOLoss (PPO clipping parameter)')
rl_group.add_argument(
    '--ntrajectory',
    type=int,
    default=2,
    help='number of episodes collected per training step')
rl_group.add_argument(
    '--lam',
    type=float,
    default=0.97,
    help='lambda of generalized advantage estimation (GAE)')
rl_group.add_argument(
    '--clip-vf',
    type=float,
    default=0.0,
    help='clip the critic prediction to within this distance of the '
         'rollout-time value; 0 disables clipping')
rl_group.add_argument(
    '--no-shuffle-minibatch',
    type=_str2bool,
    default=False,
    help='if true, each optimization pass uses the whole batch instead of '
         'shuffled minibatches of --batch-size')
rl_group.add_argument(
    '--no-adv-norm',
    type=_str2bool,
    default=False,
    help='if true, do not normalize advantages over the merged batch')

# Input/Output options. Dumping plays only takes effect when --dump-dir is
# also given; --load-checkpoint restores parameters before training/testing.
io_group = parser.add_argument_group('Input/Output')
io_group.add_argument('--dump-dir', metavar='DIR', default=None,
                      help='dump dir')
io_group.add_argument('--dump-play', action='store_true',
                      help='dump the trajectory of the plays for visualization')
io_group.add_argument('--dump-fail-only', action='store_true',
                      help='dump failure cases only')
io_group.add_argument('--load-checkpoint', metavar='FILE', default=None,
                      help='load parameters from checkpoint')

# Schedule options: lesson time budget, checkpoint/test cadence (in epochs),
# and run modes.
schedule_group = parser.add_argument_group('Schedule')
schedule_group.add_argument('--early-drop-epochs', metavar='N', type=int,
                            default=100,
                            help='epochs could spend for each lesson, early drop')
schedule_group.add_argument('--save-interval', metavar='N', type=int,
                            default=500,
                            help='the interval(number of epochs) to save checkpoint')
schedule_group.add_argument('--test-interval', metavar='N', type=int,
                            default=25,
                            help='the interval(number of epochs) to do test')
schedule_group.add_argument('--test-only', action='store_true',
                            help='test-only mode')
# NOTE(review): --extract-path is not read in the visible part of this file.
schedule_group.add_argument('--extract-path', action='store_true')

# Parse the command line once, at import time. The rest of the module reads
# the resulting global `args`.
args = parser.parse_args()

# Use the GPU only when it was requested and CUDA is actually available.
args.use_gpu = torch.cuda.is_available() if args.use_gpu else args.use_gpu
# Dumping plays needs a destination directory.
args.dump_play = args.dump_play and args.dump_dir is not None

if args.seed is not None:
    random.reset_global_seed(args.seed)

# Snapshot of the post-processed arguments. main() later mutates `args` in place.
main_args = copy.deepcopy(args)

# Every environment built by this script shares these construction options.
make_env = functools.partial(make_env, random_order=True, exclude_self=True, fix_ground=True)

# Weights for the per-size test success rates reported to Optuna. Each weight
# is proportional to the test size, and the weights are normalised to sum to 1.
cost_weight = np.array(range(args.test_number_begin,
                             args.test_number_end + args.test_number_step,
                             args.test_number_step))
cost_weight = cost_weight / np.sum(cost_weight)

logger = get_logger(__file__)


class Model(nn.Module):
    """The policy/critic model for blocks world tasks.

    Everything is configured from the module-level ``args``:

    * ``transform`` plus a feature extractor selected by ``args.model``
      (MemoryNet, LogicMachine or DifferentiableLogicMachine).
    * ``pred`` (action logits) and ``pred_valid`` (valid-move prediction)
      heads, applied to the pairwise relations.
    * A critic selected by ``args.critic_type``. Types 6, 7 and 9 are
      Q-critics (``isQnet`` is True, one value per action). The others
      predict a single state value.
    """

    def __init__(self):
        super().__init__()
        self.transform = InputTransform('cmp', exclude_self=False)

        # The 4 dimensions are: world_id, block_id, coord_x, coord_y
        if args.task == 'final':
            input_dim = 4
            # current_dim = 4 * 3 = 12
            current_dim = transformed_dim = self.transform.get_output_dim(input_dim)
            self.feature_axis = 1 if args.concat_worlds else 2
        elif args.task == 'stack':
            input_dim = 2
            current_dim = transformed_dim = self.transform.get_output_dim(input_dim)
            self.feature_axis = 2
        else:
            raise ValueError('Unknown task: {}.'.format(args.task))

        if args.model == 'memnet':
            self.feature = MemoryNet.from_args(
                current_dim, self.feature_axis, args, prefix='memnet')
            current_dim = self.feature.get_output_dim()
        elif args.model == 'nlm':
            self.features = LogicMachine.from_args(
                self._logic_input_dims(current_dim), args.nlm_attributes, args, prefix='nlm')
            current_dim = self.features.output_dims[self.feature_axis]
        elif args.model == 'dlm':
            self.features = DifferentiableLogicMachine.from_args(
                self._logic_input_dims(current_dim), args.nlm_attributes, args, prefix='nlm')
            current_dim = self.features.output_dims[self.feature_axis]
        else:
            raise ValueError('Unknown model: {}.'.format(args.model))

        self.final_transform = InputTransform('concat', exclude_self=False)
        if args.task == 'final' and args.concat_worlds:
            # Width produced by the concat_worlds path of get_binary_relations.
            current_dim = (self.final_transform.get_output_dim(current_dim) +
                           transformed_dim) * 2

        if args.model == 'dlm':
            self.pred_valid = DLMInferenceBase(current_dim, 1, False, 'root_valid')
            self.pred = DLMInferenceBase(current_dim, 1, False, 'root')
        else:
            self.pred_valid = LogicInference(current_dim, 1, [])
            self.pred = LogitsInference(current_dim, 1, [])
        self.loss = PPOLoss()
        self.pred_loss = nn.BCELoss()

        # One input-shape entry per curriculum size, for critics wrapped in
        # InvariantNObject.
        range_dims = []
        for i in range(args.curriculum_start, args.curriculum_graduate + args.curriculum_step, args.curriculum_step):
            if args.task == "final":
                range_dims.append([(i + 1) * 2, (i + 1) * 2, transformed_dim])
            elif args.task == "stack":
                range_dims.append([(i + 1), (i + 1), transformed_dim])

        # Action count given a shape x: n * (n - 1), where n = x[0] for 'stack'
        # and n = x[0] // 2 (one world's worth of objects) for 'final'.
        if args.task == "final":
            n_action = lambda x: (x[0] // 2) * ((x[0] // 2) - 1)
        elif args.task == "stack":
            n_action = lambda x: (x[0]) * (x[0] - 1)

        self.isQnet = True
        if args.critic_type == 0:
            self.critic = InvariantNObject(MLPCritic, range_dims, dict())
            self.isQnet = False
        elif args.critic_type == 1:
            self.critic = GRUCritic(transformed_dim)
            self.isQnet = False
        elif args.critic_type == 2:
            self.critic = GRUCritic(transformed_dim, shuffle_index=True)
            self.isQnet = False
        elif args.critic_type == 3:
            nlmcrit = LogicMachine.from_args(self._logic_input_dims(transformed_dim), args.nlm_attributes_critic, args, prefix='nlmcrit')
            self.critic = InvariantNObject(NLMMLPCritic, range_dims, dict(nlm=nlmcrit, feature_axis=None if args.model != 'dlm' else 0))
            self.isQnet = False
        elif args.critic_type == 4:
            self.critic = InvariantNObject(ConvCritic, range_dims, dict(input_channel=transformed_dim))
            self.isQnet = False
        elif args.critic_type == 5:
            self.critic = MixedGRUCritic(transformed_dim)
            self.isQnet = False
        elif args.critic_type == 6:
            self.critic = InvariantNObject(MLPCriticQ, range_dims, dict(n_action_func=n_action))
        elif args.critic_type == 7:
            nlmcrit = LogicMachine.from_args(self._logic_input_dims(transformed_dim), args.nlm_attributes_critic, args, prefix='nlmcrit')
            self.critic = InvariantNObject(NLMMLPCriticQ, range_dims, dict(nlm=nlmcrit, n_action_func=n_action, feature_axis=None if args.model != 'dlm' else 0))
        elif args.critic_type == 8:
            self.critic = InvariantNObject(ConvReduceCritic, range_dims, dict())
            self.isQnet = False
        elif args.critic_type == 9:
            self.critic = InvariantNObject(ConvReduceCriticQ, range_dims, dict(n_action_func=n_action))
        else:
            raise ValueError('Unknown critic_type: {}.'.format(args.critic_type))
        self.critic_loss = nn.MSELoss()

    def forward(self, feed_dict):
        """Run the policy, and the critic when requested, on a batch of states.

        In eval mode this returns ``dict(policy, logits, value)``. ``value`` is
        None unless ``feed_dict.eval_critic`` is truthy.

        In train mode it returns ``(loss, monitors, {})``. The loss is the PPO
        loss, plus the critic MSE, plus ``args.pred_weight`` times the
        valid-move BCE.
        """
        feed_dict = GView(feed_dict)

        states = feed_dict.states.float()
        f, fs = self.get_binary_relations(states)
        if args.model == 'dlm':
            # DLMInferenceBase output is indexed with [0]; assumed to hold the
            # predictions. TODO confirm.
            logits = self.pred(f)[0].squeeze(dim=-1).view(states.size(0), -1)
        else:
            logits = self.pred(f).squeeze(dim=-1).view(states.size(0), -1)
        # Clamp away exact zeros, presumably to keep log-probabilities finite.
        policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        # The critic consumes the transformed raw input (fs), not the NLM features.
        crit_out = self.critic(fs) if feed_dict.eval_critic else None

        if not self.training:
            return dict(policy=policy, logits=logits, value=crit_out)

        loss, monitors = self.loss(policy, feed_dict.old_policy, feed_dict.actions,
                                   feed_dict.advantages, args.epsilon,
                                   feed_dict.entropy_beta)

        # Critic prediction to regress: Q(s, a_taken) for Q-critics, V(s) otherwise.
        if self.isQnet:
            crit_preloss = (crit_out * feed_dict.actions_ohe).sum(-1)
        else:
            crit_preloss = crit_out
        if args.clip_vf is not None and args.clip_vf > 0.0:
            # Keep the prediction within clip_vf of the value recorded at rollout time.
            if self.isQnet:
                previous_val = (feed_dict.values * feed_dict.actions_ohe).sum(-1)
            else:
                previous_val = feed_dict.values
            crit_preloss = previous_val + torch.clamp(crit_preloss - previous_val, -args.clip_vf, args.clip_vf)
        losscrit = self.critic_loss(crit_preloss, feed_dict.returns)
        # NOTE: despite the key name, this logs the critic's MSE loss.
        monitors['critic_accuracy'] = losscrit
        # Out-of-place add: avoids mutating a tensor PPOLoss may also expose.
        loss = loss + losscrit

        # Auxiliary task: predict whether sampled (state, action) pairs are valid moves.
        pred_states = feed_dict.pred_states.float()
        f, _ = self.get_binary_relations(pred_states, depth=args.pred_depth)
        if args.model == 'dlm':
            f = self.pred_valid(f)[0].squeeze(dim=-1).view(pred_states.size(0), -1)
        else:
            f = self.pred_valid(f).squeeze(dim=-1).view(pred_states.size(0), -1)
        valid = f[range(pred_states.size(0)), feed_dict.pred_actions].clamp(min=1e-20)
        pred_loss = self.pred_loss(valid, feed_dict.valid)
        monitors['pred/accuracy'] = feed_dict.valid.eq((valid > 0.5).float()).float().mean()
        loss = loss + args.pred_weight * pred_loss
        return loss, monitors, dict()

    def get_binary_relations(self, states, depth=None):
        """Get binary relations given states, up to a certain depth.

        Returns ``(f, fs)``. ``f`` is ``meshgrid_exclude_self`` applied to the
        extracted (and, for 'final', world-merged) features. ``fs`` is the raw
        ``self.transform(states)`` output, which is fed to the critic.
        """
        # For 'final', total = 2 * the number of objects in each world.
        total = states.size(1)

        f = self.transform(states)
        fs = f
        if args.model == 'memnet':
            f = self.feature(f)
        else:
            inp = [None for _ in range(args.nlm_breadth + 1)]
            inp[2] = f
            features = self.features(inp, depth=depth)
            if args.model == 'dlm':
                f = features[0][self.feature_axis]
            else:
                f = features[self.feature_axis]

        if args.task == 'final':
            assert total % 2 == 0
            nr_objects = total // 2
            if args.concat_worlds:
                # To concat the properties of blocks with the same id in both world.
                f = torch.cat([f[:, :nr_objects], f[:, nr_objects:]], dim=-1)
                states = torch.cat([states[:, :nr_objects], states[:, nr_objects:]], dim=-1)
                transformed_input = self.transform(states)
                # And perform a 'concat' transform to binary representation (relations).
                f = torch.cat([self.final_transform(f), transformed_input], dim=-1)
            else:
                f = f[:, :nr_objects, :nr_objects].contiguous()
        elif args.task == 'stack':
            nr_objects = total
            f = f[:, :nr_objects, :nr_objects].contiguous()
        else:
            raise ValueError('Unknown task: {}.'.format(args.task))

        f = meshgrid_exclude_self(f)
        return f, fs

    @staticmethod
    def _logic_input_dims(binary_dim):
        """Return the input arity layout for a logic machine fed only binary predicates.

        The list has length ``args.nlm_breadth + 1`` and is zero everywhere
        except index 2, which holds ``binary_dim``.
        """
        input_dims = [0 for _ in range(args.nlm_breadth + 1)]
        input_dims[2] = binary_dim
        return input_dims


def make_data(traj, gamma, succ, last_next_value, lam, isQnet):
    """Turn one collected trajectory into a training batch, modifying ``traj`` in place.

    The per-step lists are stacked into tensors. ``traj['advantages']`` is
    filled with GAE(gamma, lam) estimates and ``traj['returns']`` with
    advantages plus the value baseline.

    For Q-critics (``isQnet``), the baseline of a step is Q(s, a_taken). The
    value of a following state is then the expectation of Q under the policy
    stored for that state.

    After the final step, the bootstrap value is 0 when ``succ`` is truthy and
    ``last_next_value`` otherwise. Returns ``traj``.
    """
    traj['actions'] = as_tensor(np.array(traj['actions']))
    traj['actions_ohe'] = as_tensor(np.array(traj['actions_ohe']))
    for key in ('states', 'values', 'old_policy'):
        traj[key] = torch.cat(traj[key], dim=0)

    values = traj['values']
    policies = traj['old_policy']
    actions_ohe = traj['actions_ohe']
    rewards = traj['rewards']
    n_steps = len(values)

    def baseline(t):
        # Value of the action actually taken at step t (or V(s_t) for state critics).
        if isQnet:
            return (values[t] * actions_ohe[t]).sum(-1)
        return values[t]

    advantages = torch.zeros(traj['states'].shape[0])
    gae = 0
    # Walk the trajectory backwards, accumulating the GAE recursion.
    for t in range(n_steps - 1, -1, -1):
        if t == n_steps - 1:
            bootstrap = 0 if succ else last_next_value
        elif isQnet:
            bootstrap = (values[t + 1] * policies[t + 1].cpu()).sum(-1)
        else:
            bootstrap = values[t + 1]
        delta = rewards[t] + gamma * bootstrap - baseline(t)
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    traj['advantages'] = advantages

    if isQnet:
        traj['returns'] = advantages + (values * actions_ohe).sum(-1)
    else:
        traj['returns'] = advantages + values

    return traj


def run_episode(env,
                model,
                number,
                play_name='',
                dump=False,
                dataset=None,
                eval_only=False,
                use_argmax=False,
                need_restart=False,
                entropy_beta=0.0,
                eval_critic=False):
    """Play one episode in ``env`` with ``model`` ($number blocks).

    Returns ``(succ, score, traj, length, last_next_value)``:

    * ``succ`` is 1 when the episode ended with reward > 0.99, else 0.
    * ``score`` is the sum of rewards, with zero rewards replaced by
      ``args.penalty`` when that is set.
    * ``traj`` holds per-step lists unless ``eval_only`` is set.
    * ``length`` is the number of recorded steps.
    * ``last_next_value`` is the critic's value of the final state when
      ``eval_critic`` is set, else None.

    When not ``eval_only``, moves with distinct endpoints are also appended to
    ``dataset``. Plays are written as JSON when ``args.dump_play`` and ``dump``
    are both set.
    """
    def build_feed(current_state):
        # Wrap a single state into the tensor feed_dict the model expects.
        feed = dict(states=np.array([current_state]))
        feed['entropy_beta'] = as_tensor(entropy_beta).float()
        feed['eval_critic'] = as_tensor(eval_critic)
        return as_tensor(feed)

    def query_model(feed):
        # Inference only: no autograd graph is built during rollouts.
        with torch.set_grad_enabled(False):
            return model(feed)

    if need_restart:
        env.restart()
    traj = collections.defaultdict(list)
    score = 0
    is_over = False
    nr_objects = number + 1

    # If dump_play=True, store the states and actions in a json file
    # for visualization.
    dump_play = args.dump_play and dump
    if dump_play:
        initial_state = env.unwrapped.current_state
        moves, new_pos, policies = [], [], []

    while not is_over:
        state = env.current_state
        feed_dict = build_feed(state)
        output_dict = query_model(feed_dict)
        policy = output_dict['policy']
        p = as_numpy(policy.data[0])
        if use_argmax:
            action = p.argmax()
        else:
            action = random.choice(len(p), p=p)
        # Need to ensure that the env.utils.MapActionProxy is the outermost class.
        mapped_x, mapped_y = env.mapping[action]
        # env.unwrapped to get the innermost Env class; validity is checked
        # before the action is applied.
        valid = env.unwrapped.world.moveable(mapped_x, mapped_y)
        reward, is_over = env.action(action)

        if dump_play:
            moves.append([mapped_x, mapped_y])
            landed = tuple(env.current_state[mapped_x][2:])
            new_pos.append((int(landed[0]), int(landed[1])))
            logits = as_numpy(output_dict['logits'].data[0])
            top_actions = np.argsort(p)[-10:][::-1]
            policies.append([(env.mapping[a], float(p[a]), float(logits[a]))
                             for a in top_actions])

        # For now, assume reward=1 only when succeed, otherwise reward=0.
        # Manipulate the reward and get success information according to reward.
        if reward == 0 and args.penalty is not None:
            reward = args.penalty
        succ = 1 if is_over and reward > 0.99 else 0
        score += reward

        if not eval_only:
            traj['values'].append(output_dict['value'].detach().cpu())
            traj['states'].append(feed_dict['states'].detach().cpu())
            traj['rewards'].append(reward)
            traj['actions'].append(action)
            traj['actions_ohe'].append(to_categorical(action, num_classes=policy.shape[1]))
            traj['old_policy'].append(policy.detach().cpu())
            if dataset is not None and mapped_x != mapped_y:
                dataset.append(nr_objects, state, action, valid)

    if eval_critic:
        # Critic estimate of the state the episode ended in.
        final_output = query_model(build_feed(env.current_state))
        if model.isQnet:
            last_next_value = (final_output['value'].detach() * final_output['policy'].detach()).sum(-1).cpu().numpy()
        else:
            last_next_value = final_output['value'].detach().cpu().numpy()
    else:
        last_next_value = None

    # Dump json file as record of the playing.
    if dump_play and not (args.dump_fail_only and succ):
        coords = initial_state[:, 2:].astype('int32').tolist()
        coords = [coords[:nr_objects], coords[nr_objects:]]
        # Let indent=True for an indented view of json files.
        json_str = json.dumps(dict(array=coords, moves=moves, new_pos=new_pos,
                                   policies=policies))
        dump_file = os.path.join(
            args.current_dump_dir,
            '{}_blocks{}.json'.format(play_name, env.unwrapped.nr_blocks))
        with open(dump_file, 'w') as f:
            f.write(json_str)

    return succ, score, traj, len(traj['rewards']), last_next_value


class MyTrainer(MiningTrainerBase):
    """Curriculum and mining trainer that runs PPO episodes in blocks world.

    The caller must assign ``self.trial`` (an Optuna trial) before
    ``train()``. Weighted test success rates are reported to it every
    ``args.test_interval`` epochs, and the trial may be pruned.
    """

    def save_checkpoint(self, name):
        """Save ``checkpoint_<name>.pth`` under ``args.checkpoints_dir`` (no-op when unset)."""
        if args.checkpoints_dir is not None:
            checkpoint_file = os.path.join(args.checkpoints_dir,
                                           'checkpoint_{}.pth'.format(name))
            super().save_checkpoint(checkpoint_file)

    def _dump_meters(self, meters, mode):
        """Append the averaged meters, tagged with mode and epoch, as one JSON line."""
        if args.summary_file is not None:
            meters_kv = meters._canonize_values('avg')
            meters_kv['mode'] = mode
            meters_kv['epoch'] = self.current_epoch
            with open(args.summary_file, 'a') as f:
                f.write(io.dumps_json(meters_kv))
                f.write('\n')

    def _prepare_dataset(self, epoch_size, mode):
        # Nothing to prepare: episodes are generated in _get_result_given_player.
        pass

    def _get_player(self, number, mode):
        player = make_env(args.task, number)
        player.restart()
        return player

    def _get_result_given_player(self, index, meters, number, player, mode):
        """Play with ``player`` according to ``mode``.

        In 'train' mode this collects ``args.ntrajectory`` episodes and returns
        a merged PPO feed_dict, moved to the GPU when enabled.

        In every other mode it plays one episode and returns
        ``(message, dict(succ, number, backup))``. For 'mining' and 'inherit',
        ``backup`` is a deep copy of the player taken before the episode.
        """
        assert mode in ['train', 'test', 'mining', 'inherit']
        params = dict(
            eval_only=True,
            eval_critic=False,
            number=number,
            play_name='{}_epoch{}_episode{}'.format(mode, self.current_epoch, index))
        backup = None
        if mode == 'train':
            params['eval_only'] = False
            params['eval_critic'] = True
            params['dataset'] = self.valid_action_dataset
            params['entropy_beta'] = self.entropy_beta
            meters.update(lr=self.lr, entropy_beta=self.entropy_beta)
        elif mode == 'test':
            params['dump'] = True
            params['use_argmax'] = True
        else:
            backup = copy.deepcopy(player)
            params['use_argmax'] = self.is_candidate

        if mode == 'train':
            # Rollouts run on the CPU in eval mode.
            if args.use_gpu:
                self.model.cpu()

            self.model.eval()
            trajectories = []
            for i in range(args.ntrajectory):
                succ, score, traj, length, last_next_value = run_episode(player, self.model, need_restart=(i != 0), **params)
                meters.update(number=number, succ=succ, score=score, length=length)
                feed_dict = make_data(traj, args.gamma, succ, last_next_value[0], lam=args.lam, isQnet=self.model.isQnet)
                # content from valid_move dataset
                states, actions, labels = self.valid_action_dataset.sample_batch(args.batch_size)
                feed_dict['pred_states'] = as_tensor(states)
                feed_dict['pred_actions'] = as_tensor(actions)
                feed_dict['valid'] = as_tensor(labels).float()
                trajectories.append(feed_dict)
            # Merge all trajectories along the batch dimension into the last
            # feed_dict. Rewards are not used to update the loss, so they are
            # left as-is.
            for k in feed_dict.keys():
                if k not in ["rewards"]:
                    feed_dict[k] = torch.cat([fd[k] for fd in trajectories], dim=0)
            feed_dict['entropy_beta'] = as_tensor(self.entropy_beta).float()
            feed_dict['eval_critic'] = as_tensor(True)

            # Normalize advantages over the merged batch (std needs > 1 sample).
            advantages = feed_dict['advantages']
            if advantages.shape[0] > 1 and not args.no_adv_norm:
                feed_dict['advantages'] = (advantages - advantages.mean()) / (advantages.std() + 10**-7)

            self.model.train()

            if args.use_gpu:
                feed_dict = as_cuda(feed_dict)
                self.model.cuda()
            return feed_dict
        else:
            if args.use_gpu:
                self.model.cpu()
            succ, score, _, length, _ = run_episode(player, self.model, **params)
            meters.update(number=number, succ=succ, score=score, length=length)
            message = ('> {} iter={iter}, number={number}, succ={succ}, '
                       'score={score:.4f}, length={length}').format(mode, iter=index, **meters.val)
            return message, dict(succ=succ, number=number, backup=backup)

    def _extract_info(self, extra):
        return extra['succ'], extra['number'], extra['backup']

    def _get_accuracy(self, meters):
        return meters.avg['succ']

    def _get_threshold(self):
        # Non-candidates get a threshold relaxed by args.candidate_relax steps.
        candidate_relax = 0 if self.is_candidate else args.candidate_relax
        return super()._get_threshold() - \
               self.curriculum_thresh_relax * candidate_relax

    def _upgrade_lesson(self):
        super()._upgrade_lesson()
        # Adjust lr & entropy_beta w.r.t different lesson progressively.
        self.lr *= args.lr_decay
        self.entropy_beta *= args.entropy_beta_decay
        self.set_learning_rate(self.lr)

    def _train_epoch(self, epoch_size):
        """Run one PPO training epoch, then exams, checkpointing and Optuna reporting."""
        model = self.model
        meters = GroupMeters()
        self._prepare_dataset(epoch_size, mode='train')

        def train_func(index):
            model.eval()
            feed_dict = self._get_train_data(index, meters)
            model.train()
            nbatch = feed_dict['states'].shape[0]
            minibatch_size = args.batch_size
            inds = np.arange(nbatch)
            if not args.no_shuffle_minibatch:
                for _ in range(args.noptepochs):
                    # Reshuffle every optimization pass so minibatches differ
                    # between passes, as in standard PPO.
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, minibatch_size):
                        mbinds = inds[start:start + minibatch_size]
                        subfeed_dict = {}
                        for k, v in feed_dict.items():
                            # Slice per-step tensors. Scalars and the
                            # independently sampled valid-move batch stay whole.
                            if (isinstance(v, torch.Tensor) and v.dim() != 0
                                    and k not in ('pred_states', 'pred_actions', 'valid')):
                                subfeed_dict[k] = v[mbinds, ...]
                            else:
                                subfeed_dict[k] = v
                        message, _ = self._train_step(subfeed_dict, meters)
            else:
                # Without minibatching, each optimization pass uses the full batch.
                for _ in range(args.noptepochs):
                    message, _ = self._train_step(feed_dict, meters)
            return message

        # For $epoch_size times, do train_func with tqdm progress bar.
        tqdm_for(epoch_size, train_func)
        logger.info(
            meters.format_simple(
                '> Train Epoch {:5d}: '.format(self.current_epoch),
                compressed=False))
        self._dump_meters(meters, 'train')
        if not self.is_graduated:
            self._take_exam(train_meters=copy.copy(meters))

        epoch = self.current_epoch
        if args.save_interval is not None and epoch % args.save_interval == 0:
            self.save_checkpoint(str(epoch))
        if args.test_interval is not None and epoch % args.test_interval == 0:
            test_meters = self.test()
            succ_rates = np.array([m.avg['succ'] for m in test_meters])
            # Report the success rates weighted by the module-level cost_weight
            # (previously misspelled `cost_weights`, which raised NameError).
            self.trial.report(np.dot(succ_rates, cost_weight), epoch)
            if self.trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        return meters

    def _early_stop(self, meters):
        # Each lesson gets at most args.early_drop_epochs epochs.
        t = args.early_drop_epochs
        if t is not None and self.current_epoch > t * (self.nr_upgrades + 1):
            return True
        return super()._early_stop(meters)

    def train(self):
        """Reset the valid-move dataset, lr and entropy_beta, then train."""
        self.valid_action_dataset = ValidActionDataset()
        self.lr = args.lr
        self.entropy_beta = args.entropy_beta
        return super().train()


def main(run_id, run_str, trial):
    """Build the model and trainer for one Optuna trial, then train and/or test it.

    Args:
        run_id: identifier embedded in the per-run dump directory name.
        run_str: extra string embedded in the per-run dump directory name.
        trial: the Optuna trial; attached to the trainer as ``trainer.trial``.

    Returns:
        ``(graduated, test_meters)``. ``graduated`` is None in test-only mode.
    """
    if args.dump_dir is None:
        # No dump directory: nothing is persisted for this run.
        args.checkpoints_dir = None
        args.summary_file = None
    else:
        io.mkdir(args.dump_dir)
        run_dir = os.path.join(args.dump_dir, 'run_{}_{}'.format(run_id, run_str))
        args.current_dump_dir = run_dir
        io.mkdir(run_dir)
        args.checkpoints_dir = os.path.join(run_dir, 'checkpoints')
        io.mkdir(args.checkpoints_dir)
        args.summary_file = os.path.join(run_dir, 'summary.json')

        args.log_file = os.path.join(run_dir, 'log.log')
        set_output_file(args.log_file)

    # Integer-divide epoch_size by ntrajectory. Presumably each train step
    # collects ntrajectory episodes — TODO confirm.
    args.epoch_size //= args.ntrajectory
    logger.info(format_args(args))

    net = Model()
    trainer = MyTrainer.from_args(net, get_optimizer(args.optimizer, net, args.lr), args)
    trainer.trial = trial

    if args.load_checkpoint is not None:
        trainer.load_checkpoint(args.load_checkpoint)

    if args.test_only:
        trainer.current_epoch = 0
        return None, trainer.test()

    graduated = trainer.train()
    trainer.save_checkpoint('last')
    return graduated, trainer.test()

if __name__ == '__main__':
    n_startup_trials = 50
    seed = None
    sampler = TPESampler(n_startup_trials=n_startup_trials, seed=seed)
    pruner = SuccessiveHalvingPruner(reduction_factor=3)  # less aggressive pruning
    # pruner = MedianPruner()  # even less aggressive pruning but much more slow
    # storage comes from the SQLURI environment variable; load_if_exists lets
    # a re-run resume or share an existing study with the same name.
    study = optuna.create_study(sampler=sampler, pruner=pruner, study_name=args.study_name,
                                direction="maximize", storage=os.getenv('SQLURI'), load_if_exists=True)

    def sample_hyperparams(trial):
        """Draw one hyperparameter configuration from ``trial``.

        Returns:
            dict mapping ``args`` attribute names to sampled values. The dict
            is merged over the command-line arguments.
        """
        # nlm_attributes = trial.suggest_categorical('nlm_attributes', [8, 16])
        lr = trial.suggest_float('lr', 0.0001, 0.01)
        # gamma = trial.suggest_float('gamma', 0.9, 1)
        # penalty = trial.suggest_float('penalty', 0.9, 1) - 1
        lr_decay = trial.suggest_categorical('lr_decay', [0.9, 1.0])
        # suggest_loguniform is deprecated. suggest_float(log=True) samples
        # from the same log-uniform distribution.
        entropy_beta = trial.suggest_float('entropy_beta', 0.00000001, 0.5, log=True)
        entropy_beta_decay = trial.suggest_categorical('entropy_beta_decay', [1.0, 0.8])
        # critic_type = trial.suggest_categorical('critic_type', [0,1,2,3,4,5,6,7,8,9])
        noptepochs = trial.suggest_categorical('noptepochs', [2, 4, 8])
        epsilon = trial.suggest_float('epsilon', 0, 0.3)
        lam = trial.suggest_float('lam', 0.9, 1)
        clip_vf = trial.suggest_float('clip_vf', 0, 0.3)
        no_shuffle_minibatch = trial.suggest_categorical('no_shuffle_minibatch', [True, False])
        no_adv_norm = trial.suggest_categorical('no_adv_norm', [True, False])
        ntrajectory = trial.suggest_categorical('ntrajectory', [1, 5, 10, 20])
        batch_size = trial.suggest_categorical('batch_size', [4, 8, 16, 32])

        return {
            # 'nlm_attributes': nlm_attributes,
            'lr': lr,
            # 'gamma': gamma,
            # 'penalty': penalty,
            'lr_decay': lr_decay,
            'entropy_beta': entropy_beta,
            'entropy_beta_decay': entropy_beta_decay,
            'noptepochs': noptepochs,
            'epsilon': epsilon,
            'lam': lam,
            'clip_vf': clip_vf,
            'no_shuffle_minibatch': no_shuffle_minibatch,
            'no_adv_norm': no_adv_norm,
            'ntrajectory': ntrajectory,
            'batch_size': batch_size,
        }

    def objective(trial):
        """Run one train/test session with sampled hyperparameters.

        Returns:
            The dot product of the per-setting test success rates with
            ``cost_weight``. The study maximises this value.
        """
        # main() and the trainer read the module-level ``args``, so it is
        # rebound for every trial. That shared global is why n_jobs stays 1.
        global args

        stats = []
        nr_graduated = 0

        # Copy: vars() returns main_args' own __dict__. Updating it in place
        # would write each trial's sampled values back into main_args.
        hparams = dict(vars(main_args))
        sampled_hparams = sample_hyperparams(trial)
        print()
        print("Sample new hyperparameters:")
        print(sampled_hparams)
        print()
        hparams.update(sampled_hparams)
        args = Namespace(**hparams)

        # NOTE(review): ``_trial_id`` is a private Optuna attribute. It is kept
        # because existing dump-directory names are derived from it.
        run_str = '_'.join(str(s) for s in sampled_hparams.values())
        graduated, test_meters = main(trial._trial_id, run_str, trial)
        logger.info('run {}'.format(trial._trial_id))

        if test_meters is not None:
            for j, meters in enumerate(test_meters):
                if len(stats) <= j:
                    stats.append(GroupMeters())
                stats[j].update(number=meters.avg['number'], test_succ=meters.avg['succ'])

            for meters in stats:
                logger.info('number {}, test_succ {}'.format(meters.avg['number'], meters.avg['test_succ']))

        if not args.test_only:
            nr_graduated += int(graduated)
            if graduated:
                for j, meters in enumerate(test_meters):
                    stats[j].update(grad_test_succ=meters.avg['succ'])
            if nr_graduated > 0:
                for meters in stats:
                    logger.info('number {}, grad_test_succ {}'.format(meters.avg['number'], meters.avg['grad_test_succ']))

        cost = np.array([meters.avg['test_succ'] for meters in stats])
        # Alternative objective (kept for reference):
        # if (cost > 0.2).any():
        #     ind = (cost > 0.2).nonzero()[0][-1]
        #     cost = cost[ind]
        #     return (cost_weight[ind + 1] - cost_weight[ind]) * cost + cost_weight[ind]
        # else:
        #     return 0.0
        return np.dot(cost, cost_weight)

    try:
        study.optimize(objective, n_trials=args.n_trials, n_jobs=1)
    except KeyboardInterrupt:
        pass

    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    try:
        best_trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial has completed,
        # e.g. every trial was pruned or the search was interrupted early.
        print('No completed trials yet.')
    else:
        print('Value: ', best_trial.value)

        print('Params: ')
        for key, value in best_trial.params.items():
            print('    {}: {}'.format(key, value))
