# This code is borrowed from original repo https://github.com/GFNOrg/gflownet
import argparse
import gzip
import os
import pdb
import pickle
import threading
import time
import warnings
from copy import deepcopy

import networkx as nx
import numpy as np
import torch
import torch.nn as nn

import model_atom, model_block, model_fingerprint
from mol_mdp_ext import MolMDPExtended, BlockMoleculeDataExtended
import functools
# from utils.chem import compute_num_of_modes

import sys
import io

import wandb
import random
from scipy.stats import spearmanr
from contextlib import contextmanager
from rdkit import Chem,DataStructs
from typing import Iterable, Dict, Any, List, Tuple


# Force UTF-8, line-buffered stdout so non-ASCII log text prints reliably and
# progress lines appear immediately. Prefer reconfiguring the existing stream in
# place (Python 3.7+). Wrapping sys.stdout.buffer in a second TextIOWrapper puts
# two text layers on one byte buffer. If the replaced wrapper is ever
# garbage-collected, its finalizer closes that shared buffer.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8', line_buffering=True)
elif hasattr(sys.stdout, "buffer"):
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', line_buffering=True)

# Silence all Python warnings for the whole run.
warnings.filterwarnings('ignore')

# Scratch directory, created at import time if missing.
tmp_dir = "./tmp"
os.makedirs(tmp_dir, exist_ok=True)

def _str2bool(value):
    """Parse a boolean command-line value.

    Without ``type=``, argparse keeps the raw string, so ``--flag False`` would
    produce the truthy string ``"False"``. This converter accepts 1/0, true/false,
    t/f, yes/no, y/n and on/off (case-insensitive); an empty string is False.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

parser.add_argument("--learning_rate", default=5e-4, help="Learning rate", type=float)
parser.add_argument("--mbsize", default=4, help="Minibatch size", type=int)
parser.add_argument("--opt_beta", default=0.9, type=float)
parser.add_argument("--opt_beta2", default=0.999, type=float)
parser.add_argument("--opt_epsilon", default=1e-8, type=float)
parser.add_argument("--nemb", default=256, help="#hidden", type=int)
parser.add_argument("--min_blocks", default=2, type=int)
parser.add_argument("--max_blocks", default=8, type=int)
parser.add_argument("--num_iterations", default=50000, type=int) # led-gfn & fl-gfn: 250000
parser.add_argument("--save_every", default=2000, type=int, help="Steps to perform test and saving. Save models/training informations if do_save=True")
parser.add_argument('--eval_every', default=500, type=int, help="Steps to evaluate the stats generated by the trained policy.")
parser.add_argument("--num_conv_steps", default=10, type=int)
parser.add_argument("--log_reg_c", default=2.5e-5, type=float)
parser.add_argument("--reward_exp", default=4, type=float) # led-gfn:8, fl-gfn: 4
parser.add_argument("--reward_norm", default=8, type=float)
parser.add_argument("--sample_prob", default=1, type=float)
parser.add_argument("--R_min", default=0.1, type=float)
parser.add_argument("--leaf_coef", default=10, type=float)
parser.add_argument("--clip_grad", default=0, type=float)
parser.add_argument("--clip_loss", default=0, type=float)
parser.add_argument("--replay_mode", default='online', type=str)
parser.add_argument("--bootstrap_tau", default=0, type=float)
parser.add_argument("--weight_decay", default=0, type=float)
parser.add_argument("--random_action_prob", default=0.1, type=float)
parser.add_argument("--array", default='')
parser.add_argument("--repr_type", default='block_graph')
parser.add_argument("--model_version", default='v4')
parser.add_argument("--run", default=0, help="run", type=int)
parser.add_argument("--save_path", default='results/')
parser.add_argument("--proxy_path", default='./data/pretrained_proxy')
parser.add_argument("--print_array_length", default=False, action='store_true')
parser.add_argument("--progress", default='yes')
parser.add_argument("--floatX", default='float64')
# Boolean flags are parsed with _str2bool so that "False"/"0" really mean False.
parser.add_argument("--include_nblocks", default=False, type=_str2bool)
parser.add_argument("--balanced_loss", default=True, type=_str2bool)
parser.add_argument("--early_stop_reg", default=0.1, type=float) # led-gfn:0
parser.add_argument("--initial_log_Z", default=30, type=float)
parser.add_argument("--objective", default='fm', type=str)
parser.add_argument("--entropy_coeff",  default=1.0,  help="Temperature／entropy coefficient for policy sampling", type=float)

parser.add_argument("--ignore_parents", default=False, type=_str2bool)

parser.add_argument("--subtb_lambda", default=0.99, type=float)

parser.add_argument("--model_seed", default=0, type=int)
parser.add_argument("--sampling_seed", default=0, type=int)

# Epsilon (epsilon-greedy) decay schedule
parser.add_argument("--use_exp_weight_decay", choices=[0,1], default=0, type=int)
parser.add_argument("--exp_weight_sched", type=str, default='linear', choices=['linear','cosine'])
parser.add_argument("--exp_weight_warm_frac", type=float, default=0.0)
parser.add_argument("--eps_start", type=float, default=0.2, help="Starting epsilon at step 0")
parser.add_argument("--eps_end_frac", type=float, default=0.5, help="Fraction of total steps by which epsilon reaches random_action_prob")

# Alpha schedule (active only when alpha is NOT trainable)
parser.add_argument("--use_alpha_scheduler", choices=[0,1], default=1, type=int)
parser.add_argument("--alpha_warm_frac", type=float, default=0.8)
parser.add_argument("--alpha_sched", type=str, default='hold_exp', choices=['linear','cos','hold_exp','cyc','poly'])

parser.add_argument("--alpha",default=0.5, type=float)

# Evaluation: mode detection and top-K metrics
parser.add_argument("--mode_threshold", default=7.0, type=float, help="Threshold of reward for mode detection")
parser.add_argument("--tanimoto_threshold", default=0.7, type=float, help="Tanimoto threshold for mode separation")
parser.add_argument("--top_k",nargs="+",type=int,default=[10,50,100,1000],help="Compute metrics for multiple top-K cutoffs, e.g., --top_k 10 100 1000")

parser.add_argument('--wdb', choices=[0,1], default=1, type=int, help="Whether to use wandb")
parser.add_argument('--wdb_project', default='alphagfn-mols', type=str, help="wandb project name")

# forward-looking
parser.add_argument("--fl",choices=[0,1], default=0, type=int, help="Whether to use forward-looking GFlowNets")

# Sampling threads
parser.add_argument('--debug_no_threads',choices=[0,1], default=1, type=int, help="If 1, sample in a single thread (results are reproducible). If 0, sample with --num_threads worker threads (results are not reproducible).")
parser.add_argument('--num_threads', default=8, type=int, help="Number of threads for sampling if multithread is used") # NOTE(review): original author flags bugs in multithreaded sampling

# Evaluation repository deduplication (OFF by default)

parser.add_argument("--eval_dedup_key", choices=["smiles",'none'], default="none",
                    help="Key used for dedup when eval_dedup=1. 'smiles' uses molecule.smiles; 'none' disables dedup.")

# test
parser.add_argument('--vec',choices=[0,1],default=0,type=int)

def tf(x, dtype, device):
    """Build a tensor from ``x`` on ``device``, then cast it to ``dtype``."""
    tensor = torch.tensor(x, device=device)
    return tensor.to(dtype)
def tint(x, device):
    """Build a tensor from ``x`` on ``device`` and convert it to int64 (``long``)."""
    tensor = torch.tensor(x, device=device)
    return tensor.long()
def to_log(x, device):
    """Return the element-wise natural log of ``x``, moved to ``device``.

    ``x`` is either a tensor or a list of tensors; a list is concatenated
    along dim 0 before the log. Any other type raises ``ValueError``.
    """
    if isinstance(x, torch.Tensor):
        values = x
    elif isinstance(x, list):
        values = torch.cat(x)
    else:
        raise ValueError(f'Unknown type of x {type(x)}')
    return torch.log(values).to(device)
    
def set_model_seed(args):
    """Seed every RNG used for model-level randomness from ``args.model_seed``.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, plus all CUDA
    devices when CUDA is available). Also puts cuDNN into deterministic,
    non-benchmarking mode.
    """
    seed = args.model_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _dict_to_str(d: Dict) -> str:
    """Render ``d`` as ``"k1=v1, k2=v2"``.

    Python and NumPy floats are shown with 3 decimals; every other value uses ``str``.
    """
    parts = []
    for key, value in d.items():
        if isinstance(value, (float, np.floating)):
            text = f"{float(value):.3f}"
        else:
            text = str(value)
        parts.append(f'{key}={text}')
    return ', '.join(parts)

def timer(func):
    """Decorator that prints the wall-clock duration (``time.perf_counter``) of each call to ``func``."""
    @functools.wraps(func)
    def timed(*args, **kwargs):
        t0 = time.perf_counter()
        out = func(*args, **kwargs)
        elapsed = time.perf_counter() - t0
        print(f"[{func.__name__}] Executed in {elapsed:.6f}s", flush=True)
        return out

    return timed


class ExpWeightScheduler:
    """Anneal epsilon (epsilon-greedy) from ``start`` down to ``end``.

    The value is held at ``start`` up to step ``round(warm_frac * (total_steps - 1))``.
    It then moves to ``end`` (linearly, or along a half cosine) by step
    ``round(end_frac * (total_steps - 1))`` and stays at ``end`` afterwards.
    Both fractions are clamped to [0, 1], and steps to [0, total_steps - 1].

    Args:
        end: final epsilon.
        total_steps: number of training steps (values < 1 are treated as 1).
        kind: ``"linear"`` or ``"cosine"``. Any other value raises ``ValueError``
            when the scheduler is called.
        warm_frac: fraction of steps held at ``start``.
        start: initial epsilon.
        end_frac: fraction of steps by which ``end`` is reached.

    Raises:
        ValueError: if ``warm_frac`` is not strictly less than ``end_frac``.
    """
    def __init__(self, end: float, total_steps: int, kind: str = "linear",
                 warm_frac: float = 0.0, start: float = 0.5, end_frac: float = 0.5):
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if not warm_frac < end_frac:
            raise ValueError("Cannot set warm_frac >= end_frac")
        import math
        self.math = math
        self.start = float(start)
        self.end = float(end)
        self.T = max(1, int(total_steps))

        wf = max(0.0, min(1.0, float(warm_frac)))
        ef = max(0.0, min(1.0, float(end_frac)))
        self.warm_step = int(round(wf * (self.T - 1)))
        self.end_step = int(round(max(wf, ef) * (self.T - 1)))
        self.kind = kind

    def __call__(self, step: int) -> float:
        """Return epsilon for training step ``step``."""
        t = max(0, min(int(step), self.T - 1))

        if t <= self.warm_step:
            return self.start

        # Progress through the annealing window, clamped to [0, 1].
        span = max(1, self.end_step - self.warm_step)
        x = (t - self.warm_step) / span
        x = max(0.0, min(1.0, x))

        if self.kind == "linear":
            return self.start * (1.0 - x) + self.end * x
        elif self.kind == "cosine":
            w = 0.5 * (1.0 + self.math.cos(self.math.pi * x))
            return self.end + (self.start - self.end) * w
        else:
            raise ValueError(self.kind)


def _make_loggable_args(args) -> Dict[str, object]:
    """Collect the wandb-loggable entries of ``vars(args)``.

    Keeps primitive scalars (Python or NumPy bool/int/float, and str) and
    lists/tuples of at most 64 primitives (returned as lists). NumPy scalars are
    converted to plain Python scalars. Renames ``alpha`` to ``alpha_init`` and
    ``random_action_prob`` to ``random_action_prob_final``, and skips ``floatX``.
    Everything else (dicts, sets, tensors, long lists, ...) is dropped.
    """
    primitive_types = (bool, int, float, str, np.bool_, np.integer, np.floating)
    renames = {'alpha': 'alpha_init', 'random_action_prob': 'random_action_prob_final'}

    def to_python(value):
        if isinstance(value, np.bool_):
            return bool(value)
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        return value

    loggable = {}
    for name, value in vars(args).items():
        name = renames.get(name, name)
        if name == 'floatX':
            continue
        if isinstance(value, primitive_types):
            loggable[name] = to_python(value)
        elif (isinstance(value, (list, tuple)) and len(value) <= 64
              and all(isinstance(item, primitive_types) for item in value)):
            loggable[name] = [to_python(item) for item in value]
    return loggable

class AlphaScheduler:
    """Schedule a fixed (non-trainable) alpha: hold ``alpha0``, then decay exponentially toward ``alpha_final``.

    For the first ``int(total_steps * warm_frac)`` steps, alpha equals ``alpha0``.
    After that, alpha = alpha_final + (alpha0 - alpha_final) * exp(-4 * tail_t / tail_len),
    where ``tail_t`` counts steps since the hold ended. Because the exponent stays
    finite, the curve never exactly reaches ``alpha_final`` (unless alpha0 == alpha_final).
    """
    def __init__(self, total_steps: int, alpha0: float, warm_frac: float = 0.4,
                 alpha_final: float = 0.5):
        import math
        self.math = math
        self.T = max(1, int(total_steps))
        self.a0 = float(alpha0)
        self.af = float(alpha_final)
        self.warm_frac = max(0.0, min(1.0, float(warm_frac)))
        self.T_hold = int(self.T * self.warm_frac)
        self.decay_k = 4.0  # exponential decay rate over the tail
        self.poly_p = 0.5   # NOTE(review): not used by __call__

    def __call__(self, step: int) -> float:
        """Return alpha for training step ``step`` (clamped to [0, total_steps - 1])."""
        t = max(0, min(step, self.T - 1))
        if t >= self.T_hold:
            elapsed = t - self.T_hold
            tail_len = max(1, self.T - self.T_hold)
            decay = self.math.exp(-self.decay_k * elapsed / tail_len)
            return self.af + (self.a0 - self.af) * decay
        return self.a0



@torch.jit.script
def detailed_balance_loss(P_F, P_B, F, R, traj_lengths,alpha):
    """Detailed-balance loss with an additive logit(alpha) bias, summed over all transitions.

    Transitions of all trajectories are concatenated. Trajectory ``ep`` owns
    indices ``cumul_lens[ep] : cumul_lens[ep] + traj_lengths[ep]`` of P_F, P_B and F.
    For transition ``i`` the residual is
        F[i] + P_F[i] - P_B[i] + log(alpha / (1 - alpha)) - target,
    where target is F[i + 1] for inner transitions and R[ep] for the last one.

    All arguments are un-annotated, so TorchScript types them as Tensors; alpha
    must therefore be passed as a tensor. Presumably P_F, P_B, F and R are
    log-space quantities (log-probs, log-flows, log-reward) -- TODO confirm
    against the caller.

    Returns:
        A 1-element tensor with the (unnormalized) sum of squared residuals.
    """
    # logit(alpha): constant offset added to every transition's residual.
    bias=torch.log(alpha/(1-alpha))
    # Start offset of each trajectory in the concatenated transition tensors.
    cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=traj_lengths.device), traj_lengths]), 0).long()
    total_loss = torch.zeros(1, device=traj_lengths.device)
    for ep in range(traj_lengths.shape[0]):
        offset = cumul_lens[ep]
        T = int(traj_lengths[ep])
        for i in range(T):
            # flag is 0.0 on the last transition, where the target is the reward R[ep] instead of F[i + 1].
            flag = float(i + 1 < T)
            acc = (F[offset + i] - F[offset + min(i + 1, T - 1)] * flag - R[ep] * (1 - flag)
                   + P_F[offset + i] - P_B[offset + i]+bias) # NOTE(review): original author asks whether P_B[offset+i] should be multiplied by flag here
            total_loss += acc.pow(2)
    return total_loss


# borrowed from https://github.com/ling-pan/FL-GFN/blob/main/mols/gflownet.py
@torch.jit.script
def forward_looking_detailed_balance_loss(P_F, P_B, F, R, traj_lengths, transition_rs, alpha):
    """Forward-looking detailed-balance loss (following Pan et al.) with a logit(alpha) bias.

    For transition ``i`` of trajectory ``ep`` (``T`` transitions) the residual is
        F[i] + P_F[i] - F[min(i + 1, T - 1)] - P_B[i] - transition_rs[i] + log(alpha / (1 - alpha)).
    The squared residuals are summed over all transitions (not normalized).
    ``R`` is accepted but not used.

    NOTE(review): on the last transition (i == T - 1) the "next" flow index
    clamps to the transition itself, so the F terms cancel. The original author
    suspects the reference implementation is wrong here, and that
    ``F[offset + min(i + 1, T - 1)] * flag`` may be the intended form, although
    that is not what Pan et al. do. ``flag`` is computed below but currently unused.
    """
    # logit(alpha): constant offset added to every transition's residual.
    bias=torch.log(alpha/(1-alpha))
    # Start offset of each trajectory in the concatenated transition tensors.
    cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=traj_lengths.device), traj_lengths]), 0).long()
    
    total_loss = torch.zeros(1, device=traj_lengths.device)
    for ep in range(traj_lengths.shape[0]):
        offset = cumul_lens[ep]
        T = int(traj_lengths[ep])
        for i in range(T): # NOTE(review): original author wonders whether iterating to T-1 would be more reliable
            flag = float(i + 1 < T)  # unused; see docstring
            curr_PF = P_F[offset + i]
            curr_PB = P_B[offset + i]
            curr_F = F[offset + i]
            curr_F_next = F[offset + min(i + 1, T - 1)]
            curr_r = transition_rs[offset + i]
            acc = curr_F + curr_PF - curr_F_next - curr_PB - curr_r+bias

            total_loss += acc.pow(2)

    return total_loss

@torch.jit.script
def trajectory_balance_loss(P_F, P_B, logZ, R, traj_lengths,alpha):
    """Trajectory-balance loss with a per-step logit(alpha) bias, averaged over trajectories.

    For trajectory ``ep`` with ``T`` transitions the residual is
        logZ - R[ep] + sum(P_F over the trajectory) - sum(P_B over the trajectory)
        + T * log(alpha / (1 - alpha)).
    Presumably R holds log-rewards and P_F/P_B log-probabilities -- TODO confirm.

    Returns:
        The sum of squared residuals divided by the number of trajectories.
    """
    # logit(alpha); added once per transition, i.e. T times per trajectory.
    bias=torch.log(alpha/(1-alpha))
    # Start offset of each trajectory in the concatenated transition tensors.
    cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=traj_lengths.device), traj_lengths]), 0).long()
    total_loss = torch.zeros(1, device=traj_lengths.device)
    for ep in range(traj_lengths.shape[0]):
        offset = cumul_lens[ep]
        T = int(traj_lengths[ep])
        total_loss += (logZ - R[ep] + P_F[offset:offset+T].sum() - P_B[offset:offset+T].sum()+T*bias).pow(2)
    return total_loss / float(traj_lengths.shape[0])

@torch.jit.script
def tb_lambda_loss(P_F, P_B, F, R, traj_lengths, Lambda,alpha):
    """SubTB(lambda) loss with a per-step logit(alpha) bias.

    Every sub-trajectory i..j (inclusive) of every trajectory contributes
        (F[i] - target + sum_{k=i..j}(P_F[k] - P_B[k]) + (j - i + 1) * bias) ** 2,
    weighted by ``Lambda ** (j - i + 1)``. Here target is F[j + 1] for inner
    sub-trajectories and R[ep] when j is the last transition. The weighted sum is
    divided by the total weight. The triple loop costs O(T^3) per trajectory.
    """
    # logit(alpha); added once per transition of each sub-trajectory.
    bias=torch.log(alpha/(1-alpha))
    # Start offset of each trajectory in the concatenated transition tensors.
    cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=traj_lengths.device), traj_lengths]), 0).long()
    total_loss = torch.zeros(1, device=traj_lengths.device)
    total_Lambda = torch.zeros(1, device=traj_lengths.device)
    for ep in range(traj_lengths.shape[0]):
        offset = cumul_lens[ep]
        T = int(traj_lengths[ep])
        for i in range(T):
            for j in range(i, T):
                # flag is 0.0 when the sub-trajectory ends at the last transition; then the target is R[ep] instead of F[j + 1].
                flag = float(j + 1 < T)
                acc = F[offset + i] - F[offset + min(j + 1, T - 1)] * flag - R[ep] * (1 - flag)
                for k in range(i, j + 1):
                    acc += P_F[offset + k] - P_B[offset + k]
                total_loss += (acc+(j-i+1)*bias).pow(2) * Lambda ** (j - i + 1)
                total_Lambda += Lambda ** (j - i + 1)
    return total_loss / total_Lambda

@torch.jit.script
def forward_looking_tb_lambda_loss(P_F, P_B, F, R, traj_lengths,transition_rs, Lambda,alpha):
    """Forward-looking SubTB(lambda) loss (following Pan et al.) with a per-step logit(alpha) bias.

    Every sub-trajectory i..j (inclusive) contributes
        (F[i] - F[min(j + 1, T - 1)] + sum_{k=i..j}(P_F[k] - P_B[k] - transition_rs[k])
         + (j - i + 1) * bias) ** 2,
    weighted by ``Lambda ** (j - i + 1)``; the weighted sum is divided by the
    total weight. ``R`` is accepted but not used. ``flag`` is computed but unused
    (see the NOTE below). The cost is O(T^3) per trajectory.
    """
    # logit(alpha); added once per transition of each sub-trajectory.
    bias=torch.log(alpha/(1-alpha))
    # Start offset of each trajectory in the concatenated transition tensors.
    cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=traj_lengths.device), traj_lengths]), 0).long()
    total_loss = torch.zeros(1, device=traj_lengths.device)
    total_Lambda = torch.zeros(1, device=traj_lengths.device)
    for ep in range(traj_lengths.shape[0]):
        offset = cumul_lens[ep]
        T = int(traj_lengths[ep])
        for i in range(T):
            for j in range(i, T):
                # Unused here; kept for parity with tb_lambda_loss.
                flag = float(j + 1 < T)
                acc = F[offset + i] - F[offset + min(j + 1, T - 1)] # NOTE(review): following Pan et al., flag and R[ep] are not used; the original author asks whether "* flag" or iterating to T-1 would be more correct
                for k in range(i, j + 1):
                    acc += P_F[offset + k] - P_B[offset + k]- transition_rs[offset + k]
                total_loss += (acc+(j-i+1)*bias).pow(2) * Lambda ** (j - i + 1)
                total_Lambda += Lambda ** (j - i + 1)
    return total_loss / total_Lambda


class Dataset:
    """Trajectory sampler and replay buffer for block-based molecule generation.

    Training transitions come from two sources: on-policy rollouts of
    ``self.sampling_model`` (``_get_sample_model``), or trajectories
    reconstructed backwards from a stored molecule (``_get``).
    ``start_samplers`` can prepare minibatches in background threads.
    """

    def __init__(self, args, bpath, device, floatX=torch.double):
        # Fixed-seed RNG (not used in this class; presumably for a train/test split).
        self.test_split_rng = np.random.RandomState(142857)
        # Sampling RNGs follow --sampling_seed.
        self.train_rng = np.random.RandomState(args.sampling_seed)
        self.sampling_gen = torch.Generator(device=device).manual_seed(args.sampling_seed)
        self.sampling_seed = args.sampling_seed
        self.train_mols = []
        self.test_mols = []
        self.train_mols_map = {}  # smiles -> molecule carrying a cached .reward (read by _get_reward)
        self.mdp = MolMDPExtended(bpath)
        self.mdp.post_init(device, args.repr_type, include_nblocks=args.include_nblocks)
        self.mdp.build_translation_table()
        self._device = device
        self.seen_molecules = set()
        self.stop_event = threading.Event()  # tells sampler threads to exit
        self.target_norm = [-8.6, 1.10]  # (shift, scale) applied to docking scores in r2r
        self.sampling_model = None
        self.sampling_model_prob = 0
        self.floatX = floatX
        self.mdp.floatX = self.floatX
        # Sampled results as (inverse_r2r(reward), molecule, trajectory_stats, inflow, n_steps)
        # tuples. _get_sample_model fills this only when record_samples is True.
        self.sampled_mols = []

        # Optional hyperparameters fall back to defaults when args lacks them.
        self.min_blocks = getattr(args, 'min_blocks', 2)
        self.max_blocks = getattr(args, 'max_blocks', 10)
        self.mdp._cue_max_blocks = self.max_blocks
        self.replay_mode = getattr(args, 'replay_mode', 'dataset')
        self.reward_exp = getattr(args, 'reward_exp', 1)
        self.reward_norm = getattr(args, 'reward_norm', 1)
        self.random_action_prob = getattr(args, 'random_action_prob', 0)
        self.R_min = getattr(args, 'R_min', 1e-8)
        self.ignore_parents = getattr(args, 'ignore_parents', False)
        self.early_stop_reg = getattr(args, 'early_stop_reg', 0)

        self.online_mols = []
        self.max_online_mols = 1000

        self.fl = args.fl
        self.record_samples = False  # if True, _get_sample_model() appends to sampled_mols; training keeps this False

    def _get(self, i, dset):
        """Return the transitions of one trajectory for training.

        A fresh trajectory is sampled from the model with probability
        ``sampling_model_prob``, or always when ``dset`` has fewer than 32 entries.
        Otherwise the trajectory is rebuilt backwards from ``dset[i]``, choosing a
        random parent at each step. ``dset[i]`` is either a molecule or a tuple
        whose last item is the molecule. Each transition is
        ``(parents, parent_actions, reward, state, done)``. Only the transition
        that ends the trajectory (appended first) carries the reward and ``done=1``.
        """
        if ((self.sampling_model_prob > 0  # don't sample if we don't have to
             and self.train_rng.uniform() < self.sampling_model_prob)
                or len(dset) < 32):
            return self._get_sample_model()

        # Walk backwards from a stored molecule. If another thread shrank dset in
        # the meantime, retry with a random valid index.
        while True:
            try:
                m = dset[i]
            except IndexError:
                i = self.train_rng.randint(0, len(dset))
                continue
            break
        if not isinstance(m, BlockMoleculeDataExtended):
            m = m[-1]

        r = m.reward
        done = 1

        samples = []
        # A sample is (parents(s), parent actions, reward(s), s, done). An action is
        # (blockidx, stemidx), or (-1, x) for STOP. Start with the STOP action unless
        # the molecule is already terminal (no stems, hence no actions).
        if len(m.stems):
            samples.append(((m,), ((-1, 0),), r, m, done))
            r = done = 0
        while len(m.blocks):  # and go backwards
            parents, actions = zip(*self.mdp.parents(m))
            samples.append((parents, actions, r, m, done))
            r = done = 0
            m = parents[self.train_rng.randint(len(parents))]
        return samples

    def set_sampling_model(self, model, proxy_reward, sample_prob=0.5):
        """Install the policy used for rollouts and the proxy used for rewards."""
        self.sampling_model = model
        self.sampling_model_prob = sample_prob
        self.proxy_reward = proxy_reward

    def inverse_r2r(self, reward):
        """Invert ``r2r``: ``reward_norm * reward ** (1 / reward_exp)``.

        Exact only for rewards above the R_min clip applied by ``r2r``.
        """
        return self.reward_norm * reward**(1/self.reward_exp)

    def _get_sample_model(self):
        """Roll out one trajectory on-policy from ``self.sampling_model``.

        Returns a list of transitions ``(parents, parent_actions, reward, next_state,
        done)``. When ``self.fl`` is set, the per-state reward ``r_fl`` is inserted
        after ``reward``. A STOP transition stores ``None`` as ``next_state``, and
        only the last transition has a non-zero ``reward`` and ``done=1``.

        Returns ``[]`` if the model yields non-finite logits or metrics, or if no
        step is taken (``max_blocks <= 0``). ``DatasetDirect.sample`` filters such
        results out.

        Side effects:
            - Appends to ``sampled_mols`` when ``record_samples`` is True.
            - In 'online'/'prioritized' replay mode, sets ``reward`` on the final
              molecule and adds it to the online buffer.
        """
        m = BlockMoleculeDataExtended()
        samples = []
        max_blocks = self.max_blocks

        # Optional early-stop regularization: randomly force a STOP at a pre-drawn step.
        # NOTE(review): these draws use the global np.random stream rather than
        # self.train_rng, so --sampling_seed does not govern them. A draw equal to
        # max_blocks never fires, because t below only reaches max_blocks - 1.
        if self.early_stop_reg > 0 and np.random.uniform() < self.early_stop_reg:
            early_stop_at = np.random.randint(self.min_blocks, self.max_blocks + 1)
        else:
            early_stop_at = max_blocks + 1

        # Per-step (p_taken, action, valid_logZ, forward entropy, #parents of resulting state).
        trajectory_stats = []
        for t in range(max_blocks):
            # Encode the current molecule and run the policy.
            s = self.mdp.mols2batch([self.mdp.mol2repr(m)])
            s_o, m_o = self.sampling_model(s)

            if t < self.min_blocks:
                # Push the STOP logit to -1000 so STOP is effectively never sampled.
                m_o = m_o * 0 - 1000

            # Flat logits: index 0 is STOP; the rest are the flattened per-stem outputs.
            q = torch.cat([m_o[:, 0].reshape(-1), s_o.reshape(-1)])
            if (not torch.isfinite(q).all()) or q.numel() == 0:
                print(f"[SKIP] _get_sample_model: raw logits invalid at step {t} "
                      f"(size={q.numel()}, dtype={q.dtype}). return [].", flush=True)
                return []

            # Stable softmax for sampling. If it degenerates, fall back to a
            # uniform distribution over the allowed actions.
            probs = torch.softmax(q - q.max(), dim=0)
            s_sum = probs.sum()
            if (not torch.isfinite(s_sum)) or s_sum <= 0:
                probs = torch.ones_like(q)
                if t < self.min_blocks:
                    probs[0] = 0.0
                probs = probs / probs.sum()

            action = torch.multinomial(probs, 1, generator=self.sampling_gen).item()

            # Epsilon exploration. While t < min_blocks, the lower bound 1 excludes STOP.
            if self.random_action_prob > 0 and self.train_rng.uniform() < self.random_action_prob:
                lo = int(t < self.min_blocks)
                action = self.train_rng.randint(lo, q.shape[0])

            # Forced early stop when the pre-drawn step is reached.
            if t == early_stop_at:
                action = 0  # STOP

            # ---- Metrics on the VALID domain (STOP excluded while t < min_blocks) ----
            valid_logits = q.clone()
            if t < self.min_blocks:
                valid_logits[0] = -float('inf')

            valid_logZ = torch.logsumexp(valid_logits, dim=0)
            if (not torch.isfinite(valid_logZ)) or (not torch.isfinite(q[action])):
                print(f"[SKIP] _get_sample_model: non-finite metrics (logZ/action) at step {t}. return [].", flush=True)
                return []
            step_logZ = float(valid_logZ.item())

            # Log-probability of the chosen action on the valid domain. It is stored
            # as a probability, with the exponent clamped to stay float-safe.
            taken_logp = q[action] - valid_logZ
            p_taken = float(torch.exp(torch.clamp(taken_logp, max=80)).item())

            # Forward-policy entropy on the valid domain (uniform fallback if degenerate).
            p_vec = torch.softmax(valid_logits - torch.nan_to_num(valid_logits.max(), nan=0.0), dim=0)
            p_vec = torch.nan_to_num(p_vec, nan=0.0).clamp_min_(0.0)
            s_sum = p_vec.sum()
            if (not torch.isfinite(s_sum)) or s_sum <= 0:
                mask = torch.isfinite(valid_logits)
                p_vec = torch.where(mask, torch.ones_like(valid_logits), torch.zeros_like(valid_logits))
                p_vec = p_vec / p_vec.sum()
            fwd_ent = float(-(p_vec * torch.log(p_vec + 1e-12)).sum().item())

            # ----- APPLY ACTION -----
            if t >= self.min_blocks and action == 0:
                # STOP is allowed: the current molecule is terminal.
                r = self._get_reward(m)
                if self.fl:
                    # Same state as the reward above: reuse it instead of querying the proxy again.
                    r_fl = r
                    samples.append(((m,), ((-1, 0),), r, r_fl, None, 1))
                else:
                    samples.append(((m,), ((-1, 0),), r, None, 1))
                back_parent_cnt = len(self.mdp.parents(m))
                trajectory_stats.append((p_taken, int(action), step_logZ, fwd_ent, int(back_parent_cnt)))
                break

            # Map the flat action index (excluding STOP) to (block_idx, stem_idx).
            action_idx = max(0, action - 1)
            action_pair = (action_idx % self.mdp.num_blocks, action_idx // self.mdp.num_blocks)

            m_old = m
            m = self.mdp.add_block_to(m, *action_pair)

            # Terminal if the new molecule has no free stems, or this was the last allowed step.
            is_terminal = (len(m.blocks) and not len(m.stems)) or t == max_blocks - 1
            if is_terminal:
                r = self._get_reward(m)
                step_r, r_fl, done = r, r, 1
            else:
                step_r, done = 0, 0
                # Forward-looking GFlowNets also need the reward of intermediate states.
                r_fl = self._get_reward(m) if self.fl else None

            # Parents of the new state, computed once for both the transition and the backward stats.
            parents_m = self.mdp.parents(m)
            if self.ignore_parents:
                head = ((m_old,), (action_pair,))
            else:
                head = tuple(zip(*parents_m))
            if self.fl:
                samples.append((*head, step_r, r_fl, m, done))
            else:
                samples.append((*head, step_r, m, done))
            trajectory_stats.append((p_taken, int(action), step_logZ, fwd_ent, int(len(parents_m))))
            if is_terminal:
                break

        if not samples:
            # No step was taken (max_blocks <= 0); r and t below would be undefined.
            return []

        # Log-inflow of the final state: logsumexp of the model outputs indexed by
        # the (parent, action) pairs of the last transition.
        p = self.mdp.mols2batch([self.mdp.mol2repr(i) for i in samples[-1][0]])
        qp = self.sampling_model(p, None)
        qsa_p = self.sampling_model.index_output_by_action(
            p, qp[0], qp[1][:, 0],
            torch.tensor(samples[-1][1], device=self._device).long()
        )
        inflow = torch.logsumexp(qsa_p.flatten(), 0).item()

        if self.record_samples:
            self.sampled_mols.append((self.inverse_r2r(r), m, trajectory_stats, inflow, t + 1))

        if self.replay_mode == 'online' or self.replay_mode == 'prioritized':
            m.reward = r
            self._add_mol_to_online(r, m, inflow)

        return samples

    def _add_mol_to_online(self, r, m, inflow):
        """Insert molecule ``m`` (reward ``r``, log-inflow ``inflow``) into the online buffer."""
        if self.replay_mode == 'online':
            # Small noise makes exact reward ties unlikely.
            r = r + self.train_rng.normal() * 0.01
            if len(self.online_mols) < self.max_online_mols or r > self.online_mols[0][0]:
                self.online_mols.append((r, m))
            if len(self.online_mols) > self.max_online_mols:
                # Sort on the reward only. On a tie, plain tuple sorting would compare
                # the molecule objects and raise TypeError (unless they define ordering).
                # Then drop the lowest 5% of entries (at least one).
                ranked = sorted(self.online_mols, key=lambda entry: entry[0])
                self.online_mols = ranked[max(int(0.05 * self.max_online_mols), 1):]
        elif self.replay_mode == 'prioritized':
            # Priority = |log-inflow - log(reward)|.
            self.online_mols.append((abs(inflow - np.log(r)), m))
            if len(self.online_mols) > self.max_online_mols * 1.1:
                self.online_mols = self.online_mols[-self.max_online_mols:]

    def _get_reward(self, m):
        """Return the training reward of molecule ``m``.

        Returns the cached reward for molecules in ``train_mols_map``; otherwise
        ``r2r`` of the proxy score.
        """
        rdmol = m.mol
        if rdmol is None:
            # NOTE(review): R_min is returned as-is, not passed through r2r. With
            # reward_norm/reward_exp != 1 it can therefore exceed rewards of valid
            # molecules, which r2r clips at (R_min / reward_norm) ** reward_exp.
            # Confirm this is intended.
            return self.R_min
        smi = m.smiles
        if smi in self.train_mols_map:
            return self.train_mols_map[smi].reward
        return self.r2r(normscore=self.proxy_reward(m))

    def sample(self, n):
        """Draw ``n`` trajectories according to ``replay_mode`` and return ``zip(*transitions)``.

        Raises:
            ValueError: if ``replay_mode`` is not 'dataset', 'online' or 'prioritized'.
        """
        if self.replay_mode == 'dataset':
            eidx = self.train_rng.randint(0, len(self.train_mols), n)
            samples = sum((self._get(i, self.train_mols) for i in eidx), [])
        elif self.replay_mode == 'online':
            eidx = self.train_rng.randint(0, max(1, len(self.online_mols)), n)
            samples = sum((self._get(i, self.online_mols) for i in eidx), [])
        elif self.replay_mode == 'prioritized':
            if not len(self.online_mols):
                # _get will sample from the model
                samples = sum((self._get(0, self.online_mols) for i in range(n)), [])
            else:
                prio = np.float32([i[0] for i in self.online_mols])
                eidx = self.train_rng.choice(len(self.online_mols), n, False, prio/prio.sum())
                samples = sum((self._get(i, self.online_mols) for i in eidx), [])
        else:
            raise ValueError(f"Unknown replay_mode: {self.replay_mode!r}")
        return zip(*samples)

    def sample2batch(self, mb):
        """Convert the output of ``sample`` into model inputs and tensors on ``self._device``.

        Returns ``(p, p_batch, a, r, s, d, mols, *extra)``, with ``r_fl`` inserted
        after ``r`` when ``self.fl`` is set.
        """
        if self.fl:
            p, a, r, r_fl, s, d, *o = mb
        else:
            p, a, r, s, d, *o = mb
        mols = (p, s)
        # The batch index of each parent
        p_batch = torch.tensor(sum([[i]*len(p) for i, p in enumerate(p)], []),
                               device=self._device).long()
        # Convert all parents and states to repr. This concatenates all the
        # parent lists, which is why p_batch is needed.
        p = self.mdp.mols2batch(list(map(self.mdp.mol2repr, sum(p, ()))))
        s = self.mdp.mols2batch([self.mdp.mol2repr(i) for i in s])
        # Concatenate all the actions (one per parent per sample)
        a = torch.tensor(sum(a, ()), device=self._device).long()
        # rewards and dones
        r = torch.tensor(r, device=self._device).to(self.floatX)
        d = torch.tensor(d, device=self._device).to(self.floatX)
        if self.fl:
            r_fl = torch.tensor(r_fl, device=self._device).to(self.floatX)
            return (p, p_batch, a, r, r_fl, s, d, mols, *o)
        else:
            return (p, p_batch, a, r, s, d, mols, *o)

    def r2r(self, dockscore=None, normscore=None):
        """Map a score to a training reward.

        A docking score, if given, is first normalized with ``target_norm``. The
        normalized score is clipped below at ``R_min``, divided by ``reward_norm``
        and raised to the power ``reward_exp``.
        """
        if dockscore is not None:
            normscore = 4-(min(0, dockscore)-self.target_norm[0])/self.target_norm[1]
        normscore = max(self.R_min, normscore)
        return (normscore/self.reward_norm) ** self.reward_exp

    def start_samplers(self, n, mbsize):
        """Start ``n`` background threads that each prepare minibatches of ``mbsize``.

        Each thread calls ``sample2batch(sample(mbsize))``, publishes the result,
        signals its ready event, then blocks until resumed. Returns a ``get()``
        function that polls the threads round-robin and returns the first ready result.

        If a thread raises, it records ``failed``/``exception`` on its Thread object,
        signals ready and exits. ``get`` then returns whatever is in that thread's
        result slot (possibly ``None`` or a stale batch), so callers should check
        ``sampler_threads[i].failed``.
        """
        self.ready_events = [threading.Event() for _ in range(n)]
        self.resume_events = [threading.Event() for _ in range(n)]
        self.results = [None] * n

        def f(idx):
            while not self.stop_event.is_set():
                try:
                    self.results[idx] = self.sample2batch(self.sample(mbsize))
                except Exception as e:
                    print("Exception while sampling:")
                    print(e)
                    self.sampler_threads[idx].failed = True
                    self.sampler_threads[idx].exception = e
                    self.ready_events[idx].set()
                    break
                # Clear the resume event BEFORE signalling ready. Otherwise get()
                # could set resume in between, and the later clear() would erase
                # that wake-up, leaving this thread blocked forever.
                self.resume_events[idx].clear()
                self.ready_events[idx].set()
                self.resume_events[idx].wait()

        self.sampler_threads = [threading.Thread(target=f, args=(i,)) for i in range(n)]
        for thread in self.sampler_threads:
            thread.failed = False
        for thread in self.sampler_threads:
            thread.start()
        round_robin_idx = [0]

        def get():
            while True:
                idx = round_robin_idx[0]
                round_robin_idx[0] = (round_robin_idx[0] + 1) % n
                if self.ready_events[idx].is_set():
                    r = self.results[idx]
                    self.ready_events[idx].clear()
                    self.resume_events[idx].set()
                    return r
                elif round_robin_idx[0] == 0:
                    # A full pass found nothing ready: back off briefly.
                    time.sleep(0.001)
        return get

    def stop_samplers_and_join(self):
        """Signal the sampler threads to stop and wait until all of them have exited."""
        self.stop_event.set()
        if hasattr(self, 'sampler_threads'):
            while any(thread.is_alive() for thread in self.sampler_threads):
                # Wake threads blocked on their resume event so they observe stop_event.
                for event in self.resume_events:
                    event.set()
                for thread in self.sampler_threads:
                    thread.join(0.05)



class DatasetDirect(Dataset):
    """Dataset that serves whole sampled trajectories as flat, per-transition minibatches.

    Every transition of every kept trajectory becomes one row. Each row is paired with
    the index of the trajectory it came from, and the list of trajectory lengths is
    also returned.
    """

    def sample(self, n):
        """Draw ``n`` trajectories with ``self._get_sample_model()`` and flatten them.

        Empty trajectories are dropped. The original comment calls them "nan
        trajectories"; presumably the sampler returns an empty list for rejected
        samples — TODO confirm.

        Returns:
            A tuple of the transposed transition fields (one tuple per field), followed
            by the per-transition trajectory index list and the trajectory length list.
            NOTE(review): if every trajectory is empty, only those last two lists are
            present, and ``sample2batch`` cannot unpack the result.
        """
        kept = []
        for _ in range(n):
            traj = self._get_sample_model()
            if len(traj) >= 1:
                kept.append(traj)
        transitions = [step for traj in kept for step in traj]
        owner = [traj_no for traj_no, traj in enumerate(kept) for _ in traj]
        lengths = [len(traj) for traj in kept]
        return (*zip(*transitions), owner, lengths)

    def sample2batch(self, mb):
        """Convert a batch produced by ``sample`` into tensors on ``self._device``.

        Returns:
            ``(s, a, r, d, n, mols, idc, lens)``, or, when ``self.fl`` is set,
            ``(s, a, r, r_fl, d, n, mols, idc, lens)``. Here ``s`` is
            ``self.mdp.mols2batch`` of the current states; ``a`` is the concatenated
            actions as a long tensor; ``r``/``r_fl``/``d`` are cast to ``self.floatX``.
            ``n`` is the number of parents reported by ``self.mdp.parents`` for each
            next state (1 when the next state is None). ``mols`` is the raw
            ``(s, sp)`` pair. ``idc`` and ``lens`` are long tensors.
        """
        if self.fl:
            s, a, r, r_fl, sp, d, idc, lens = mb
        else:
            s, a, r, sp, d, idc, lens = mb
        dev = self._device
        fx = self.floatX
        mols = (s, sp)
        graph_batch = self.mdp.mols2batch([self.mdp.mol2repr(state[0]) for state in s])
        actions = torch.tensor(sum(a, ()), device=dev).long()
        rewards = torch.tensor(r, device=dev).to(fx)
        dones = torch.tensor(d, device=dev).to(fx)
        parent_counts = [1 if nxt is None else len(self.mdp.parents(nxt)) for nxt in sp]
        n_parents = torch.tensor(parent_counts, device=dev).to(fx)
        traj_idx = torch.tensor(idc, device=dev).long()
        traj_lens = torch.tensor(lens, device=dev).long()
        if not self.fl:
            return (graph_batch, actions, rewards, dones, n_parents, mols, traj_idx, traj_lens)
        fl_rewards = torch.tensor(r_fl, device=dev).to(fx)
        return (graph_batch, actions, rewards, fl_rewards, dones, n_parents, mols, traj_idx, traj_lens)

def make_model(args, mdp, out_per_mol=1):
    """Build the network selected by ``args.repr_type``.

    Args:
        args: configuration. Reads ``repr_type``, ``nemb``, ``num_conv_steps``,
            ``model_version`` and, for ``'atom_graph'``, the optional ``include_nblocks``.
        mdp: molecule MDP. ``mdp.num_blocks`` sets the per-stem output size; the MDP
            itself is passed as ``mdp_cfg`` to the block-graph model.
        out_per_mol: number of per-molecule outputs.

    Returns:
        The freshly constructed model.

    Raises:
        ValueError: for ``'morgan_fingerprint'`` (not re-implemented) or any unknown
            ``repr_type``. Previously an unknown type fell through and crashed with
            ``UnboundLocalError``.
    """
    if args.repr_type == 'block_graph':
        return model_block.GraphAgent(nemb=args.nemb,
                                      nvec=0,
                                      out_per_stem=mdp.num_blocks,
                                      out_per_mol=out_per_mol,
                                      num_conv_steps=args.num_conv_steps,
                                      mdp_cfg=mdp,
                                      version=args.model_version)
    if args.repr_type == 'atom_graph':
        return model_atom.MolAC_GCN(nhid=args.nemb,
                                    nvec=0,
                                    num_out_per_stem=mdp.num_blocks,
                                    num_out_per_mol=out_per_mol,
                                    num_conv_steps=args.num_conv_steps,
                                    version=args.model_version,
                                    # include_nblocks is optional on older configs
                                    do_nblocks=getattr(args, 'include_nblocks', False),
                                    dropout_rate=0.1)
    if args.repr_type == 'morgan_fingerprint':
        # Former (now unreachable) code: model_fingerprint.MFP_MLP(args.nemb, 3, mdp.num_blocks, 1)
        raise ValueError('reimplement me')
    raise ValueError(f'Unknown repr_type {args.repr_type!r}')


class Proxy:
    """Reward model restored from a pre-trained checkpoint directory.

    ``args.proxy_path`` must contain ``info.pkl.gz`` (a pickled dict whose ``'args'``
    entry holds the proxy's own training arguments) and ``best_params.pkl.gz`` (a
    pickled, indexable sequence of parameter arrays). Calling the instance on a
    molecule returns the proxy's second output as a Python float.
    """

    def __init__(self, args, bpath, device):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # args.proxy_path at trusted checkpoints.
        # `with` closes the gzip handles, which previously were left open.
        with gzip.open(f'{args.proxy_path}/info.pkl.gz') as fh:
            eargs = pickle.load(fh)['args']
        with gzip.open(f'{args.proxy_path}/best_params.pkl.gz') as fh:
            params = pickle.load(fh)
        self.mdp = MolMDPExtended(bpath)
        self.mdp.post_init(device, eargs.repr_type)
        self.mdp.floatX = args.floatX
        self.proxy = make_model(eargs, self.mdp)
        # NOTE(review): make_model builds the 'atom_graph' variant with dropout_rate=0.1,
        # and this wrapper never calls self.proxy.eval(). Verify whether dropout is
        # active when rewards are computed.

        # If you get an error when loading the proxy parameters, it is probably due to a version
        # mismatch in torch geometric. Try uncommenting this code instead of using the
        # super_hackish_param_map
        # for a,b in zip(self.proxy.parameters(), params):
        #    a.data = torch.tensor(b, dtype=self.mdp.floatX)
        # Explicit name -> checkpoint-index mapping (note the conv.lin / conv.bias ordering).
        super_hackish_param_map = {
            'mpnn.lin0.weight': params[0],
            'mpnn.lin0.bias': params[1],
            'mpnn.conv.bias': params[3],
            'mpnn.conv.nn.0.weight': params[4],
            'mpnn.conv.nn.0.bias': params[5],
            'mpnn.conv.nn.2.weight': params[6],
            'mpnn.conv.nn.2.bias': params[7],
            'mpnn.conv.lin.weight': params[2],
            'mpnn.gru.weight_ih_l0': params[8],
            'mpnn.gru.weight_hh_l0': params[9],
            'mpnn.gru.bias_ih_l0': params[10],
            'mpnn.gru.bias_hh_l0': params[11],
            'mpnn.lin1.weight': params[12],
            'mpnn.lin1.bias': params[13],
            'mpnn.lin2.weight': params[14],
            'mpnn.lin2.bias': params[15],
            'mpnn.set2set.lstm.weight_ih_l0': params[16],
            'mpnn.set2set.lstm.weight_hh_l0': params[17],
            'mpnn.set2set.lstm.bias_ih_l0': params[18],
            'mpnn.set2set.lstm.bias_hh_l0': params[19],
            'mpnn.lin3.weight': params[20],
            'mpnn.lin3.bias': params[21],
        }
        for k, v in super_hackish_param_map.items():
            self.proxy.get_parameter(k).data = torch.tensor(v, dtype=self.mdp.floatX)
        self.proxy.to(device)

    @torch.no_grad()
    def __call__(self, m):
        """Score molecule ``m`` with the proxy and return the result as a Python float.

        Runs under ``torch.no_grad()``: the result is converted with ``.item()``, so
        no gradient could flow back anyway, and this avoids building an autograd graph
        on every reward query.
        """
        batch = self.mdp.mols2batch([self.mdp.mol2repr(m)])
        return self.proxy(batch, do_stems=False)[1].item()

# Mutable slot for the current run's shutdown callback. train_model_with_proxy stores
# its stop_everything closure (which stops and joins the sampler threads) in _stop[0].
# A one-element list lets that function rebind the slot without a `global` statement.
# NOTE(review): presumably invoked by the script's top-level error handling — confirm.
_stop = [None]

def train_one_step(args,model,minibatch,device):
    """Run the policy on one minibatch and return the GFlowNet training loss.

    Args:
        args: run configuration. Reads ``fl``, ``floatX``, ``objective`` and, for
            ``'subtb'``, ``subtb_lambda``.
        model: policy network. Must expose ``action_negloglikelihood`` and ``alpha``
            (and ``logZ`` for the ``'tb'`` objective).
        minibatch: ``(s, a, r, d, n, mols, idc, lens, ...)`` as built by
            ``sample2batch``. When ``args.fl`` is set, an extra ``r_fl`` follows ``r``.
            Trailing items are ignored.
        device: torch device for freshly created tensors.

    Returns:
        ``(loss, r, mols)``. ``r`` and ``mols`` are passed through unchanged for
        compatibility with the original codebase.

    Raises:
        ValueError: if ``args.objective`` is not ``'tb'``, ``'db'`` or ``'subtb'``.
    """
    forward_looking = args.fl
    fl_ratios = []
    if forward_looking:
        s, a, r, r_fl, d, n, mols, idc, lens, *_extra = minibatch
        # Cut r_fl into per-trajectory segments, each ending at a transition whose
        # done flag is set. Turn each segment into successive ratios
        # r_fl[t] / r_fl[t-1], with an implicit 1.0 in front of the segment.
        seg_start = 0
        for t in range(d.shape[0]):
            if not d[t]:
                continue
            padded = torch.cat((tf([1.0], args.floatX, device), r_fl[seg_start:t + 1]))
            fl_ratios.append(padded[1:] / padded[:-1])
            seg_start = t + 1
    else:
        s, a, r, d, n, mols, idc, lens, *_extra = minibatch

    stem_out_s, mol_out_s = model(s, None)
    # Log-probability of the action actually taken at each transition.
    logits = -model.action_negloglikelihood(s, a, 0, stem_out_s, mol_out_s)
    # Scatter-add per-transition rewards into one total per trajectory (idc = trajectory id).
    n_traj = idc[-1] + 1
    traj_r = torch.zeros(n_traj, device=device, dtype=args.floatX).index_add(0, idc, r)
    alpha = model.alpha
    # log P_B = -log(#parents of the next state), i.e. a uniform backward policy.
    log_PB = -torch.log(n)

    objective = args.objective
    if objective == 'tb':
        # Rewards are clamped before the log so a zero reward cannot give -inf.
        log_r = torch.log(traj_r.clamp_min(1e-6))
        loss = trajectory_balance_loss(logits, log_PB, model.logZ, log_r, lens, alpha)
    elif objective in ('db', 'subtb'):
        # Second per-molecule head; presumably the log state flow — TODO confirm.
        flows = mol_out_s[:, 1]
        log_r = torch.log(traj_r)
        if objective == 'db':
            if forward_looking:
                loss = forward_looking_detailed_balance_loss(
                    logits, log_PB, flows, log_r, lens, to_log(fl_ratios, device), alpha)
            else:
                loss = detailed_balance_loss(logits, log_PB, flows, log_r, lens, alpha)
        else:
            lam = tf([args.subtb_lambda], args.floatX, device)
            if forward_looking:
                loss = forward_looking_tb_lambda_loss(
                    logits, log_PB, flows, log_r, lens, to_log(fl_ratios, device), lam, alpha)
            else:
                loss = tb_lambda_loss(logits, log_PB, flows, log_r, lens, lam, alpha)
    else:
        raise ValueError(f'Unknown objective {args.objective}')
    return loss, r, mols

# Borrowed from https://github.com/aanjaa/gflownet/blob/trunk/src/gflownet/utils/metrics_final_eval.py with modifications
@timer
def compute_num_of_modes(
    candidates: List[Tuple[Any, Any, Any, Any, Any]],
    reward_thresh: float = 8.0,
    tanimoto_thresh: float = 0.7,
    top_k: Iterable[int] = (100,),
) -> Dict[str, Any]:
    """
    Efficient metrics over sampled molecules without fingerprinting all candidates.

    Candidates are sorted once by descending reward. RDKit fingerprints are computed
    lazily (and at most once per candidate) only for the top-K window and the
    above-threshold prefix.

    Args:
        candidates:
            Each item is a tuple:
              (normalized_reward, BlockMoleculeDataExtended, traj_stats, inflow, traj_len).
            We use:
              reward := c[0]
              mol    := c[1].mol (RDKit Mol; a missing mol or one RDKit cannot
                        fingerprint is treated as invalid and skipped)
              length := c[-1]
        reward_thresh:
            Only candidates with reward >= reward_thresh enter mode counting.
        tanimoto_thresh:
            A candidate starts a new mode if max(similarity to existing modes) < tanimoto_thresh.
        top_k:
            A single K (Python or numpy integer) or an iterable of K values (non-positive
            values in an iterable are dropped). For each K:
              - take the reward-sorted prefix,
              - keep only valid molecules,
              - if some among top-K are invalid, extend the window until K valid are collected,
              - compute mean reward/length on those K valid,
              - compute mean pairwise Tanimoto among those K valid.

    Mode extraction is greedy in descending reward order over the above-threshold
    prefix. The first valid molecule seeds the mode set. Each later valid molecule is
    compared with the current modes and becomes a new mode when its maximum similarity
    is below ``tanimoto_thresh``.

    Returns:
        Dict with keys (NaN where a metric cannot be computed):
          - "all_samples_avg_reward", "all_samples_avg_length"
          - "num_candidates_above_thresh" (counts invalid molecules too),
            "candidates_avg_reward", "candidates_avg_length",
            "candidates_avg_similarity" (mean, over non-seed valid candidates, of their
            mean similarity to the modes existing at that time)
          - "num_modes",
            "modes_avg_reward", "modes_avg_length",
            "modes_avg_similarity" (same quantity, restricted to candidates that became modes)
          - For each k in top_k:
              "top_{k}_count" (number of valid molecules used, may be < k)
              "top_{k}_avg_reward", "top_{k}_avg_length", "top_{k}_avg_similarity"
    """
    nan = float("nan")

    # Normalize top_k (numpy integer scalars are a single K, not an iterable).
    if isinstance(top_k, (int, np.integer)):
        top_ks = [int(top_k)]
    else:
        top_ks = sorted({int(k) for k in top_k if int(k) > 0})

    out: Dict[str, Any] = {}

    n_total = len(candidates)
    if n_total == 0:
        out.update({
            "all_samples_avg_reward": nan,
            "all_samples_avg_length": nan,
            "num_candidates_above_thresh": 0,
            "candidates_avg_reward": nan,
            "candidates_avg_length": nan,
            "candidates_avg_similarity": nan,
            "num_modes": 0,
            "modes_avg_reward": nan,
            "modes_avg_length": nan,
            "modes_avg_similarity": nan,
        })
        for k in top_ks:
            out[f"top_{k}_count"] = 0
            out[f"top_{k}_avg_reward"] = nan
            out[f"top_{k}_avg_length"] = nan
            out[f"top_{k}_avg_similarity"] = nan
        return out

    # 1) Single sort by reward (no fingerprinting)
    c_sorted = sorted(candidates, key=lambda m: m[0], reverse=True)
    rewards = np.asarray([c[0] for c in c_sorted], dtype=float)
    lengths = np.asarray([c[-1] for c in c_sorted], dtype=float)

    out["all_samples_avg_reward"] = float(np.mean(rewards))
    out["all_samples_avg_length"] = float(np.mean(lengths))

    # Fingerprints memoized by sorted position. The top-K scan and the mode scan both
    # walk c_sorted from the highest reward, so their prefixes overlap.
    fp_cache: Dict[int, Any] = {}

    def _fingerprint(idx: int):
        """RDKit fingerprint of c_sorted[idx][1].mol; None if missing or RDKit fails."""
        if idx not in fp_cache:
            fp = None
            mol = getattr(c_sorted[idx][1], "mol", None)
            if mol is not None:
                try:
                    fp = Chem.RDKFingerprint(mol)
                except Exception:
                    fp = None  # invalid Mol -> treated as missing and skipped
            fp_cache[idx] = fp
        return fp_cache[idx]

    # 2) Top-K metrics: fingerprint only as many as needed
    if top_ks:
        max_k = max(top_ks)
        top_rewards: List[float] = []
        top_lengths: List[float] = []
        top_fps: List[Any] = []

        idx = 0
        while len(top_fps) < max_k and idx < n_total:
            fp = _fingerprint(idx)
            if fp is not None:
                top_fps.append(fp)
                top_rewards.append(float(rewards[idx]))
                top_lengths.append(float(lengths[idx]))
            idx += 1

        # pair_*_prefix[m] accumulate over all pairs among top_fps[:m + 1],
        # so each K is answered in O(1).
        pair_sum_prefix: List[float] = []
        pair_cnt_prefix: List[int] = []
        ref_fps: List[Any] = []
        for fp_i in top_fps:
            if not ref_fps:
                pair_sum_prefix.append(0.0)
                pair_cnt_prefix.append(0)
            else:
                sims = DataStructs.BulkTanimotoSimilarity(fp_i, ref_fps)
                pair_sum_prefix.append(pair_sum_prefix[-1] + float(np.sum(sims)))
                pair_cnt_prefix.append(pair_cnt_prefix[-1] + len(sims))
            ref_fps.append(fp_i)

        for k in top_ks:
            k_eff = min(k, len(top_fps))
            out[f"top_{k}_count"] = int(k_eff)
            if k_eff == 0:
                out[f"top_{k}_avg_reward"] = nan
                out[f"top_{k}_avg_length"] = nan
                out[f"top_{k}_avg_similarity"] = nan
                continue
            out[f"top_{k}_avg_reward"] = float(np.mean(top_rewards[:k_eff]))
            out[f"top_{k}_avg_length"] = float(np.mean(top_lengths[:k_eff]))
            if k_eff == 1:
                out[f"top_{k}_avg_similarity"] = nan
            else:
                ps = pair_sum_prefix[k_eff - 1]
                pc = pair_cnt_prefix[k_eff - 1]  # = k_eff*(k_eff-1)/2
                out[f"top_{k}_avg_similarity"] = (ps / pc) if pc > 0 else nan

    # 3) Modes: only consider reward >= threshold (rewards are sorted descending)
    cutoff = 0
    rth = float(reward_thresh)
    while cutoff < n_total and rewards[cutoff] >= rth:
        cutoff += 1

    out["num_candidates_above_thresh"] = int(cutoff)
    if cutoff == 0:
        out.update({
            "candidates_avg_reward": nan,
            "candidates_avg_length": nan,
            "candidates_avg_similarity": nan,
            "num_modes": 0,
            "modes_avg_reward": nan,
            "modes_avg_length": nan,
            "modes_avg_similarity": nan,
        })
        return out

    out["candidates_avg_reward"] = float(np.mean(rewards[:cutoff]))
    out["candidates_avg_length"] = float(np.mean(lengths[:cutoff]))

    mode_fps: List[Any] = []
    modes_rewards: List[float] = []
    modes_lengths: List[float] = []
    modes_added_sims: List[float] = []     # similarity of a newly-added mode to existing modes
    candidates_avg_sims: List[float] = []  # similarity of each candidate to current modes
    sim_thresh = float(tanimoto_thresh)

    for j in range(cutoff):
        fp = _fingerprint(j)
        if fp is None:
            continue
        if not mode_fps:
            # The first valid molecule seeds the mode set; it has nothing to compare to.
            mode_fps.append(fp)
            modes_rewards.append(float(rewards[j]))
            modes_lengths.append(float(lengths[j]))
            continue

        sims = DataStructs.BulkTanimotoSimilarity(fp, mode_fps)
        mean_sim = float(np.mean(sims))
        if not np.isnan(mean_sim):
            candidates_avg_sims.append(mean_sim)

        # New mode if far from all existing modes
        if max(sims) < sim_thresh:
            mode_fps.append(fp)
            modes_rewards.append(float(rewards[j]))
            modes_lengths.append(float(lengths[j]))
            if not np.isnan(mean_sim):
                modes_added_sims.append(mean_sim)

    out["candidates_avg_similarity"] = float(np.mean(candidates_avg_sims)) if candidates_avg_sims else nan
    out["num_modes"] = int(len(modes_rewards))
    out["modes_avg_reward"] = float(np.mean(modes_rewards)) if modes_rewards else nan
    out["modes_avg_length"] = float(np.mean(modes_lengths)) if modes_lengths else nan
    out["modes_avg_similarity"] = float(np.mean(modes_added_sims)) if modes_added_sims else nan

    return out



def _dedup_new(repo_list, key_index: set, new_slice, key: str = "smiles"):
    """
    Deduplicate only the newly appended `new_slice` of repo_list using `key`.
    key in {"smiles","none"}.
    Returns the kept items list.
    """
    if len(new_slice) == 0:
        return []
    if key == "none":
        return list(new_slice)  # no dedup

    def _key_fn(mol):
        if key == "smiles":
            return getattr(mol, "smiles", None)
        return None  # default: treat as unique

    kept = []
    for item in new_slice:
        mol = item[1]
        k = _key_fn(mol)
        if (k is None) or (k not in key_index):
            kept.append(item)
            if (k is not None):
                key_index.add(k)
    return kept



@contextmanager
def temporarily_set_sampling_model(dataset, new_model, proxy,
                                   epsilon: float = 0.0,
                                   disable_early_stop: bool = True,
                                   isolate_replay: bool = True,
                                   record_samples: bool = True,
                                   use_inference_mode: bool = True,
                                   set_eval: bool = True):
    """Temporarily reconfigure ``dataset`` (and optionally the models) for evaluation sampling.

    Inside the ``with`` block:
      - ``dataset.set_sampling_model(new_model, proxy, sample_prob=1.0)`` is in effect,
      - ``dataset.random_action_prob`` is ``epsilon`` (e.g. 0.0 for eval),
      - ``dataset.early_stop_reg`` is 0.0 if ``disable_early_stop``,
      - ``dataset.replay_mode`` is ``'dataset'`` if ``isolate_replay``,
      - ``dataset.record_samples`` is ``bool(record_samples)``. It presumably controls
        whether ``_get_sample_model()`` appends to ``dataset.sampled_mols`` — TODO confirm.
      - if ``set_eval``, ``new_model.eval()`` / ``proxy.eval()`` are called when those
        methods exist,
      - if ``use_inference_mode``, the body runs under ``torch.inference_mode()``.

    Everything is restored on exit. This includes the case where applying the temporary
    settings itself raises (e.g. in ``set_sampling_model`` or ``eval()``): the changes
    are applied inside the ``try``, so a partial setup cannot leak into the dataset.
    """
    old = {
        'sampling_model': dataset.sampling_model,
        'sampling_model_prob': dataset.sampling_model_prob,
        'proxy_reward': getattr(dataset, 'proxy_reward', None),
        'random_action_prob': dataset.random_action_prob,
        'early_stop_reg': dataset.early_stop_reg,
        'replay_mode': dataset.replay_mode,
        'record_samples': dataset.record_samples,
    }
    # cache train/eval states to restore later
    model_was_training = getattr(new_model, 'training', False)
    proxy_was_training = getattr(proxy, 'training', False) if proxy is not None else False

    try:
        dataset.set_sampling_model(new_model, proxy, sample_prob=1.0)
        dataset.random_action_prob = float(epsilon)
        if disable_early_stop:
            dataset.early_stop_reg = 0.0
        if isolate_replay:
            dataset.replay_mode = 'dataset'
        dataset.record_samples = bool(record_samples)

        # eval() disables dropout/batch-norm updates during sampling
        if set_eval and hasattr(new_model, 'eval'):
            new_model.eval()
        if set_eval and (proxy is not None) and hasattr(proxy, 'eval'):
            proxy.eval()

        # inference_mode is safe to nest inside no_grad
        if use_inference_mode:
            with torch.inference_mode():
                yield
        else:
            yield
    finally:
        # restore dataset hooks
        dataset.set_sampling_model(old['sampling_model'], old['proxy_reward'],
                                   sample_prob=old['sampling_model_prob'])
        dataset.random_action_prob = old['random_action_prob']
        dataset.early_stop_reg = old['early_stop_reg']
        dataset.replay_mode = old['replay_mode']
        dataset.record_samples = old['record_samples']
        # restore model/proxy training states
        if set_eval and hasattr(new_model, 'train') and model_was_training:
            new_model.train()
        if set_eval and (proxy is not None) and hasattr(proxy, 'train') and proxy_was_training:
            proxy.train()


@torch.no_grad()
def accumulate_trained_policy_samples(dataset, trained_model, proxy,
                                      n_trajs: int,
                                      key_index: set = None,
                                      dedup_key: str = "none"):
    """Sample ``n_trajs`` trajectories greedily (epsilon=0) and keep them in the repository.

    Sampling runs inside ``temporarily_set_sampling_model`` with ``trained_model`` and
    ``proxy``, recording enabled, early stop disabled, replay isolated, eval mode and
    inference mode on. When ``dedup_key`` is not ``"none"``, only the entries appended
    by this call are deduplicated with ``_dedup_new`` against ``key_index`` (a fresh
    set if ``None``), and the tail of ``dataset.sampled_mols`` is replaced by the
    survivors. Returns ``None``.
    """
    first_new = len(dataset.sampled_mols)
    eval_ctx = temporarily_set_sampling_model(
        dataset, trained_model, proxy,
        epsilon=0.0,
        disable_early_stop=True,
        isolate_replay=True,
        record_samples=True,
        use_inference_mode=True,
        set_eval=True,
    )
    with eval_ctx:
        for _ in range(int(n_trajs)):
            dataset._get_sample_model()

    if dedup_key == "none":
        return
    seen = set() if key_index is None else key_index
    fresh = dataset.sampled_mols[first_new:]
    dataset.sampled_mols[first_new:] = _dedup_new(dataset.sampled_mols, seen, fresh, key=dedup_key)

@timer
@torch.no_grad()
def eval_accumulate_and_compute_metrics(dataset, trained_model, proxy, args,
                                        n_trajs: int,
                                        key_index: set = None,
                                        dedup_key: str = "none",
                                        entropy_scope: str = 'batch'):
    """Sample evaluation trajectories with the trained policy and summarize them.

    Generates ``n_trajs`` trajectories via ``accumulate_trained_policy_samples``. They
    are sampled with epsilon=0, appended to ``dataset.sampled_mols``, and optionally
    deduplicated by ``dedup_key`` against ``key_index``. The function then computes:
      - forward/backward chosen-action probability and entropy, plus mean reward, over
        the newly added samples (``entropy_scope='batch'``, or any value other than
        ``'repo'``) or over the whole repository (``entropy_scope='repo'``);
      - mode / top-k / reward / length summaries over the whole repository with
        ``compute_num_of_modes`` (thresholds from ``args.mode_threshold``,
        ``args.tanimoto_threshold`` and ``args.top_k``).

    Returns:
        A flat dict. Every ``compute_num_of_modes`` key gets an ``_eval`` suffix, and
        ``eval_repo_size`` is the repository length after this call.
    """
    start_len = len(dataset.sampled_mols)
    # Sampling + optional dedup is exactly the work accumulate_trained_policy_samples does.
    accumulate_trained_policy_samples(dataset, trained_model, proxy,
                                      n_trajs=n_trajs,
                                      key_index=key_index,
                                      dedup_key=dedup_key)

    # --- entropy / chosen-action prob ---
    def _agg_entropy(slice_list):
        """Average per-step action statistics over 5-tuple repository entries."""
        if len(slice_list) == 0:
            return {
                'forward_avg_action_prob_eval': float('nan'),
                'forward_policy_entropy_eval':  float('nan'),
                'backward_avg_action_prob_eval': float('nan'),
                'backward_policy_entropy_eval':  float('nan'),
                'avg_current_reward_eval':        float('nan'),
            }
        fwd_p, fwd_H, bwd_p, bwd_H, batch_rewards = [], [], [], [], []
        for normscore, _mol, traj_stats, _inflow, _tlen in slice_list:
            batch_rewards.append(float(normscore))
            for (p_taken, _a, _valid_logZ, f_ent, back_par) in traj_stats:
                fwd_p.append(float(p_taken))
                fwd_H.append(float(f_ent))
                # Backward stats treat P_B as uniform over `back_par` parents.
                if back_par and back_par > 0:
                    bwd_p.append(1.0 / float(back_par))
                    bwd_H.append(float(np.log(back_par + 1e-12)))
        return {
            'forward_avg_action_prob_eval': float(np.mean(fwd_p)) if fwd_p else float('nan'),
            'forward_policy_entropy_eval':  float(np.mean(fwd_H)) if fwd_H else float('nan'),
            'backward_avg_action_prob_eval': float(np.mean(bwd_p)) if bwd_p else float('nan'),
            'backward_policy_entropy_eval':  float(np.mean(bwd_H)) if bwd_H else float('nan'),
            'avg_current_reward_eval':        float(np.mean(batch_rewards)) if batch_rewards else float('nan'),
        }

    entropy_stats = _agg_entropy(dataset.sampled_mols if entropy_scope == 'repo'
                                 else dataset.sampled_mols[start_len:])

    # --- modes/top-k on the accumulated repository ---
    modes_raw = compute_num_of_modes(
        candidates=dataset.sampled_mols,
        reward_thresh=args.mode_threshold,
        tanimoto_thresh=args.tanimoto_threshold,
        top_k=args.top_k,
    )

    # Consistent suffix keeps the logging namespace stable.
    modes_stats = {f"{k}_eval": v for k, v in modes_raw.items()}
    modes_stats["eval_repo_size"] = len(dataset.sampled_mols)

    out = {}
    out.update(entropy_stats)
    out.update(modes_stats)
    return out


def train_model_with_proxy(args, model, proxy, dataset, num_steps, do_save=False):
    """Train ``model`` on trajectories it samples itself, scored by ``proxy``.

    Each of the ``num_steps`` iterations proceeds as follows:
      - Draw a minibatch from the background sampler threads (or synchronously when
        ``args.debug_no_threads`` is set).
      - Compute the ``args.objective`` loss with ``train_one_step``.
      - Take an Adam step. If the loss or the gradients are non-finite, the step is
        skipped and the optimizer state is cleared; the rest of the iteration (EMA,
        evaluation, logging) still runs.
    Every ``args.save_every`` steps, predicted-vs-true log-reward correlations are
    computed (and artifacts pickled when ``do_save``). Every ``args.eval_every`` steps,
    the policy is evaluated greedily and the stats are printed and, if ``args.wdb``,
    logged to wandb. On every other step a batch of greedy evaluation samples is still
    accumulated.

    Args:
        args: run configuration.
        model: policy network. ``alpha`` (and ``logZ`` for 'tb') are attached here.
        proxy: reward model passed to ``dataset.set_sampling_model``.
        dataset: sampling dataset (``DatasetDirect`` for trajectory objectives).
        num_steps: number of training iterations.
        do_save: whether to pickle artifacts under ``args.save_path``.

    Returns:
        The trained ``model``, or ``None`` if a sampler thread failed (after entering
        ``pdb.post_mortem``).

    Raises:
        ValueError: if ``args.objective`` is not ``'tb'``, ``'db'`` or ``'subtb'``.
    """
    # Validate before any sampler thread starts, so a bad config cannot leave
    # threads running behind the exception.
    if args.objective not in ('tb', 'db', 'subtb'):
        raise ValueError(f'Unknown objective {args.objective}')

    device = torch.device('cuda')

    target_ema_tau = args.bootstrap_tau
    if target_ema_tau > 0:
        # Target network tracked as an exponential moving average of `model`.
        target_model = deepcopy(model)

    if do_save:
        exp_dir = f'{args.save_path}/{args.array}_{args.run}/'
        os.makedirs(exp_dir, exist_ok=True)

    dataset.set_sampling_model(model, proxy, sample_prob=args.sample_prob)

    def _dump(obj, path):
        # `with` guarantees the gzip trailer is written and the handle is closed.
        with gzip.open(path, 'wb') as fh:
            pickle.dump(obj, fh)

    @timer
    def save_stuff(step, do_save=False):
        """Return (true_log_r, pred_log_r) lists; optionally pickle training artifacts."""
        corr_logp = compute_correlation(model, dataset.mdp, args)
        if do_save:
            _dump(corr_logp, f'{exp_dir}/{step}_model_logp_pred.pkl.gz')
            _dump([p.data.cpu().numpy() for p in model.parameters()],
                  f'{exp_dir}/{step}_params.pkl.gz')
            _dump(dataset.sampled_mols, f'{exp_dir}/{step}_sampled_mols.pkl.gz')
            _dump({'train_losses': train_losses,
                   'test_losses': test_losses,
                   'test_infos': test_infos,
                   'time_start': time_start,
                   'time_now': time.time(),
                   'args': args},
                  f'{exp_dir}/{step}_info.pkl.gz')
            _dump(train_infos, f'{exp_dir}/{step}_train_info.pkl.gz')
        true_log_r = [np.log(c[0][0]) for c in corr_logp]
        pred_log_r = [c[1] for c in corr_logp]
        return true_log_r, pred_log_r

    # alpha is never optimized: it is excluded from the optimizer below and, with the
    # scheduler, overwritten each step. requires_grad must be disabled on the Parameter
    # itself, because nn.Parameter ignores the flag of the tensor it wraps. Otherwise
    # alpha's gradient would enter the clip-by-norm and the finiteness check.
    model.alpha = nn.Parameter(torch.tensor(args.alpha, device=device, dtype=args.floatX),
                               requires_grad=False)

    if args.objective == 'tb':
        model.logZ = nn.Parameter(tf(args.initial_log_Z, args.floatX, device))

    opt = torch.optim.Adam(
        [
            {
                'params': [p for n, p in model.named_parameters() if 'alpha' not in n],
                'lr': args.learning_rate,
                'weight_decay': args.weight_decay,
                'betas': (args.opt_beta, args.opt_beta2),
                'eps': args.opt_epsilon,
            }
        ]
    )

    mbsize = args.mbsize

    if not args.debug_no_threads:
        sampler = dataset.start_samplers(args.num_threads, mbsize)

    def stop_everything():
        print('joining')
        dataset.stop_samplers_and_join()
    _stop[0] = stop_everything

    def _optimizer_step(loss, acc):
        """Backprop ``loss`` and step ``opt``; return whether the step was applied.

        On a non-finite loss or gradient, the step is skipped and Adam's state is
        cleared to avoid NaN propagation. Gradient-clipping statistics are accumulated
        into ``acc``.
        """
        opt.zero_grad(set_to_none=True)
        if not torch.isfinite(loss):
            print("[WARN] non-finite loss -> skip step and reset optimizer state", flush=True)
            opt.state.clear()
            return False
        loss.backward()
        pre_clip_norm = None
        if args.clip_grad > 0:
            # clip_grad_norm_ returns the total norm measured BEFORE clipping. A norm
            # taken after clipping is always <= clip_grad, which made the trigger
            # rate ~0.
            pre_clip_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                                 args.clip_grad))
        if not all(p.grad is None or torch.isfinite(p.grad).all() for p in model.parameters()):
            print("[WARN] non-finite grads -> skip step and reset optimizer state", flush=True)
            opt.zero_grad(set_to_none=True)
            opt.state.clear()
            return False
        clip_tau = float(args.clip_grad)
        triggered = pre_clip_norm is not None and pre_clip_norm > clip_tau
        acc['sum_scale'] += clip_tau / (pre_clip_norm + 1e-12) if triggered else 1.0
        acc['sum_trigger'] += 1.0 if triggered else 0.0
        acc['cnt'] += 1
        opt.step()
        return True

    my_losses = []
    train_losses = []
    test_losses = []
    test_infos = []
    train_infos = []
    time_start = time.time()

    last_eval_wall = time.time()
    # index for dedup (used only if dedup_key != 'none')
    eval_key_index = set()
    clip_acc = {'sum_scale': 0.0, 'sum_trigger': 0.0, 'cnt': 0}
    log_stats = {}
    for i in range(num_steps):
        # Epsilon schedule (stored inside dataset so _get_sample_model reads it)
        if bool(getattr(args, 'use_exp_weight_decay', 0)):
            if i == 0:
                eps_sched = ExpWeightScheduler(
                    end=args.random_action_prob,
                    total_steps=args.num_iterations,
                    kind=getattr(args, 'exp_weight_sched'),
                    warm_frac=getattr(args, 'exp_weight_warm_frac'),
                    start=getattr(args, 'eps_start'),
                    end_frac=getattr(args, 'eps_end_frac'),
                )
            dataset.random_action_prob = eps_sched(i)

        # Alpha schedule (only when alpha is NOT trainable)
        if args.use_alpha_scheduler:
            if i == 0:
                alpha_sched = AlphaScheduler(
                    total_steps=args.num_iterations,
                    alpha0=args.alpha,
                    warm_frac=getattr(args, 'alpha_warm_frac'),
                    alpha_final=0.5,
                )
            with torch.no_grad():
                model.alpha.copy_(torch.tensor(alpha_sched(i), device=device, dtype=args.floatX))

        if not args.debug_no_threads:
            minibatch = sampler()
            for thread in dataset.sampler_threads:
                if thread.failed:
                    stop_everything()
                    pdb.post_mortem(thread.exception.__traceback__)
                    return
        else:
            minibatch = dataset.sample2batch(dataset.sample(mbsize))

        loss, r, mols = train_one_step(args, model, minibatch, device)
        if _optimizer_step(loss, clip_acc):
            loss_val = loss.detach().cpu().item()
            my_losses.append(loss_val)
            train_losses.append((loss_val,))
            if not i % 50:
                train_infos.append((
                    r.data.cpu().numpy(),
                    mols[1],
                    [p.pow(2).sum().item() for p in model.parameters()],
                ))

        model.training_steps = i + 1
        if target_ema_tau > 0:
            for _a, b in zip(model.parameters(), target_model.parameters()):
                b.data.mul_(1 - target_ema_tau).add_(target_ema_tau * _a)

        if (not i % args.save_every) or (i == num_steps - 1):  # test stats at every save step
            true_log_r, pred_log_r = save_stuff(i, do_save)
            pearson_corr_test = np.corrcoef(true_log_r, pred_log_r)[0][1]
            spearman_corr_test, _ = spearmanr(true_log_r, pred_log_r)
            # Warning: passing candidates=dataset.sampled_mols to compute_num_of_modes()
            # is NOT recommended while sampler threads run; evaluation is not multi-thread safe.
            test_stats = {
                'pearson_corr_test': pearson_corr_test,
                'spearman_corr_test': spearman_corr_test,
            }
        else:
            test_stats = {}

        if (i % args.eval_every == 0) or (i == num_steps - 1):
            eval_stats = eval_accumulate_and_compute_metrics(
                dataset, model, proxy, args,
                n_trajs=args.mbsize,
                key_index=eval_key_index,
                dedup_key=args.eval_dedup_key,
                entropy_scope='batch'   # or 'repo'
            )
            start_idx = max(0, len(my_losses) - args.eval_every)
            avg_loss = float(np.mean(my_losses[start_idx:])) if my_losses else float('nan')
            curr_loss = float(my_losses[-1]) if my_losses else float(loss.detach().cpu().item())
            if clip_acc['cnt'] > 0:
                clip_trigger_rate_interval = clip_acc['sum_trigger'] / clip_acc['cnt']
                clip_scale_avg_interval = clip_acc['sum_scale'] / clip_acc['cnt']
            else:
                clip_trigger_rate_interval = 0.0
                clip_scale_avg_interval = 1.0
            log_stats = {
                'step': i,
                # alpha_sched only exists when the scheduler is enabled.
                'alpha': alpha_sched(i) if args.use_alpha_scheduler else model.alpha.item(),
                'random_action_prob': dataset.random_action_prob,
                f"avg_loss_of_last_{args.eval_every}_steps": avg_loss,
                'current_loss': curr_loss,
                'clip_trigger_rate_interval': clip_trigger_rate_interval,
                'clip_scale_avg_interval': clip_scale_avg_interval,
            }
            eval_interval_time_s = time.time() - last_eval_wall

            log_stats.update(eval_stats)
            log_stats.update(test_stats)

            print(f'task={args.wdb_name}, eval_interval={eval_interval_time_s:3f}s, ' + _dict_to_str(log_stats))

            # reset accumulators
            clip_acc = {'sum_scale': 0.0, 'sum_trigger': 0.0, 'cnt': 0}
            last_eval_wall = time.time()
            # The last step is logged once below, together with the run settings.
            if args.wdb and (i != num_steps - 1):
                wandb.log(log_stats)
        else:
            # only generate a batch of samples for evaluation
            accumulate_trained_policy_samples(
                dataset, model, proxy,
                n_trajs=args.mbsize,
                key_index=eval_key_index,
                dedup_key=args.eval_dedup_key  # "none" (no dedup) or "smiles"
            )

    setting_stats = _make_loggable_args(args)
    log_stats.update(setting_stats)
    if args.wdb:
        wandb.log(log_stats)
    print(f"task={args.wdb_name}, args: " + _dict_to_str(setting_stats))

    stop_everything()
    if do_save and num_steps > 0:
        save_stuff(num_steps - 1, do_save)
    print('End training!', flush=True)
    return model


def main(args):
    """Set up dataset, policy model and proxy from ``args`` and train.

    ``args.floatX`` is converted in place from its CLI string to a torch
    dtype: ``'float32'`` maps to ``torch.float`` and any other value maps to
    ``torch.double``. Objectives other than ``'fm'`` also set
    ``args.ignore_parents = True`` and use ``DatasetDirect`` instead of
    ``Dataset``.

    Args:
        args: parsed command-line namespace.
    """
    set_model_seed(args)
    blocks_path = "data/blocks_PDB_105.json"

    device = torch.device('cuda')
    print(args)

    args.floatX = torch.float if args.floatX == 'float32' else torch.double

    if args.objective != 'fm':
        # NOTE(review): flag name suggests parent enumeration is skipped for
        # these objectives — confirm in DatasetDirect.
        args.ignore_parents = True
        dataset = DatasetDirect(args, blocks_path, device, floatX=args.floatX)
    else:
        dataset = Dataset(args, blocks_path, device, floatX=args.floatX)

    mdp = dataset.mdp

    # 'subtb' and 'db' need one extra per-molecule output head.
    per_mol_heads = 2 if args.objective in ('subtb', 'db') else 1
    model = make_model(args, mdp, out_per_mol=per_mol_heads)
    model.to(args.floatX)
    model.to(device)

    proxy = Proxy(args, blocks_path, device)

    # Saving is disabled to speed up training.
    train_model_with_proxy(
        args, model, proxy, dataset,
        num_steps=args.num_iterations,
        do_save=False,
    )

    print('Done.')



def get_mol_path_graph(mol):
    """Build the DAG of construction paths from ``mol`` back through its ancestors.

    Node 0 is ``mol`` itself. Every other node is an ancestor reached by
    repeatedly applying ``mdp.parents``. Ancestors that
    ``mdp.graphs_are_isomorphic`` deems equal are merged into a single node.
    An edge ``u -> v`` points from a molecule ``u`` to its parent ``v`` and
    stores ``action`` such that ``mdp.add_block_to(ancestors[v], *action)``
    should reproduce ``ancestors[u]``. Actions that do not reproduce the child
    are repaired by trying every stem of the parent with the same block.

    Args:
        mol: molecule to expand (presumably a ``BlockMoleculeDataExtended``
            — TODO confirm).

    Returns:
        networkx.DiGraph whose nodes carry their molecule in the ``'mol''``
        attribute and whose edges carry ``'action'``.

    Raises:
        ValueError: if an edge's action cannot be repaired.
    """
    bpath = "data/blocks_PDB_105.json"
    # A fresh CPU MDP is constructed (and its translation table built) on
    # every call, so looping callers pay this setup cost per molecule.
    mdp = MolMDPExtended(bpath)
    mdp.post_init(torch.device('cpu'), 'block_graph')
    mdp.build_translation_table()
    mdp.floatX = torch.float
    agraph = nx.DiGraph()
    agraph.add_node(0)
    # ancestors[k] is the molecule of node k; node 0 is `mol` itself.
    ancestors = [mol]
    # ancestor_graphs[k] is the nx graph of ancestors[k + 1] (mol itself is
    # not included); used to merge isomorphic ancestors into one node.
    ancestor_graphs = []

    par = mdp.parents(mol)
    # Parallel depth-first stacks: mstack holds parent molecules still to
    # visit, pstack the matching (child node index, action) pair.
    mstack = [i[0] for i in par]
    pstack = [[0, a] for i,a in par]
    while len(mstack):
        m = mstack.pop() # pop() takes the last item -> depth-first order
        p, pa = pstack.pop()
        match = False
        mgraph = mdp.get_nx_graph(m)
        # Reuse an existing node if this parent is isomorphic to a known ancestor.
        for ai, a in enumerate(ancestor_graphs):
            if mdp.graphs_are_isomorphic(mgraph, a):
                agraph.add_edge(p, ai+1, action=pa)
                match = True
                break
        if not match:
            agraph.add_edge(p, len(ancestors), action=pa) # new node id = current len(ancestors) (node 0 is mol)
            ancestors.append(m) # keeps ancestors[node id] == molecule for that node
            ancestor_graphs.append(mgraph)
            # Only molecules with at least one block are expanded further.
            if len(m.blocks):
                par = mdp.parents(m)
                mstack += [i[0] for i in par]
                pstack += [(len(ancestors)-1, i[1]) for i in par]

    # Verify each edge: applying the stored action to the parent must
    # reproduce the child (compared on true_block=True graphs, unlike the
    # merge step above). Otherwise search the parent's stems for one that does.
    for u, v in agraph.edges:
        c = mdp.add_block_to(ancestors[v], *agraph.edges[(u,v)]['action'])
        geq = mdp.graphs_are_isomorphic(mdp.get_nx_graph(c, true_block=True),
                                        mdp.get_nx_graph(ancestors[u], true_block=True))
        if not geq: # try to fix the action
            block, stem = agraph.edges[(u,v)]['action']
            for i in range(len(ancestors[v].stems)):
                c = mdp.add_block_to(ancestors[v], block, i)
                geq = mdp.graphs_are_isomorphic(mdp.get_nx_graph(c, true_block=True),
                                                mdp.get_nx_graph(ancestors[u], true_block=True))
                if geq:
                    agraph.edges[(u,v)]['action'] = (block, i)
                    break
        if not geq:
            raise ValueError('could not fix action')
    # Attach each node's molecule so callers can rebuild model inputs.
    for u in agraph.nodes:
        agraph.nodes[u]['mol'] = ancestors[u]
    return agraph
    
@timer
def compute_correlation(model, mdp, args):
    """Compute the policy's log-probability of generating each test molecule.

    Molecules are read from ``data/some_mols_U_1k.pkl.gz`` (at most the first
    1000). For each one, the function:

    1. builds the ancestor DAG with ``get_mol_path_graph``;
    2. scores every edge with the policy's entropy-tempered action
       log-probabilities;
    3. marginalises over all construction paths with ``torch.logaddexp``.

    Args:
        model: policy network called as ``model(batch, None)`` and returning
            ``(stem_out, mol_out)``.
        mdp: MDP used to convert molecules into model input batches.
        args: namespace providing ``entropy_coeff``, ``min_blocks``,
            ``max_blocks`` and ``floatX``.

    Returns:
        List of ``(mol_entry, logprob)`` pairs. ``mol_entry`` is the pickled
        record (``mol_entry[1]`` is the molecule) and ``logprob`` is a Python
        float. Molecules whose path graph cannot be built, or that have no
        construction path, are skipped.
    """
    device = torch.device('cuda')

    # Repository-local data file. pickle.load can execute arbitrary code,
    # so never point this at untrusted input.
    with gzip.open('data/some_mols_U_1k.pkl.gz') as fh:
        test_mols = pickle.load(fh)
    logsoftmax = nn.LogSoftmax(0)
    logp = []
    for moli in test_mols[:1000]:
        mol = moli[1]
        try:
            agraph = get_mol_path_graph(mol)
        except Exception:
            # Best effort: skip molecules whose ancestor graph cannot be built
            # (e.g. get_mol_path_graph raises ValueError for unfixable actions).
            continue
        s = mdp.mols2batch([mdp.mol2repr(agraph.nodes[i]['mol']) for i in agraph.nodes])
        with torch.no_grad():
            # Outputs for every node of the path graph, not just the final molecule.
            stem_out_s, mol_out_s = model(s, None)
            # Temper logits by the entropy coefficient.
            stem_out_s = stem_out_s / args.entropy_coeff
            mol_out_s = mol_out_s / args.entropy_coeff

        # per_mol_out[j] = (add-block log-probs shaped like stem_out_s rows,
        #                   STOP log-prob) for node j, i.e. log pi(a|s_j).
        per_mol_out = []
        for j in range(len(agraph.nodes)):
            a, b = s._slice_dict['stems'][j:j + 2]
            # Below min_blocks STOP is disallowed: a -1000 logit makes it ~impossible.
            stop_allowed = len(agraph.nodes[j]['mol'].blocks) >= args.min_blocks
            stop_logit = mol_out_s[j, :1] if stop_allowed else tf([-1000], args.floatX, device)
            mp = args.entropy_coeff * logsoftmax(torch.cat([stem_out_s[a:b].reshape(-1), stop_logit]))
            per_mol_out.append((mp[:-1].reshape((-1, stem_out_s.shape[1])), mp[-1]))

        # Generation stops automatically at max_blocks. A molecule with fewer
        # blocks must also account for the explicit STOP action's log-prob.
        stops_early = len(mol.blocks) < args.max_blocks
        if stops_early:
            with torch.no_grad():
                stem_out_last, mol_out_last = model(mdp.mols2batch([mdp.mol2repr(mol)]), None)
                stem_out_last = stem_out_last / args.entropy_coeff
                mol_out_last = mol_out_last / args.entropy_coeff
                mplast = args.entropy_coeff * logsoftmax(
                    torch.cat([stem_out_last.reshape(-1), mol_out_last[0, :1]]))
            MSTOP = mplast[-1]

        # Edge u -> v goes from a molecule to its parent v. Its log-prob is
        # that of the stored action evaluated at v (block -1 means STOP).
        for u, v in agraph.edges:
            act = agraph.edges[u, v]['action']
            if act[0] == -1:
                agraph.edges[u, v]['logprob'] = per_mol_out[v][1]
            else:
                agraph.edges[u, v]['logprob'] = per_mol_out[v][0][act[1], act[0]]

        # Marginalise over paths: visit nodes in reverse topological order and
        # log-sum-exp each (edge + downstream path) into the child node.
        for n in reversed(list(nx.topological_sort(agraph))):
            for c in agraph.predecessors(n):
                path_lp = agraph.edges[c, n]['logprob'] + agraph.nodes[n].get('logprob', 0)
                if stops_early and c == 0:
                    path_lp = path_lp + MSTOP
                agraph.nodes[c]['logprob'] = torch.logaddexp(
                    agraph.nodes[c].get('logprob', tf(-1000, args.floatX, device)),
                    path_lp)

        # Node 0 (the molecule itself) is the DAG's unique source, so its
        # accumulated value covers every construction path.
        root = agraph.nodes[0]
        if 'logprob' not in root:
            continue  # no construction path (mdp.parents returned nothing)
        logp.append((moli, root['logprob'].item()))
    return logp

    
# Optional project module. The wildcard import is kept because the names it
# exports are not visible here. Only a missing module is tolerated; errors
# raised while importing an existing ``arrays`` module now propagate.
try:
    from arrays import *
except ImportError:
    print("no arrays")

# Hyper-parameter preset, annotated with tuning notes for each setting.
good_config = dict(
    replay_mode='online',
    sample_prob=1,
    mbsize=4,
    max_blocks=8,
    min_blocks=2,
    repr_type='block_graph',  # this representation is reported to be stable
    model_version='v4',
    nemb=256,
    # By ~30k iterations the reward distribution usually stops improving, but
    # generated molecules stay unique; training longer mainly yields more
    # high-reward states.
    num_iterations=30000,
    R_min=0.1,
    log_reg_c=(0.1 / 8) ** 4,
    # Scales rewards to roughly [0, 1]: the proxy outputs about 0..10, with
    # very few values above 8.
    reward_norm=8,
    # Larger exponents are riskier but give higher average rewards when
    # training succeeds.
    reward_exp=10,
    learning_rate=5e-4,
    num_conv_steps=10,  # more steps help but cost more compute
    # Too low reduces diversity. Too high makes high-reward molecules so rare
    # that the model never learns about them. 0.05 and 0.02 are sensible.
    random_action_prob=0.05,
    opt_beta2=0.999,  # optimization is sensitive to this; the default works
    leaf_coef=10,  # can be much larger; the trade-off is not well understood
    include_nblocks=False,
)

if __name__ == '__main__':
    args = parser.parse_args()
    args.debug_no_threads=bool(args.debug_no_threads)
    args.use_alpha_scheduler=bool(args.use_alpha_scheduler)
    args.use_exp_weight_decay=bool(args.use_exp_weight_decay)
    args.fl=bool(args.fl)
    args.vec=bool(args.vec)
    args.wdb=bool(args.wdb)
    # ---- normalize top_k ----
    if isinstance(args.top_k, int):
        args.top_k = [args.top_k]
    args.top_k = sorted({int(k) for k in args.top_k if int(k) > 0})

    method=args.objective
    if args.fl:
        method="fl_" + method
    wdb_name=f'm({method})_a({args.alpha})_s({args.sampling_seed})_v({args.vec})'
    if args.wdb:
        args.wdb_name= wdb_name
        wandb.init(project=args.wdb_project, name=args.wdb_name)
    print("start training", args.wdb_name, flush=True)
    try:
        main(args)
    finally:
        if args.wdb:
            wandb.finish()