from asyncio.constants import ACCEPT_RETRY_DELAY
import collections
import copy
import functools
import json
import os

import numpy as np

from numpy import linalg as LA
import torch
import torch.nn as nn
import torch.nn.functional as F

import jacinle.random as random
import jacinle.io as io

from random import randint

from keras.utils.np_utils import to_categorical
from difflogic.cli import format_args
from difflogic.dataset.utils import ValidActionDataset
from difflogic.nn.baselines import MemoryNet
from difflogic.nn.critic import *
from difflogic.nn.neural_logic import (
    InputTransform,
    LogicInference,
    LogicMachine,
    LogitsInference,
)
from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self
from difflogic.nn.dlm.layer import DifferentiableLogicMachine
from difflogic.nn.dlm.neural_logic import DLMInferenceBase
from difflogic.nn.rl.reinforce import REINFORCELoss, REINFORCELogLoss

from difflogic.nn.rl.ppo import PPOLoss
from difflogic.train import MiningTrainerBase

from difflogic.envs.blocksworld.ddpg import *

from jacinle.cli.argument import JacArgumentParser
from jacinle.logging import get_logger
from jacinle.logging import set_output_file
from jacinle.utils.container import GView
from jacinle.utils.meter import GroupMeters
from difflogic.train.accum_grad import AccumGrad
from jactorch.optim.quickaccess import get_optimizer
from jactorch.utils.meta import as_cuda
from jactorch.utils.meta import as_numpy
from jactorch.utils.meta import as_tensor
from difflogic.tqdm_utils import tqdm_for
from difflogic.hopcroft import hopcroft_minimization
from difflogic.prompt import (
    get_public_prompt,
    set_input_prompt,
    get_task_prompt,
    tensor_to_tuple,
    action_to_letter,
)

import wandb

# Task names accepted by ``--task``.
TASKS = ["highway", "final"]
# Module-level counters. NOTE(review): not referenced in the visible part of
# this file — confirm they are still used elsewhere.
global_times, run_times = 0, 0

parser = JacArgumentParser()

parser.add_argument(
    "--model",
    choices=["nlm", "memnet", "dlm"],
    default="dlm",
    help="model choices, nlm: Neural Logic Machine, memnet: Memory Networks, dlm: Differentiable Logic Machine",
)

# Policy logic-machine options (prefix "nlm"). Model.__init__ hard-codes the
# DLM depth/breadth and only reads --nlm-attributes from this group.
nlm_group = parser.add_argument_group("Neural Logic Machines")
DifferentiableLogicMachine.make_nlm_parser(
    nlm_group,
    dict(depth=7, breadth=2, exclude_self=True, logic_hidden_dim=[]),
    prefix="nlm",
)
nlm_group.add_argument(
    "--nlm-attributes",
    metavar="N",
    type=int,
    default=8,
    help="number of output attributes in each group of each layer of the LogicMachine",
)

# Critic logic-machine options (prefix "nlmcrit"); used by critic types 3 and 7.
nlmcritic_group = parser.add_argument_group("Neural Logic Machines critic")
LogicMachine.make_nlm_parser(
    nlmcritic_group,
    dict(
        depth=3,
        breadth=2,
        residual=True,
        exclude_self=True,
        logic_hidden_dim=[],
    ),
    prefix="nlmcrit",
)
nlmcritic_group.add_argument(
    "--nlm-attributes-critic",
    metavar="N",
    type=int,
    default=2,
    help="number of output attributes in each group of each layer of the critic LogicMachine",
)

# Memory Network options (relevant only for --model memnet).
memnet_group = parser.add_argument_group("Memory Networks")
MemoryNet.make_memnet_parser(memnet_group, dict(), prefix="memnet")

parser.add_argument("--task", choices=TASKS, required=True, help="tasks choices")

# Options for the auxiliary action-validity prediction task.
method_group = parser.add_argument_group("Method")
method_group.add_argument(
    "--concat-worlds",
    metavar="B",
    type="bool",
    default=True,
    help="concat the features of objects of same id among two worlds accordingly",
)
method_group.add_argument(
    "--pred-depth",
    metavar="N",
    type=int,
    default=None,
    help="the depth of nlm used for prediction task",
)
method_group.add_argument(
    "--pred-weight",
    metavar="F",
    type=float,
    default=0,
    help="the linear scaling factor for prediction task",
)

# Random-graph generation options.
# NOTE(review): none of these are read in the visible part of this file; they
# look inherited from the graph-task scripts — confirm before relying on them.
data_gen_group = parser.add_argument_group("Data Generation")
data_gen_group.add_argument(
    "--gen-method",
    choices=["dnc", "edge"],
    default="dnc",
    help="method use to generate random graph",
)
data_gen_group.add_argument(
    "--gen-graph-pmin",
    metavar="F",
    type=float,
    default=0.3,
    help="control parameter p reflecting the graph sparsity",
)
data_gen_group.add_argument(
    "--gen-graph-pmax",
    metavar="F",
    type=float,
    default=0.3,
    help="control parameter p reflecting the graph sparsity",
)
data_gen_group.add_argument(
    "--gen-max-len",
    metavar="N",
    type=int,
    default=4,
    help="maximum length of shortest path during training",
)
data_gen_group.add_argument(
    "--gen-test-len",
    metavar="N",
    type=int,
    default=4,
    help="length of shortest path during testing",
)
data_gen_group.add_argument(
    "--gen-directed", help="directed graph", action="store_true"
)


# Curriculum / hard-example-mining schedule defaults for the trainer. Every key
# uses the underscore spelling of the option name.
MiningTrainerBase.make_trainer_parser(
    parser,
    {
        "epochs": 1000,
        "epoch_size": 40,
        "test_epoch_size": 6,
        "test_number_begin": 3,
        "test_number_step": 1,
        "test_number_end": 3,
        "curriculum_start": 1,
        "curriculum_step": 1,
        "curriculum_graduate": 1,
        "curriculum_thresh_relax": 0.01,
        "curriculum_thresh": 1,
        # don't learn well on several lessons at a time with PPO/multiple trajectories
        "sample_array_capacity": 1,
        # Previously spelled "disable-balanced-sample"; the hyphenated key did
        # not match the underscore naming of every other override, so the
        # intended default was most likely ignored.
        # NOTE(review): confirm against MiningTrainerBase.make_trainer_parser.
        "disable_balanced_sample": True,
        "inherit_neg_data": False,
        "enable_mining": True,
        "mining_interval": 10,  # 6 for sort/path
        "mining_epoch_size": 20,
        "mining_dataset_size": 200,
        "prob_pos_data": 0.6,  # 0.5 for sort/path
    },
)

# Optimisation and training-loop options.
train_group = parser.add_argument_group("Train")
train_group.add_argument("--seed", metavar="SEED", type=int, default=None)
train_group.add_argument("--use-gpu", help="use GPU or not", action="store_true")
train_group.add_argument(
    "--optimizer",
    choices=["SGD", "Adam", "AdamW"],
    default="AdamW",
    help="optimizer choices",
)
train_group.add_argument(
    "--lr",
    metavar="F",
    type=float,
    default=0.05,
    help="initial learning rate",
)
train_group.add_argument(
    "--lr-decay",
    metavar="F",
    type=float,
    default=0.99,
    help="exponential decay of learning rate per lesson",
)
train_group.add_argument(
    "--accum-grad",
    metavar="N",
    type=int,
    default=1,
    help="accumulated gradient (default: 1)",
)
train_group.add_argument(
    "--ntrajectory",
    metavar="N",
    type=int,
    default=5,
    help="number of trajectories to compute gradient",
)
train_group.add_argument(
    "--batch-size",
    metavar="N",
    type=int,
    default=4,
    help="batch size for extra prediction",
)
train_group.add_argument(
    "--candidate-relax",
    metavar="N",
    type=int,
    default=0,
    help="number of thresh relaxation for candidate",
)
train_group.add_argument(
    "--extract-path", help="extract path or not", action="store_true"
)
train_group.add_argument("--extract-rule", help="extract rule", action="store_true")
# Starting values of the DLM annealing schedule (annealed in Model.stoch_decay)
# and the softmax temperature used by Model.forward for --distribution 1.
train_group.add_argument("--gumbel-noise-begin", default=0.1, type=float)
train_group.add_argument("--dropout-prob-begin", default=0.001, type=float)
train_group.add_argument("--tau-begin", default=1, type=float)
train_group.add_argument("--last-tau", default=0.01, type=float)
train_group.add_argument(
    "--entropy-reg",
    metavar="F",
    type=float,
    default=0.0,
    help="entropy regularization weight for interpretability",
)

# Reinforcement-learning (PPO + critic) options.
rl_group = parser.add_argument_group("Reinforcement Learning")
rl_group.add_argument(
    "--gamma",
    metavar="F",
    type=float,
    default=0.99,
    help="discount factor for accumulated reward function in reinforcement learning",
)
rl_group.add_argument(
    "--penalty",
    metavar="F",
    type=float,
    default=-0.01,
    help="a small penalty each step",
)
rl_group.add_argument(
    "--entropy-beta",
    metavar="F",
    type=float,
    default=0.0,
    help="entropy loss scaling factor",
)
rl_group.add_argument(
    "--entropy-beta-decay",
    metavar="F",
    type=float,
    default=0.8,
    help="entropy beta exponential decay factor",
)
# --critic-type selects the critic built in Model.__init__ (0-9);
# --epsilon is the PPO clipping range passed to the policy loss;
# --clip-vf clips the critic's value update around the previous estimate.
rl_group.add_argument("--critic-type", default=5, type=int)
rl_group.add_argument("--noptepochs", default=2, type=int)
rl_group.add_argument("--epsilon", default=0.2, type=float)
rl_group.add_argument("--lam", default=0.9, type=float)
rl_group.add_argument("--clip-vf", default=0.2, type=float)
rl_group.add_argument("--no-shuffle-minibatch", default=True, type="bool")
rl_group.add_argument("--no-adv-norm", default=False, type="bool")
rl_group.add_argument(
    "--dlm-noise", metavar="N", type=int, default=2, help="dlm noise handling"
)
# Policy distribution: 0 = NLRL-style renormalisation, 1 = softmax,
# 2 = reweighting by a learned action selector (see Model.forward).
rl_group.add_argument(
    "--distribution",
    metavar="N",
    type=int,
    default=1,
    help="distribution used to transform reasonning to action selection",
)
rl_group.add_argument("--no-decay", default=False, type="bool")

# Dumping / checkpoint I/O options.
io_group = parser.add_argument_group("Input/Output")
io_group.add_argument("--dump-dir", metavar="DIR", default=None, help="dump dir")
io_group.add_argument(
    "--dump-play",
    help="dump the trajectory of the plays for visualization",
    action="store_true",
)
io_group.add_argument(
    "--dump-fail-only", help="dump failure cases only", action="store_true"
)
io_group.add_argument(
    "--load-checkpoint",
    metavar="FILE",
    default=None,
    help="load parameters from checkpoint",
)

# Run scheduling: number of runs, early stopping, save/test cadence, test modes.
schedule_group = parser.add_argument_group("Schedule")
schedule_group.add_argument(
    "--runs", metavar="N", type=int, default=1, help="number of runs"
)
schedule_group.add_argument(
    "--early-drop-epochs",
    metavar="N",
    type=int,
    default=50,
    help="epochs could spend for each lesson, early drop",
)
schedule_group.add_argument(
    "--save-interval",
    metavar="N",
    type=int,
    default=10,
    help="the interval(number of epochs) to save checkpoint",
)
schedule_group.add_argument(
    "--test-interval",
    metavar="N",
    type=int,
    default=None,
    help="the interval(number of epochs) to do test",
)
schedule_group.add_argument("--test-only", help="test-only mode", action="store_true")
schedule_group.add_argument(
    "--test-inter",
    help="test-inter mode when turning on test-only mode",
    action="store_true",
)
schedule_group.add_argument(
    "--test-not-graduated", help="test not graduated models also", action="store_true"
)

args = parser.parse_args()

# Use the GPU only if it was requested and CUDA is actually available.
args.use_gpu = args.use_gpu and torch.cuda.is_available()
# Play dumps need a dump directory to be written into.
args.dump_play = args.dump_play and (args.dump_dir is not None)
# --epoch-size is divided (integer division) by --ntrajectory.
args.epoch_size //= args.ntrajectory

# Zero-initialised list of length 40% of --test-epoch-size, plus a step counter.
succ_rate = [0 for _ in range((args.test_epoch_size * 4) // 10)]
tot_step = 0

if args.dump_dir is None:
    # Nothing is written to disk without a dump directory.
    args.checkpoints_dir = None
    args.summary_file = None
else:
    io.mkdir(args.dump_dir)
    args.log_file = os.path.join(args.dump_dir, "log.log")
    set_output_file(args.log_file)

if args.seed is not None:
    random.reset_global_seed(args.seed)

# Task-specific overrides of the parsed options.
# NOTE(review): with TASKS limited to "highway"/"final", neither branch below
# can trigger; they appear inherited from other task scripts.
if "nlrl" in args.task:
    # No world concatenation, no per-step penalty, and no validity-prediction
    # loss (TODO: dataset for pred_valid cannot contain unary+binary states).
    args.concat_worlds = False
    args.penalty = None
    args.pred_weight = 0.0

    # No curriculum for NLRL tasks: start and graduate at the same lesson.
    args.curriculum_start = args.curriculum_graduate = 4
    args.mining_epoch_size = 20

    # Test-number range (not used for these tasks).
    args.test_number_begin = 4
    args.test_number_step = 1
    args.test_number_end = 4
elif args.task in ("sort", "path"):
    args.concat_worlds = False
    args.pred_weight = 0.0

# Environment factory: the highway env takes no extra options, the other
# blocksworld envs get randomised order, self-exclusion and a fixed ground.
if args.task != "highway":
    from difflogic.envs.blocksworld import make as make_env

    make_env = functools.partial(
        make_env, random_order=True, exclude_self=True, fix_ground=True
    )
else:
    from difflogic.envs.blocksworld import get_highway_env as make_env

    make_env = functools.partial(make_env)

logger = get_logger(__file__)


def grip_near_object(err):
    """Return a predicate scoring how close the gripper is to a grasp pose.

    The predicate builds a 4-vector from:
      * ``sys_state[0:3]`` minus (``sys_state[3:6]`` shifted up by 0.065 in z);
      * ``sys_state[9] + sys_state[10] - 0.1``.
    It returns ``err`` minus the Euclidean norm of that vector, so the value is
    non-negative exactly when the norm is at most ``err``. ``res_state`` is
    ignored. ``sys_state`` must be a NumPy array (it is sliced and subtracted).
    """

    def predicate(sys_state, res_state):
        target = sys_state[3:6] + np.array([0.0, 0.0, 0.065])
        gap = sys_state[:3] - target
        finger_term = sys_state[9] + sys_state[10] - 0.1
        return err - LA.norm(np.concatenate([gap, [finger_term]]))

    return predicate


def hold_object(err):
    """Return a predicate scoring whether the object is held by the gripper.

    The predicate takes the norm of ``sys_state[0:3] - sys_state[3:6]``
    extended with ``sys_state[9] + sys_state[10] - 0.045``, and returns
    ``err`` minus that norm. ``res_state`` is ignored.
    """

    def predicate(sys_state, res_state):
        offset = sys_state[:3] - sys_state[3:6]
        finger_term = sys_state[9] + sys_state[10] - 0.045
        return err - LA.norm(np.concatenate([offset, [finger_term]]))

    return predicate


def object_in_air(sys_state, res_state):
    """Return ``sys_state[5] - 0.45``: positive once that height exceeds 0.45.

    ``res_state`` is ignored.
    """
    target_height = 0.45
    return sys_state[5] - target_height


def object_at_goal(err):
    """Return a predicate scoring whether the object has reached the goal.

    The predicate takes the norm of the last three entries of ``sys_state``
    extended with ``sys_state[9] + sys_state[10] - 0.045``, and returns
    ``err`` minus that norm. ``res_state`` is ignored.
    """

    def predicate(sys_state, res_state):
        tail = sys_state[-3:]
        finger_term = sys_state[9] + sys_state[10] - 0.045
        return err - LA.norm(np.concatenate([tail, [finger_term]]))

    return predicate


class Model(nn.Module):

    def __init__(self):
        """Build the DLM policy, its heads, the critic and the DDPG controllers.

        All configuration is read from the module-level ``args`` namespace.

        Raises:
            ValueError: if ``args.model`` is not ``"dlm"``, if
                ``args.critic_type`` is not in 0-9, or if the curriculum range
                ``curriculum_start .. curriculum_graduate`` is empty.
        """
        super().__init__()
        # Arity-indexed input sizes: 62 nullary predicates (get_outputs builds
        # 12 distance bins x 5 quantities + 2 height flags), no unary/binary.
        input_dims = [62, 0, 0]
        self.feature_axis = 0

        if args.model != "dlm":
            # Only the DLM policy is wired up in this script.
            raise ValueError(
                f"unsupported model: {args.model!r} (only 'dlm' is implemented)"
            )

        self.features = DifferentiableLogicMachine(
            depth=3,
            breadth=2,
            input_dims=input_dims,
            output_dims=args.nlm_attributes,
            residual=False,
            dlm_intern_params={
                "atoms_per_rule": 2,
                "fuzzy_or": True,
                "add_negation": True,
            },
        )
        # Only the nullary output channel feeds the policy heads.
        output_dim = self.features.output_dims[0]

        # Auxiliary validity head and policy head; 4 is presumably the number
        # of discrete actions (matches action_dim below) — confirm.
        self.pred_valid = DLMInferenceBase(output_dim, 4, False, "root_valid")
        # Author's note (translated): "try changing both heads".
        self.pred = DLMInferenceBase(output_dim, 4, False, "root", atoms_per_rule=2)
        if args.distribution == 2:
            self.ac_selector = ActionSelector(output_dim)

        # Starting point of the tau / dropout / Gumbel annealing schedule.
        self.tau = args.tau_begin
        self.dropout_prob = args.dropout_prob_begin
        self.gumbel_prob = args.gumbel_noise_begin

        self.update_stoch()
        if args.entropy_reg != 0.0:
            # lowernoise + restorenoise leaves dropout disabled, because
            # restorenoise re-enables dropout only when entropy_reg == 0.
            self.lowernoise()
            self.restorenoise()

        self.loss = PPOLoss()
        self.pred_loss = nn.BCELoss()
        self.force_decay = False

        # Curriculum lessons, from curriculum_start to curriculum_graduate.
        lessons = range(
            args.curriculum_start,
            args.curriculum_graduate + args.curriculum_step,
            args.curriculum_step,
        )
        if len(lessons) == 0:
            raise ValueError(
                "empty curriculum: curriculum_start "
                f"({args.curriculum_start}) > curriculum_graduate "
                f"({args.curriculum_graduate})"
            )

        # Per-lesson size spec for the InvariantNObject critics; only populated
        # for the "final" task, as is the action-count function.
        range_dims = []
        if args.task == "final":
            range_dims = [[(n + 1) * 2, (n + 1) * 2, input_dims] for n in lessons]

            def n_action(x):
                return (x[0] // 2) * ((x[0] // 2) - 1)

        # Critic selection. isQnet marks critics that output one value per
        # action (types 6, 7, 9); all others output a state value.
        self.isQnet = True
        if args.critic_type == 0:
            self.critic = InvariantNObject(MLPCritic, range_dims, dict())
            self.isQnet = False
        elif args.critic_type == 1:
            self.critic = GRUCritic(input_dims)
            self.isQnet = False
        elif args.critic_type == 2:
            self.critic = GRUCritic(input_dims, shuffle_index=True)
            self.isQnet = False
        elif args.critic_type == 3:
            nlmcrit = LogicMachine.from_args(
                input_dims, args.nlm_attributes_critic, args, prefix="nlmcrit"
            )
            self.critic = InvariantNObject(
                NLMMLPCritic,
                range_dims,
                dict(nlm=nlmcrit, feature_axis=None if args.model != "dlm" else 0),
            )
            self.isQnet = False
        elif args.critic_type == 4:
            self.critic = InvariantNObject(
                ConvCritic, range_dims, dict(input_channel=input_dims)
            )
            self.isQnet = False
        elif args.critic_type == 5:
            self.critic = MixedGRUCritic(input_dims)
            self.isQnet = False
        elif args.critic_type == 6:
            self.critic = InvariantNObject(
                MLPCriticQ, range_dims, dict(n_action_func=n_action)
            )
        elif args.critic_type == 7:
            nlmcrit = LogicMachine.from_args(
                input_dims, args.nlm_attributes_critic, args, prefix="nlmcrit"
            )
            self.critic = InvariantNObject(
                NLMMLPCriticQ,
                range_dims,
                dict(
                    nlm=nlmcrit,
                    n_action_func=n_action,
                    feature_axis=None if args.model != "dlm" else 0,
                ),
            )
        elif args.critic_type == 8:
            self.critic = InvariantNObject(ConvReduceCritic, range_dims, dict())
            self.isQnet = False
        elif args.critic_type == 9:
            self.critic = InvariantNObject(
                ConvReduceCriticQ, range_dims, dict(n_action_func=n_action)
            )
        else:
            raise ValueError(f"unknown critic_type: {args.critic_type}")
        self.critic_loss = nn.MSELoss()

        # Low-level continuous DDPG controllers (4 of them, sharing one
        # DDPGParams instance). NOTE(review): presumably one per discrete
        # policy action — confirm.
        state_dim = 31
        action_dim = 4
        action_bound = [1.0, 1.0, 1.0, 1.0]
        hyperparams = DDPGParams(
            state_dim,
            action_dim,
            action_bound,
            minibatch_size=256,
            # The last curriculum lesson index: the value the previous code
            # picked up from the leaked curriculum loop variable `i`.
            # NOTE(review): confirm this is really meant as an episode count.
            num_episodes=lessons[-1],
            discount=0.95,
            actor_hidden_dim=256,
            critic_hidden_dim=256,
            epsilon_decay=3e-6,
            decay_function="linear",
            steps_per_update=100,
            gradients_per_update=100,
            buffer_size=200000,
            sigma=0.15,
            epsilon_min=0.3,
            target_noise=0.0003,
            target_clip=0.003,
            warmup=1000,
            max_timesteps=1000,
        )
        self.ddpg_controller = [DDPG(hyperparams) for _ in range(4)]
        # cjy

    def update_stoch(self):
        """Push the current tau, Gumbel-noise and dropout values into the DLM parts.

        No-op unless ``args.model == "dlm"``. Each setting is applied to
        ``features``, ``pred_valid`` and ``pred`` in that order, one setting at
        a time: all tau updates first, then Gumbel noise, then dropout.
        """
        if args.model != "dlm":
            return

        modules = (self.features, self.pred_valid, self.pred)
        for module in modules:
            module.update_tau(self.tau)
        for module in modules:
            module.update_gumbel_noise(self.gumbel_prob)
        for module in modules:
            module.update_dropout_prob(self.dropout_prob)

    def lowernoise(self):
        """Disable per-sample noise, Gumbel noise and dropout in the DLM parts.

        No-op unless ``args.model == "dlm"``. The inference heads take these
        switches as plain attributes; the feature stack takes them as calls.
        """
        if args.model != "dlm":
            return

        for head in (self.pred, self.pred_valid):
            head.independant_noise_per_sample = False
            head.with_gumbel = False
            head.with_dropout = False

        self.features.independant_noise_per_sample(False)
        self.features.with_gumbel(False)
        self.features.with_dropout(False)

    def restorenoise(self):
        """Re-enable per-sample noise and Gumbel noise (and maybe dropout).

        No-op unless ``args.model == "dlm"``. Dropout is switched back on only
        when ``args.entropy_reg == 0.0``; otherwise it stays as it was.
        """
        if args.model != "dlm":
            return

        for switch in ("independant_noise_per_sample", "with_gumbel"):
            setattr(self.pred, switch, True)
            setattr(self.pred_valid, switch, True)
            getattr(self.features, switch)(True)

        if args.entropy_reg == 0.0:
            self.pred.with_dropout = True
            self.pred_valid.with_dropout = True
            self.features.with_dropout(True)

    def stoch_decay(self, lesson, train_succ):
        """Anneal tau, noise levels and the auxiliary prediction weight.

        Annealing first triggers for a DLM model (without ``--no-decay``) at the
        graduation lesson with ``train_succ > 0.8``. ``force_decay`` then
        latches it on for every later call. Each step scales tau by 0.995, the
        Gumbel and dropout probabilities by 0.98, and the global
        ``args.pred_weight`` by 0.98. When tau falls to 0.2 or below the run is
        treated as failed, and tau/dropout/Gumbel restart from their initial
        values (``args.pred_weight`` is not reset).
        """
        triggered = (
            args.model == "dlm"
            and not args.no_decay
            and lesson == args.curriculum_graduate
            and train_succ > 0.8
        )
        if not (triggered or self.force_decay):
            return

        self.force_decay = True
        self.tau *= 0.995
        self.gumbel_prob *= 0.98
        self.dropout_prob *= 0.98
        args.pred_weight *= 0.98

        if self.tau <= 0.2:
            self.tau = args.tau_begin
            self.dropout_prob = args.dropout_prob_begin
            self.gumbel_prob = args.gumbel_noise_begin

        self.update_stoch()

    def forward(self, feed_dict):
        """Run the policy (and optionally the critic) on a batch.

        Args:
            feed_dict: mapping wrapped in ``GView``.
                Always read: ``states``, ``training``, ``eval_critic``.
                In training mode also read: ``old_policy``, ``actions``,
                ``advantages``, ``entropy_beta``, ``values``, ``returns``, and
                ``actions_ohe`` for Q-critics. When ``args.pred_weight != 0``
                it also reads ``pred_states``, ``pred_actions`` and ``valid``.

        Returns:
            When ``feed_dict.training`` is falsy: a dict with ``policy``,
            ``logits``, ``value`` (critic output or ``None``) and ``more_info``.
            Otherwise: ``(loss, monitors, extra)``, where ``extra`` is
            ``more_info`` if it has a non-None ``"p_star_dict"`` entry, else ``{}``.

        Raises:
            ValueError: if ``args.distribution`` is not 0, 1 or 2.
        """
        feed_dict = GView(feed_dict)

        if args.task in ["highway", "final", "stack", "sort"]:
            states = feed_dict.states.float()
            batch_size = states.size(0)
        else:
            states = feed_dict.states
            batch_size = states[0].size(0)

        f, fs, more_info = self.get_outputs(states)
        saved_for_fa = f

        if args.model == "dlm":
            # Only the nullary output channel of the DLM feeds the policy head.
            f = f[0].float()
            logger.debug("after dlm cal: %s", f)
            f = self.pred(f)
            logits = f[0].squeeze(dim=-1).view(batch_size, -1)
            logger.debug("logits: %s", logits)
            # Squash into [1e-5, 1 - 1e-5] so no action gets exactly zero mass.
            logits = 1e-5 + logits * (1.0 - 2e-5)

            if args.distribution == 0:
                # Renormalise when the total mass exceeds 1; otherwise spread
                # the missing mass uniformly over the actions.
                sigma = logits.sum(-1).unsqueeze(-1)
                policy = torch.where(
                    sigma > 1.0, logits / sigma, logits + (1 - sigma) / logits.shape[1]
                )
            elif args.distribution == 1:
                policy = F.softmax(logits / args.last_tau, dim=-1).clamp(min=1e-20)
                if args.test_only:
                    print(
                        [round(float(x), 5) for x in logits[0]],
                        [round(float(x), 5) for x in policy[0]],
                    )
            elif args.distribution == 2:  # was misspelled `args.distributsion`
                if self.training:
                    # NOTE(review): saved_for_fa is the raw get_outputs result
                    # (indexed as f[0] above); confirm it supports .detach().
                    fa = self.ac_selector(saved_for_fa.detach())
                    policy = (fa.sigmoid() + 1e-5) * logits
                else:
                    policy = logits
                policy = policy / policy.sum(-1).unsqueeze(-1)
            else:
                raise ValueError(f"unknown distribution: {args.distribution}")

            if feed_dict.training:
                if "saturation" in more_info.keys():
                    more_info["saturation"].extend(f[1]["saturation"])
                else:
                    more_info["saturation"] = [f[1]["saturation"]]
                # NOTE(review): unlike saturation, entropies are only collected
                # if get_outputs already provided the key; with entropy_reg != 0
                # a missing key makes the loss below raise KeyError.
                if "entropies" in more_info.keys():
                    more_info["entropies"].extend(f[1]["entropies"])
        else:
            logits = self.pred(f).squeeze(dim=-1).view(batch_size, -1)
            policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        crit_out = self.critic(fs) if feed_dict.eval_critic else None

        if not feed_dict.training:
            return dict(
                policy=policy, logits=logits, value=crit_out, more_info=more_info
            )

        loss, monitors = self.loss(
            policy,
            feed_dict.old_policy,
            feed_dict.actions,
            feed_dict.advantages,
            args.epsilon,
            feed_dict.entropy_beta,
        )

        # Critic loss. Q-critics are evaluated at the taken action. With
        # clip_vf > 0 the new estimate is clipped around the previous one.
        if self.isQnet:
            crit_preloss = (crit_out * feed_dict.actions_ohe).sum(-1)
        else:
            crit_preloss = crit_out
        if args.clip_vf is not None and args.clip_vf > 0.0:
            if self.isQnet:
                previous_val = (feed_dict.values * feed_dict.actions_ohe).sum(-1)
            else:
                previous_val = feed_dict.values
            crit_preloss = previous_val + torch.clamp(
                crit_preloss - previous_val, -args.clip_vf, args.clip_vf
            )
        losscrit = self.critic_loss(crit_preloss, feed_dict.returns)
        monitors["critic_accuracy"] = losscrit
        loss += losscrit

        if args.pred_weight != 0.0:
            # Auxiliary action-validity prediction loss.
            pred_states = feed_dict.pred_states.float()
            # NOTE(review): get_binary_relations is not defined in the visible
            # part of this class (the extractor used above is get_outputs).
            f, _, _ = self.get_binary_relations(pred_states, depth=args.pred_depth)
            if args.model == "dlm":
                f = self.pred_valid(f)[0].squeeze(dim=-1).view(pred_states.size(0), -1)
            else:
                f = self.pred_valid(f).squeeze(dim=-1).view(pred_states.size(0), -1)
            # Set minimal value to avoid loss to be nan.
            valid = f[range(pred_states.size(0)), feed_dict.pred_actions].clamp(
                min=1e-20
            )
            pred_loss = self.pred_loss(valid, feed_dict.valid)
            monitors["pred/accuracy"] = (
                feed_dict.valid.eq((valid > 0.5).float()).float().mean()
            )
            loss = loss + args.pred_weight * pred_loss

        if args.model == "dlm":
            # Saturation = distance of each output from its nearest 0/1 value.
            pred = (logits.detach().cpu() > 0.5).float()
            sat = 1 - (logits.detach().cpu() - pred).abs()
            monitors.update({"saturation/min": np.array(sat.min())})
            monitors.update({"saturation/mean": np.array(sat.mean())})
            saturation_inside = torch.cat(
                [a.flatten() for a in more_info["saturation"]]
            )
            monitors.update(
                {"saturation-inside/min": np.array(saturation_inside.cpu().min())}
            )
            monitors.update(
                {"saturation-inside/mean": np.array(saturation_inside.cpu().mean())}
            )
            monitors.update({"tau": np.array(self.tau)})
            monitors.update({"dropout_prob": np.array(self.dropout_prob)})
            monitors.update({"gumbel_prob": np.array(self.gumbel_prob)})

            if args.entropy_reg != 0.0:
                entropies = torch.cat([a.flatten() for a in more_info["entropies"]])
                loss += args.entropy_reg * entropies.mean()

        return (
            loss,
            monitors,
            dict() if more_info.get("p_star_dict") is None else more_info,
        )

    def get_outputs(self, states, depth=None):
        """get binary relations given states, up to certain depth."""
        more_info = None

        # PREDICATES CHECKPOINT
        full_obs = states.tolist()
        all_nullary_predicates = []
        all_unary_predicates = []
        all_binary_predicates = []

        print("full_obs: ", full_obs)

        nullary_predicates_dict = {
            "x": {
                "list": [],
                "boolean_list": [],
                "desc": "The x-axis relative distance of the end effector and the object is between _ and *.",
            },
            "y": {
                "list": [],
                "boolean_list": [],
                "desc": "The y-axis relative distance of the end effector and the object is between _ and *.",
            },
            "z": {
                "list": [],
                "boolean_list": [],
                "desc": "The z-axis relative distance of the end effector and the object is between _ and *.",
            },
            "gripper1": {
                "list": [],
                "boolean_list": [],
                "desc": "The displacement of the left gripper is between _ and *.",
            },
            "gripper2": {
                "list": [],
                "boolean_list": [],
                "desc": "The displacement of the right gripper is between _ and *.",
            },
            "block_z1": {
                "list": [],
                "boolean_list": [],
                "desc": "The height of the object is lower than the target height 0.45.",
            },
            "block_z2": {
                "list": [],
                "boolean_list": [],
                "desc": "The height of the object is higher than the target height 0.45.",
            },
        }
        ll = [
            0,
            0.002,
            0.004,
            0.006,
            0.008,
            0.01,
            0.012,
            0.014,
            0.016,
            0.018,
            0.02,
            0.026,
            1,
        ]
        for _ in range(len(full_obs)):
            nullary_predicates = []
            # nullary_predicates.append(states[0][-2] < 0.5)
            # nullary_predicates.append(states[0][-2] > 0.5)
            # nullary_predicates.append(states[0][-1] < 0.5)
            # nullary_predicates.append(states[0][-1] > 0.5)

            s_counter = 0
            for i in range(len(ll) - 1):
                # End effector x position in global coordinates - Block x position in global coordinates
                nullary_predicates.append(
                    ll[i] <= abs(states[0][0] - states[0][3]) < ll[i + 1]
                )
                nullary_predicates_dict["x"]["list"].append(s_counter)
                nullary_predicates_dict["x"]["boolean_list"].append(
                    nullary_predicates[-1]
                )
                s_counter += 1
                # End effector y position in global coordinates - Block y position in global coordinates
                nullary_predicates.append(
                    ll[i] <= abs(states[0][1] - states[0][4]) < ll[i + 1]
                )
                nullary_predicates_dict["y"]["list"].append(s_counter)
                nullary_predicates_dict["y"]["boolean_list"].append(
                    nullary_predicates[-1]
                )
                s_counter += 1
                # End effector z position in global coordinates - Block z position in global coordinates
                nullary_predicates.append(
                    ll[i] <= abs(states[0][2] - states[0][5] - 0.065) < ll[i + 1]
                )
                nullary_predicates_dict["z"]["list"].append(s_counter)
                nullary_predicates_dict["z"]["boolean_list"].append(
                    nullary_predicates[-1]
                )
                s_counter += 1
                # Joint displacement of the right gripper finger - Joint displacement of the left gripper finger
                nullary_predicates.append(
                    ll[i] <= abs(states[0][9] - states[0][10] - 0.1) < ll[i + 1]
                )
                nullary_predicates_dict["gripper1"]["list"].append(s_counter)
                nullary_predicates_dict["gripper1"]["boolean_list"].append(
                    nullary_predicates[-1]
                )
                s_counter += 1
                nullary_predicates.append(
                    ll[i] <= abs(states[0][9] - states[0][10] - 0.045) < ll[i + 1]
                )
                nullary_predicates_dict["gripper2"]["list"].append(s_counter)
                nullary_predicates_dict["gripper2"]["boolean_list"].append(
                    nullary_predicates[-1]
                )
                s_counter += 1

            # Block z position in global coordinates
            nullary_predicates.append(states[0][5] <= 0.45)
            nullary_predicates_dict["block_z1"]["list"].append(s_counter)
            nullary_predicates_dict["block_z1"]["boolean_list"].append(
                nullary_predicates[-1]
            )
            s_counter += 1
            nullary_predicates.append(states[0][5] > 0.45)
            nullary_predicates_dict["block_z2"]["list"].append(s_counter)
            nullary_predicates_dict["block_z2"]["boolean_list"].append(
                nullary_predicates[-1]
            )
            s_counter += 1

            all_nullary_predicates.append(nullary_predicates)

            # UNARY

            all_unary_predicates.append([[]])

            # BINARY

            binary_predicates = []

            all_binary_predicates.append(binary_predicates)

            # print(unary_predicates.size())
            # print(binary_predicates.size())

        all_nullary_predicates = torch.tensor(all_nullary_predicates).float()
        all_unary_predicates = torch.tensor(all_unary_predicates).float()
        all_binary_predicates = torch.tensor(all_binary_predicates).float()

        fs = [all_nullary_predicates, all_unary_predicates, all_binary_predicates]
        inp = [all_nullary_predicates, all_unary_predicates, all_binary_predicates]

        # PREDICATES END
        print("input to dlm: ", inp)
        print("all_nullary_predicates: ", inp[0].shape)
        if args.model == "dlm":
            if args.extract_path:
                self.features.extract_graph(self.feature_axis, self.pred)
                for i in range(len(inp)):
                    if inp[i] is None:
                        continue
                    # print("Hello", inp[i])
                    inp[i] = inp[i].bool()
            # print([x.size() if x is not None else 0 for x in inp])
            # print(inp)

            features = self.features(inp, depth=depth, extract_rule=args.extract_rule)
            # print("After transformation: f")
            print("features: ", features)
            f = features[0]
            # print("length of f[0]: ", len(f[0]))
            more_info = features[1]
            p_star = features[2]
            # print(f)

        # f = meshgrid_exclude_self(f)
        print("more info: ", more_info)
        # p_star是list，将p_star中重复的元素去掉
        p_star = list(set(p_star[0]))
        inp = inp[0][0]
        # 根据p_star中的元素，将inp对应位置的元素取出，存入p_star_dict中
        print("inp: ", inp)
        print("inp's length: ", len(inp))
        print("p_star: ", p_star)
        print("nullary_predicates_dict: ", nullary_predicates_dict)
        # 将nullary_predicates_dict中的list中的元素与desc一一对应
        p_dict = {}
        prompt = "["
        for key in nullary_predicates_dict.keys():
            for i, p in enumerate(nullary_predicates_dict[key]["list"]):
                p_dict[p] = {}
                p_dict[p]["bool"] = nullary_predicates_dict[key]["boolean_list"][i]
                p_dict[p]["desc"] = nullary_predicates_dict[key]["desc"]
                p_dict[p]["desc"] = p_dict[p]["desc"].replace("_", str(ll[i]))
                p_dict[p]["desc"] = p_dict[p]["desc"].replace("*", str(ll[i + 1]))
                # print(f"p_dict[{p}]: {p_dict[p]}")
                if i == 0 and len(nullary_predicates_dict[key]["list"]) > 1:
                    prompt += f'{{"id": "P{p}", "desc": "{p_dict[p]["desc"]}"}}, ..., '
                if i == len(nullary_predicates_dict[key]["list"]) - 1:
                    prompt += f'{{"id": "P{p}", "desc": "{p_dict[p]["desc"]}"}}, '
        prompt += "]"
        print("prompt: ", prompt)
        # 根据p_star中的元素，将p_dict中的元素取出，存入p_star_dict中
        p_star_dict = {}
        for p in p_star:
            p_star_dict[p] = p_dict[p]
            # print(f"p_star_dict[{p}]: {p_star_dict[p]}")
            # print(f"p_dict[{p}]: {p_dict[p]}")
        more_info["p_star_dict"] = p_star_dict
        more_info["p_dict"] = p_dict
        return f, fs, more_info


def make_data(traj, gamma, succ, last_next_value, lam, isQnet):
    """Aggregate a collected trajectory into a batch for RL optimization.

    The per-step lists stored in ``traj`` are stacked into tensors and
    generalized advantage estimates (GAE) are computed backwards in time.

    Args:
        traj: dict of per-step lists ("actions", "actions_ohe", "states",
            "values", "old_policy", "rewards"); it is modified in place.
        gamma: discount factor.
        succ: when truthy, the value following the last step is taken as 0;
            otherwise ``last_next_value`` is used to bootstrap.
        last_next_value: value estimate of the state after the last step.
        lam: GAE lambda.
        isQnet: when true, "values" rows are treated as per-action values and
            reduced with the old policy (next state) or the one-hot taken
            action (current state).

    Returns:
        The same ``traj`` dict, with "advantages" and "returns" added.
    """
    traj["actions"] = as_tensor(np.array(traj["actions"]))
    traj["actions_ohe"] = as_tensor(np.array(traj["actions_ohe"]))

    raw_states = traj["states"]
    if type(raw_states[0]) is list:
        # Each step stores a pair of tensors: concatenate each component separately.
        traj["states"] = [
            torch.cat([pair[0] for pair in raw_states], dim=0),
            torch.cat([pair[1] for pair in raw_states], dim=0),
        ]
    else:
        traj["states"] = torch.cat(raw_states, dim=0)

    values = torch.cat(traj["values"], dim=0)
    old_policy = torch.cat(traj["old_policy"], dim=0)
    traj["values"] = values
    traj["old_policy"] = old_policy

    horizon = values.shape[0]
    advantages = torch.zeros(horizon)
    traj["advantages"] = advantages
    rewards = traj["rewards"]
    actions_ohe = traj["actions_ohe"]

    running_gae = 0
    for t in range(horizon - 1, -1, -1):
        # Value of the successor state of step t.
        if t == horizon - 1:
            bootstrap = 0 if succ else last_next_value
        elif isQnet:
            bootstrap = (values[t + 1] * old_policy[t + 1].cpu()).sum(-1)
        else:
            bootstrap = values[t + 1]

        # Value of the current state (for the taken action when isQnet).
        if isQnet:
            current = (values[t] * actions_ohe[t]).sum(-1)
        else:
            current = values[t]

        td_error = rewards[t] + gamma * bootstrap - current
        running_gae = td_error + gamma * lam * running_gae
        advantages[t] = running_gae

    if isQnet:
        baseline = (values * actions_ohe).sum(-1)
    else:
        baseline = values
    traj["returns"] = advantages + baseline
    return traj


def _episode_feed_dict(state, entropy_beta, eval_critic):
    """Build the model input dict for one augmented environment state.

    Used both for action selection and for the final bootstrap-value query in
    ``run_episode`` so that both see identically prepared inputs.
    """
    # "nlrl" not in "sort", so the extra check only documents intent.
    if "nlrl" not in args.task or args.task == "sort":
        feed_dict = dict(states=np.array([state]))
    else:
        feed_dict = dict(states=state)

    feed_dict["entropy_beta"] = as_tensor(entropy_beta).float()
    feed_dict["eval_critic"] = as_tensor(eval_critic)
    feed_dict["training"] = as_tensor(False)
    feed_dict = as_tensor(feed_dict)
    if args.use_gpu:
        feed_dict = as_cuda(feed_dict)
    return feed_dict


def run_episode(
    env,
    model,
    mode,
    number,
    play_name="",
    dump=False,
    dataset=None,
    eval_only=False,
    use_argmax=False,
    need_restart=False,
    entropy_beta=0.0,
    eval_critic=False,
):
    """Run one episode using the model with $number blocks.

    At every outer step the model picks a high-level action, which indexes
    ``model.ddpg_controller``; the selected DDPG actor is then executed for up
    to 10 low-level environment steps. A hand-written stage machine
    (1 -> 2 -> 3 -> 4, and 5 once ``object_at_goal`` fires) shapes the reward.

    Returns:
        (succ, score, traj, length, last_next_value, optimal, history):
        ``succ`` is 1 once the environment returned a reward of 0.0;
        ``score`` is the sum of shaped rewards; ``traj`` holds per-step rollout
        data (left empty when ``eval_only``); ``length`` counts low-level
        environment steps; ``last_next_value`` is the critic estimate for the
        final state (None unless ``eval_critic``); ``optimal`` is always None
        here; ``history`` holds one
        ``[action, "from->to", reward, p_star_dict, p_dict]`` entry per outer step.
    """
    global global_times

    is_over = False
    traj = collections.defaultdict(list)
    score = 0
    if need_restart:
        env.restart()

    optimal = None

    if args.model == "dlm":
        # By default the network isn't in training mode during data collection,
        # but with DLM we don't want to use argmax only -- except in 2 cases
        # (testing the interpretability, or the last mining phase to get an
        # interpretable policy). The graduate-number check applies to all of
        # mining/inherit/test; previously `and` bound only to the "test" term,
        # which made the mining/inherit check in the else branch unreachable.
        final_phase = (
            ("mining" in mode) or ("inherit" in mode) or ("test" in mode)
        ) and number == args.curriculum_graduate
        if ("inter" in mode) or final_phase:
            model.lowernoise()
        else:
            model.train(True)

            if args.dlm_noise == 1 and (
                ("mining" in mode) or ("inherit" in mode) or ("test" in mode)
            ):
                model.lowernoise()
            elif args.dlm_noise == 2:
                model.lowernoise()

    if "test" in mode:
        use_argmax = True
        model.eval()

    step = 0

    # Two extra (currently constant) features appended to every raw env state.
    global_predicate_info = np.array([0, 0])
    # NOTE(review): cnt_2 is never incremented, so the early break below never fires.
    cnt_2 = 0
    stage = 1
    history = []
    succ = 0
    sys_dim = 31
    while not is_over:
        state = np.concatenate([env.current_state, global_predicate_info])
        feed_dict = _episode_feed_dict(state, entropy_beta, eval_critic)

        with torch.set_grad_enabled(False):
            output_dict = model(feed_dict)
        policy = output_dict["policy"]

        p = as_numpy(policy.data[0])

        action = p.argmax() if use_argmax else random.choice(len(p), p=p)
        if args.pred_weight != 0.0:
            # Need to ensure that the env.utils.MapActionProxy is the outermost class.
            mapped_x, mapped_y = env.mapping[action]
            # env.unwrapped to get the innermost Env class.
            valid = env.unwrapped.world.moveable(mapped_x, mapped_y)

        # Stage progression: 1 -> 2 -> 3 (-> 4 -> 5).
        reward = 0
        flag1 = 0
        flag2 = 0
        flag3 = 0

        # When resuming from a checkpoint, force controller 2 while in stage 3
        # during the first few training epochs (global_times <= 6).
        if (
            (args.load_checkpoint is not None)
            and ("train" in mode)
            and stage == 3
            and global_times <= 6
        ):
            print("optimize stage 3!")
            action = 2

        # Run the selected controller for up to 10 consecutive low-level steps.
        for i in range(10):
            state = env.current_state
            state = np.concatenate([state, global_predicate_info])
            sys_state = state[:sys_dim]
            res_state = state[sys_dim:]
            a = grip_near_object(0.03)(sys_state, res_state)
            b = hold_object(0.03)(sys_state, res_state)
            c = object_in_air(sys_state, res_state)
            d = object_at_goal(0.03)(sys_state, res_state)
            flag3 = 1 if c > 0 else 0
            if a > 0 and stage == 1 and i > 1:
                flag1 = 1
                if args.test_only:
                    break
            elif b > 0 and stage == 2 and i > 1:
                flag2 = 1
                if args.test_only:
                    break

            print("a, b, c: ", a, b, c)
            ddpg = model.ddpg_controller[action]

            # The low-level actor sees the state without the 2 appended features.
            ddpg_action = ddpg.actor.get_action(state[:-2])

            env_reward, is_over = env.action(ddpg_action)
            if env_reward == 0.0:
                succ = 1
            step += 1
            print("act_dlm: ", action)
            if is_over:
                # Never step a terminated environment; a later step could also
                # overwrite is_over and keep the episode running.
                break

        # Graph algorithm: advance the stage machine.
        ori_stage = stage
        p_star_dict = output_dict["more_info"]["p_star_dict"]
        p_dict = output_dict["more_info"]["p_dict"]
        if flag1 and stage == 1:
            stage = 2
        elif action != 1 and stage == 2:
            stage = 1
        elif flag2 and stage == 2:
            stage = 3
        elif stage == 3 and action == 0:
            stage = 1
        elif stage == 3 and flag3 and action == 2:
            stage = 4
        elif stage == 4 and action == 3:
            reward += 4
        tempstr = str(flag1) + str(flag2) + str(flag3)
        if d > 0:
            stage = 5
        print("mode info: ", mode, use_argmax, p)
        print(ori_stage, "->", stage, " tempstr: ", tempstr)

        # NOTE(review): combined with the elif branch above, remaining in stage
        # 4 with action 3 earns +8 in total; confirm the double bonus is intended.
        if stage == 4 and action == 3:
            reward += 4

        print(ori_stage, "->", stage)
        reward += stage - ori_stage

        # Small penalty for steps that make no progress.
        if reward == 0:
            reward -= 0.1
        history.append(
            [action, str(ori_stage) + "->" + str(stage), reward, p_star_dict, p_dict]
        )
        print("mode, use_argmax, p: ", mode, use_argmax, p)
        print("step: ", step, "fetch reward (adapted): ", reward)
        print("p_star_dict: ", p_star_dict)

        score += reward
        if cnt_2 >= 5 and "test" not in mode:
            break

        if not eval_only:
            traj["values"].append(output_dict["value"].detach().cpu())
            if type(feed_dict["states"]) is list:
                traj["states"].append([f.detach().cpu() for f in feed_dict["states"]])
            else:
                traj["states"].append(feed_dict["states"].detach().cpu())
            traj["rewards"].append(reward)
            traj["actions"].append(action)
            traj["actions_ohe"].append(
                to_categorical(action, num_classes=policy.shape[1])
            )
            traj["old_policy"].append(policy.detach().cpu())
        if args.pred_weight != 0.0:
            if not eval_only and dataset is not None and mapped_x != mapped_y:
                dataset.append(number + 1, state, action, valid)
        print("#########################################################")

    if eval_critic:
        # Bootstrap value of the final state, prepared exactly like rollout
        # states (including the appended global_predicate_info features).
        state = np.concatenate([env.current_state, global_predicate_info])
        feed_dict = _episode_feed_dict(state, entropy_beta, eval_critic)

        with torch.set_grad_enabled(False):
            output_dict = model(feed_dict)
        if model.isQnet:
            last_next_value = (
                (output_dict["value"].detach() * output_dict["policy"].detach())
                .sum(-1)
                .cpu()
                .numpy()
            )
        else:
            last_next_value = output_dict["value"].detach().cpu().numpy()
    else:
        last_next_value = None

    length = step

    if args.model == "dlm":
        model.restorenoise()

    print("over see the action of this episode: ")
    for a in history:
        print(
            "action: ",
            a[0],
            " transition: ",
            a[1],
            " reward: ",
            a[2],
            " p_star_dict: ",
            a[3],
        )

    return succ, score, traj, length, last_next_value, optimal, history


class MyTrainer(MiningTrainerBase):
    def __init__(
        self,
        model,
        optimizer,
        epochs,
        epoch_size,
        test_epoch_size,
        test_number_begin,
        test_number_step,
        test_number_end,
        curriculum_start,
        curriculum_step,
        curriculum_graduate,
        enable_candidate,
        curriculum_thresh,
        curriculum_thresh_relax,
        curriculum_force_upgrade_epochs,
        sample_array_capacity,
        enhance_epochs,
        enable_mining,
        repeat_mining,
        candidate_mul,
        mining_interval,
        mining_epoch_size,
        mining_dataset_size,
        inherit_neg_data,
        disable_balanced_sample,
        prob_pos_data,
    ):
        """Pass every trainer/curriculum/mining setting unchanged to the base.

        All arguments are forwarded positionally, in declaration order, to
        ``MiningTrainerBase.__init__``; this subclass adds no state here.
        """
        base_arguments = (
            model,
            optimizer,
            epochs,
            epoch_size,
            test_epoch_size,
            test_number_begin,
            test_number_step,
            test_number_end,
            curriculum_start,
            curriculum_step,
            curriculum_graduate,
            enable_candidate,
            curriculum_thresh,
            curriculum_thresh_relax,
            curriculum_force_upgrade_epochs,
            sample_array_capacity,
            enhance_epochs,
            enable_mining,
            repeat_mining,
            candidate_mul,
            mining_interval,
            mining_epoch_size,
            mining_dataset_size,
            inherit_neg_data,
            disable_balanced_sample,
            prob_pos_data,
        )
        super().__init__(*base_arguments)

    def save_checkpoint(self, name):
        """Save ``checkpoint_<name>.pth`` into ``args.checkpoints_dir``.

        Does nothing when no checkpoint directory is configured.
        """
        target_dir = args.checkpoints_dir
        if target_dir is None:
            return
        filename = "checkpoint_{}.pth".format(name)
        super().save_checkpoint(os.path.join(target_dir, filename))

    def _dump_meters(self, meters, mode):
        """Append averaged meter values as one JSON line to ``args.summary_file``.

        The record is tagged with ``mode`` and the current epoch. Nothing is
        written when no summary file is configured.
        """
        summary_path = args.summary_file
        if summary_path is None:
            return
        record = meters._canonize_values("avg")
        record["mode"] = mode
        record["epoch"] = self.current_epoch
        with open(summary_path, "a") as handle:
            handle.write(io.dumps_json(record))
            handle.write("\n")

    def _prepare_dataset(self, epoch_size, mode):
        """Intentionally a no-op: this trainer needs no dataset preparation."""
        return None

    def _get_player(self, number, mode, index):
        """Create and restart an environment for an episode with ``number`` objects.

        For the "path" task, the distance range is pinned to
        ``(args.gen_test_len, args.gen_test_len)`` in test modes or in the DLM's
        final mining phase, and is ``(1, args.gen_max_len)`` otherwise. NLRL tasks
        in test modes select a variation via ``index % 5``.
        """
        task = args.task
        if task == "path":
            test_len = args.gen_test_len
            final_mining = (
                args.model == "dlm"
                and number == args.curriculum_graduate
                and "mining" in mode
            )
            if final_mining or "test" in mode:
                dist_range = (test_len, test_len)
            else:
                dist_range = (1, args.gen_max_len)
            player = make_env(task, number, dist_range=dist_range)
        elif "nlrl" in task and "test" in mode:
            # Testing env for NLRL: assumes 5 variations per env.
            player = make_env(task, number, variation_index=index % 5)
        else:
            player = make_env(task, number)

        player.restart()
        return player

    def _get_result_given_player(self, index, meters, number, player, mode):
        """Play episode(s) with ``player`` and report the result.

        In "train" mode, ``args.ntrajectory`` episodes are played. Each is turned
        into a batch with ``make_data``, and the batches are concatenated along
        dim 0 (except "rewards"). Advantages are optionally normalized, and the
        merged feed dict is returned.

        In every other mode, one evaluation episode is played and
        ``(message, extra)`` is returned. ``extra`` holds succ, number, backup and
        the episode history. ``backup`` is a deep copy of the player for
        mining/inherit modes and None otherwise.
        """
        global succ_rate
        global tot_step

        assert mode in [
            "train",
            "test",
            "mining",
            "mining-stoch",
            "mining-deter",
            "inherit",
            "test-inter",
            "test-inter-deter",
            "test-deter",
        ], f"{mode}"

        if args.test_only:
            mode = "test-inter" if args.test_inter else "test"

        params = dict(
            eval_only=True,
            eval_critic=False,
            number=number,
            play_name="{}_epoch{}_episode{}".format(mode, self.current_epoch, index),
        )
        backup = None
        if mode == "train":
            params.update(
                eval_only=False,
                eval_critic=True,
                dataset=self.valid_action_dataset,
                entropy_beta=self.entropy_beta,
            )
            meters.update(lr=self.lr, entropy_beta=self.entropy_beta)
        elif "test" in mode:
            params["dump"] = True
            params["use_argmax"] = "deter" in mode or args.test_only
        else:
            backup = copy.deepcopy(player)
            params["use_argmax"] = "deter" in mode

        if mode != "train":
            # Single evaluation episode.
            succ, score, traj, length, last_next_value, optimal, history = run_episode(
                player, self.model, mode, **params
            )
            print("************ total step {} ************".format(tot_step))
            succ_rate[tot_step // 10] += succ
            tot_step += 1
            if args.task in ["sort", "path"]:
                meters.update(
                    number=number,
                    succ=succ,
                    score=score,
                    length=length,
                    optimal=optimal,
                )
                template = (
                    "> {} iter={iter}, number={number}, succ={succ}, "
                    "score={score:.4f}, length={length}, optimal={optimal}"
                )
            else:
                meters.update(number=number, succ=succ, score=score, length=length)
                template = (
                    "> {} iter={iter}, number={number}, succ={succ}, "
                    "score={score:.4f}, length={length}"
                )
            message = template.format(mode, iter=index, **meters.val)
            self.histories.append(history)
            return message, dict(
                succ=succ, number=number, backup=backup, history=history
            )

        # Training: collect several trajectories and merge them into one batch.
        batches = []
        for traj_idx in range(args.ntrajectory):
            succ, score, traj, length, last_next_value, optimal, _history = (
                run_episode(
                    player, self.model, mode, need_restart=(traj_idx != 0), **params
                )
            )
            episode_stats = dict(number=number, succ=succ, score=score, length=length)
            if args.task in ["sort", "path"]:
                episode_stats["optimal"] = optimal
            meters.update(**episode_stats)

            feed_dict = make_data(
                traj,
                args.gamma,
                succ,
                last_next_value[0],
                lam=args.lam,
                isQnet=self.model.isQnet,
            )
            # Extra supervision from the valid-move dataset.
            if args.pred_weight != 0.0:
                states, actions, labels = self.valid_action_dataset.sample_batch(
                    args.batch_size
                )
                feed_dict["pred_states"] = as_tensor(states)
                feed_dict["pred_actions"] = as_tensor(actions)
                feed_dict["valid"] = as_tensor(labels).float()
            batches.append(feed_dict)

        for key in feed_dict:
            if key == "rewards":
                continue  # rewards are not used by the loss
            if type(batches[0][key]) is list:
                feed_dict[key] = [
                    torch.cat([batch[key][0] for batch in batches], dim=0),
                    torch.cat([batch[key][1] for batch in batches], dim=0),
                ]
            else:
                feed_dict[key] = torch.cat([batch[key] for batch in batches], dim=0)

        feed_dict["entropy_beta"] = as_tensor(self.entropy_beta).float()
        feed_dict["eval_critic"] = as_tensor(True)
        feed_dict["training"] = as_tensor(True)

        advantages = feed_dict["advantages"]
        if advantages.shape[0] > 1 and not args.no_adv_norm:
            feed_dict["advantages"] = (advantages - advantages.mean()) / (
                advantages.std() + 10**-7
            )

        if args.use_gpu:
            feed_dict = as_cuda(feed_dict)
        self.model.train()
        return feed_dict

    def _extract_info(self, extra):
        """Return ``(succ, number, backup)`` from an evaluation episode's extra dict."""
        return tuple(extra[field] for field in ("succ", "number", "backup"))

    def _get_accuracy(self, meters):
        """Use the average success rate recorded in ``meters`` as accuracy."""
        averages = meters.avg
        return averages["succ"]

    def _get_threshold(self):
        """Base threshold, relaxed by ``args.candidate_relax`` steps for non-candidates."""
        if self.is_candidate:
            relax_steps = 0
        else:
            relax_steps = args.candidate_relax
        base_threshold = super()._get_threshold()
        return base_threshold - self.curriculum_thresh_relax * relax_steps

    def _upgrade_lesson(self):
        """Advance the curriculum, then decay entropy_beta and the learning rate.

        The decayed learning rate is pushed to the optimizer immediately.
        """
        super()._upgrade_lesson()
        self.entropy_beta *= args.entropy_beta_decay
        self.lr *= args.lr_decay
        self.set_learning_rate(self.lr)

    def _train_epoch(self, epoch_size):
        """Run one training epoch of ``epoch_size`` collect-and-update iterations.

        Each iteration gathers a batch via ``_get_train_data`` and performs
        ``args.noptepochs`` passes of ``_train_step``. Unless
        ``args.no_shuffle_minibatch`` is set, each pass goes over shuffled
        minibatches of ``args.batch_size``. Afterwards, meters are logged and
        dumped, and an exam is taken while not graduated. Every
        ``args.test_interval`` epochs, the model is also evaluated, optionally
        checkpointed, and its score is logged to wandb.
        """
        global global_times, run_times
        global_times += 1
        model = self.model
        meters = GroupMeters()
        self._prepare_dataset(epoch_size, mode="train")
        # NOTE(review): self.lr is decayed every epoch here but set_learning_rate
        # is not called, so the optimizer only picks it up at the next lesson
        # upgrade (which decays it again). Confirm this is intended.
        self.lr *= args.lr_decay

        # Auxiliary valid-action tensors are shared by all minibatches, not sliced.
        unsliced_keys = ("pred_states", "pred_actions", "valid")

        def slice_batch(batch, indices):
            minibatch = {}
            for key, value in batch.items():
                if (
                    type(value) == torch.Tensor
                    and len(value.shape) != 0
                    and key not in unsliced_keys
                ):
                    minibatch[key] = value[indices, ...]
                else:
                    minibatch[key] = value
            return minibatch

        def train_func(index):
            model.eval()
            feed_dict = self._get_train_data(index, meters)
            model.train()
            if args.no_shuffle_minibatch:
                for _ in range(args.noptepochs):
                    message, _ = self._train_step(feed_dict, meters)
                return message

            total = feed_dict["states"].shape[0]
            chunk = args.batch_size
            order = np.arange(total)
            np.random.shuffle(order)
            for _ in range(args.noptepochs):
                for begin in range(0, total, chunk):
                    minibatch = slice_batch(feed_dict, order[begin : begin + chunk])
                    message, _ = self._train_step(minibatch, meters)
            return message

        # For $epoch_size times, do train_func with tqdm progress bar.
        tqdm_for(epoch_size, train_func)
        header = "> Train Epoch {:5d}: ".format(self.current_epoch)
        logger.info(meters.format_simple(header, compressed=False))
        self._dump_meters(meters, "train")
        if not self.is_graduated:
            self._take_exam(train_meters=copy.copy(meters))

        self.model.stoch_decay(self.current_number, meters.avg["succ"])

        epoch = self.current_epoch
        if args.test_interval is None or epoch % args.test_interval != 0:
            return meters

        ret = self.test_cjy()
        test_score = ret[-1].avg["score"]
        if (
            args.save_interval is not None
            and epoch % args.save_interval == 0
            and test_score >= 1
        ):
            self.save_checkpoint(str(epoch) + "_" + str(test_score))
        print(
            "performance of checkpoint ",
            epoch,
            ": ",
            test_score,
            "with index: ",
            global_times,
            " current run id: ",
            run_times,
        )
        wandb.log({"Reward": test_score})
        return meters

    def _early_stop(self, meters):
        """Stop once ``args.early_drop_epochs`` per lesson has been exceeded.

        Otherwise defer to the base class decision.
        """
        limit = args.early_drop_epochs
        exceeded = limit is not None and self.current_epoch > limit * (
            self.nr_upgrades + 1
        )
        if exceeded:
            return True
        return super()._early_stop(meters)

    def train(self):
        """Reset the valid-action dataset, lr and entropy_beta, then train."""
        self.lr, self.entropy_beta = args.lr, args.entropy_beta
        self.valid_action_dataset = ValidActionDataset()
        return super().train()

    def test(self):
        """Evaluate the model and return the better-scoring result.

        Non-DLM models compare the default test with the deterministic advanced
        test. DLM models instead compare the two interpretable ("inter") advanced
        tests; before choosing, they build the automaton from the collected
        histories.
        """
        default_ret = super().test()
        deter_ret = super().advanced_test(inter=False, deter=True)
        if args.model != "dlm":
            if default_ret[-1].avg["score"] > deter_ret[-1].avg["score"]:
                return default_ret
            return deter_ret

        inter_stoch_ret = super().advanced_test(inter=True, deter=False)
        inter_deter_ret = super().advanced_test(inter=True, deter=True)
        # Run for its printed output; the returned automaton is not used here.
        self.generate_init_automaton()
        if inter_stoch_ret[-1].avg["score"] > inter_deter_ret[-1].avg["score"]:
            return inter_stoch_ret
        return inter_deter_ret

    def test_cjy(self):
        """Run only the interpretable, non-deterministic advanced test."""
        return super().advanced_test(inter=True, deter=False)

    def process_history(self, histories, largeset_letter=3):
        """Build automaton components from recorded episode histories.

        Each history entry is expected to be
        ``[action, "from->to", reward, p_star_dict, p_dict]`` (the layout
        ``run_episode`` records). A state is identified by the tuple of the
        ``"bool"`` values in the entry's ``p_star_dict``, and each state gets a
        letter:

        - If the predicate tuple already has a letter, that earlier letter wins.
        - Otherwise, a step leaving stage 1 is "1" (the start state).
        - Otherwise, a step entering stage 5 is "2" (the accepting state).
        - Otherwise, a fresh letter is allocated from ``largeset_letter`` upward.

        Args:
            histories: iterable of episode histories.
            largeset_letter: first integer used for newly created state letters.

        Returns:
            A tuple ``(states_set, alphabet, accept_states, total_transitions,
            state_letter_mapping, letter_state_mapping, ks)``:

            - ``total_transitions`` maps ``(state_letter, action_letter)`` to the
              next state letter; later episodes overwrite earlier ones on conflict.
            - ``ks`` is the predicate-id list of the last processed step, or
              ``[]`` if there were no steps.
        """
        states_set = set()
        accept_states = set()
        alphabet = set()
        state_letter_mapping = {}
        letter_state_mapping = {}
        total_transitions = {}
        # Bound up front: previously only assigned inside the loop, so empty
        # histories (e.g. no successful episode) raised UnboundLocalError.
        ks = []
        for history in histories:
            episode_transitions = {}
            episode_states = []
            episode_actions = []
            for j, transition in enumerate(history):
                action = action_to_letter(transition[0])
                state = transition[3]
                bools = [v["bool"] for v in state.values()]
                ks = list(state.keys())
                state_key = tensor_to_tuple(torch.tensor(bools, dtype=torch.int32))
                stages = transition[1].split("->")

                # A predicate tuple may first receive a fresh letter and later
                # appear in a start/accept step (or start and accept may
                # coincide); the earliest assigned letter is always kept.
                existing = state_letter_mapping.get(state_key)
                if existing:
                    letter = existing
                elif stages[0] == "1":
                    letter = "1"
                elif stages[1] == "5":
                    letter = "2"
                else:
                    letter = str(largeset_letter)
                    largeset_letter += 1

                episode_states.append(letter)
                episode_actions.append(action)
                alphabet.add(str(action))

                # Record both directions of the state <-> letter mapping.
                state_letter_mapping[state_key] = letter
                letter_state_mapping.setdefault(letter, set()).add(state_key)

                states_set.add(letter)
                if letter == "2":
                    accept_states.add(letter)

                # Transition from the previous step's (state, action) to this state.
                if j != 0:
                    episode_transitions[
                        (episode_states[j - 1], episode_actions[j - 1])
                    ] = letter

            total_transitions.update(episode_transitions)

        return (
            states_set,
            alphabet,
            accept_states,
            total_transitions,
            state_letter_mapping,
            letter_state_mapping,
            ks,
        )

    def generate_init_automaton(self, divide=True):
        """Build a DFA from the recorded trajectories and minimize it.

        The trajectories are converted into states/alphabet/transitions by
        ``self.process_history`` and then minimized with
        ``hopcroft_minimization`` using ``"1"`` as the start state. The
        automaton representation follows this shape::

            states = {'A', 'B', 'C', 'D', 'E', 'F'}
            alphabet = {'0', '1'}
            start_state = 'A'
            accept_states = {'C', 'D'}
            transitions = {
                ('A', '0'): 'A', ('A', '1'): 'B',
                ('B', '0'): 'B', ('B', '1'): 'A',
                ('C', '0'): 'C', ('C', '1'): 'D',
                ('D', '0'): 'D', ('D', '1'): 'C',
                ('E', '0'): 'E', ('E', '1'): 'F',
                ('F', '0'): 'F', ('F', '1'): 'E',
            }

        Args:
            divide: when True (the default, matching the previously hard-coded
                behaviour) ``self.histories`` is split with
                ``divide_trajectory``. Only the successful trajectories feed the
                automaton; the failed ones are passed to the prompt builder.
                When False every history is used and an empty failure list is
                passed.

        Returns:
            dict with keys ``min_states``, ``alphabet``, ``min_start_state``,
            ``min_accept_states``, ``min_transitions``,
            ``state_letter_mapping``, ``letter_state_mapping`` and
            ``key_predicates_num`` (the key-predicate list returned by
            ``process_history``).
        """
        # Label of the start state handed to the minimizer.
        start_state = "1"

        if divide:
            fail_histories, succ_histories = self.divide_trajectory()
            print("length of fail_histories: ", len(fail_histories))
            print("length of succ_histories: ", len(succ_histories))
        else:
            # Previously this name was left unbound on this path, so building
            # the input prompt below raised NameError.
            # NOTE(review): confirm set_input_prompt accepts an empty list.
            fail_histories = []
            succ_histories = self.histories
        print("length of histories: ", len(self.histories))

        (
            states_set,
            alphabet,
            accept_states,
            total_transitions,
            state_letter_mapping,
            letter_state_mapping,
            ks,
        ) = self.process_history(succ_histories)

        print("states: ", states_set)
        print("alphabet: ", alphabet)
        print("start_states: ", start_state)
        print("accept_states: ", accept_states)
        print("total_transitions: ", total_transitions)
        print("------------------- After Hopcroft ---------------------")
        min_states, min_start_state, min_accept_states, min_transitions = (
            hopcroft_minimization(
                states_set, alphabet, start_state, accept_states, total_transitions
            )
        )
        print("states =", min_states)

        automaton_prompt = (
            f"start_state = {min_start_state}\n"
            f"accept_states = {min_accept_states}\n"
            f"transitions = {min_transitions}\n"
        )
        automaton = {
            "states": min_states,
            "start_state": min_start_state,
            "accept_states": min_accept_states,
            "transitions": min_transitions,
        }
        print("automaton_prompt: ", automaton_prompt)
        print("----------------------------------------")
        print("key predicates no.: ", ks)

        # TODO: generate the per-state description prompt (in a separate function).
        # The prompts below are built but not consumed yet; they are kept for
        # the pending LLM integration.
        system_prompt = get_public_prompt()
        input_prompt = set_input_prompt(
            p_dict=self.histories[0][0][4],
            key_predicates=ks,
            automaton=automaton,
            letter_state_mapping=letter_state_mapping,
            failure_trajectory=fail_histories,
        )
        task_prompt = get_task_prompt()
        # TODO: Add back LLM later
        return dict(
            min_states=min_states,
            alphabet=alphabet,
            min_start_state=min_start_state,
            min_accept_states=min_accept_states,
            min_transitions=min_transitions,
            state_letter_mapping=state_letter_mapping,
            letter_state_mapping=letter_state_mapping,
            key_predicates_num=ks,
        )

    def divide_trajectory(self):
        """Partition ``self.histories`` into failed and successful trajectories.

        The transition string of a trajectory's last step (element ``[1]``) is
        split on ``"->"``. If the second segment equals ``"5"`` the trajectory
        is successful; otherwise it is a failure.

        Returns:
            tuple ``(fail_histories, succ_histories)``. Both are lists that
            keep the relative order of ``self.histories``.
        """
        # TODO: record the failing transitions for later analysis. We still
        # need to decide what counts as a failed transition, then feed the
        # trajectory plus a rough judgement to the LLM so it can pinpoint
        # what went wrong.
        failed, succeeded = [], []
        for trajectory in self.histories:
            final_target = trajectory[-1][1].split("->")[1]
            destination = succeeded if final_target == "5" else failed
            destination.append(trajectory)
        return failed, succeeded

    def generate_state_description(self, input_dict):
        """Generate a natural-language description for each automaton state.

        ``input_dict["letter_state_mapping"]`` maps each automaton letter to a
        collection of predicate-value tuples. Element ``j`` of a tuple is the
        value of key predicate ``input_dict["key_predicates_num"][j]``. That
        predicate's description is read from the ``p_star_dict`` stored at
        ``self.histories[0][0][3]``. When the value equals 0, the phrase
        "is between" is negated to "is not between".

        Args:
            input_dict: dict with at least ``key_predicates_num`` and
                ``letter_state_mapping`` (e.g. the dict returned by
                ``generate_init_automaton``).

        Returns:
            dict mapping each letter to a list of description strings, one per
            predicate tuple, in the letter's collection iteration order.
        """
        key_predicates_num = input_dict["key_predicates_num"]
        letter_state_mapping = input_dict["letter_state_mapping"]
        p_star_dict = self.histories[0][0][3]

        for k, v in p_star_dict.items():
            print(f"p_star_dict[{k}]: {v['desc']}")
        print("----------------------------------------")

        state_prompt_dict = {}
        for letter, predicate_tuples in letter_state_mapping.items():
            state_prompt_dict[letter] = []
            for i, values in enumerate(predicate_tuples):
                # The description is built fresh for every tuple. Bit j
                # describes key predicate j (the old code indexed with the
                # tuple counter i and never reset the accumulated prompt).
                parts = []
                for j, bv in enumerate(values):
                    desc = p_star_dict[key_predicates_num[j]]["desc"]
                    if bv == 0:
                        desc = desc.replace("is between", "is not between")
                    parts.append(desc)
                prompt = " ".join(parts)
                state_prompt_dict[letter].append(prompt)
                print(f"state_prompt_dict[{letter}][{i}]: {prompt}")
        return state_prompt_dict


def main(run_id):
    """Run one training (or test-only) session.

    Sets up per-run dump/checkpoint paths, builds the model, optimizer and
    trainer, and loads the four DDPG controller checkpoints. It then either
    tests only, or trains, saves a ``"last"`` checkpoint and optionally tests.

    Args:
        run_id: index of this run. Currently unused; per-run dump directories
            are numbered by the module-level ``run_times`` counter instead.

    Returns:
        ``(graduated, test_meters)``. In ``--test-only`` mode ``graduated`` is
        ``None`` and ``test_meters`` is the result of ``trainer.test()``. In
        training mode ``test_meters`` is ``None`` unless the run graduated or
        ``args.test_not_graduated`` is set.
    """
    global run_times, succ_rate, tot_step
    if args.dump_dir is not None:
        if args.runs > 1:
            args.current_dump_dir = os.path.join(
                args.dump_dir, "run_{}".format(run_times)
            )
            io.mkdir(args.current_dump_dir)
        else:
            args.current_dump_dir = args.dump_dir
        args.checkpoints_dir = os.path.join(args.current_dump_dir, "checkpoints")
        io.mkdir(args.checkpoints_dir)
        args.summary_file = os.path.join(args.current_dump_dir, "summary.json")

    logger.info(format_args(args))

    model = Model()
    optimizer = get_optimizer(args.optimizer, model, args.lr)
    if args.accum_grad > 1:
        optimizer = AccumGrad(optimizer, args.accum_grad)

    if args.use_gpu:
        model.cuda()
    trainer = MyTrainer.from_args(model, optimizer, args)

    if args.load_checkpoint is not None:
        trainer.load_checkpoint(args.load_checkpoint)

    # Hard-coded toggle: always load the pretrained DDPG controllers.
    if 1:
        print("load ddpg controller!")
        trainer.model.ddpg_controller[0].load("fetch_model_4/actor_0_0")
        trainer.model.ddpg_controller[1].load("fetch_model_4/actor_1_0")
        trainer.model.ddpg_controller[2].load("fetch_model_4/actor_2_0")
        trainer.model.ddpg_controller[3].load("models_PickAndPlace/actor_3_0")

    if not args.test_only:
        # ``[6:]`` presumably strips a fixed dump-dir prefix (e.g. "dumps/") —
        # TODO confirm. Guard against a missing dump dir, which previously
        # raised TypeError; wandb picks a name itself when given None.
        run_name = args.dump_dir[6:] if args.dump_dir is not None else None
        wandb.init(
            settings=wandb.Settings(start_method="fork"),
            name=run_name,
            project="dlm-fetch-aws-spec4",
            entity="neuralsymbolic",
        )

    if args.test_only:
        trainer.current_epoch = 0
        ret = trainer.test()
        print("*************** tot_step: {} ***************".format(tot_step))
        succ_sum = 0
        for sr in succ_rate:
            print("*************** succ_rate: {} ***************".format(sr / 10))
            succ_sum += sr
        # Avoid ZeroDivisionError when no steps were recorded.
        total_rate = succ_sum / tot_step if tot_step else 0.0
        print(
            "*************** total succ_rate: {} ***************".format(total_rate)
        )
        # Advance the counter on this path too. Otherwise repeated test-only
        # runs all reuse the run_0 dump directory.
        run_times += 1
        return None, ret

    graduated = trainer.train()
    trainer.save_checkpoint("last")
    test_meters = trainer.test() if graduated or args.test_not_graduated else None
    run_times += 1
    return graduated, test_meters


if __name__ == "__main__":
    # Run ``args.runs`` independent sessions and aggregate the per-test-task
    # meters across them in ``stats`` (one GroupMeters per task index).
    stats = []
    nr_graduated = 0
    print("runs: ", args.runs)
    for i in range(args.runs):
        graduated, test_meters = main(i)
        logger.info(f"run {i + 1}")

        if test_meters is not None:
            for j, meters in enumerate(test_meters):
                # Grow the accumulator list lazily, one entry per task index.
                if j >= len(stats):
                    stats.append(GroupMeters())
                stats[j].update(
                    number=meters.avg["number"], test_succ=meters.avg["succ"]
                )
            for meters in stats:
                logger.info(
                    f"number {meters.avg['number']}, test_succ {meters.avg['test_succ']}"
                )

        # Graduation bookkeeping only applies to training runs.
        if args.test_only:
            continue

        nr_graduated += int(graduated)
        logger.info(f"graduate_ratio {nr_graduated / (i + 1)}")
        if graduated:
            for j, meters in enumerate(test_meters):
                stats[j].update(grad_test_succ=meters.avg["succ"])
        if nr_graduated > 0:
            for meters in stats:
                logger.info(
                    f"number {meters.avg['number']}, grad_test_succ {meters.avg['grad_test_succ']}"
                )
