import argparse
import copy
import gzip
import heapq
import itertools
import os
import pickle
from collections import defaultdict
from itertools import count

import numpy as np
from scipy.stats import norm,spearmanr,pearsonr
# from tqdm import tqdm
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import torch.nn.functional as F

import functools
import wandb
import random
import sys
import tempfile
import datetime
import time

from itertools import chain,combinations
from typing import List, Dict

# Flush stdout on every newline so progress prints show up promptly even when
# output is redirected to a file or pipe.
sys.stdout.reconfigure(line_buffering=True)

from abc import ABC,abstractmethod
import math


print("Loading args...",flush=True)
parser = argparse.ArgumentParser()

# --- Runtime / logging ---
parser.add_argument("--device", default='cuda', type=str)
parser.add_argument("--wdb", choices=[0,1],default=1,type=int,help="Whether to use wandb")
parser.add_argument('--wdb_project', default='Trainable-Alpha-Set', type=str, help="wandb project name")

# --- Seeds: model seed and sampling seed are kept separate ---
parser.add_argument("--sampling_seed",type=int,default=0,help="Random seed for sampling only (does not affect model initialization or training)")
parser.add_argument("--model_seed",type=int, default=0, help="The random seed to determine the initialization of models")

# --- Method ---
parser.add_argument("--method", default='db_gfn', type=str)
parser.add_argument("--fl", choices=[0,1], default=0, type=int, help="Whether to use forward-looking")

# --- Optimization ---
parser.add_argument("--learning_rate", default=1e-4, help="Learning rate, deprecated", type=float)
parser.add_argument("--tb_lr", default=0.001, help="Learning rate", type=float)
parser.add_argument("--tb_z_lr", default=0.1, help="Learning rate for log Z (the TB log-partition parameter)", type=float)
parser.add_argument("--mbsize", default=16, help="Minibatch size", type=int)
parser.add_argument("--n_hid", default=256, type=int) # 64 for training, 256 for sampling
parser.add_argument("--n_layers", default=2, type=int) # 1 for training, 2 for sampling
parser.add_argument("--n_train_steps", default=10000, type=int)

# --- Exploration (uniform-mixture weight) schedule ---
parser.add_argument("--use_exp_weight_decay", choices=[0,1], default=1, type=int, help="Whether to use exp_weight decay")
parser.add_argument("--exp_weight", default=0.05, type=float, help="the final exp_weight after decay")
parser.add_argument("--exp_weight_sched", type=str, default="linear",
                    choices=["linear", "cosine"],
                    help="Schedule to anneal exp_weight from 1.0 to --exp_weight")
parser.add_argument("--exp_weight_warm_frac", type=float, default=0.0,
                    help="Warmup fraction of total training steps to keep exp_weight at 1.0")

# --- Policy ---
parser.add_argument("--temp", default=1., type=float,
                    help="Softmax temperature applied to forward logits when sampling")
parser.add_argument("--rand_pb", choices=[0,1], default=0, type=int,
                    help="Whether to use a uniform backward policy")

# --- Env ---
parser.add_argument("--size", default='small', type=str, choices=['small', 'medium', 'large'])
parser.add_argument("--action_dim", default=10, type=int)
parser.add_argument("--set_size", default=5, type=int)
parser.add_argument("--bufsize", default=16, type=int)
parser.add_argument("--n_eval_steps",default=50, type=int, help="Number of steps to perform evaluation")

# --- Alpha GFlowNets ---
parser.add_argument("--alpha", default=0.5,help="alpha-GFN forward policy weight. When alpha is not trainable, this should be a number in (0,1). When alpha is trainable, this will be processed to a logit.", type=float)
parser.add_argument("--mode_threshold", default=0.25,type=float) # manually set the reward threshold. 0.25 for small, 700000 for medium and large
parser.add_argument("--num_threads",default=8,type=int)

# --- Test set ---
parser.add_argument("--test_set_size",type=int,default=1000,help="The size of the test set. In this set, correlation metrics are computed.")
parser.add_argument("--n_test_steps",default=50,type=int,help="The number of steps to test the model over the test set")

# --- Alpha schedule ---
parser.add_argument("--use_alpha_scheduler", choices=[0,1], default=1, type=int,
                    help="Enable alpha warmup→anneal when alpha is NOT trainable")
parser.add_argument("--alpha_warm_frac", type=float, default=0.4,
                    help="Fraction of total steps to hold alpha at --alpha before annealing toward 0.5")
parser.add_argument("--alpha_sched", type=str, default='cos',
                    choices=['linear','cos','hold_exp','cyc','poly'],
                    help="Schedule type for alpha (start=args.alpha, end=0.5)")

# --- Gradient clipping / SubTB ---
parser.add_argument("--use_grad_clip", choices=[0,1], default=1, type=int,
                    help="Enable global grad-norm clip when alpha is NOT trainable")
parser.add_argument("--grad_clip_norm", type=float, default=2.0,
                    help="Global grad-norm clip threshold; <=0 disables")
parser.add_argument("--subtb_lambda", default=0.99, type=float)
print("Initialized args, start initializing models",flush=True)



def timer(func):
    """Decorator that reports the wall-clock duration of every call to ``func``.

    The duration is measured with ``time.perf_counter`` and printed (flushed)
    as ``[<name>] Executed in <seconds>s``; the wrapped function's return value
    is passed through unchanged.
    """

    @functools.wraps(func)
    def timed(*args, **kwargs):
        t0 = time.perf_counter()
        out = func(*args, **kwargs)
        t1 = time.perf_counter()
        elapsed = t1 - t0
        print(f"[{func.__name__}] Executed in {elapsed:.6f}s", flush=True)
        return out

    return timed


def _dict_to_str(d: Dict) -> str:
    def _fmt(v):
        if isinstance(v, (float, np.floating)):
            return f"{float(v):.3f}"
        elif isinstance(v, (int, np.integer, str)):
            return str(v)
        else:
            return str(v)
    return ', '.join(f'{k}={_fmt(v)}' for k, v in d.items())


# Module-level default device used by tf()/tl(); change it with set_device().
_dev = [torch.device('cuda')]

# Sentinel meaning "use the current default device (_dev[0])". A sentinel is
# needed because an explicit device=None already means "do not move the tensor".
_DEFAULT_DEVICE = object()


def set_device(dev):
    """Set the default device used by tf()/tl() when no device is passed."""
    _dev[0] = dev


def _resolve_device(device):
    """Map the _DEFAULT_DEVICE sentinel to the current default device.

    Resolving at call time (instead of binding ``_dev[0]`` as a default value
    at definition time) is what makes set_device() take effect.
    """
    return _dev[0] if device is _DEFAULT_DEVICE else device


def _to_tensor(data, dtype, list_ctor, device):
    """Shared implementation of tf()/tl(); see tf() for the accepted inputs.

    Args:
        data: the input to convert.
        dtype: target torch dtype.
        list_ctor: legacy typed constructor (torch.FloatTensor / torch.LongTensor)
            used for lists of lists and lists of scalars.
        device: target device, None/falsy for "do not move", or _DEFAULT_DEVICE.
    """
    device = _resolve_device(device)

    def _place(t):
        # A falsy device (None / "") leaves the tensor where it was created.
        return t.to(device) if device else t

    # Case 1: single NumPy ndarray.
    if isinstance(data, np.ndarray):
        return _place(torch.from_numpy(data).to(dtype))

    if isinstance(data, list):
        # Empty list -> empty 1-D tensor (np.stack would fail on it).
        if not data:
            return _place(list_ctor([]))

        # Case 2: list of ndarrays; stacked when shapes agree, else one tensor each.
        if all(isinstance(d, np.ndarray) for d in data):
            try:
                stacked = np.stack(data)
            except ValueError:
                return [_place(torch.from_numpy(d).to(dtype)) for d in data]
            return _place(torch.from_numpy(stacked).to(dtype))

        # Cases 3/4: list of lists (assumed rectangular) or list of scalars.
        if all(isinstance(d, list) for d in data) or all(np.isscalar(d) for d in data):
            return _place(list_ctor(data))

    # Fallback (scalars, tensors, tuples, mixed lists). torch.as_tensor keeps a
    # Python scalar as a 0-d value, whereas torch.FloatTensor(int) would allocate
    # an uninitialized tensor of that length.
    return _place(torch.as_tensor(data, dtype=dtype))


def tf(data, device=_DEFAULT_DEVICE):
    """
    Converts various types of input data to a torch.FloatTensor efficiently.

    Args:
        data: The input data. Can be a numpy ndarray, list of ndarrays, list of lists,
            list of scalars, or anything torch.as_tensor accepts.
        device: The device to move the tensor to (e.g., 'cuda' or 'cpu'). Defaults to
            the current module default set via set_device(); pass None to skip moving.

    Returns:
        torch.FloatTensor, or a list of torch.FloatTensor when ``data`` is a list of
        ndarrays whose shapes differ.
    """
    return _to_tensor(data, torch.float32, torch.FloatTensor, device)


def tl(data, device=_DEFAULT_DEVICE):
    """
    Converts input data to torch.LongTensor efficiently.

    Args:
        data: Input data (numpy array, list of ndarrays, list of lists, list of scalars,
            or anything torch.as_tensor accepts).
        device: Device to move the tensor to (e.g., 'cuda' or 'cpu'). Defaults to the
            current module default set via set_device(); pass None to skip moving.

    Returns:
        torch.LongTensor, or a list of torch.LongTensor when ``data`` is a list of
        ndarrays whose shapes differ.
    """
    return _to_tensor(data, torch.long, torch.LongTensor, device)

# def make_mlp(l, act=nn.LeakyReLU(), tail=[]):
#     return nn.Sequential(*(sum([[nn.Linear(i, o)] + ([act] if n < len(l)-2 else []) for n, (i, o) in enumerate(zip(l, l[1:]))], []) + tail))

def make_mlp(sizes, act=nn.LeakyReLU(), tail=None):
    """Build an MLP as an ``nn.Sequential`` from a list of layer widths.

    Args:
        sizes: layer widths ``[in, hidden..., out]``; each consecutive pair
            becomes one ``nn.Linear``.
        act: activation module inserted after every Linear except the last.
            The same module instance is reused between layers.
        tail: optional list of extra modules appended after the final Linear.

    Returns:
        nn.Sequential with the Linear/activation stack followed by ``tail``.
    """
    tail = [] if tail is None else tail
    n_linear = len(sizes) - 1
    layers = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
        layers.append(nn.Linear(fan_in, fan_out))
        # Decide by position, not by width: a hidden layer whose width happens
        # to equal the output width must still be followed by an activation.
        if idx < n_linear - 1:
            layers.append(act)
    layers.extend(tail)
    return nn.Sequential(*layers)

class ExpWeightScheduler:
    """Anneal exp_weight from 1.0 to a target value; supports linear / cosine with optional warmup.

    The schedule is defined on steps ``0 .. total_steps - 1`` (other steps are
    clamped into that range). For the first ``warm_frac * total_steps`` steps it
    returns 1.0; afterwards it moves from 1.0 to ``end`` so that the final step
    ``total_steps - 1`` reaches ``end``. Returned values are clipped to [0, 1].

    Args:
        end: exp_weight reached at the final step.
        total_steps: schedule length (values < 1 are treated as 1).
        kind: ``"linear"`` or ``"cosine"``.
        warm_frac: fraction of ``total_steps`` held at 1.0 (clipped to [0, 1]).

    Raises:
        ValueError: if ``kind`` is not a known schedule.
    """

    KINDS = ("linear", "cosine")

    def __init__(self, end: float, total_steps: int,
                 kind: str = "linear", warm_frac: float = 0.0):
        # Fail fast on a typo instead of at the first scheduled step.
        if kind not in self.KINDS:
            raise ValueError(f"Unknown exp_weight schedule kind: {kind}")
        self.start = 1.0
        self.end = float(end)
        self.T = max(1, int(total_steps))
        self.warm = int(max(0.0, min(1.0, warm_frac)) * self.T)
        self.kind = kind

    def __call__(self, step: int) -> float:
        t = max(0, min(step, self.T - 1))
        if t < self.warm:
            return self.start
        # Progress x in [0, 1]: 0 on the first post-warmup step and 1 on the last
        # reachable step T-1. Dividing by (T - warm) would stop short of `end`.
        x = (t - self.warm) / max(1, self.T - 1 - self.warm)
        if self.kind == "linear":
            w = self.start + (self.end - self.start) * x
        elif self.kind == "cosine":
            w = self.end + (self.start - self.end) * 0.5 * (1 + math.cos(math.pi * x))
        else:
            raise ValueError(f"Unknown exp_weight schedule kind: {self.kind}")
        return float(max(0.0, min(1.0, w)))

class AlphaScheduler:
    """Hold alpha at ``alpha0``, then decay it exponentially toward ``alpha_final``.

    For steps below ``int(total_steps * warm_frac)`` the schedule returns
    ``alpha0``. Afterwards it returns
    ``alpha_final + (alpha0 - alpha_final) * exp(-decay_k * tail_t / tail_len)``,
    where ``tail_t`` counts steps since the hold ended and ``tail_len`` is the
    number of remaining steps. Steps are clamped to ``[0, total_steps - 1]``.
    """

    def __init__(self, total_steps: int, alpha0: float, warm_frac: float = 0.4,
                 alpha_final: float = 0.5):
        self.math = math  # kept as an attribute for compatibility with older code
        self.T = max(1, int(total_steps))
        self.a0 = float(alpha0)
        self.af = float(alpha_final)
        self.warm_frac = max(0.0, min(1.0, float(warm_frac)))
        self.T_hold = int(self.T * self.warm_frac)
        self.decay_k = 4.0  # exponential decay rate over the post-hold tail
        self.poly_p = 0.5   # NOTE(review): not read by __call__

    def __call__(self, step: int) -> float:
        clamped = max(0, min(step, self.T - 1))
        if clamped >= self.T_hold:
            elapsed = clamped - self.T_hold
            span = max(1, self.T - self.T_hold)
            decay = self.math.exp(-self.decay_k * elapsed / span)
            return self.af + (self.a0 - self.af) * decay
        return self.a0


class SetEnv:
    """Environment that builds a set by adding one item per step.

    The state is an int32 indicator vector of length ``action_dim``. Adding
    item ``a`` yields the reward ``exp(-intermediate_energies[a])``; the episode
    is done once ``set_size`` steps have been taken. The environment itself does
    not reject re-adding an item that is already in the set.
    """

    def __init__(self, action_dim, set_size, intermediate_energies):
        self.action_dim = action_dim
        self.set_size = set_size
        energies = np.array(intermediate_energies)
        self.intermediate_energies = energies
        self.intermediate_rewards = np.exp(-energies)

    def reward_func(self, a):
        """Reward for adding item ``a``."""
        return self.intermediate_rewards[a]

    def obs(self, s=None):
        """Return ``s`` if given, otherwise the (mutable) internal state array."""
        if s is not None:
            return s
        return self._state

    def reset(self):
        """Start a new episode with the empty set and return its state."""
        self._state = np.zeros(self.action_dim, dtype=np.int32)
        self._step = 0
        return self.obs()

    def step(self, a):
        """Add item ``a``; return ``(state, reward, done)``."""
        self._state[a] = 1
        self._step += 1
        finished = self._step == self.set_size
        reward = self.reward_func(a)
        return self.obs(), reward, finished

class GFNAgent:
    """Base GFlowNet agent for set-construction environments.

    Two copies of one MLP are kept: ``training_model`` (whose parameters are
    exposed through ``parameters()``) and ``sampling_model`` (used to roll out
    training trajectories; refreshed by
    ``update_sampling_model_with_training_model``). The MLP output is read as
    ``[:action_dim]`` forward-policy logits, ``[action_dim:-1]`` backward-policy
    logits and ``[-1]`` a scalar log-flow.

    Subclasses implement the training objective in ``learn_from`` and, for
    forward-looking methods, ``learn_from_fl``.
    """

    def __init__(self, args, train_envs, test_envs):
        """Build models and bookkeeping from parsed arguments.

        Args:
            args: parsed arguments; reads ``action_dim``, ``n_hid``, ``n_layers``,
                ``dev`` (set by the caller), ``sampling_seed``, ``exp_weight``,
                ``temp``, ``rand_pb``, ``alpha`` and, if present, ``subtb_lambda``.
            train_envs: environments used by ``sample_many`` (at least ``mbsize``).
            test_envs: environments used by ``gen_test_set_with_sampling_model``.
        """
        out_dim = args.action_dim + args.action_dim + 1  # forward logits + backward logits + log-flow
        # The seed that determines model initialization is fixed in set_model_seed().
        self.training_model = make_mlp([args.action_dim] + [args.n_hid] * args.n_layers + [out_dim])
        self.sampling_model = copy.deepcopy(self.training_model)
        self.dev = args.dev
        self.training_model.to(self.dev)
        self.sampling_model.to(self.dev)
        print(self.training_model, flush=True)

        # Dedicated RNG for trajectory sampling, separate from model-init randomness.
        self.sampling_gen = torch.Generator(device=args.dev).manual_seed(args.sampling_seed)

        self.action_dim = args.action_dim
        self.train_envs = train_envs
        self.test_envs = test_envs

        self.exp_weight = args.exp_weight  # weight of the uniform exploration mixture
        self.temp = args.temp              # softmax temperature for sampling
        self.uniform_pb = args.rand_pb     # truthy -> uniform backward policy
        # Geometric sub-trajectory weight read by SubTB agents (0.99 = CLI default).
        self.subtb_lambda = getattr(args, "subtb_lambda", 0.99)

        # Evaluation bookkeeping, filled by sample_many(evaluate=True).
        self.all_eval_unique_Rs = []
        self.eval_visited_strs = []
        self.cur_eval_unique_Rs = []
        self.cur_eval_Rs = []

        self.test_set = None  # set by gen_test_set_with_sampling_model

        # NOTE(review): nn.Parameter defaults to requires_grad=True, which overrides the
        # tensor-level requires_grad=False, so alpha does receive (unused) gradients.
        # Left unchanged in case other code trains alpha — confirm intent.
        self.alpha = nn.Parameter(torch.tensor(args.alpha, device=args.dev, requires_grad=False))  # alpha-GFN

        self.inf = 1000000000  # large constant subtracted from logits to mask invalid actions

    def parameters(self):
        """Return the parameters of ``training_model``."""
        return self.training_model.parameters()

    # @timer
    def update_sampling_model_with_training_model(self):
        """
        Use training_model to update sampling_model (replaced by a deep copy).
        """
        self.sampling_model = copy.deepcopy(self.training_model)

    def reset_sampling_model(self, sampling_model):
        """
        Reset the sampling model with an external model (moved to ``self.dev``).
        """
        self.sampling_model = sampling_model.to(self.dev)

    def _reset_cur_eval_Rs(self):
        """Clear the rewards recorded for the current evaluation batch."""
        self.cur_eval_Rs = []
        self.cur_eval_unique_Rs = []

    def get_eval_Rs(self):
        """Return ``(cur_eval_unique_Rs, cur_eval_Rs)`` from the last evaluation batch."""
        return self.cur_eval_unique_Rs, self.cur_eval_Rs

    def _full_log_probs(self, pred, states):
        """Masked forward/backward log-probabilities over all actions at each state.

        Forward: items already in the set (state == 1) are masked out.
        Backward: items not in the set (state == 0) are masked out; with
        ``uniform_pb`` the backward logits are zeroed, i.e. uniform over the set.
        """
        edge_mask = states.float()
        fwd_full_lp = (pred[..., : self.action_dim] - self.inf * edge_mask).log_softmax(1)
        init_edge_mask = (states == 0).float()
        back_full_lp = ((0 if self.uniform_pb else 1) * pred[..., self.action_dim:-1] - self.inf * init_edge_mask).log_softmax(1)
        return fwd_full_lp, back_full_lp

    def _rollout(self, model, envs, n_traj, explore, generator=None):
        """Roll out ``n_traj`` complete episodes with ``model`` (no autograd).

        Args:
            model: network producing the forward logits in its first ``action_dim`` outputs.
            envs: environments; all are reset, the first ``n_traj`` are rolled out.
            n_traj: number of trajectories to collect.
            explore: if True, mix the policy with a uniform distribution over
                valid actions using ``self.exp_weight``.
            generator: torch.Generator for action sampling; None uses the global RNG.

        Returns:
            ``[batch_s, batch_a, batch_steps, batch_ri]`` where, per trajectory,
            ``batch_s`` stacks the visited states (one more than actions),
            ``batch_a`` the actions as ``(T, 1)``, ``batch_steps`` the number of
            states and ``batch_ri`` the intermediate rewards as ``(T, 1)`` floats.
        """
        env_list = list(envs)
        batch_s = [[] for _ in range(n_traj)]
        batch_a = [[] for _ in range(n_traj)]
        batch_ri = [[] for _ in range(n_traj)]
        active = list(range(n_traj))  # indices of trajectories still running

        # Every env is reset and fed through the model on the first step (as the
        # original sampler did), which keeps RNG consumption identical.
        s = tf([env.reset() for env in env_list])
        while active:
            with torch.no_grad():
                pred = model(s)
                edge_mask = s.float()
                logits = (pred[..., : self.action_dim] - self.inf * edge_mask).log_softmax(1)
                if explore:
                    sample_ins_probs = (1 - self.exp_weight) * (logits / self.temp).softmax(1) + self.exp_weight * (1 - edge_mask) / (1 - edge_mask + 0.0000001).sum(1).unsqueeze(1)
                else:
                    sample_ins_probs = (logits / self.temp).softmax(1)
                acts = sample_ins_probs.multinomial(1, generator=generator).squeeze(-1)

            step = [env_list[env_idx].step(a) for env_idx, a in zip(active, acts)]

            still_active = []
            for env_idx, curr_s, curr_a, (ns, r, d) in zip(active, s, acts, step):
                batch_s[env_idx].append(curr_s)
                batch_a[env_idx].append(curr_a.unsqueeze(-1))
                batch_ri[env_idx].append(r)
                if d:
                    batch_s[env_idx].append(tf(ns))  # terminal state
                else:
                    still_active.append(env_idx)
            active = still_active
            s = tf([ns for ns, _, d in step if not d])

        batch_steps = [len(traj) for traj in batch_s]
        for i in range(n_traj):
            batch_s[i] = torch.stack(batch_s[i])
            batch_a[i] = torch.stack(batch_a[i])
            assert batch_s[i].shape[0] - batch_a[i].shape[0] == 1
            batch_ri[i] = torch.tensor(batch_ri[i]).unsqueeze(-1).float().to(self.dev)
        return [batch_s, batch_a, batch_steps, batch_ri]

    def gen_test_set_with_sampling_model(self, test_set_size):
        """
        Generate the fixed test set with sampling_model (exploration mixture on).

        Unlike sample_many, actions are drawn from the global torch RNG (i.e. the
        model seed), not from ``sampling_gen``. Stores
        ``[states, actions, episode_lens, intermediate_rewards, log_r]`` in
        ``self.test_set`` where ``log_r`` is a NumPy array of per-trajectory
        summed log intermediate rewards.
        """
        batch_s, batch_a, batch_steps, batch_ri = self._rollout(
            self.sampling_model, self.test_envs, test_set_size, explore=True)

        self.test_set = [batch_s, batch_a, batch_steps, batch_ri]
        batch_log_r = [torch.sum(torch.log(ri.squeeze(-1))).cpu().item() for ri in batch_ri]
        self.test_set.append(np.array(batch_log_r))

    @timer
    def calc_stats_in_test_set(self, corr_metrics: List[str]) -> Dict:
        """
        Correlate training_model trajectory log-probabilities with log-rewards on the test set.

        For each test trajectory the summed forward log-probability of its actions
        under training_model is correlated with its log-reward using every metric
        in ``corr_metrics`` ("pearson" / "spearman"). Forward/backward policy
        entropies and mean taken-action probabilities are also reported.

        Raises:
            RuntimeError: if the test set has not been generated yet.
            ValueError: for an unknown correlation metric.
        """
        corr_funcs = {"pearson": pearsonr, "spearman": spearmanr}
        if self.test_set is None:
            raise RuntimeError("Test set not generated; call gen_test_set_with_sampling_model first")

        states, actions, episode_lens, intermediate_rewards, log_r = self.test_set
        traj_logits, fwd_ent, fwd_p, back_ent, back_p = self._policy_stats(
            states, actions, episode_lens, intermediate_rewards)

        results = {}
        for corr_metric in corr_metrics:
            if corr_metric not in corr_funcs:
                raise ValueError("Unknown correlation metric. Available: pearson, spearman")
            results[corr_metric], _ = corr_funcs[corr_metric](np.array(traj_logits), log_r)
        results['forward_policy_entropy'] = float(np.mean(fwd_ent))
        results['forward_avg_action_prob'] = float(np.mean(fwd_p))
        results['backward_avg_policy_entropy'] = float(np.mean(back_ent))
        results['backward_avg_action_prob'] = float(np.mean(back_p))
        return results

    def sample_many(self, mbsize, evaluate=False):
        """
        Generate a batch of ``mbsize`` trajectories, sampling with ``sampling_gen``.

        Training (``evaluate=False``): sampling_model with the exploration mixture.
        Evaluation (``evaluate=True``): training_model without exploration; the
        product of intermediate rewards of each trajectory is recorded in
        ``cur_eval_Rs``, and in ``cur_eval_unique_Rs`` / ``all_eval_unique_Rs`` if
        its terminal set was never seen in any earlier evaluation.

        Returns:
            ``[batch_s, batch_a, batch_steps, batch_ri]`` (see ``_rollout``).
        """
        if evaluate:
            self._reset_cur_eval_Rs()

        model = self.training_model if evaluate else self.sampling_model
        batch = self._rollout(model, self.train_envs, mbsize,
                              explore=not evaluate, generator=self.sampling_gen)

        if evaluate:
            batch_s, _, _, batch_ri = batch
            for traj_s, traj_ri in zip(batch_s, batch_ri):
                curr_R = torch.prod(traj_ri).item()
                # Terminal set as a sorted list of selected item indices.
                curr_formatted_s = np.where(traj_s[-1].cpu().data.numpy() == 1)[0].tolist()
                self.cur_eval_Rs.append(curr_R)
                if curr_formatted_s not in self.eval_visited_strs:
                    self.all_eval_unique_Rs.append(curr_R)
                    self.cur_eval_unique_Rs.append(curr_R)
                    self.eval_visited_strs.append(curr_formatted_s)

        return batch

    def _extract_sample_from_batch(self, states, actions, episode_lens, intermediate_rewards, data_idx):
        """
        Run training_model on trajectory ``data_idx`` of a batch.

        Returns:
            ``(curr_states, curr_episode_len, pred, logits, back_logits,
            curr_intermediate_rewards)`` where ``logits[t]`` is log P_F(a_t | s_t),
            ``back_logits[t]`` is log P_B(a_t | s_{t+1}) and ``pred`` is the raw
            model output for every state.
        """
        curr_episode_len = episode_lens[data_idx]

        curr_states = states[data_idx][:curr_episode_len, :]
        curr_actions = actions[data_idx][:curr_episode_len - 1, :]
        curr_intermediate_rewards = intermediate_rewards[data_idx].squeeze(-1)

        pred = self.training_model(curr_states)
        fwd_full_lp, back_full_lp = self._full_log_probs(pred, curr_states)

        logits = fwd_full_lp[:-1, :].gather(1, curr_actions).squeeze(1)
        back_logits = back_full_lp[1:, :].gather(1, curr_actions).squeeze(1)

        return curr_states, curr_episode_len, pred, logits, back_logits, curr_intermediate_rewards

    def _policy_stats(self, states, actions, episode_lens, intermediate_rewards):
        """Per-trajectory diagnostics of training_model, computed without autograd.

        Returns five lists (one entry per trajectory): summed forward
        log-probability of the taken actions, mean forward-policy entropy over
        non-terminal states, mean taken-action forward probability, mean
        backward-policy entropy over non-initial states, mean taken-action
        backward probability.
        """
        traj_logits, fwd_ent, fwd_p, back_ent, back_p = [], [], [], [], []
        with torch.no_grad():
            for data_idx in range(len(states)):
                curr_states, _, pred, logits, back_logits, _ = self._extract_sample_from_batch(
                    states, actions, episode_lens, intermediate_rewards, data_idx)
                traj_logits.append(logits.sum().cpu().item())

                fwd_full_lp, back_full_lp = self._full_log_probs(pred, curr_states)
                fwd_lp = fwd_full_lp[:-1, :]
                fwd_ent.append(-(fwd_lp.exp() * fwd_lp).sum(dim=1).mean().item())
                back_lp = back_full_lp[1:, :]
                back_ent.append(-(back_lp.exp() * back_lp).sum(dim=1).mean().item())

                fwd_p.append(logits.exp().mean().item())
                back_p.append(back_logits.exp().mean().item())
        return traj_logits, fwd_ent, fwd_p, back_ent, back_p

    def train_step(self, batch_size, method_name):
        """
        One training step for the agent.
        New training samples are generated by sampling_model, then fed into training_model;
        gradients are accumulated via ``loss.backward()`` and the loss is returned.
        """
        experiences = self.sample_many(batch_size)  # collect samples for training
        if method_name in ('fl_db_gfn', 'fl_subtb_gfn'):
            loss = self.learn_from_fl(experiences)
        elif method_name in ('db_gfn', 'subtb_gfn', 'tb_gfn'):
            loss = self.learn_from(experiences)
        else:
            raise ValueError(f"Unknown method name: {method_name}")
        loss.backward()
        return loss

    def eval_batch(self, batch_size):
        """
        Evaluate the agent on a freshly sampled evaluation batch.

        Returns:
            ``(forward_policy_entropy, forward_avg_action_prob,
            backward_avg_policy_entropy, backward_avg_action_prob)`` — step-level
            entropies from the full masked distributions, averaged over trajectories.
        """
        states, actions, episode_lens, intermediate_rewards = self.sample_many(batch_size, evaluate=True)
        _, fwd_ent, fwd_p, back_ent, back_p = self._policy_stats(
            states, actions, episode_lens, intermediate_rewards)

        forward_policy_entropy = float(np.mean(fwd_ent))
        forward_avg_action_prob = float(np.mean(fwd_p))
        backward_avg_policy_entropy = float(np.mean(back_ent))
        backward_avg_action_prob = float(np.mean(back_p))
        return forward_policy_entropy, forward_avg_action_prob, backward_avg_policy_entropy, backward_avg_action_prob

    @abstractmethod
    def learn_from(self, batch):
        """Return the training loss for ``batch``; implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} does not implement learn_from")

    @abstractmethod
    def learn_from_fl(self, batch):
        """Return the forward-looking training loss for ``batch``; implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} does not implement learn_from_fl")

class TBFlowNetAgent(GFNAgent):
    """Trajectory-balance agent with a learned scalar ``Z`` (used in log space)."""

    def __init__(self, args, train_envs, test_envs):
        super().__init__(args, train_envs, test_envs)
        # Learnable log-partition term, a leaf tensor of shape (1,).
        self.Z = torch.zeros(1, device=args.dev, requires_grad=True)

    def learn_from(self, batch):
        """Trajectory-balance loss averaged over the trajectories in ``batch``.

        Per trajectory the residual is
        ``Z + sum log P_F - log R(x) - sum log P_B + (len - 1) * log(alpha / (1 - alpha))``
        where ``R(x)`` is the product of intermediate rewards and ``len`` the
        number of states.
        """
        states, actions, episode_lens, intermediate_rewards = batch
        log_alpha_ratio = torch.log(self.alpha / (1 - self.alpha))  # alpha-GFN

        squared = []
        for data_idx in range(len(states)):
            _, traj_len, _, fwd_lp, back_lp, rewards = self._extract_sample_from_batch(
                states, actions, episode_lens, intermediate_rewards, data_idx)
            fwd_total = torch.sum(fwd_lp)
            back_total = torch.sum(back_lp)
            log_return = torch.prod(rewards).log()
            residual = self.Z + fwd_total - log_return - back_total + (traj_len - 1) * log_alpha_ratio
            squared.append(residual ** 2)

        return torch.cat(squared).sum() / len(states)

class VarTBFlowNetAgent(GFNAgent):
    """
    Use the log-partition variance loss in https://arxiv.org/pdf/2302.05446

    Per trajectory the residual
    ``sum log P_F - log R(x) - sum log P_B + (len - 1) * log(alpha / (1 - alpha))``
    is computed (``R(x)`` = product of intermediate rewards, ``len`` = number of
    states); the loss is the batch variance of these residuals divided by the
    batch size.
    """
    def learn_from(self, batch):
        """Return the variance-based trajectory-balance loss for ``batch``.

        Args:
            batch: ``[states, actions, episode_lens, intermediate_rewards]`` as
                produced by ``sample_many``.

        Returns:
            Scalar tensor. With fewer than two trajectories the unbiased variance
            is undefined (NaN), so a zero loss that stays attached to the autograd
            graph is returned instead.
        """
        states, actions, episode_lens, intermediate_rewards = batch

        alpha = self.alpha

        ll_diff = []
        for data_idx in range(len(states)):
            curr_states, curr_episode_len, pred, logits, back_logits, curr_intermediate_rewards = self._extract_sample_from_batch(states, actions, episode_lens, intermediate_rewards, data_idx)

            sum_logits = torch.sum(logits)
            sum_back_logits = torch.sum(back_logits)

            curr_return = torch.prod(curr_intermediate_rewards)

            curr_ll_diff = sum_logits - curr_return.log() - sum_back_logits + (curr_episode_len-1)*torch.log(alpha/(1-alpha)) # alpha-GFN
            ll_diff.append(curr_ll_diff)

        # torch.var needs a tensor, not a Python list of 0-d tensors.
        ll_diff = torch.stack(ll_diff)
        if ll_diff.numel() < 2:
            # Avoid a NaN loss (which would poison the parameters) while keeping
            # loss.backward() valid.
            return ll_diff.sum() * 0

        loss = torch.var(ll_diff) / len(states)

        return loss
        

class DBFlowNetAgent(GFNAgent):
    """Detailed-balance agent; the last model output is read as the log-flow of a state."""

    def learn_from(self, batch):
        """Detailed-balance loss: mean squared residual over every transition in ``batch``.

        For transition t (s_t -> s_{t+1}) the residual is
        ``log F(s_t) + log P_F(a_t|s_t) - log F(s_{t+1}) - log P_B(a_t|s_{t+1})
        + log(alpha / (1 - alpha))``; on the last transition ``log F`` of the
        terminal state is replaced by ``log R(x)`` (log of the product of
        intermediate rewards).

        Args:
            batch: ``[states, actions, episode_lens, intermediate_rewards]`` as
                produced by ``sample_many``.
        """
        states, actions, episode_lens, intermediate_rewards = batch
        log_alpha_ratio = torch.log(self.alpha / (1 - self.alpha))  # alpha-GFN

        residuals = []
        for data_idx in range(len(states)):
            traj_states, _, pred, fwd_lp, back_lp, rewards = self._extract_sample_from_batch(
                states, actions, episode_lens, intermediate_rewards, data_idx)
            log_return = torch.prod(rewards).log()
            log_flow = pred[..., -1][:-1]  # log-flows of the non-terminal states

            diff = torch.zeros(traj_states.shape[0] - 1, device=self.dev)
            diff += log_flow
            diff += fwd_lp
            diff[:-1] -= log_flow[1:]
            diff -= back_lp
            diff[-1] -= log_return
            diff += log_alpha_ratio
            residuals.append(diff ** 2)

        residuals = torch.cat(residuals)
        return residuals.sum() / len(residuals)

    def learn_from_fl(self, batch):
        """Forward-looking detailed-balance loss (mean squared residual per transition).

        For transition t the residual is
        ``log F~(s_t) + log P_F - log P_B - log F~(s_{t+1}) - log r_t
        + log(alpha / (1 - alpha))`` with ``r_t`` the transition's intermediate
        reward and ``log F~`` the last model output.
        NOTE(review): the terminal state's predicted log-flow is used as-is
        (not pinned to 0) — confirm this matches the intended FL formulation.
        """
        states, actions, episode_lens, intermediate_rewards = batch
        log_alpha_ratio = torch.log(self.alpha / (1 - self.alpha))  # alpha-GFN

        residuals = []
        for data_idx in range(len(states)):
            traj_states, _, pred, fwd_lp, back_lp, rewards = self._extract_sample_from_batch(
                states, actions, episode_lens, intermediate_rewards, data_idx)
            log_flow = pred[..., -1]

            diff = torch.zeros(traj_states.shape[0] - 1, device=self.dev)
            diff += log_flow[:-1]
            diff += fwd_lp
            diff -= back_lp
            diff -= log_flow[1:]
            diff -= rewards.log()
            diff += log_alpha_ratio
            residuals.append(diff ** 2)

        residuals = torch.cat(residuals)
        return residuals.sum() / len(residuals)

class SubTBFlowNetAgent(GFNAgent):
    """
    GFlowNet agent trained with the sub-trajectory balance (SubTB) objective.

    For a trajectory s_0 -> ... -> s_T, every sub-trajectory (i, j) with
    0 <= i < j <= T contributes the squared residual
        (log F(s_i) - log F(s_j) + sum_{k=i}^{j-1} step_k + (j - i) * log(alpha / (1 - alpha)))^2
    weighted by ``self.subtb_lambda ** (j - i)``. The per-trajectory loss is the
    weighted mean of these residuals and the batch loss is the mean over
    trajectories.
    """

    def _subtb_loss(self, boundary_log_flow, step_terms, log_alpha_ratio):
        """
        Lambda-weighted SubTB loss of one trajectory, vectorised over all (i, j).

        With prefix sums, sum_{k=i}^{j-1} step_k = prefix[j] - prefix[i], so the
        residual of (i, j) is potential[i] - potential[j] where
        potential = boundary_log_flow - prefix. All sub-trajectories are thus
        evaluated with a handful of tensor ops instead of O(T^2) Python
        iterations. Values can differ from an explicit double loop by
        floating-point rounding only.

        :param boundary_log_flow: 1-D tensor of length T + 1 with the log-flow
            used for s_0 .. s_T (the caller decides the terminal entry).
        :param step_terms: 1-D tensor of length T with the per-transition terms
            summed along each sub-trajectory.
        :param log_alpha_ratio: log(alpha / (1 - alpha)), added once per transition.
        :return: tensor of shape (1,) with the weighted mean squared residual
            (NaN for an empty trajectory, where the total weight is 0).
        """
        step_terms = step_terms.to(boundary_log_flow)
        num_steps = step_terms.shape[0]
        prefix = torch.cat([step_terms.new_zeros(1), torch.cumsum(step_terms, dim=0)])
        potential = boundary_log_flow - prefix

        positions = torch.arange(num_steps + 1, device=potential.device)
        span = positions.unsqueeze(0) - positions.unsqueeze(1)  # span[i, j] = j - i
        valid = span > 0                                         # keep only i < j
        span = span[valid].to(potential.dtype)
        # (potential[i] - potential[j]) for every valid (i, j), row-major like `span`.
        residual = (potential.unsqueeze(1) - potential.unsqueeze(0))[valid]
        residual = residual + span * log_alpha_ratio
        weight = self.subtb_lambda ** span
        return ((residual.pow(2) * weight).sum() / weight.sum()).reshape(1)

    def learn_from(self, batch):
        """
        Compute the SubTB loss of a batch (no backward pass is run here).

        The terminal boundary log-flow is fixed to log R, where R is the product
        of the trajectory's intermediate rewards; the non-terminal log-flows are
        the last output column of ``pred``. Per-transition terms are
        ``logits - back_logits``.

        :param batch: tuple (states, actions, episode_lens, intermediate_rewards).
        :return: scalar loss tensor (mean of the per-trajectory losses).
        """
        states, actions, episode_lens, intermediate_rewards = batch
        log_alpha_ratio = torch.log(self.alpha / (1 - self.alpha))

        per_traj = []
        for data_idx in range(len(states)):
            _, _, pred, logits, back_logits, curr_intermediate_rewards = self._extract_sample_from_batch(
                states, actions, episode_lens, intermediate_rewards, data_idx)
            log_return = torch.prod(curr_intermediate_rewards).log()
            non_terminal_log_flow = pred[..., -1][:-1]
            # .to(...) aligns dtype/device so the boundary value can be concatenated.
            boundary_log_flow = torch.cat(
                [non_terminal_log_flow, log_return.reshape(1).to(non_terminal_log_flow)])
            per_traj.append(self._subtb_loss(boundary_log_flow, logits - back_logits, log_alpha_ratio))

        per_traj = torch.cat(per_traj)
        return per_traj.sum() / len(per_traj)

    def learn_from_fl(self, batch):
        """
        Compute the forward-looking SubTB loss of a batch (no backward pass is run here).

        All T + 1 predicted log-flows (last output column of ``pred``) are used as
        boundaries. The per-transition terms are
        ``logits - back_logits - log(intermediate_reward)``.

        NOTE(review): unlike ``learn_from``, the terminal log-flow here is the
        network's prediction rather than a fixed boundary value — confirm this
        is intended.

        :param batch: tuple (states, actions, episode_lens, intermediate_rewards).
        :return: scalar loss tensor (mean of the per-trajectory losses).
        """
        states, actions, episode_lens, intermediate_rewards = batch
        log_alpha_ratio = torch.log(self.alpha / (1 - self.alpha))

        per_traj = []
        for data_idx in range(len(states)):
            _, _, pred, logits, back_logits, curr_intermediate_rewards = self._extract_sample_from_batch(
                states, actions, episode_lens, intermediate_rewards, data_idx)
            boundary_log_flow = pred[..., -1]
            step_terms = logits - back_logits - curr_intermediate_rewards.log()
            per_traj.append(self._subtb_loss(boundary_log_flow, step_terms, log_alpha_ratio))

        per_traj = torch.cat(per_traj)
        return per_traj.sum() / len(per_traj)
    

def main(args):
    """
    Train a GFlowNet agent on the set-generation task and log progress.

    Setup:
    - The run name encodes the seeds, ``rand_pb``, size, method and alpha.
    - ``args.size`` selects the action dimension, set size and per-element
      intermediate energies.
    - ``args.method`` selects the agent class and optimiser parameter groups.

    Each of the ``args.n_train_steps`` steps optionally updates alpha and
    exp_weight from their schedules. It then evaluates the model on a batch and
    takes one optimiser step, with optional gradient-norm clipping. Finally it
    calls ``agent.update_sampling_model_with_training_model()``.

    Logging:
    - Every ``args.n_eval_steps`` steps, and on the last step, metrics are
      printed and, when ``args.wdb`` is set, logged to wandb.
    - Test-set statistics are recomputed on evaluation steps that are multiples
      of ``args.n_test_steps``, and on the last step.
    - On the other evaluation steps, the most recent test values are re-logged.

    :param args: parsed namespace (after ``process_bool_args``).
    :raises ValueError: if ``args.size`` or ``args.method`` is not recognised.
    """
    method_name = args.method
    if args.method in ('db_gfn', 'subtb_gfn') and args.fl:
        method_name = 'fl_' + args.method
    wdb_name = f'ss({args.sampling_seed})_ms({args.model_seed})_upb({int(args.rand_pb)})_sz({args.size})_m({method_name})_a({args.alpha})'

    if args.wdb:
        wandb.init(project=args.wdb_project, name=wdb_name)
        print("Successfully initialized wandb", flush=True)

    args.dev = torch.device(args.device)
    set_device(args.dev)
    print("Successfully allocated device: ", args.dev)

    # In https://arxiv.org/pdf/2302.01687 , exp_weight is fine-tuned for each method by grid search, which corresponds to the epsilon-greedy sampling trick
    # if method_name == 'fl_db_gfn' or method_name == 'fl_subtb_gfn':
    #     args.exp_weight = 0.5
    # elif method_name == 'db_gfn' or method_name == 'subtb_gfn':
    #     args.exp_weight = 1.0
    # elif method_name == 'tb_gfn':
    #     args.exp_weight = 0.0625

    if args.size == 'small':
        args.action_dim = 30
        args.set_size = 20
        intermediate_energies = [0.6961474461702144, 0.883494921538938, -0.2745059751263419, 0.883494921538938, 0.7442282139370466, 0.7442282139370466, -0.40046449111789073, 0.6261749162235306, -0.4381397850674522, 0.4110535720923896, -0.4381397850674522, 0.8206350287761408, 0.6961474461702144, 0.013149744911117978, 0.6961474461702144, 0.4110535720923896, 0.013149744911117978, -0.40046449111789073, -0.4381397850674522, 0.883494921538938, 0.7442282139370466, -0.2745059751263419, 0.6261749162235306, 0.6261749162235306, 0.8206350287761408, -0.40046449111789073, 0.8206350287761408, -0.2745059751263419, 0.4110535720923896, 0.013149744911117978]
        # if method_name in ('fl_db_gfn', 'fl_subtb_gfn'):
        #     args.exp_weight = 1.0
        # elif method_name == 'tb_gfn':
        #     args.exp_weight = 0.
    elif args.size == 'medium':
        args.action_dim = 80
        args.set_size = 60
        intermediate_energies = [-0.5230497555411129, 0.5731802451199923, 0.6881812517688572, -0.7866830411595669, -0.41860806745880197, 0.7396350970666805, 0.5731802451199923, 0.7396350970666805, -0.7866830411595669, 0.6009291618806774, 0.6881812517688572, -0.5230497555411129, -0.7837872949891345, -0.7837872949891345, -0.41860806745880197, -0.7866830411595669, 0.6009291618806774, -0.41860806745880197, -0.3855205729595568, -0.5230497555411129, -0.5230497555411129, -0.41860806745880197, 0.2961218224370139, 0.6881812517688572, 0.6009291618806774, -0.7837872949891345, 0.2961218224370139, -0.7866830411595669, 0.5731802451199923, 0.2961218224370139, -0.3855205729595568, -0.41860806745880197, -0.5230497555411129, -0.7837872949891345, 0.7396350970666805, -0.41860806745880197, 0.6009291618806774, -0.41860806745880197, 0.2961218224370139, -0.5230497555411129, 0.2961218224370139, 0.6881812517688572, -0.5230497555411129, -0.7837872949891345, 0.6881812517688572, 0.2961218224370139, 0.6881812517688572, 0.5731802451199923, 0.5731802451199923, 0.6881812517688572, -0.3855205729595568, -0.3855205729595568, -0.5230497555411129, 0.5731802451199923, 0.2961218224370139, -0.7837872949891345, -0.41860806745880197, 0.7396350970666805, 0.6881812517688572, -0.3855205729595568, 0.7396350970666805, 0.6009291618806774, -0.7837872949891345, -0.7866830411595669, 0.6009291618806774, 0.5731802451199923, 0.7396350970666805, -0.7866830411595669, 0.2961218224370139, -0.7866830411595669, 0.6009291618806774, 0.6009291618806774, 0.5731802451199923, 0.7396350970666805, -0.3855205729595568, 0.7396350970666805, -0.7866830411595669, -0.3855205729595568, -0.3855205729595568, -0.7837872949891345]
    elif args.size == 'large':
        args.action_dim = 100
        args.set_size = 80
        intermediate_energies = [-0.15382957507887518, -0.15382957507887518, -0.8596854107736154, -0.8596854107736154, 0.13832182722858843, -0.4589997720511263, 0.9244538380742333, -0.4589997720511263, -0.5938789622812419, 0.7326989860019331, 0.3925029176153736, 0.7231982581431591, 0.7326989860019331, -0.8596854107736154, 0.9244538380742333, -0.5938789622812419, -0.5938789622812419, 0.13832182722858843, -0.5596177369051512, 0.7326989860019331, -0.5596177369051512, -0.8596854107736154, 0.13832182722858843, 0.9244538380742333, 0.7231982581431591, -0.15382957507887518, -0.15382957507887518, -0.5938789622812419, 0.9244538380742333, 0.13832182722858843, 0.9244538380742333, 0.3925029176153736, -0.15382957507887518, 0.3925029176153736, -0.4589997720511263, -0.4589997720511263, -0.8596854107736154, 0.13832182722858843, -0.8596854107736154, -0.5596177369051512, 0.13832182722858843, -0.8596854107736154, 0.13832182722858843, 0.7326989860019331, 0.9244538380742333, 0.7231982581431591, 0.13832182722858843, 0.7231982581431591, -0.4589997720511263, -0.15382957507887518, 0.7326989860019331, -0.5938789622812419, -0.5938789622812419, -0.5596177369051512, 0.9244538380742333, 0.7231982581431591, 0.7231982581431591, -0.8596854107736154, 0.7326989860019331, -0.4589997720511263, 0.3925029176153736, -0.5938789622812419, 0.7326989860019331, 0.3925029176153736, 0.9244538380742333, -0.4589997720511263, -0.5596177369051512, -0.15382957507887518, 0.3925029176153736, -0.5938789622812419, -0.15382957507887518, -0.5938789622812419, 0.13832182722858843, 0.7326989860019331, -0.5596177369051512, -0.8596854107736154, -0.15382957507887518, -0.5596177369051512, 0.9244538380742333, 0.3925029176153736, -0.4589997720511263, -0.8596854107736154, 0.7326989860019331, 0.7231982581431591, -0.4589997720511263, -0.5596177369051512, -0.5596177369051512, 0.3925029176153736, 0.3925029176153736, 0.7231982581431591, 0.13832182722858843, -0.15382957507887518, 0.7231982581431591, 0.7231982581431591,
            -0.5596177369051512, -0.5938789622812419, 0.7326989860019331, -0.4589997720511263, 0.3925029176153736, 0.9244538380742333]
        # if method_name == 'tb_gfn':
        #     args.exp_weight = 0.
    else:
        # Fail fast: otherwise intermediate_energies would be unbound below.
        raise ValueError(f"Unknown size: {args.size}")

    # Prepare the environments
    train_envs = [SetEnv(args.action_dim, args.set_size, intermediate_energies) for _ in range(args.mbsize)]
    test_envs = [SetEnv(args.action_dim, args.set_size, intermediate_energies) for _ in range(args.test_set_size)]
    print("Successfully initialized the environment", flush=True)

    if args.method == 'tb_gfn':
        agent = TBFlowNetAgent(args, train_envs, test_envs)
    elif args.method == 'db_gfn':
        agent = DBFlowNetAgent(args, train_envs, test_envs)
    elif args.method == 'subtb_gfn':
        agent = SubTBFlowNetAgent(args, train_envs, test_envs)
    else:
        raise ValueError(f"Unknown method: {args.method}")

    if args.method == 'tb_gfn':
        # TB also optimises agent.Z, with its own learning rate.
        opt = torch.optim.Adam([{'params': agent.parameters(), 'lr': args.tb_lr}, {'params': [agent.Z], 'lr': args.tb_z_lr}])
    else:  # 'db_gfn' or 'subtb_gfn' (method already validated above)
        opt = torch.optim.Adam([{'params': agent.parameters(), 'lr': args.tb_lr}])

    # Generate the test set
    agent.gen_test_set_with_sampling_model(test_set_size=args.test_set_size)

    # exp_weight schedule towards args.exp_weight (start value is defined by ExpWeightScheduler)
    if args.use_exp_weight_decay:
        exp_sched = ExpWeightScheduler(
            end=args.exp_weight,
            total_steps=args.n_train_steps,
            kind=args.exp_weight_sched,
            warm_frac=args.exp_weight_warm_frac,
        )

    if args.use_alpha_scheduler:
        alpha_sched = AlphaScheduler(
            total_steps=args.n_train_steps,
            alpha0=args.alpha,
            warm_frac=args.alpha_warm_frac,
            alpha_final=0.5
        )

    def _set_agent_alpha_value(agent, value, device):
        """Overwrite agent.alpha with a scalar tensor, in place if it is a Parameter."""
        with torch.no_grad():
            if isinstance(agent.alpha, torch.nn.Parameter):
                agent.alpha.data = torch.tensor(float(value), device=device)
            else:
                agent.alpha = torch.tensor(float(value), device=device)

    print("Successfully initialized the optimizer, start training", flush=True)

    # ---- clipping interval accumulators (reset after every evaluation) ----
    clip_sum_trigger = 0.0
    clip_sum_scale = 0.0
    clip_count = 0
    # -----------------------------------------------------------------------

    corr_methods = ['pearson', 'spearman']
    test_stats = {}  # most recent test-set statistics, refreshed on test steps

    for i in range(args.n_train_steps):
        cur_time = time.time()
        # Whenever the alpha scheduler is enabled it overwrites alpha, including
        # the data of a trainable nn.Parameter alpha.
        if args.use_alpha_scheduler:
            _set_agent_alpha_value(agent, alpha_sched(i), args.dev)

        if args.use_exp_weight_decay:
            agent.exp_weight = exp_sched(i)

        # Evaluate the current training model on a batch of samples.
        forward_policy_entropy, forward_avg_action_prob, backward_avg_policy_entropy, backward_avg_action_prob = agent.eval_batch(args.mbsize)
        opt.zero_grad()
        loss = agent.train_step(args.mbsize, method_name)

        if args.use_grad_clip and (args.grad_clip_norm is not None) and (args.grad_clip_norm > 0):
            params_to_clip = [p for g in opt.param_groups for p in g['params'] if p.grad is not None]
            total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=args.grad_clip_norm).cpu().item()
            tau = float(args.grad_clip_norm)
            clipped = total_norm > tau
            clip_sum_scale += tau / (total_norm + 1e-12) if clipped else 1.0
            clip_sum_trigger += 1.0 if clipped else 0.0
            clip_count += 1

        opt.step()

        agent.update_sampling_model_with_training_model()

        if i % args.n_eval_steps == 0 or i == args.n_train_steps - 1:
            cur_loss = loss.detach().cpu().item()
            print(f"{wdb_name}, step={i}, loss={cur_loss}, time={time.time()-cur_time:.2f}s", flush=True)

            all_unique_Rs = np.array(sorted(agent.all_eval_unique_Rs, reverse=True))
            # Slicing past the end returns the whole array, so no length check is needed.
            mean_top_100_R = np.mean(all_unique_Rs[:100])
            mean_top_1000_R = np.mean(all_unique_Rs[:1000])
            mean_R = sum(all_unique_Rs) / len(all_unique_Rs)
            modes = (all_unique_Rs > args.mode_threshold).sum()
            cur_alpha = agent.alpha.detach().cpu().item()
            training_dict = {
                'step': i,
                'mean_top_100_R': mean_top_100_R,
                'mean_top_1000_R': mean_top_1000_R,
                'mean_R': mean_R,
                'modes': modes,
                'loss': cur_loss,
                'unique_samples': len(all_unique_Rs),
                'exp_weight': float(agent.exp_weight),
                'alpha': cur_alpha,
                'forward_policy_entropy_eval': forward_policy_entropy,
                'forward_avg_action_prob_eval': forward_avg_action_prob,
                'backward_policy_entropy_eval': backward_avg_policy_entropy,
                'backward_avg_action_prob_eval': backward_avg_action_prob,
            }
            cur_unique_Rs, _ = agent.get_eval_Rs()
            num_cur_unique_Rs = len(cur_unique_Rs)
            if num_cur_unique_Rs > 0:
                training_dict['avg_cur_unique_R'] = sum(cur_unique_Rs) / num_cur_unique_Rs
                training_dict['num_cur_unique_R'] = num_cur_unique_Rs
            else:
                training_dict['avg_cur_unique_R'], training_dict['num_cur_unique_R'] = np.nan, 0
            print(f"{wdb_name}, " + _dict_to_str(training_dict), flush=True)

            # ---- log interval averages for clipping ----
            if clip_count > 0:
                training_dict['clip_trigger_rate_interval'] = clip_sum_trigger / clip_count
                training_dict['clip_scale_avg_interval'] = clip_sum_scale / clip_count
            else:
                training_dict['clip_trigger_rate_interval'] = 0
                training_dict['clip_scale_avg_interval'] = 1.0

            # reset accumulators for the next interval
            clip_sum_trigger = 0.0
            clip_sum_scale = 0.0
            clip_count = 0
            # -------------------------------------------

            # Refresh test-set statistics on test steps; otherwise re-log the latest ones.
            if i % args.n_test_steps == 0 or i == args.n_train_steps - 1:
                test_stats = agent.calc_stats_in_test_set(corr_methods)
                print(f"{wdb_name}, step={i}, test, " + _dict_to_str(test_stats))
            for stat_name, value in test_stats.items():
                suffix = '_corr_test' if stat_name in corr_methods else '_test'
                training_dict[f'{stat_name}{suffix}'] = value

            if i == args.n_train_steps - 1:
                # Attach the run configuration to the final log entry. alpha and
                # exp_weight are stored under new names so they do not overwrite
                # the current (possibly scheduled) values already in training_dict.
                setting_dict = vars(args).copy()
                setting_dict['alpha_init'] = setting_dict.get('alpha')
                setting_dict['exp_weight_final'] = setting_dict.get('exp_weight')
                for k in ('alpha', 'exp_weight', 'wdb', 'device', 'dev', 'wdb_project'):
                    setting_dict.pop(k, None)
                training_dict.update(setting_dict)
            if args.wdb:
                wandb.log(training_dict)


def process_bool_args(args):
    """
    Convert the 0/1 integer flags in ``args`` to ``bool``, in place.

    Converted flags: ``wdb``, ``fl``, ``use_alpha_scheduler``,
    ``use_grad_clip`` and ``use_exp_weight_decay``. Other attributes are left
    untouched.

    :param args: argparse namespace holding the flags.
    :return: the same namespace object, for chaining.
    """
    for flag in ('wdb', 'fl', 'use_alpha_scheduler', 'use_grad_clip', 'use_exp_weight_decay'):
        setattr(args, flag, bool(getattr(args, flag)))
    return args

def set_model_seed(args):
    """
    Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``args.model_seed``.

    Also forces cuDNN into deterministic mode and disables its autotuner so
    that model initialisation and training are reproducible for a given seed.
    """
    seed = args.model_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Reproducibility over speed: no autotuned (possibly nondeterministic) kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    print("start parsing args...", flush=True)
    args = parser.parse_args()
    # Echo the raw (pre-conversion) argument values into the run log.
    for arg_name, arg_value in vars(args).items():
        print(f"{arg_name}: {arg_value}")
    args = process_bool_args(args)
    print("successfully parsed args!", flush=True)

    # Thread count and RNG seeds must be fixed before any model is built.
    torch.set_num_threads(args.num_threads)
    set_model_seed(args)
    try:
        main(args)
    finally:
        # Close the wandb run even if training raised.
        if getattr(args, 'wdb', False):
            wandb.finish()

