import os 

import numpy as np
import pandas as pd
import torch
import argparse
import sys
import os
from foundation_models import (
    MolFormerRegressor,
    RobertaRegressor,
    T5Regressor,
    GPT2Regressor,
    Llama2Regressor,
)
from foundation_models import (
    get_molformer_tokenizer,
    get_roberta_tokenizer,
    get_t5_tokenizer,
    get_gpt2_tokenizer,
    get_llama2_tokenizer,
)
from bayesopt.acqf import ucb, ei, thompson_sampling
from problems.data_processor import (
    RedoxDataProcessor,
    SolvationDataProcessor,
    KinaseDockingDataProcessor,
    LaserEmitterDataProcessor,
    PhotovoltaicsPCEDataProcessor,
    PhotoswitchDataProcessor,
    D4DockingDataProcessor, 
    AmpCDockingDataProcessor
)
from problems.prompting import PromptBuilder
from utils import helpers
from utils.configs import LaplaceConfig, FSPLaplaceConfig, LLMFeatureType
from peft import LoraConfig, get_peft_model
import math
import time
from gpytorch.kernels import ScaleKernel, MaternKernel, RBFKernel, LinearKernel
from gauche.kernels.fingerprint_kernels import TanimotoKernel
from gpytorch.means import ZeroMean

from gpytorch.constraints import Interval

from bayesopt.surrogates.gp import MLLGP
from bayesopt.surrogates.laplace import FixedFeatLaplace
from bayesopt.surrogates.fsplaplace import FixedFeatFSPLaplace
from gauche.kernels.fingerprint_kernels import *

from sklearn.preprocessing import StandardScaler

import torch.utils.data as data_utils

import sys
import torch
from torch import nn, optim
import math
import pandas as pd
import tqdm
from transformers import get_scheduler

from utils.rogi_xd import rogi, Metric

# Maps the --activation_fn choice to the torch.nn activation class (not an instance)
# used between the hidden layers of the Laplace surrogates' MLP.
ACTIVATIONS = dict(
    relu=torch.nn.ReLU,
    tanh=torch.nn.Tanh,
    lrelu=torch.nn.LeakyReLU,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--problem",
    choices=["redox-mer", "solvation", "kinase", "laser", "pce", "photoswitch", "ampc", "d4"],
    default="redox-mer",
)
# There is no default surrogate: without one the script would only fail after
# fine-tuning, when the surrogate is built. Fail fast at parse time instead.
parser.add_argument("--method", choices=["gp", "laplace", "fsplaplace"], required=True)
parser.add_argument(
    "--foundation_model",
    default="molformer",
    choices=[
        "molformer",
        "roberta-large",
        "t5-base",
        "t5-base-chem",
        "gpt2-medium",
        "gpt2-large",
        "llama-2-7b",
    ],
)
parser.add_argument(
    "--kernel",
    choices=[
        "matern_0.5",
        "matern_1.5",
        "matern_2.5",
        "rbf",
        "linear"
    ],
    default="matern_2.5",
)
parser.add_argument(
    "--activation_fn", choices=["relu", "lrelu", "tanh"], default="tanh"
)
# Numeric flags carry an explicit `type=` so malformed values are rejected by
# argparse with a usage message rather than surfacing later as a ValueError.
parser.add_argument(
    "--noise_var", type=float, default=0.0001
)
parser.add_argument(
    "--wd", type=float, default=1e-1
)
parser.add_argument(
    "--lr", type=float, default=1e-3
)
parser.add_argument(
    "--map_context_points", choices=["uniform", "bo_candidates"], default="bo_candidates"
)
parser.add_argument(
    "--cov_context_points", choices=["sobol", "bo_candidates"], default="bo_candidates"
)
parser.add_argument(
    "--prompt_type",
    choices=["single-number", "just-smiles", "completion"],
    default="just-smiles",
)
parser.add_argument("--lora_alpha_factor", type=float, default=4.)
parser.add_argument("--lora_dropout", type=float, default=0.1)
parser.add_argument("--lr_lora", type=float, default=3e-4)
parser.add_argument("--lora_r", type=int, default=4)
parser.add_argument("--acqf", choices=["ei", "ucb", "ts"], default="ts")
parser.add_argument("--n_init_data", type=int, default=10)
parser.add_argument("--exp_len", type=int, default=200)
parser.add_argument("--randseed", type=int, default=1)
parser.add_argument("--dtype", choices=["float32", "float64"], default="float64")
args = parser.parse_args()

# Seed NumPy and PyTorch so that each --randseed yields a reproducible run.
np.random.seed(args.randseed)
torch.manual_seed(args.randseed)

# Several numeric flags are declared without `type=`, so command-line values can
# arrive as strings; coerce each one to its numeric type (order kept stable).
for _hp_name, _hp_cast in (
    ("lr", float),
    ("wd", float),
    ("noise_var", float),
    ("lora_alpha_factor", float),
    ("lora_r", int),
    ("lora_dropout", float),
    ("lr_lora", float),
):
    setattr(args, _hp_name, _hp_cast(getattr(args, _hp_name)))
del _hp_name, _hp_cast

# Run on the GPU when one is visible; --dtype picks the floating-point precision.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float64 if args.dtype != "float32" else torch.float32
print("Running on GPU" if DEVICE == "cuda" else "Running on CPU")

# Select the tokenizer for the chosen backbone. MolFormer only accepts SMILES,
# so its prompt type is forced to "just-smiles" before anything else uses it.
if args.foundation_model == "molformer":
    args.prompt_type = "just-smiles"
    feature_dim = 768
    tokenizer = get_molformer_tokenizer()
elif "roberta" in args.foundation_model:
    # NOTE(review): feature_dim is not assigned on this branch — confirm that
    # nothing downstream reads it for RoBERTa backbones.
    tokenizer = get_roberta_tokenizer(args.foundation_model)
elif "t5" in args.foundation_model:
    feature_dim = 768
    # The "-chem" CLI alias maps to the GT4SD chemistry T5 checkpoint.
    foundation_model_real = (
        "GT4SD/multitask-text-and-chemistry-t5-base-augm"
        if "chem" in args.foundation_model
        else args.foundation_model
    )
    tokenizer = get_t5_tokenizer(foundation_model_real)
elif "gpt2" in args.foundation_model:
    # NOTE(review): 768 is used for every GPT-2 size — confirm it matches the
    # width returned by the regressor's forward_features.
    feature_dim = 768
    tokenizer = get_gpt2_tokenizer(args.foundation_model)
elif "llama-2" in args.foundation_model:
    feature_dim = 768
    tokenizer = get_llama2_tokenizer(args.foundation_model)

# Tokenizers that define no pad token fall back to padding with the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Per-problem settings: (CSV path, objective column, whether the objective is
# maximized, data-processor class). All problems share the same prompt builder.
_PROBLEM_SPECS = {
    "redox-mer": ("data/redox_mer_with_iupac.csv.gz", "Ered", False, RedoxDataProcessor),
    "solvation": ("data/redox_mer_with_iupac.csv.gz", "Gsol", False, SolvationDataProcessor),
    "kinase": ("data/enamine10k.csv.gz", "score", False, KinaseDockingDataProcessor),
    "laser": (
        "data/laser_multi10k.csv.gz",
        "Fluorescence Oscillator Strength",
        True,
        LaserEmitterDataProcessor,
    ),
    "pce": ("data/photovoltaics_pce10k.csv.gz", "pce", True, PhotovoltaicsPCEDataProcessor),
    "photoswitch": (
        "data/photoswitches.csv.gz",
        "Pi-Pi* Transition Wavelength",
        True,
        PhotoswitchDataProcessor,
    ),
    "ampc": ("data/Zinc_AmpC_Docking_filtered.csv.gz", "dockscore", False, AmpCDockingDataProcessor),
    "d4": ("data/Zinc_D4_Docking_filtered.csv.gz", "dockscore", False, D4DockingDataProcessor),
}

if args.problem not in _PROBLEM_SPECS:
    print("Invalid test function!")
    sys.exit(1)

_csv_path, OBJ_COL, MAXIMIZATION, _processor_cls = _PROBLEM_SPECS[args.problem]
dataset = pd.read_csv(_csv_path)
prompt_builder = PromptBuilder(kind=args.prompt_type)
data_processor = _processor_cls(prompt_builder, tokenizer)

# The BO loop always maximizes, so negate minimization objectives up front.
if not MAXIMIZATION:
    dataset[OBJ_COL] = -dataset[OBJ_COL]
ground_truth_max = dataset[OBJ_COL].max()

# Echo the run configuration: a shared prefix, the surrogate-specific fields,
# then the shared LoRA/noise suffix.
_header_prefix = (
    f"Test Function: {args.problem}; Foundation Model: {args.foundation_model}; "
    f"Prompt Type: {args.prompt_type}; Randseed: {args.randseed}; "
)
_header_suffix = (
    f"noise_var: {args.noise_var} - lr_lora {args.lr_lora} - "
    f"lora_alpha_factor: {args.lora_alpha_factor} - lora_dropout: {args.lora_dropout} - "
    f"lora_r: {args.lora_r};"
)
_header_middle = {
    "gp": f"kernel: {args.kernel}; - ",
    "laplace": f"wd: {args.wd}; activation: {args.activation_fn}; lr: {args.lr} - ",
    "fsplaplace": (
        f"kernel: {args.kernel}; activation: {args.activation_fn}; "
        f"map_context_points: {args.map_context_points}; "
        f"cov_context_points: {args.cov_context_points}; lr: {args.lr} - "
    ),
}

print()
if args.method in _header_middle:
    print(_header_prefix + _header_middle[args.method] + _header_suffix, flush=True)
print(
    "---------------------------------------------------------------------------------------------------------------"
)
print()

# Draw the initial design at random from `dataset` via helpers.pop_df (which
# presumably removes the picked row from `dataset`), skipping any row that
# already attains the global optimum.
dataset_train = []
while len(dataset_train) < args.n_init_data:
    idx = np.random.randint(len(dataset))
    is_optimum = dataset.loc[idx][OBJ_COL] >= ground_truth_max
    if not is_optimum:
        dataset_train.append(helpers.pop_df(dataset, idx))
# Get foundation models
def get_model():
    """Instantiate the regressor for ``args.foundation_model`` and wrap it with LoRA.

    LoRA adapters are injected into the backbone-specific attention modules
    listed below. The ``head`` module is passed as ``modules_to_save``, and the
    parameters of the head's ``original_module`` are frozen afterwards.

    Returns:
        The model returned by ``get_peft_model``.

    Raises:
        NotImplementedError: if ``args.foundation_model`` matches no known backbone.
    """
    name = args.foundation_model
    # Keyword arguments shared by every text-LLM regressor (all but MolFormer).
    llm_kwargs = dict(tokenizer=tokenizer, reduction=LLMFeatureType.AVERAGE)

    if name == "molformer":
        model = MolFormerRegressor(tokenizer, dtype=args.dtype)
        target_modules = ["query", "value"]
    elif "roberta" in name:
        model = RobertaRegressor(kind=name, **llm_kwargs)
        target_modules = ["query", "value"]
    elif "gpt2" in name:
        model = GPT2Regressor(kind=name, **llm_kwargs)
        target_modules = ["c_attn"]
    elif "llama-2" in name:
        model = Llama2Regressor(kind=name, **llm_kwargs)
        target_modules = ["q_proj", "v_proj"]
    elif "t5" in name:
        # The "-chem" alias resolves to the GT4SD chemistry T5 checkpoint.
        t5_kind = "GT4SD/multitask-text-and-chemistry-t5-base-augm" if "chem" in name else name
        model = T5Regressor(kind=t5_kind, **llm_kwargs)
        target_modules = ["q", "v"]
    else:
        raise NotImplementedError

    lora_model = get_peft_model(
        model,
        LoraConfig(
            r=args.lora_r,
            lora_alpha=int(args.lora_alpha_factor * args.lora_r),
            target_modules=target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            modules_to_save=["head"],
        ),
    )
    # Freeze every parameter of the head's original (pre-wrapping) module.
    lora_model.base_model.head.original_module.requires_grad_(False)

    return lora_model


# Passed as `append_eos` to every dataloader: MolFormer and T5 inputs get no
# appended EOS token, every other backbone does.
APPEND_EOS = not (args.foundation_model == "molformer" or "t5" in args.foundation_model)

def _make_base_kernel(input_dim, linear_batch_shape=None):
    """Return the unscaled covariance function selected by ``--kernel``.

    Args:
        input_dim: number of ARD lengthscales for the Matern/RBF kernels.
        linear_batch_shape: batch shape for ``LinearKernel``; ``None`` keeps its default.

    Raises:
        ValueError: if ``args.kernel`` is not a supported kernel name.
    """
    if "matern" in args.kernel:
        # e.g. "matern_2.5" -> nu = 2.5
        nu = float(args.kernel.split("_")[-1])
        return MaternKernel(nu=nu, ard_num_dims=input_dim, lengthscale_constraint=Interval(1e-3, 1e3))
    if args.kernel == "rbf":
        return RBFKernel(ard_num_dims=input_dim, lengthscale_constraint=Interval(1e-3, 1e3))
    if args.kernel == "linear":
        if linear_batch_shape is None:
            return LinearKernel()
        return LinearKernel(batch_shape=linear_batch_shape)
    raise ValueError(f"Invalid kernel: {args.kernel}")


def _make_mlp(input_dim, activation):
    """Return the input_dim -> 50 -> 50 -> 1 MLP used by both Laplace surrogates."""
    return torch.nn.Sequential(
        torch.nn.Linear(input_dim, 50),
        activation(),
        torch.nn.Linear(50, 50),
        activation(),
        torch.nn.Linear(50, 1),
    ).to(DTYPE).to(DEVICE)


def build_surrogate_model(train_x, train_y, bo_candidates):
    """Construct the surrogate selected by ``--method`` on fixed features.

    Args:
        train_x: feature matrix of the observed molecules; its last dimension
            is used as the input width of the kernel / MLP.
        train_y: observed objective values.
        bo_candidates: candidate features; only used by the FSP-Laplace surrogate.

    Returns:
        The surrogate, moved to ``DEVICE`` and cast to ``DTYPE``.

    Raises:
        ValueError: on an unknown ``--method`` or ``--kernel``.
    """
    # Use the width of the features actually produced by the fine-tuned model,
    # rather than a per-backbone constant that may be missing or mismatched.
    input_dim = train_x.shape[-1]

    if args.method == "gp":
        kernel_fn = _make_base_kernel(input_dim)
        kernel = ScaleKernel(kernel_fn) if args.kernel != "linear" else kernel_fn
        model = MLLGP(
            train_X=train_x, train_Y=train_y, kernel=kernel, noise_var=args.noise_var, device=DEVICE, dtype=DTYPE, normalize_x=False
        )
    elif args.method == "laplace":
        cfg = LaplaceConfig(wd=args.wd, lr=args.lr, activation=args.activation_fn, noise_var=args.noise_var)
        activation = ACTIVATIONS[args.activation_fn]

        def get_net():
            return _make_mlp(input_dim, activation)

        model = FixedFeatLaplace(
            train_x, train_y, get_net, hess_factorization="kron", laplace_config=cfg, device=DEVICE, dtype=DTYPE, normalize_x=False
        )
    elif args.method == "fsplaplace":
        cfg = FSPLaplaceConfig(
            cov_context_points=args.cov_context_points,
            map_context_points=args.map_context_points,
            lr=args.lr,
            activation=args.activation_fn,
            noise_var=args.noise_var
        )
        activation = ACTIVATIONS[args.activation_fn]

        def get_net():
            return _make_mlp(input_dim, activation)

        # The FSP-Laplace prior uses a batch shape of [1] for mean and kernel.
        mean_fn = ZeroMean(batch_shape=torch.Size([1]))
        kernel_fn = _make_base_kernel(input_dim, linear_batch_shape=torch.Size([1]))
        if args.kernel not in ["linear"]:
            kernel_fn = ScaleKernel(kernel_fn, batch_shape=torch.Size([1])).to(DTYPE).to(DEVICE)

        model = FixedFeatFSPLaplace(
            train_x, train_y, bo_candidates, mean_fn.to(DTYPE).to(DEVICE),
            kernel_fn, initialize_nn=get_net, laplace_config=cfg, device=DEVICE, dtype=DTYPE, normalize_x=False
        )
    else:
        # Previously fell through and failed with UnboundLocalError on `model`.
        raise ValueError(f"Unknown surrogate method: {args.method}")

    return model.to(DEVICE).to(DTYPE)


def finetune_foundation_model(
    model,
    train_loader,
    n_epochs=100,
    grad_clip=0.0,
    lr_lora=None,
    lr_head=1e-3,
    weight_decay=5e-4,
):
    """Jointly fine-tune the LoRA adapters and the regression head with an MSE loss.

    Args:
        model: model from ``get_model()``. Trainable parameters whose name
            contains "lora" form the adapter group; all other trainable
            parameters form the head group.
        train_loader: iterable of batches accepted by ``model(batch)``, each
            with a ``"labels"`` entry.
        n_epochs: number of passes over ``train_loader``.
        grad_clip: max gradient norm for the LoRA parameters; 0 disables clipping.
        lr_lora: AdamW learning rate for the adapters; ``None`` uses ``args.lr_lora``.
        lr_head: AdamW learning rate for the head parameters.
        weight_decay: AdamW weight decay for both parameter groups.

    Returns:
        The same model, switched to eval mode.
    """
    if lr_lora is None:
        # Honour --lr_lora; it was previously hard-coded to its default of 3e-4.
        lr_lora = args.lr_lora

    model.train()
    loss_func = nn.MSELoss()

    lora_params = [
        p for n, p in model.named_parameters() if p.requires_grad and "lora" in n
    ]
    head_params = [
        p for n, p in model.named_parameters() if p.requires_grad and "lora" not in n
    ]
    optimizer_lora = optim.AdamW(lora_params, lr=lr_lora, weight_decay=weight_decay)
    optimizer_head = optim.AdamW(head_params, lr=lr_head, weight_decay=weight_decay)

    n_batches = len(train_loader)
    num_training_steps = n_epochs * n_batches
    # Linear decay for the adapters, cosine decay for the head, no warmup.
    scheduler_lora = get_scheduler(
        name="linear",
        optimizer=optimizer_lora,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )
    scheduler_head = get_scheduler(
        name="cosine",
        optimizer=optimizer_head,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )
    # enabled=False: the scaler passes losses/steps through unchanged.
    scaler = torch.amp.GradScaler(DEVICE, enabled=False)

    for epoch in tqdm.trange(
        n_epochs, position=1, leave=False, desc="[Training]", colour="blue", file=sys.stdout, disable=True
    ):
        epoch_loss = 0.0
        for batch in train_loader:
            model.train()
            labels = batch["labels"].to(DEVICE, DTYPE, non_blocking=True)

            loss = loss_func(model(batch), labels)
            scaler.scale(loss).backward()

            # Only the LoRA parameters are clipped; the head is left unclipped.
            if grad_clip != 0.0:
                scaler.unscale_(optimizer_lora)
                torch.nn.utils.clip_grad_norm_(lora_params, grad_clip)

            scaler.step(optimizer_lora)
            scaler.step(optimizer_head)
            scaler.update()
            scheduler_lora.step()
            scheduler_head.step()
            optimizer_lora.zero_grad(set_to_none=True)
            optimizer_head.zero_grad(set_to_none=True)
            epoch_loss += loss.item()

        # Log the mean batch loss once per logged epoch, after all its batches.
        if n_batches and (epoch % 10 == 0 or epoch == n_epochs - 1):
            print(f"{epoch}/{n_epochs} - LORA loss: {epoch_loss / n_batches}", flush=True)

    model.eval()

    return model

# Best objective value among the initial design (internally always maximized).
best_y = pd.DataFrame(dataset_train)[OBJ_COL].max()
print(
    f"[Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}]", flush=True
)

# Iterator driving the BO loop (progress bar disabled). A duplicate, unused
# trange that was created before the print above has been removed.
pbar = tqdm.trange(args.exp_len, file=sys.stdout, disable=True)
# Per-iteration traces; index 0 of the best-y/timing traces is the initial design.
trace_best_acqval = []
trace_best_y = [helpers.y_transform(best_y, MAXIMIZATION)] * (args.exp_len + 1)
trace_timing = [0.0] * (args.exp_len + 1)

timing_ft = []
timing_train = []
timing_preds = []
rogi_scores = []

def _extract_features(feat_model, loader):
    """Embed every batch of ``loader`` with ``feat_model.forward_features``.

    Returns:
        Two lists: per-example feature tensors (moved to CPU) and the
        per-example entries of each batch's ``"labels"``.
    """
    feats, labels = [], []
    with torch.no_grad():
        for batch in loader:
            feats += list(feat_model.forward_features(batch).cpu())
            labels += list(batch["labels"])
    return feats, labels


for i in pbar:
    # Time the whole iteration; CUDA events are used so GPU work is included.
    if DEVICE == "cuda":
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
    else:
        start = time.time()

    # Fine-tune a fresh LoRA model on the observed molecules.
    # Note: we standardize the labels during finetuning
    _dataset_train = pd.DataFrame(dataset_train, columns=dataset.columns)
    ft_dataloader, _, _ = data_processor.get_dataloader(
        _dataset_train, batch_size=128, shuffle=False, append_eos=APPEND_EOS, standardize_y=True
    )
    feat_model = get_model().to(DEVICE)
    start_ft = time.time()
    feat_model = finetune_foundation_model(feat_model, ft_dataloader)
    timing_ft.append(time.time() - start_ft)

    # Embed the unobserved pool with the fine-tuned model.
    mol_dataloader = data_processor.get_dataloader(
        dataset, batch_size=256, shuffle=False, append_eos=APPEND_EOS
    )
    features, targets = _extract_features(feat_model, mol_dataloader)

    # Embed the observed molecules. The frame is rebuilt rather than reused in
    # case the standardizing dataloader above modified it.
    _dataset_train = pd.DataFrame(dataset_train, columns=dataset.columns)
    train_dataloader = data_processor.get_dataloader(
        _dataset_train, batch_size=256, shuffle=False, append_eos=APPEND_EOS
    )
    train_features, train_targets = _extract_features(feat_model, train_dataloader)

    # Clean up
    del feat_model, _dataset_train, ft_dataloader, mol_dataloader, train_dataloader

    # Compute ROGI on the raw (unstandardized) features of pool + observed set.
    all_features = np.concatenate([features, train_features], axis=0)
    all_targets = np.concatenate([targets, train_targets], axis=0)

    rogi_score = rogi(all_features, all_targets, normalize=True, metric=Metric.EUCLIDEAN, nboots=10)
    print(f"ROGI score: {rogi_score.rogi} +/- {rogi_score.uncertainty}", flush=True)
    rogi_scores.append({"mean": rogi_score.rogi, "std": rogi_score.uncertainty})

    # Standardize features jointly. Build the tensor directly in DTYPE instead of
    # rounding through float32 first, which discarded precision in float64 runs.
    x_preprocessor = StandardScaler()
    all_features = list(
        torch.tensor(
            x_preprocessor.fit_transform(np.array(all_features).reshape(-1, all_features[0].shape[-1])),
            dtype=DTYPE,
        )
    )
    features, train_features = all_features[:len(features)], all_features[len(features):]

    # Fit the surrogate; its candidate set is every embedded molecule.
    start_train = time.time()
    train_x = torch.stack(train_features).to(DTYPE).to(DEVICE)
    train_y = torch.stack(train_targets).to(DTYPE).to(DEVICE)
    X_candidates = torch.stack(all_features).to(DTYPE).to(DEVICE)
    model = build_surrogate_model(train_x, train_y, X_candidates).to(DTYPE).to(DEVICE)
    timing_train.append(time.time() - start_train)

    ## BO: score every unobserved molecule with the acquisition function.
    dataloader = data_utils.DataLoader(
        data_utils.TensorDataset(torch.stack(features), torch.stack(targets)),
        batch_size=256,
        shuffle=False,
    )

    preds, uncerts, labels = [], [], []
    acq_vals = []
    start_pred = time.time()

    for x, y in dataloader:
        x, y = x.to(DEVICE).to(DTYPE), y.to(DEVICE).to(DTYPE)
        posterior = model.posterior(x)
        f_mean, f_var = posterior.mean, posterior.variance

        if args.acqf == "ei":
            acq_vals.append(ei(f_mean, f_var, best_y))
        elif args.acqf == "ucb":
            acq_vals.append(ucb(f_mean, f_var))
        else:
            acq_vals.append(thompson_sampling(f_mean, f_var))

        preds.append(f_mean)
        uncerts.append(f_var.sqrt())
        labels.append(y)

    del dataloader

    timing_preds.append(time.time() - start_pred)

    acq_vals = torch.cat(acq_vals, dim=0).cpu().squeeze()
    preds, uncerts, labels = (
        torch.cat(preds, dim=0).cpu(),
        torch.cat(uncerts, dim=0).cpu(),
        torch.cat(labels, dim=0).cpu(),
    )
    test_loss = torch.nn.MSELoss()(preds, labels).item()

    # Pick a molecule (a row in the current dataset) that maximizes the acquisition
    idx_best = torch.argmax(acq_vals).item()
    data_best = helpers.pop_df(dataset, idx_best)
    new_y = data_best[OBJ_COL]
    dataset_train.append(data_best)

    trace_best_acqval.append(torch.max(acq_vals).item())

    # Update the current best y
    if new_y.item() > best_y:
        best_y = new_y.item()

    print(
        f"{i}/{args.exp_len} - [Best f(x) = {helpers.y_transform(best_y, MAXIMIZATION):.3f}, "
        + f"curr f(x) = {helpers.y_transform(new_y.item(), MAXIMIZATION):.3f}, test MSE = {test_loss:.3f}]",
        flush=True
    )

    # Housekeeping
    if DEVICE == "cuda":
        end.record()
        torch.cuda.synchronize()
        timing = start.elapsed_time(end) / 1000
    else:
        timing = time.time() - start

    trace_best_y[i + 1] = helpers.y_transform(best_y, MAXIMIZATION)
    trace_timing[i + 1] = timing

    # Early stopping if we already got the max
    if best_y >= ground_truth_max:
        # Carry the final best forward; otherwise the remaining trace entries
        # would still hold the initial-design best set before the loop.
        for k in range(i + 2, len(trace_best_y)):
            trace_best_y[k] = trace_best_y[i + 1]
        break

# Save results under a directory keyed by problem, surrogate and backbone.
path = f"results/feature_finetuning/{args.problem}/{args.method}_{args.foundation_model}_finetuning/"
# exist_ok avoids the check-then-create race when several seeds write to the
# same directory concurrently.
os.makedirs(path, exist_ok=True)

# Filename suffix: LoRA hyper-parameters plus the surrogate-specific ones.
suffix = f"{args.lora_alpha_factor}_{args.lora_r}_{args.lora_dropout}_{args.lr_lora}"
if args.method == "gp":
    suffix += f"_{args.kernel}_{args.noise_var}"
elif args.method == "fsplaplace":
    suffix += f"_{args.kernel}_{args.activation_fn}_{args.map_context_points}_{args.cov_context_points}_{args.lr}_{args.noise_var}"
elif args.method == "laplace":
    suffix += f"_{args.wd}_{args.activation_fn}_{args.lr}_{args.noise_var}"

run_tag = f"{args.n_init_data}_{args.acqf}_{args.randseed}_{suffix}"
# For non-MolFormer backbones the best-y and timing traces are prefixed with
# the prompt type.
trace_prefix = "" if args.foundation_model == "molformer" else f"{args.prompt_type}_"

for stem, values in (
    ("timing_train", timing_train),
    ("timing_preds", timing_preds),
    ("timing_ft", timing_ft),
    ("rogi_scores", rogi_scores),
    ("trace_best_acqval", trace_best_acqval),
    (f"{trace_prefix}trace_best_y", trace_best_y),
    (f"{trace_prefix}trace_timing", trace_timing),
):
    np.save(f"{path}/{stem}_{run_tag}.npy", values)
