import numpy as np
import torch
import gym
import argparse
import os
import random
import math
import time
import copy
from pathlib import Path
import yaml
import h5py
from tqdm import tqdm

import algo.utils as utils
from envs.env_utils import call_terminal_func
from envs.common import call_env
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

import ott
import d4rl
import scipy as sp

import ot
import jax.numpy as jnp
import numpy as np
import jax

# from MulticoreTSNE import MulticoreTSNE as TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import subprocess

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import umap

from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import pairwise_kernels
from scipy import linalg
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial import distance


class ResBlock(nn.Module):
    """Residual block: LayerNorm -> Linear -> SiLU -> DropPath, added to the input.

    Args:
        dim: Feature width; the linear layer maps ``dim`` to ``dim`` so the
            branch output can be summed with the input.
        drop_prob: Drop probability forwarded to ``DropPath``.
    """

    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.linear = nn.Linear(dim, dim)
        self.activation = nn.SiLU()
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        """Return ``x`` plus the transformed (and possibly dropped) branch."""
        branch = self.norm(x)
        branch = self.linear(branch)
        branch = self.activation(branch)
        branch = self.drop_path(branch)
        return x + branch

class DropPath(nn.Module):
    """Randomly zero the whole input per sample during training.

    In training mode each sample (index along dim 0) is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / (1 - drop_prob)``; otherwise it is
    replaced by zeros. In eval mode, or when ``drop_prob == 0``, the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample's input.
    """

    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply the per-sample drop mask to ``x`` (identity outside training)."""
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Everything is dropped; dividing by keep_prob == 0 would give
            # inf/NaN instead of the intended zeros.
            return torch.zeros_like(x)
        # One random draw per sample, broadcast over the remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        # floor() maps values in [keep_prob, 1 + keep_prob) to {0, 1}.
        binary_mask = torch.floor(random_tensor)
        return x / keep_prob * binary_mask

class InverseRewardModelResNet(nn.Module):
    """Residual MLP mapping a ``state_dim * 2`` wide input to one scalar per row.

    Args:
        state_dim: State width; the network input width is ``state_dim * 2``.
        hidden_dim: Width of the residual trunk.
        n_layers: Number of ``ResBlock`` modules in the trunk.
    """

    def __init__(self, state_dim, hidden_dim=256, n_layers=3):
        super().__init__()
        pair_dim = state_dim * 2
        self.input_proj = nn.Linear(pair_dim, hidden_dim)
        trunk = []
        for _ in range(n_layers):
            trunk.append(ResBlock(hidden_dim))
        self.res_blocks = nn.Sequential(*trunk)
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, s):
        """Return predictions of shape ``(batch,)`` for ``s`` of shape (batch, state_dim * 2)."""
        hidden = self.res_blocks(self.input_proj(s))
        prediction = self.output_layer(hidden)
        # Drop the trailing size-1 output dimension.
        return prediction.squeeze(-1)

class InverseActionModelResNet(nn.Module):
    """Residual MLP mapping a ``state_dim * 2`` wide input to an ``action_dim`` vector.

    Args:
        state_dim: State width; the network input width is ``state_dim * 2``.
        action_dim: Output width.
        hidden_dim: Width of the residual trunk.
        n_layers: Number of ``ResBlock`` modules in the trunk.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, n_layers=3):
        super().__init__()
        pair_dim = state_dim * 2
        self.input_proj = nn.Linear(pair_dim, hidden_dim)
        trunk = []
        for _ in range(n_layers):
            trunk.append(ResBlock(hidden_dim))
        self.res_blocks = nn.Sequential(*trunk)
        self.output_layer = nn.Linear(hidden_dim, action_dim)

    def forward(self, s):
        """Return predictions of shape ``(batch, action_dim)``."""
        hidden = self.res_blocks(self.input_proj(s))
        return self.output_layer(hidden)

def train_inverse_action_model(model, x, y, device, epochs=1000, batch_size=128, lr=1e-3):
    """Fit ``model`` to regress ``y`` from ``x`` using Adam and an MSE loss.

    Args:
        model: Module called as ``model(batch_x)``; trained in place and left
            in training mode.
        x: NumPy array of inputs (converted to float32).
        y: NumPy array of regression targets (converted to float32).
        device: Device each mini-batch is moved to.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        None. Progress is printed at epoch 1 and every 100th epoch.
    """
    inputs = torch.from_numpy(x).float()
    targets = torch.from_numpy(y).float()
    data = TensorDataset(inputs, targets)
    batches = DataLoader(data, batch_size=batch_size, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    n_samples = len(data)

    model.train()
    for epoch in range(epochs):
        running = 0.0
        for xb, yb in batches:
            xb = xb.to(device)
            yb = yb.to(device)
            loss = F.mse_loss(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running += loss.item() * xb.size(0)
        avg_loss = running / n_samples
        if epoch == 1 or epoch % 100 == 0:
            print(f"[InverseActionModel] Epoch {epoch} | Loss: {avg_loss:.6f}")

def train_inverse_reward_model(model, x, y, device, epochs=1000, batch_size=128, lr=1e-3):
    """Fit ``model`` to regress scalar targets ``y`` from ``x`` using Adam and MSE.

    Args:
        model: Module called as ``model(batch_x)``; trained in place and left
            in training mode.
        x: NumPy array of inputs (converted to float32).
        y: NumPy array of targets (converted to float32); its last dimension
            is squeezed before the loss is computed.
        device: Device each mini-batch is moved to.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        None. Progress is printed at epoch 1 and every 100th epoch.
    """
    inputs = torch.from_numpy(x).float()
    targets = torch.from_numpy(y).float()
    data = TensorDataset(inputs, targets)
    batches = DataLoader(data, batch_size=batch_size, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    n_samples = len(data)

    model.train()
    for epoch in range(epochs):
        running = 0.0
        for xb, yb in batches:
            xb = xb.to(device)
            yb = yb.to(device)
            # Squeeze the target's trailing dimension before comparing with the prediction.
            loss = F.mse_loss(model(xb), yb.squeeze(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running += loss.item() * xb.size(0)
        avg_loss = running / n_samples
        if epoch == 1 or epoch % 100 == 0:
            print(f"[InverseRewardmodel] Epoch {epoch} | Loss: {avg_loss:.6f}")

class RewardModel(nn.Module):
    """Plain MLP mapping a ``state_dim * 2`` wide input to one scalar per row.

    Args:
        state_dim: State width; the network input width is ``state_dim * 2``.
        hidden_dim: Width of every hidden layer.
        n_layers: Number of hidden (Linear [+ BatchNorm1d] + SiLU) stages.
            With ``n_layers=0`` the model is a single linear layer.
        use_bn: Insert ``BatchNorm1d`` after each hidden linear layer.
    """

    def __init__(self, state_dim, hidden_dim=256, n_layers=3, use_bn=False):
        super().__init__()
        input_dim = state_dim * 2
        layers = []
        for _ in range(n_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.SiLU())
            input_dim = hidden_dim
        # Use the running width (not hidden_dim) so n_layers=0 still yields a
        # layer that matches the raw input width.
        layers.append(nn.Linear(input_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, s):
        """Return predictions of shape ``(batch,)``."""
        return self.net(s).squeeze(-1)


class NextStateScoreNet(nn.Module):
    """Score network for a noisy next state, conditioned on time and a state.

    The time ``t`` and the conditioning state are each embedded by a small
    MLP; both embeddings are concatenated with the noisy next state and
    passed through a stack of Linear(+BatchNorm1d)+SiLU blocks. A skip
    connection is added around a block whenever its input and output shapes
    match. The output has width ``state_dim``.

    Args:
        state_dim: Width of the state / next state.
        hidden_dim: Width of each block.
        n_layers: Number of blocks. With ``n_layers=0`` the concatenated
            features go straight into the final linear layer.
        emb_dim: Width of the time and conditioning embeddings.
        bn: Add ``BatchNorm1d`` to the time MLP and to every block.
    """

    def __init__(self, state_dim, hidden_dim=256, n_layers=4, emb_dim=128, bn=False):
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim

        self.cond_mlp = nn.Sequential(
            nn.Linear(state_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU()
        )
        if bn:
            self.time_mlp = nn.Sequential(
                nn.Linear(1, emb_dim),
                nn.BatchNorm1d(emb_dim),
                nn.SiLU(),
                nn.Linear(emb_dim, emb_dim),
                nn.BatchNorm1d(emb_dim),
                nn.SiLU()
            )
        else:
            self.time_mlp = nn.Sequential(
                nn.Linear(1, emb_dim),
                nn.SiLU(),
                nn.Linear(emb_dim, emb_dim),
                nn.SiLU()
            )

        self.blocks = nn.ModuleList()
        # Block input = noisy next state + time embedding + condition embedding.
        last_dim = state_dim + emb_dim + emb_dim
        for _ in range(n_layers):
            layers = [nn.Linear(last_dim, hidden_dim)]
            if bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.SiLU())
            self.blocks.append(nn.Sequential(*layers))
            last_dim = hidden_dim

        # Track the running width so the head also fits when n_layers == 0.
        self.final = nn.Linear(last_dim, state_dim)

    def forward(self, noisy_s_prime, t, cond_s):
        """Return the score estimate, shape ``(batch, state_dim)``.

        Args:
            noisy_s_prime: Noisy next state, ``(batch, state_dim)``.
            t: Time values, ``(batch,)`` or ``(batch, 1)``.
            cond_s: Conditioning state, ``(batch, state_dim)``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        t_emb = self.time_mlp(t)
        cond_emb = self.cond_mlp(cond_s)
        h = torch.cat([noisy_s_prime, t_emb, cond_emb], dim=-1)
        for block in self.blocks:
            h_pre = h
            h = block(h)
            # Residual connection only where the widths line up.
            if h.shape == h_pre.shape:
                h = h + h_pre
        out = self.final(h)
        return out

class LabelCondScoreNet(nn.Module):
    """Score network for a noisy input, conditioned on time and a label.

    The time ``t`` and the label are each embedded by a small MLP; both
    embeddings are concatenated with the input and passed through a stack of
    Linear(+BatchNorm1d)+SiLU blocks. A skip connection is added around a
    block whenever its input and output shapes match. The output has width
    ``input_dim``.

    Args:
        input_dim: Width of the data being scored.
        label_dim: Width of the label vector.
        hidden_dim: Width of each block.
        n_layers: Number of blocks. With ``n_layers=0`` the concatenated
            features go straight into the final linear layer.
        emb_dim: Width of the time and label embeddings.
        bn: Add ``BatchNorm1d`` inside every block.
    """

    def __init__(self, input_dim, label_dim=1, hidden_dim=256, n_layers=4, emb_dim=128, bn=False):
        super().__init__()
        self.input_dim = input_dim
        self.label_dim = label_dim
        self.emb_dim = emb_dim
        self.label_mlp = nn.Sequential(
            nn.Linear(label_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU()
        )
        self.time_mlp = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU()
        )
        self.blocks = nn.ModuleList()
        # Block input = data + time embedding + label embedding.
        last_dim = input_dim + emb_dim + emb_dim
        for _ in range(n_layers):
            block = []
            block.append(nn.Linear(last_dim, hidden_dim))
            if bn:
                block.append(nn.BatchNorm1d(hidden_dim))
            block.append(nn.SiLU())
            self.blocks.append(nn.Sequential(*block))
            last_dim = hidden_dim
        # Track the running width so the head also fits when n_layers == 0.
        self.final = nn.Linear(last_dim, input_dim)

    def forward(self, x, t, label):
        """Return the score estimate, shape ``(batch, input_dim)``.

        Args:
            x: Noisy input, ``(batch, input_dim)``.
            t: Time values, ``(batch,)`` or ``(batch, 1)``.
            label: Conditioning label, ``(batch, label_dim)``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        t_emb = self.time_mlp(t)
        label_emb = self.label_mlp(label)
        h = torch.cat([x, t_emb, label_emb], dim=-1)

        for block in self.blocks:
            h_pre = h
            h = block(h)
            # Residual connection only where the widths line up.
            if h.shape == h_pre.shape:
                h = h + h_pre
        out = self.final(h)
        return out


def sigma_t(t, alpha_min=0.1, alpha_max=20.0):
    """Return the noise scale ``sqrt(1 - exp(-B(t)))`` for tensor ``t``.

    ``B(t) = alpha_min * t + 0.5 * (alpha_max - alpha_min) * t**2``.
    """
    linear_part = alpha_min * t
    quadratic_part = 0.5 * (alpha_max - alpha_min) * t ** 2
    integrated = linear_part + quadratic_part
    return torch.sqrt(1.0 - torch.exp(-integrated))

def lambda_t(sigma):
    """Return the loss weight for noise scale ``sigma``, i.e. ``sigma`` squared."""
    return sigma * sigma


def train_label_cond_score_model(
    model, src_data, tgt_data, src_label, tgt_label,
    optimizer, device, epochs=100, batch_size=128, alpha_min=0.01, alpha_max=50.0
):
    """Train a label-conditioned score model on mixed target/source batches.

    Every step draws half a batch from the target data and half from the
    source data (the source loader is restarted whenever it runs out),
    perturbs the samples with noise of scale ``sigma_t(t)`` for a uniformly
    drawn ``t``, and regresses the model output onto ``-z / sigma`` with a
    ``sigma**2`` weight.

    Args:
        model: Called as ``model(x_noisy, t, label)``; trained in place.
        src_data, tgt_data: NumPy arrays of samples, one row per sample.
        src_label, tgt_label: NumPy arrays of per-sample labels.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        epochs: Number of passes over the target data.
        batch_size: Combined batch size; each domain contributes half.
        alpha_min, alpha_max: Noise schedule parameters passed to ``sigma_t``.

    Returns:
        None. The mean per-sample loss is printed every 100th epoch.
    """
    half_batch = max(1, batch_size // 2)

    x_tensor = torch.from_numpy(tgt_data).float()
    y_tensor = torch.from_numpy(tgt_label).float()
    dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=half_batch, shuffle=True)

    x_src_tensor = torch.from_numpy(src_data).float()
    y_src_tensor = torch.from_numpy(src_label).float()
    src_dataset = torch.utils.data.TensorDataset(x_src_tensor, y_src_tensor)
    src_dataloader = torch.utils.data.DataLoader(src_dataset, batch_size=half_batch, shuffle=True)
    src_data_iter = iter(src_dataloader)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        n_seen = 0
        for batch_x, batch_label in dataloader:
            # Cycle through the source loader independently of target epochs.
            try:
                batch_x_src, batch_label_src = next(src_data_iter)
            except StopIteration:
                src_data_iter = iter(src_dataloader)
                batch_x_src, batch_label_src = next(src_data_iter)

            batch_x = torch.cat([batch_x.to(device), batch_x_src.to(device)], dim=0)
            batch_label = torch.cat([batch_label.to(device), batch_label_src.to(device)], dim=0)

            # Shuffle samples and labels with the SAME permutation so every
            # sample keeps its own domain label.
            perm = torch.randperm(batch_x.size(0), device=batch_x.device)
            batch_x = batch_x[perm]
            batch_label = batch_label[perm]

            batch_size_ = batch_x.size(0)
            t = torch.rand(batch_size_, device=device).unsqueeze(-1)
            sigma = sigma_t(t, alpha_min, alpha_max)
            z = torch.randn_like(batch_x)
            xt = batch_x + sigma * z
            # Clamp avoids division by zero when sigma underflows near t = 0.
            target = -z / torch.clamp(sigma, min=1e-9)
            score_pred = model(xt, t, batch_label)
            weight = lambda_t(sigma.squeeze(-1))
            loss = ((score_pred - target) ** 2).sum(dim=1) * weight
            loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch_size_
            n_seen += batch_size_

        # Normalise by the samples actually processed (target + source).
        avg_loss = total_loss / max(n_seen, 1)
        if epoch % 100 == 0:
            print(f"[LabelCondScoreNet DL] Epoch {epoch+1} | Loss: {avg_loss:.6f}")





def train_conditional_next_state_score_model(
    model, tgt_s, tgt_s_prime,
    optimizer, device, epochs=2000, batch_size=128,
    alpha_min=0.1, alpha_max=20.0,
):
    """Train a score model for next states conditioned on the current state.

    Each step perturbs a batch of next states with noise of scale
    ``sigma_t(t)`` for a uniformly drawn ``t`` and regresses the model output
    onto ``-z / sigma`` with a ``sigma**2`` weight.

    Args:
        model: Called as ``model(noisy_next_state, t, state)``; trained in place.
        tgt_s: NumPy array of conditioning states.
        tgt_s_prime: NumPy array of next states (the data being noised).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        alpha_min, alpha_max: Noise schedule parameters passed to ``sigma_t``.

    Returns:
        None. The mean loss is printed every 100th epoch and on the last one.
    """
    next_states = torch.from_numpy(tgt_s_prime).float()
    cond_states = torch.from_numpy(tgt_s).float()
    pairs = TensorDataset(next_states, cond_states)
    dataloader = DataLoader(pairs, batch_size=batch_size, shuffle=True)
    last_epoch = epochs - 1

    model.train()
    for epoch in range(epochs):
        running = 0.0
        for clean_next, cond in dataloader:
            clean_next = clean_next.to(device)
            cond = cond.to(device)
            n_batch = clean_next.size(0)

            t = torch.rand(n_batch, device=device)
            sigma = sigma_t(t, alpha_min, alpha_max)
            sigma_col = sigma[:, None]
            z = torch.randn_like(clean_next)
            noisy_next = clean_next + sigma_col * z
            # Clamp avoids division by zero for very small sigma.
            target = -z / torch.clamp(sigma_col, min=1e-9)

            residual = model(noisy_next, t, cond) - target
            per_sample = (residual ** 2).sum(dim=1) * lambda_t(sigma)
            loss = per_sample.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item() * n_batch

        avg_loss = running / len(dataloader.dataset)
        if epoch == last_epoch or epoch % 100 == 0:
            print(f"Epoch {epoch+1} | Loss: {avg_loss:.6f}")



def make_time_grid(num_steps, t_start=1.0, t_end=1e-5, gamma=3.0, device='cuda'):
    """Return ``num_steps`` time points running from ``t_start`` to ``t_end``.

    A uniform grid on [0, 1] is raised to the power ``gamma`` and used as the
    interpolation weight between ``t_start`` (weight 0) and ``t_end`` (weight 1).
    """
    uniform = torch.linspace(0., 1., num_steps, device=device)
    weight = torch.pow(uniform, gamma)
    return t_start * (1. - weight) + t_end * weight

def pc_sampler_cond(
    batch_size, dim, score_model, ymax, num_steps=1000,  
    alpha_min=0.1, alpha_max=20,
    corrector=True, snr=0.16, corrector_steps=1,
    device='cuda', t_end=1e-5, t_gamma=3, 
):
    """Draw samples from a label-conditioned score model.

    Labels are drawn uniformly from ``[0, ymax)``, one per sample. Starting
    from Gaussian noise scaled by the first noise level, each of the
    ``num_steps`` iterations applies a score-driven update plus fresh noise,
    optionally followed (except on the last step) by ``corrector_steps``
    updates of the form ``x + eps * score + sqrt(2 * eps) * noise``.

    The model is evaluated in eval mode under ``torch.no_grad()``; if it was
    in training mode on entry, that mode is restored before returning.

    Args:
        batch_size: Number of samples to draw.
        dim: Width of each sample.
        score_model: Called as ``score_model(x, t, label)``.
        ymax: Upper bound of the uniform label distribution.
        num_steps: Number of denoising iterations.
        alpha_min, alpha_max: Noise schedule parameters passed to ``sigma_t``.
        corrector: Whether to run the corrector updates.
        snr: Factor that sets the corrector step size.
        corrector_steps: Corrector updates per iteration.
        device: Device to sample on.
        t_end: Final time of the grid.
        t_gamma: Exponent passed to ``make_time_grid``.

    Returns:
        NumPy array of shape ``(batch_size, dim)``.
    """
    t_seq = make_time_grid(num_steps+1, t_start=1.0, t_end=t_end, gamma=t_gamma, device=device)
    sigma_seq = sigma_t(t_seq, alpha_min, alpha_max)
    noise_norm = (dim ** 0.5)

    # Sample in eval mode so layers such as BatchNorm/Dropout behave
    # deterministically; put the caller's mode back afterwards.
    was_training = bool(getattr(score_model, "training", False))
    if was_training:
        score_model.eval()

    try:
        x_t = torch.randn(batch_size, dim, device=device) * sigma_seq[0]

        label = torch.rand(batch_size, 1, device=device) * float(ymax)
        label = label.float()

        with torch.no_grad():
            for i in range(num_steps):
                is_last_step = (i == num_steps - 1)
                sigma_t_cur = sigma_seq[i]
                sigma_t_minus_1 = sigma_seq[i+1]
                sigma_delta_sq = sigma_t_cur ** 2 - sigma_t_minus_1 ** 2
                sigma_delta = torch.sqrt(torch.clamp(sigma_delta_sq, min=1e-9))

                t_cur = t_seq[i].repeat(batch_size, 1)
                t_minus_1 = t_seq[i+1].repeat(batch_size, 1)

                # Score-driven update followed by fresh noise.
                grad_t = score_model(x_t, t_cur, label)
                x_t_minus_1 = x_t + sigma_delta_sq * grad_t
                noise = torch.randn_like(x_t)
                x_t_minus_1 = x_t_minus_1 + sigma_delta * noise

                if corrector and not is_last_step:
                    for _ in range(corrector_steps):
                        grad_t_minus_1 = score_model(x_t_minus_1, t_minus_1, label)
                        grad_norm = torch.norm(grad_t_minus_1.reshape(batch_size, -1), dim=-1).mean()
                        # Step size from the ratio of sqrt(dim) to the mean score norm.
                        step_size = (snr * noise_norm / (grad_norm + 1e-9)) ** 2 * 2
                        noise = torch.randn_like(x_t_minus_1)
                        x_t_minus_1 = x_t_minus_1 + step_size * grad_t_minus_1 + torch.sqrt(2 * step_size) * noise
                x_t = x_t_minus_1
    finally:
        if was_training:
            score_model.train()

    return x_t.cpu().numpy()



def conditional_sampler(
    cond_s,              
    score_model,         
    num_steps=1000,
    alpha_min=0.1,
    alpha_max=20.0,
    corrector=True,
    snr=0.16,
    corrector_steps=1,
    device='cuda',
    t_end=1e-5,
    t_gamma=3,
):
    """Sample one next state per conditioning state and return the pairs.

    Starting from Gaussian noise scaled by the first noise level, each of the
    ``num_steps - 1`` iterations applies a score-driven update plus fresh
    noise, optionally followed (except on the last iteration) by
    ``corrector_steps`` updates of the form
    ``x + eps * score + sqrt(2 * eps) * noise``. Sampling runs under
    ``torch.no_grad()``.

    Args:
        cond_s: Tensor of conditioning states, shape ``(batch, state_dim)``.
        score_model: Called as ``score_model(x, t, cond_s)``.
        num_steps: Number of points on the time grid.
        alpha_min, alpha_max: Noise schedule parameters passed to ``sigma_t``.
        corrector: Whether to run the corrector updates.
        snr: Factor that sets the corrector step size.
        corrector_steps: Corrector updates per iteration.
        device: Device to sample on.
        t_end: Final time of the grid.
        t_gamma: Exponent passed to ``make_time_grid``.

    Returns:
        NumPy array of shape ``(batch, 2 * state_dim)``: the conditioning
        states concatenated with the generated next states.
    """
    batch_size, state_dim = cond_s.shape

    t_seq = make_time_grid(num_steps, t_start=1.0, t_end=t_end, gamma=t_gamma, device=device)
    # Honour the caller's schedule instead of silently using sigma_t's defaults.
    sigma_seq = sigma_t(t_seq, alpha_min, alpha_max)
    cond_s = cond_s.to(device)
    noise_norm = state_dim ** 0.5

    # no_grad keeps autograd out of the loop and lets .numpy() succeed even
    # when the function is called outside a no_grad context.
    with torch.no_grad():
        x_t = torch.randn(batch_size, state_dim, device=device) * sigma_seq[0]

        for i in range(num_steps - 1):
            t_cur = t_seq[i].repeat(batch_size, 1)
            t_next = t_seq[i+1].repeat(batch_size, 1)
            sigma_cur = sigma_seq[i]
            sigma_next = sigma_seq[i+1]

            sigma_delta_sq = sigma_cur ** 2 - sigma_next ** 2
            sigma_delta = torch.sqrt(torch.clamp(sigma_delta_sq, min=1e-9))

            # Score-driven update followed by fresh noise.
            grad = score_model(x_t, t_cur, cond_s)
            x_t_minus_1 = x_t + sigma_delta_sq * grad
            noise = torch.randn_like(x_t)
            x_t_minus_1 = x_t_minus_1 + sigma_delta * noise

            if corrector and i < num_steps - 2:
                for _ in range(corrector_steps):
                    grad_corrector = score_model(x_t_minus_1, t_next, cond_s)
                    grad_norm = torch.norm(grad_corrector.reshape(batch_size, -1), dim=-1).mean()
                    # Step size from the ratio of sqrt(state_dim) to the mean score norm.
                    step_size = 2.0 * (snr * noise_norm / (grad_norm + 1e-9)) ** 2
                    noise_corrector = torch.randn_like(x_t_minus_1)
                    x_t_minus_1 = x_t_minus_1 + step_size * grad_corrector + torch.sqrt(2 * step_size) * noise_corrector
            x_t = x_t_minus_1

    gen_s_prime = x_t.cpu().numpy()
    cond_s_np = cond_s.cpu().numpy()

    transitions = np.concatenate([cond_s_np, gen_s_prime], axis=1)
    return transitions


def conditional_sampler_batch(
        cond_s_all, score_model, batch_size=512, **sampler_kwargs
):
    """Run ``conditional_sampler`` over ``cond_s_all`` in chunks and stack the results.

    The model is switched to eval mode (and left there) and sampling runs
    under ``torch.no_grad()``.

    Args:
        cond_s_all: Tensor of conditioning states, one row per sample.
        score_model: Score model forwarded to ``conditional_sampler``.
        batch_size: Number of rows handled per call.
        **sampler_kwargs: Forwarded unchanged to ``conditional_sampler``;
            ``device`` (default ``'cuda'``) also selects where each chunk is moved.

    Returns:
        NumPy array with the per-chunk outputs stacked along axis 0.
    """
    score_model.eval()
    device = sampler_kwargs.get('device', 'cuda')
    total = cond_s_all.shape[0]
    chunks = []

    with torch.no_grad():
        for lo in range(0, total, batch_size):
            hi = min(lo + batch_size, total)
            cond_chunk = cond_s_all[lo:hi].to(device)
            chunks.append(conditional_sampler(cond_chunk, score_model, **sampler_kwargs))
            # Release cached GPU memory between chunks.
            torch.cuda.empty_cache()
    return np.vstack(chunks)



def filter_outliers(X, thresh, max_violation_frac=0.1, method="z_score"):
    """Drop rows of ``X`` that have too many outlying features.

    A feature value is a violation when it falls outside the per-column
    bounds defined by ``method``. A row is kept when its fraction of
    violating features is at most ``max_violation_frac``.

    Args:
        X: 2-D NumPy array, one row per sample.
        thresh: Threshold whose meaning depends on ``method``:
            ``"iqr"``: multiple of the inter-quartile range beyond Q1/Q3;
            ``"robust_z_score"``: bound on ``0.6745 * |x - median| / MAD``;
            ``"z_score"``: bound on ``|x - mean| / std``.
        max_violation_frac: Largest tolerated fraction of violating features
            per row.
        method: One of ``"iqr"``, ``"robust_z_score"``, ``"z_score"``.

    Returns:
        The rows of ``X`` that pass the filter.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "iqr":
        q1 = np.nanpercentile(X, 25, axis=0)
        q3 = np.nanpercentile(X, 75, axis=0)
        iqr = q3 - q1
        lower = q1 - thresh * iqr
        upper = q3 + thresh * iqr
        violations = (X < lower) | (X > upper)
    elif method == "robust_z_score":
        med = np.nanmedian(X, axis=0)
        mad = np.nanmedian(np.abs(X - med), axis=0)
        # Constant columns have MAD == 0; substitute a tiny value to avoid 0-division.
        mad = np.where(mad == 0, 1e-9, mad)
        violations = np.abs(0.6745 * (X - med) / mad) > thresh
    elif method == "z_score":
        mu = X.mean(axis=0)
        std = X.std(axis=0) + 1e-12
        violations = np.abs((X - mu) / std) > thresh
    else:
        raise ValueError(
            f"Unknown outlier method {method!r}; expected 'iqr', 'robust_z_score' or 'z_score'"
        )

    row_violation_frac = violations.mean(axis=1)
    return X[row_violation_frac <= max_violation_frac]



if __name__ == "__main__":
    # Pipeline: load source/target offline data, train (or load) the score,
    # inverse-action and reward models, generate synthetic transitions, save
    # them as an .npz buffer, then replace this process with run_offlinerl.py.

    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", default="./costlogs")
    parser.add_argument("--policy", default="OTDF", help='policy to use, support OTDF')
    parser.add_argument("--env", default="halfcheetah")
    parser.add_argument("--seed", default=0, type=int)   
    parser.add_argument("--metric", default='cosine', type=str)     
    parser.add_argument('--srctype', default='medium', type=str)
    parser.add_argument("--tartype", default='medium', type=str)
    # argparse does not run `type` on non-string defaults, so give an int directly.
    parser.add_argument("--steps", default=int(1e6), type=int)
    parser.add_argument("--source-sample-size", type=int, default=5000, help="Number of source samples for t-SNE (default: 5000)")
    parser.add_argument("--score_epoch", type=int, default=800, help="Num of score network training epoch (default: 800")
    parser.add_argument("--idn", type=str, default=None, help="Identity of experiments")
    parser.add_argument("--tr_score", type=int, default=1, help="choose training score function")
    
    parser.add_argument("--num_gen", type=int, default=200000, help="num of gen samples")
    parser.add_argument("--deno_steps", type=int, default=500, help="choose denoising iteration step")
    parser.add_argument("--corrector_steps", type=int, default=1)
    parser.add_argument("--alpha_min", type=float, default=0.1)
    parser.add_argument("--alpha_max", type=float, default=20.0)
    parser.add_argument("--snr", type=float, default=0.05)
    parser.add_argument("--t_end", type=float, default=1e-6)
    parser.add_argument("--t_gamma", type=int, default=4)
    
    parser.add_argument("--model_save_dir", type=str, default="stateEP10000,tranEP10000,seed2")
    parser.add_argument("--state_score_epoch", type=int, default=10000)
    parser.add_argument("--tran_score_epoch", type=int, default=10000)

    parser.add_argument("--ymax", type=float, default=0.2, help="choosing conditional state lambda(default=0)")
    parser.add_argument("--use_z_thresh", type=int, default=0)
    parser.add_argument("--z_thresh", type=float, default=3.0)
    parser.add_argument("--max_frac", type=float, default=0.0)

    parser.add_argument("--wandb_idn", type=str, default=None)
    parser.add_argument("--reg_weight", type=float, default=0.001)

    args = parser.parse_args()

    # --idn is used to build directory names, the output file name and the
    # follow-up command line; fail early instead of after the datasets load.
    if args.idn is None:
        parser.error("--idn is required (it names the experiment directories and output files)")

    with open(f"{str(Path(__file__).parent.absolute())}/config/{args.policy.lower()}/{args.env.replace('-', '_')}.yaml", 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print("------------------------------------------------------------")
    print("Policy: {}, Env: {}, Seed: {}".format(args.policy, args.env, args.seed))
    print("------------------------------------------------------------")

    os.makedirs(args.dir, exist_ok=True)

    if '_' in args.env:
        args.env = args.env.replace('_', '-')

    # ---- Environments and seeding ----
    src_env_name = args.env.split('-')[0] + '-' + args.srctype + '-v2'
    src_env = gym.make(src_env_name)
    src_env.seed(args.seed)
    tar_env = call_env(config['tar_env_config'])
    tar_env.seed(args.seed)

    src_env.action_space.seed(args.seed)
    tar_env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)

    state_dim = src_env.observation_space.shape[0]
    action_dim = src_env.action_space.shape[0] 
    max_action = float(src_env.action_space.high[0])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config['metric'] = args.metric
    config.update({
        'state_dim': state_dim,
        'action_dim': action_dim,
        'max_action': max_action,
    })

    # ---- Data ----
    src_replay_buffer = utils.OTReplayBuffer(state_dim, action_dim, device)
    tar_replay_buffer = utils.ReplayBuffer(state_dim, action_dim, device)

    src_dataset = d4rl.qlearning_dataset(src_env)
    tar_dataset = utils.call_tar_dataset(args.env, args.tartype)

    source_terminal_info = src_dataset['terminals']
    src_replay_buffer.convert_D4RL(src_dataset)
    tar_replay_buffer.convert_D4RL(tar_dataset)

    training_score = args.tr_score

    exp_dir = os.path.join("experiments", args.idn)
    os.makedirs(exp_dir, exist_ok=True)

    # (state, next_state) pairs for each domain.
    src_data = np.concatenate([src_replay_buffer.state, src_replay_buffer.next_state], axis=1)
    tgt_data = np.concatenate([tar_replay_buffer.state, tar_replay_buffer.next_state], axis=1)

    tgt_action = tar_replay_buffer.action
    tgt_reward = tar_replay_buffer.reward

    # Domain labels for the label-conditioned state score model.
    tgt_labels = np.zeros((len(tgt_data), 1))  # label=0
    src_labels = np.ones((len(src_data),  1))  # label=1

    model_save_dir = os.path.join("experiments", args.env + '_' + args.srctype + '_' + args.tartype +'_' + args.model_save_dir+'_' + str(args.seed))
    os.makedirs(model_save_dir, exist_ok=True)

    model_paths = {  
        "score": os.path.join(model_save_dir, "score_model.pth"),
        "state": os.path.join(model_save_dir ,"state_model.pth"),
        "cond": os.path.join(model_save_dir ,"cond_model.pth"),
        "inv_action": os.path.join(model_save_dir, "inverse_action_model.pth"),
        "reward": os.path.join(model_save_dir, "reward_model.pth")
    }

    # Both domains are normalised with the TARGET statistics.
    mean = tgt_data.mean(axis=0)
    std = tgt_data.std(axis=0)
    epsilon = 1e-9

    train_data_norm = (tgt_data - mean) / (std + epsilon)
    src_data_norm = (src_data - mean) / (std + epsilon)

    # ---- Models ----
    tran_cond_model        = NextStateScoreNet(state_dim=state_dim).to(device) 
    state_model            = LabelCondScoreNet(input_dim=state_dim).to(device)
    inv_action_model_train = InverseActionModelResNet(state_dim, action_dim).to(device)
    reward_model_train     = InverseRewardModelResNet(state_dim).to(device)

    optimizer_cond = optim.Adam(tran_cond_model.parameters(), lr=1e-4)
    optimizer_state = optim.Adam(state_model.parameters(), lr=1e-4)
    torch.cuda.empty_cache()

    if training_score:
        train_conditional_next_state_score_model(tran_cond_model, train_data_norm[:, :state_dim],
                                                 train_data_norm[:, state_dim:], optimizer_cond, device, epochs = args.tran_score_epoch)
        torch.save(tran_cond_model.state_dict(), model_paths["cond"])
        print(f"ScoreNet weights saved to {model_paths['cond']}")

        train_label_cond_score_model(
            model=state_model,
            src_data=src_data_norm[:,:state_dim],
            tgt_data=train_data_norm[:,:state_dim],
            src_label=src_labels,
            tgt_label=tgt_labels,
            optimizer=optimizer_state,
            device=device,
            epochs=args.state_score_epoch,
            batch_size=128,
            alpha_min=args.alpha_min,
            alpha_max=args.alpha_max,
        )
        torch.save(state_model.state_dict(), model_paths["state"])
        print(f"ScoreNet weights saved to {model_paths['state']}")

        # The inverse models are trained on the UN-normalised target pairs.
        train_inverse_action_model(inv_action_model_train, tgt_data, tgt_action, device)
        torch.save(inv_action_model_train.state_dict(), model_paths["inv_action"])
        print(f"InverseActionModel weights saved to {model_paths['inv_action']}")

        train_inverse_reward_model(reward_model_train, tgt_data, tgt_reward, device)
        torch.save(reward_model_train.state_dict(), model_paths["reward"])
        print(f"InverseRewardModel weights saved to {model_paths['reward']}")

    else:
        state_model.load_state_dict(torch.load(model_paths["state"], map_location=device))
        print(f"Loaded ScoreNet weights from {model_paths['state']}")

        tran_cond_model.load_state_dict(torch.load(model_paths["cond"], map_location=device))
        print(f"Loaded ScoreNet weights from {model_paths['cond']}")

        inv_action_model_train.load_state_dict(torch.load(model_paths["inv_action"], map_location=device))
        print(f"Loaded InverseActionModel weights from {model_paths['inv_action']}")

        reward_model_train.load_state_dict(torch.load(model_paths["reward"], map_location=device))
        print(f"Loaded InverseRewardModel weights from {model_paths['reward']}")

    # Everything below is inference. The training helpers leave the models in
    # train mode, where DropPath inside the inverse models would randomly drop
    # residual branches; switch all models to eval mode on both code paths.
    for net in (tran_cond_model, state_model, inv_action_model_train, reward_model_train):
        net.eval()

    torch.cuda.empty_cache()
    print("training (or loading) complete!")

    # ---- Output locations ----
    save_dir = os.path.join(os.getcwd(), 'experiments_result', args.idn)
    os.makedirs(save_dir, exist_ok=True)

    tar_env_underscore = args.env.replace('-', '_')
    base_dir = os.path.join(os.getcwd(), "datasets_modified")
    os.makedirs(base_dir, exist_ok=True)
    target_env_name = f"{tar_env_underscore}_{args.tartype}"
    target_dir = os.path.join(base_dir, target_env_name)
    os.makedirs(target_dir, exist_ok=True)
    sav_name = f"modified_{args.env}_{args.srctype}_{args.tartype}_{args.idn}_{str(args.ymax)}_{str(args.seed)}.npz"
    save_path = os.path.join(target_dir, sav_name)

    # ---- Target-like transitions (labels fixed at 0 via ymax=0.0) ----
    gen_mix_s_norm = pc_sampler_cond(
        batch_size=45000, dim=state_dim,
        score_model=state_model, ymax=0.0, 
        num_steps=args.deno_steps, corrector=True, snr=args.snr,
        alpha_min=args.alpha_min, alpha_max=args.alpha_max,
        corrector_steps=args.corrector_steps, device=device, t_end=args.t_end, t_gamma=args.t_gamma,
    )

    # Real target states plus generated label-0 states act as conditions.
    gen_target_state = np.concatenate([train_data_norm[:,:state_dim], gen_mix_s_norm], axis=0)

    gen_target_state_torch = torch.from_numpy(gen_target_state).float()
    gen_target_norm = conditional_sampler_batch(
        gen_target_state_torch,
        score_model=tran_cond_model,
        num_steps=args.deno_steps,
        alpha_min=0.1,
        alpha_max=20.0,
        corrector=True,
        snr=0.16,
        corrector_steps=1,
        device=device,
        t_end=args.t_end,
        batch_size=512,  
        t_gamma=args.t_gamma,
    )
    gen_target_samples = gen_target_norm * (std + epsilon) + mean

    # ---- Mixed transitions (labels uniform in [0, ymax)) ----
    n_gen_temp = 100000  # states generated per round
    gen_mix_norm_list, n_gen_sum = [], 0

    n_gen_sample_total = 0
    n_gen_filtered_sample_total = 0
    while n_gen_sum < args.num_gen:
        gen_mix_s_norm = pc_sampler_cond(
            batch_size=n_gen_temp, dim=state_dim,
            score_model=state_model, ymax=args.ymax,  
            num_steps=args.deno_steps, corrector=True, snr=args.snr, alpha_min=args.alpha_min, alpha_max=args.alpha_max,
            corrector_steps=args.corrector_steps, device=device, t_end=args.t_end, t_gamma=args.t_gamma, 
        )

        gen_mix_norm_temp = conditional_sampler_batch(
            torch.from_numpy(gen_mix_s_norm).float(),
            score_model=tran_cond_model,
            num_steps=args.deno_steps,
            alpha_min=0.1,
            alpha_max=20.0,
            corrector=True,
            snr=0.16,
            corrector_steps=1,
            device=device,
            t_end=args.t_end,
            batch_size=512,  
            t_gamma=args.t_gamma,
        )
        print("gen_mix_norm_temp before criterion :", gen_mix_norm_temp.shape)
        n_gen_sample_total = n_gen_sample_total + len(gen_mix_norm_temp)

        if args.use_z_thresh: # method : ["iqr", "robust_z_score", "z_score"]
            gen_mix_norm_temp = filter_outliers(gen_mix_norm_temp, thresh=args.z_thresh, max_violation_frac=args.max_frac, method="z_score")  
            print("gen_mix_norm_temp after (z-score criterion) :", gen_mix_norm_temp.shape)  
            n_gen_filtered_sample_total = n_gen_filtered_sample_total + len(gen_mix_norm_temp)

        gen_mix_norm_list.append(gen_mix_norm_temp)
        n_gen_sum = n_gen_sum + len(gen_mix_norm_temp)
        print("n_gen_sum", n_gen_sum)

    # Percentage of generated samples that survived filtering (0 when the
    # filter is disabled, because the filtered counter is never incremented).
    filtered_rate = int((n_gen_filtered_sample_total / n_gen_sample_total) * 100)

    gen_mix_norm = np.concatenate(gen_mix_norm_list)[:args.num_gen, :]
    gen_mix_samples = gen_mix_norm * (std + epsilon) + mean

    low_sa_samples = np.concatenate([gen_target_samples, gen_mix_samples], axis=0)

    # ---- Label the generated pairs and save the buffer ----
    save_buf = True
    if save_buf:

        # Cast once to float32 so the input dtype matches the model weights
        # even when the NumPy arithmetic above produced float64.
        samples_t = torch.as_tensor(low_sa_samples, dtype=torch.float32, device=device)
        with torch.no_grad():
            inverse_action_hat_train = inv_action_model_train(samples_t)
            inverse_action_hat_train = torch.clamp(inverse_action_hat_train, -1.0, 1.0)
            reward_hat_train = reward_model_train(samples_t)

        state_bundle      = low_sa_samples[:, :state_dim]
        action_bundle     = inverse_action_hat_train.cpu().numpy()
        next_state_bundle = low_sa_samples[:, state_dim:]
        reward_bundle     = reward_hat_train.cpu().numpy()

        # Terminal flags are copied from the source dataset, tiled or
        # truncated to the number of generated samples.
        low_sa_samples_len = len(low_sa_samples)
        source_terminal_info_arr = np.array(source_terminal_info)

        if len(source_terminal_info_arr) < low_sa_samples_len:
            repeat_times = int(np.ceil(low_sa_samples_len / len(source_terminal_info_arr)))
            padded_arr = np.tile(source_terminal_info_arr, repeat_times)[:low_sa_samples_len]
        else:
            padded_arr = source_terminal_info_arr[:low_sa_samples_len]

        terminal_bundle = padded_arr
        print("state_bundle", state_bundle.shape)
        print("action_bundle", action_bundle.shape)
        print("next_state_bundle", next_state_bundle.shape)
        print("reward_bundle", reward_bundle.shape)

        np.savez(
            save_path,
            states=state_bundle,
            actions=action_bundle,
            next_states=next_state_bundle,
            rewards=reward_bundle,
            terminals=terminal_bundle,
            timeouts=terminal_bundle,
        )
        print("Save buffer data complete! save_path:", save_path)

    # ---- Hand over to the offline-RL run ----
    cmd = [
        "python", "run_offlinerl.py",
        "--env", args.env,
        "--policy", "OTDF",  
        "--srctype", args.srctype,
        "--tartype", args.tartype,
        "--seed", str(args.seed),
        "--idn", sav_name,
        "--wandb_idn", f"{args.idn}ymax_{args.ymax}_reg{args.reg_weight}_filteredrate{filtered_rate}",
        "--reg_weight", str(args.reg_weight),
        "--select_ratio", "2",
    ]

    # exec* replaces this process without flushing buffered output, so flush
    # stdout explicitly with the last message.
    print("[INFO] Running second command:", " ".join(cmd), flush=True)
    # Pass the environment through unchanged: an unset CUDA_VISIBLE_DEVICES
    # must stay unset, since setting it to "" would hide every GPU from the
    # child process. Nothing after this call runs in this process.
    os.execvpe("python", cmd, dict(os.environ))






