import argparse
import json
import pathlib
import time
import traceback

import numpy as np
import torch
import tqdm
from torch import nn, optim

from CHANGELOG import MODEL_VERSION, CHECK_VERSION
from tigprompt.model.feature_getter import NumericalFeature
from tigprompt.utils import BackgroundThreadGenerator

from init_utils import init_data, init_model
from train_utils import EarlyStopMonitor, get_logger, hash_args, seed_all

from tigprompt.model.temporal_prompt_generator import TransformerTProG, VanillaTProG, ProjectionTProG

from tigprompt.eval_utils import eval_edge_prediction, warmup
import pickle
import shutil



def run(*, prefix, gpu, seed, lr, n_epochs,
        patience, force, ckpt_path, root, bs,
        use_valid, dropout,
        # above are new parameters
        prompter_type,
        teaching_prompt, prompt_teaching_coef,
        fine_tune_mode, prompt_start, prompt_end, prompt_dim,
        # above are prompting parameters
        data, dim, feature_as_buffer, num_workers,
        hit_type, restarter_type, hist_len,
        n_neighbors, n_layers, n_heads,
        strategy, msg_src, upd_src,
        mem_update_type, msg_tsfm_type,
        restart_prob, warmup_steps, subset,
        embedding_type, dyrep, no_memory,
        **kwargs
        ):
    """Train a temporal prompt generator on top of a pre-trained encoder.

    The encoder weights are loaded from ``ckpt_path``. A prompter selected by
    ``prompter_type`` is trained on the prompt data loader with a
    link-prediction loss (optionally plus an MSE "teaching" loss), evaluated
    on the validation loaders after every epoch and, when ``use_valid`` is
    set, evaluated on the test loaders with the best checkpoint.

    Args:
        prefix: Non-empty run name; a hash of the arguments is appended to it
            and the result is used for every output path.
        gpu: CUDA index as a string; ``"-1"`` selects the CPU.
        seed: Seed passed to ``seed_all`` and ``init_data``.
        lr: Learning rate of both Adam optimizers.
        n_epochs: Maximum number of training epochs.
        patience: ``max_round`` of the early-stop monitor.
        force: Re-run even if the result JSON already exists.
        ckpt_path: Path of the pre-trained encoder state dict.
        root, bs, num_workers, subset, strategy, data, prompt_start,
        prompt_end: Forwarded to ``init_data``.
        use_valid: Enable early stopping and the final test evaluation.
        prompter_type: ``"transformer"``, ``"vanilla"`` or ``"projection"``.
        teaching_prompt: Add the MSE loss between the prompter's surrogate
            states and the encoder's previous states (skipped if
            ``no_memory``).
        prompt_teaching_coef: Weight of that MSE loss.
        fine_tune_mode: Also optimize the encoder parameters.
        prompt_dim: Forwarded to the prompter constructor.
        dim, feature_as_buffer, hit_type, restarter_type, hist_len,
        n_neighbors, n_layers, n_heads, msg_src, upd_src, mem_update_type,
        msg_tsfm_type, embedding_type, dyrep, no_memory, dropout: Model
            hyper-parameters forwarded to ``init_model`` / the prompter.
        restart_prob: Restart mode is enabled when this is positive.
        warmup_steps: Warm-up steps; only valid in restart mode.
        **kwargs: Ignored; allows callers to pass a saved result dict.

    Returns:
        ``False`` if the result file already exists and ``force`` is not set,
        otherwise ``None``.

    Raises:
        ValueError: If ``prefix`` is empty, warm-up is requested without
            restart mode, or ``prompter_type`` is unknown.
        Exception: Any error raised during training is logged and re-raised.
    """
    # Validate before the hash is appended; afterwards the prefix can never
    # be empty. No local may be defined before `args` is captured below.
    if not prefix:
        raise ValueError('Prefix should be given explicitly.')
    args = {k: v for k, v in locals().items()
            if k not in {'gpu', 'force', 'kwargs'}}
    HASH = hash_args(**args, MODEL_VERSION=MODEL_VERSION)
    prefix += '.' + HASH
    if gpu == "-1":
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{gpu}')

    restart_mode = restart_prob > 0

    # Sanity check
    if (not restart_mode) and (warmup_steps > 0):
        raise ValueError('Warmup is not needed without restart.')

    RESULT_SAVE_PATH = f"results/{prefix}.json"
    PICKLE_SAVE_PATH = "results/{}.pkl".format(prefix)
    PROMPTER_SAVE_PATH = f'./saved_prompters/{prefix}.pth'
    MODEL_SAVE_PATH = f'./saved_models/{prefix}.pth'

    # Make sure every output directory exists before anything is written.
    for out_path in (RESULT_SAVE_PATH, PICKLE_SAVE_PATH,
                     PROMPTER_SAVE_PATH, MODEL_SAVE_PATH):
        pathlib.Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    ckpts_dir = pathlib.Path(f"./saved_checkpoints/{prefix}")
    ckpts_dir.mkdir(parents=True, exist_ok=True)
    get_checkpoint_path = lambda epoch: ckpts_dir / f'{epoch}.pth'

    prompter_ckpts_dir = pathlib.Path(f"./saved_prompter_ckpts/{prefix}")
    prompter_ckpts_dir.mkdir(parents=True, exist_ok=True)
    get_prompter_path = lambda epoch: prompter_ckpts_dir / f'{epoch}.pth'

    logger = get_logger(prefix)
    logger.info(f'[START {prefix}]')
    logger.info(f'Model version: {MODEL_VERSION}')
    logger.info(", ".join([f"{k}={v}" for k, v in args.items()]))

    if pathlib.Path(RESULT_SAVE_PATH).exists() and not force:
        logger.info('Duplicate task! Abort!')
        return False

    # Names of the encoder state-dict entries restored from the checkpoint
    # at the start of every epoch.
    buffer_list = ('left_memory.vals', 'left_memory.update_ts', 'left_memory.active_mask',
                   'right_memory.vals', 'right_memory.update_ts', 'right_memory.active_mask',
                   'msg_memory.vals', 'msg_memory.update_ts', 'msg_memory.active_mask',
                   'upd_memory.vals', 'upd_memory.update_ts', 'upd_memory.active_mask')

    try:
        # Init
        seed_all(seed)
        # ============= Load Data ===========
        # Pick the restarter type the data loaders are built for: without a
        # restarter, the transformer prompter uses 'seq' data and the other
        # prompters use 'static'; a 'projection' restarter also uses 'static'.
        if restarter_type == 'none':
            data_restarter_type = 'seq' if prompter_type == 'transformer' else 'static'
        elif restarter_type == 'projection':
            data_restarter_type = 'static'
        else:
            data_restarter_type = restarter_type

        basic_data, graphs, dls = init_data(
            data, root, seed,
            num_workers=num_workers, bs=bs, warmup_steps=warmup_steps,
            subset=subset, strategy=strategy,
            n_layers=n_layers, n_neighbors=n_neighbors,
            restarter_type=data_restarter_type, hist_len=hist_len,
            prompt_start=prompt_start, prompt_end=prompt_end
        )
        (
            nfeats, efeats, full_data, train_data, val_data, test_data,
            inductive_val_data, inductive_test_data, prompt_data, train_and_prompt_data
        ) = basic_data
        train_graph, full_graph, train_and_prompt_graph = graphs
        (
            train_dl, offline_dl, val_dl, ind_val_dl,
            test_dl, ind_test_dl, val_warmup_dl, test_warmup_dl, prompt_dl
        ) = dls

        # ============= Init Model ===========
        encoder = init_model(
            nfeats, efeats, train_and_prompt_graph, full_graph, full_data, device,
            feature_as_buffer=feature_as_buffer, dim=dim,
            n_layers=n_layers, n_heads=n_heads, n_neighbors=n_neighbors,
            hit_type=hit_type, dropout=dropout,
            restarter_type=restarter_type, hist_len=hist_len,
            msg_src=msg_src, upd_src=upd_src,
            msg_tsfm_type=msg_tsfm_type, mem_update_type=mem_update_type,
            embedding_type=embedding_type, dyrep=dyrep, no_memory=no_memory,
            and_prompt=False
        )

        # Raw features arrive as numpy arrays; `dim` falls back to the
        # feature width when it was not given.
        if nfeats is not None:
            nfeats = torch.from_numpy(nfeats).float()
            dim = nfeats.shape[1] if dim is None else dim
        if efeats is not None:
            efeats = torch.from_numpy(efeats).float()
            dim = efeats.shape[1] if dim is None else dim

        raw_feat_getter = NumericalFeature(
            nfeats, efeats, dim=dim, register_buffer=feature_as_buffer, device=device
        )
        raw_feat_getter.n_nodes = full_graph.num_node
        raw_feat_getter.n_edges = len(full_data)
        ts_delta_mean, ts_delta_std, *_ = full_data.get_stats()

        if prompter_type == "transformer":
            prompter = TransformerTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                hist_len=hist_len,
                n_head=n_heads, dropout=dropout,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)
        elif prompter_type == "vanilla":
            prompter = VanillaTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)
        elif prompter_type == "projection":
            prompter = ProjectionTProG(
                raw_feat_getter=raw_feat_getter,
                graph=full_graph,
                dyrep=dyrep, prompt_dim=prompt_dim
            ).to(device)
        else:
            # Fail with a clear message instead of an unbound `prompter`.
            raise ValueError(f'Unknown prompter_type: {prompter_type}')

        # load model ckpt
        encoder.load_state_dict(torch.load(ckpt_path, map_location=device))
        encoder.eval()

        fine_tune_loss_fn = nn.BCEWithLogitsLoss()
        prompt_loss_fn = nn.MSELoss()
        if fine_tune_mode:
            optimizer = optim.Adam(encoder.parameters(), lr=lr)
        optimizer_prompt = optim.Adam(prompter.parameters(), lr=lr)

        val_aps = []
        ind_val_aps = []
        val_aucs = []
        ind_val_aucs = []
        epoch_times = []
        total_epoch_times = []
        train_losses = []

        if use_valid:
            early_stopper = EarlyStopMonitor(max_round=patience)
        for epoch in range(n_epochs):
            start_epoch_t0 = time.time()
            logger.info('Start {} epoch'.format(epoch))

            m_loss = []
            m_contrast_loss = []
            m_prompt_loss = []
            it = BackgroundThreadGenerator(prompt_dl)
            it = tqdm.tqdm(it, total=len(prompt_dl), ncols=50)

            encoder.reset()

            # Overwrite the listed entries of the encoder state with the
            # values stored in the checkpoint file.
            ckpt_buffers = torch.load(ckpt_path, map_location=device)
            encoder_dict = encoder.state_dict()
            for buffer in buffer_list:
                encoder_dict[buffer] = ckpt_buffers[buffer]
            encoder.load_state_dict(encoder_dict)

            if fine_tune_mode:
                encoder.train()
            prompter.train()
            for src_ids, dst_ids, neg_dst_ids, ts, eids, _, comp_graph in it:
                src_ids = src_ids.long().to(device)
                dst_ids = dst_ids.long().to(device)
                neg_dst_ids = neg_dst_ids.long().to(device)
                ts = ts.float().to(device)
                eids = eids.long().to(device)
                comp_graph.to(device)

                # Gradients flow through the encoder only when fine-tuning.
                with torch.set_grad_enabled(bool(fine_tune_mode)):
                    (h_left_with_negs, h_left, h_prev_left, h_prev_right,
                     n_batch, batch_node_ids, involved_node_reprs) = \
                        encoder.contrast_learning_with_prompt(
                            src_ids, dst_ids, neg_dst_ids, ts, eids, comp_graph)

                optimizer_prompt.zero_grad()
                if fine_tune_mode:
                    optimizer.zero_grad()

                index = comp_graph.restart_data.index
                unique_nids = torch.cat([src_ids, dst_ids])[index]
                unique_ts = ts.repeat(2)[index]

                index_neg = comp_graph.restart_data_neg.index
                unique_nids_neg = torch.cat([src_ids, neg_dst_ids])[index_neg]
                unique_ts_neg = ts.repeat(2)[index_neg]

                index = index.to(device)
                unique_nids = unique_nids.to(device)
                unique_ts = unique_ts.to(device)

                index_neg = index_neg.to(device)
                unique_nids_neg = unique_nids_neg.to(device)
                unique_ts_neg = unique_ts_neg.to(device)

                if encoder.dyrep:
                    local_center_nids = comp_graph.local_index[batch_node_ids]
                    h_prev_right_with_negs = involved_node_reprs[local_center_nids]
                    x, y, neg_y = h_prev_right_with_negs.reshape(3, n_batch, encoder.nfeat_dim)
                else:
                    x, y, neg_y = h_left_with_negs.reshape(3, n_batch, encoder.nfeat_dim)

                surrogate_h_prev_left, surrogate_h_prev_right, prompted_x = prompter(
                    unique_nids, unique_ts, x, src_ids)
                _1, _2, prompted_y = prompter(unique_nids, unique_ts, y, dst_ids)
                _1, _2, prompted_neg_y = prompter(
                    unique_nids_neg, unique_ts_neg, neg_y, neg_dst_ids, is_neg=True)

                src_hit, dst_hit, neg_src_hit, neg_dst_hit = comp_graph.hit_data
                if encoder.hit_type == 'vec':
                    x_pos_pair = torch.cat([prompted_x, src_hit], 1)  # [bs, dim + n_neigh]
                    y_pos_pair = torch.cat([prompted_y, dst_hit], 1)  # [bs, dim + n_neigh]
                    x_neg_pair = torch.cat([prompted_x, neg_src_hit], 1)
                    y_neg_pair = torch.cat([prompted_neg_y, neg_dst_hit], 1)
                elif encoder.hit_type == 'bin':
                    x_pos_pair = prompted_x + encoder.hit_embedding(src_hit.max(1).values.long())
                    y_pos_pair = prompted_y + encoder.hit_embedding(dst_hit.max(1).values.long())
                    x_neg_pair = prompted_x + encoder.hit_embedding(neg_src_hit.max(1).values.long())
                    y_neg_pair = prompted_neg_y + encoder.hit_embedding(neg_dst_hit.max(1).values.long())
                elif encoder.hit_type == 'count':
                    x_pos_pair = prompted_x + encoder.hit_embedding(src_hit.sum(1).long())
                    y_pos_pair = prompted_y + encoder.hit_embedding(dst_hit.sum(1).long())
                    x_neg_pair = prompted_x + encoder.hit_embedding(neg_src_hit.sum(1).long())
                    y_neg_pair = prompted_neg_y + encoder.hit_embedding(neg_dst_hit.sum(1).long())
                else:
                    x_pos_pair = x_neg_pair = prompted_x
                    y_pos_pair = prompted_y
                    y_neg_pair = prompted_neg_y

                pos_scores = encoder.score_fn(x_pos_pair, y_pos_pair).squeeze(1)
                neg_scores = encoder.score_fn(x_neg_pair, y_neg_pair).squeeze(1)
                # compute loss: positives are labelled 1, negatives 0
                bce_targets = torch.cat(
                    [torch.ones_like(pos_scores), torch.zeros_like(neg_scores)], 0)
                contrast_loss = fine_tune_loss_fn(
                    torch.cat([pos_scores, neg_scores], 0), bce_targets)

                # if no_memory, cannot teaching prompt
                prompt_loss = torch.tensor(0, device=contrast_loss.device)
                if (not no_memory) and teaching_prompt:
                    targets = torch.cat([h_prev_left[index], h_prev_right[index]], 0)  # [2n, d]
                    preds = torch.cat([surrogate_h_prev_left, surrogate_h_prev_right], 0)  # [2n, d]
                    # Rows whose target is all zeros are excluded from the loss.
                    valid_rows = torch.where(~(targets == 0).all(1))[0]
                    if len(valid_rows):
                        prompt_loss = prompt_loss_fn(preds[valid_rows], targets[valid_rows].detach())

                loss = contrast_loss + prompt_teaching_coef * prompt_loss
                loss.backward()
                optimizer_prompt.step()
                if fine_tune_mode:
                    optimizer.step()
                m_contrast_loss.append(contrast_loss.item())
                m_prompt_loss.append(prompt_loss.item())
                m_loss.append(loss.item())

            epoch_time = time.time() - start_epoch_t0
            # Record the training time so the pickled `epoch_times` is filled.
            epoch_times.append(epoch_time)

            # validation
            encoder.flush_msg()
            encoder.graph = full_graph

            uptodate_nodes = set()

            if restart_mode:
                encoder.msg_store.clear()
                if warmup_steps:
                    uptodate_nodes = warmup(encoder, val_warmup_dl, device)
                else:
                    uptodate_nodes = set()
            elif subset < 1.0:
                _ = eval_edge_prediction(encoder, offline_dl, device, prompter=prompter, restart_mode=False)

            memory_state_train_end = encoder.save_memory_state()  # save states at t_train_end
            inference_time_start = time.time()
            val_ap, val_auc = eval_edge_prediction(
                encoder, val_dl, device, restart_mode, prompter=prompter, uptodate_nodes=uptodate_nodes.copy()
            )  # memory modified
            inference_time = time.time() - inference_time_start
            memory_state_valid_end = encoder.save_memory_state()  # save states at t_valid_end
            encoder.load_memory_state(memory_state_train_end)  # load states at t_train_end
            inference_time_start_ind = time.time()
            ind_val_ap, ind_val_auc = eval_edge_prediction(
                encoder, ind_val_dl, device, restart_mode, prompter=prompter, uptodate_nodes=uptodate_nodes.copy()
            )
            inference_time_ind = time.time() - inference_time_start_ind
            encoder.load_memory_state(memory_state_valid_end)

            total_epoch_time = time.time() - start_epoch_t0
            total_epoch_times.append(total_epoch_time)

            # save this epoch's prompter (and the encoder when fine-tuning)
            encoder.flush_msg()
            torch.save(prompter.state_dict(), get_prompter_path(epoch))
            if fine_tune_mode:
                torch.save(encoder.state_dict(), get_checkpoint_path(epoch))

            logger.info('Epoch {:4d} training took  {:.2f}s'.format(epoch, epoch_time))
            logger.info('Epoch {:4d} total    took  {:.2f}s'.format(epoch, total_epoch_time))
            logger.info('Epoch {:4d} trans. inference took  {:.2f}s'.format(epoch, inference_time))
            logger.info('Epoch {:4d} ind. inference took  {:.2f}s'.format(epoch, inference_time_ind))
            logger.info(f'Epoch mean    total loss: {np.mean(m_loss):.4f}')
            logger.info(f'Epoch mean contrast loss: {np.mean(m_contrast_loss):.4f}')
            logger.info(f'Epoch mean   prompt loss: {np.mean(m_prompt_loss):.4f}')
            logger.info(f'Val     ap: {val_ap:.4f}, Val     auc: {val_auc:.4f}')
            logger.info(f'Val ind ap: {ind_val_ap:.4f}, Val ind auc: {ind_val_auc:.4f}')

            val_aps.append(val_ap)
            ind_val_aps.append(ind_val_ap)
            val_aucs.append(val_auc)
            ind_val_aucs.append(ind_val_auc)
            train_losses.append(np.mean(m_loss))

            # The checkpoints for this epoch were already written above, so
            # nothing else needs saving when training continues.
            if use_valid and early_stopper.early_stop_check(val_ap):
                logger.info('No improvement over {} epochs, stop training'.format(
                    early_stopper.max_round))
                break

        if use_valid:
            # Testing
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            best_epoch = early_stopper.best_epoch
            best_val_ap = val_aps[best_epoch]
            best_val_auc = val_aucs[best_epoch]
            best_ind_val_ap = ind_val_aps[best_epoch]
            best_ind_val_auc = ind_val_aucs[best_epoch]
            logger.info(f'[ Val] Best     ap: {best_val_ap:.4f} Best     auc: {best_val_auc:.4f}')
            logger.info(f'[ Val] Best ind ap: {best_ind_val_ap:.4f} Best ind auc: {best_ind_val_auc:.4f}')

            if fine_tune_mode:
                best_model_path = get_checkpoint_path(early_stopper.best_epoch)
                model_state = torch.load(best_model_path, map_location=device)
                encoder.load_state_dict(model_state)
            else:
                encoder.load_state_dict(torch.load(ckpt_path, map_location=device))
            torch.save(encoder.state_dict(), MODEL_SAVE_PATH)  # save to the model save folder

            # load best prompt
            best_prompter_path = get_prompter_path(early_stopper.best_epoch)
            prompter_state = torch.load(best_prompter_path, map_location=device)
            prompter.load_state_dict(prompter_state)
            torch.save(prompter.state_dict(), PROMPTER_SAVE_PATH)

            encoder.eval()
            prompter.eval()
            encoder.graph = full_graph
            # Defined explicitly rather than relying on the value left over
            # from the last training epoch.
            uptodate_nodes = set()
            if restart_mode:
                encoder.msg_store.clear()
                if warmup_steps:
                    uptodate_nodes = warmup(encoder, test_warmup_dl, device)

            memory_state_val_end = encoder.save_memory_state()  # save states at t_valid_end
            test_ap, test_auc = eval_edge_prediction(
                encoder, test_dl, device, restart_mode, prompter=prompter,
                uptodate_nodes=uptodate_nodes.copy()
            )  # memory modified
            encoder.load_memory_state(memory_state_val_end)  # load states at t_valid_end
            ind_test_ap, ind_test_auc = eval_edge_prediction(
                encoder, ind_test_dl, device, restart_mode, prompter=prompter,
                uptodate_nodes=uptodate_nodes.copy()
            )
            logger.info(f'[Test]     ap: {test_ap:.4f}     auc: {test_auc:.4f}')
            logger.info(f'[Test] ind ap: {ind_test_ap:.4f} ind auc: {ind_test_auc:.4f}')

        else:
            logger.info('No validation set. Use the last epoch result.')
            test_ap = val_aps[-1]
            ind_test_ap = ind_val_aps[-1]
            test_auc = val_aucs[-1]
            ind_test_auc = ind_val_aucs[-1]
            # Without early stopping the last epoch is also reported as the
            # "best" one, so the result summary below is always defined.
            best_val_ap = val_aps[-1]
            best_ind_val_ap = ind_val_aps[-1]
            best_val_auc = val_aucs[-1]
            best_ind_val_auc = ind_val_aucs[-1]

        # Save results for this run
        with open(PICKLE_SAVE_PATH, "wb") as pickle_fh:
            pickle.dump({
                "val_aps": val_aps,
                "val_aucs": val_aucs,
                "ind_val_aps": ind_val_aps,
                "ind_val_aucs": ind_val_aucs,
                "test_ap": test_ap,
                "ind_test_ap": ind_test_ap,
                "test_auc": test_auc,
                "ind_test_auc": ind_test_auc,
                "epoch_times": epoch_times,
                "train_losses": train_losses,
                "total_epoch_times": total_epoch_times
            }, pickle_fh)

        results = args.copy()
        results.update(
            HASH=HASH,
            VERSION=MODEL_VERSION,
            val_ap=best_val_ap, ind_val_ap=best_ind_val_ap,
            val_auc=best_val_auc, ind_val_auc=best_ind_val_auc,
            test_ap=test_ap, test_auc=test_auc,
            ind_test_ap=ind_test_ap, ind_test_auc=ind_test_auc
        )
        with open(RESULT_SAVE_PATH, 'w') as result_fh:
            json.dump(results, result_fh)

        # remove all ckpts
        shutil.rmtree(ckpts_dir)
        shutil.rmtree(prompter_ckpts_dir)

    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
        raise


def get_args():
    """Parse the command-line options of the prompt fine-tuning script.

    Returns:
        argparse.Namespace: The parsed options (``code``, ``json``, ``ckpt``,
        ``root``, ``seed``, ``gpu``, ``n_epochs``, ``bs``, ``patience``,
        ``dropout``, ``lr``, ``use_valid``, ``force``, ``prompter_type``,
        ``teaching_prompt``, ``prompt_teaching_coef``, ``fine_tune_mode``,
        ``prompt_end`` and ``prompt_dim``).
    """
    def parse_bool(text):
        # `type=bool` would turn every non-empty string (even "False") into
        # True, so boolean values are parsed explicitly.
        lowered = text.strip().lower()
        if lowered in {'true', 't', 'yes', 'y', '1', 'on'}:
            return True
        if lowered in {'false', 'f', 'no', 'n', '0', 'off', ''}:
            return False
        raise argparse.ArgumentTypeError(f'Boolean value expected, got {text!r}')

    parser = argparse.ArgumentParser()
    # Exp Setting
    parser.add_argument('--code', type=str, default='', help='Name of the saved result and model')
    parser.add_argument('--json', type=str, default='', help='Path to model result (json file)')
    parser.add_argument('--ckpt', type=str, default='', help='Path to model check point')
    parser.add_argument('--root', type=str, default='.', help='Dataset root')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--gpu', type=str, default='0', help='Cuda index')

    parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--bs', type=int, default=100, help='Batch size')
    parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout probability')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--use_valid', type=parse_bool, default=True, help='Use validation set')

    parser.add_argument('--force', action='store_true', help='Overwirte the existing task')

    # prompt and fine tune
    parser.add_argument('--prompter_type', type=str, default="transformer", choices=["transformer", "vanilla", "projection"], help='TProG type')
    parser.add_argument('--teaching_prompt', action='store_true', help='teaching prompt when fine tuning')
    parser.add_argument('--prompt_teaching_coef', type=float, default=1.0, help='prompt teaching loss coef')
    parser.add_argument('--fine_tune_mode', action='store_true', help='Using fine_tune_mode to train the encoder when training prompter')

    parser.add_argument('--prompt_end', type=float, default=0.7, help='quantile of prompting end time, this value should >= prompt_start')

    parser.add_argument('--prompt_dim', type=int, default=None, help='dim of prompts')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Entry point: load the saved result of a pre-trained model, check its
    # version, and launch prompt training with the saved hyper-parameters
    # overridden by the command-line options.
    args = get_args()

    if args.code:
        with open(f'./results/{args.code}.json') as fh:
            saved_results = json.load(fh)
        ckpt_path = f'./saved_models/{args.code}.pth'
    else:
        if not args.json or not args.ckpt:
            raise ValueError('Either --code or both --json and --ckpt must be given.')
        with open(args.json) as fh:
            saved_results = json.load(fh)
        # The option is declared as `--ckpt`, so the attribute is `args.ckpt`.
        ckpt_path = args.ckpt

    if not CHECK_VERSION(saved_results['VERSION'], MODEL_VERSION):
        raise ValueError('version not match: {} != {}'.format(
            saved_results['VERSION'], MODEL_VERSION))

    prefix = saved_results['HASH'] if saved_results['prefix'] == '' else saved_results['prefix']
    prefix += '-fine_tune'
    # Every keyword passed explicitly to run() below must be dropped from the
    # saved results, otherwise the call fails with a duplicate keyword error.
    explicit_keys = {
        'prefix', 'gpu', 'seed', 'lr', 'n_epochs', 'bs', 'patience', 'force',
        'root', 'dropout', 'use_valid', 'ckpt_path',
        'prompter_type', 'teaching_prompt', 'prompt_teaching_coef',
        'fine_tune_mode', 'prompt_start', 'prompt_end', 'prompt_dim',
    }
    kwargs = {k: v for k, v in saved_results.items() if k not in explicit_keys}

    run(
        prefix=prefix, gpu=args.gpu, seed=args.seed,
        lr=args.lr, dropout=args.dropout, bs=args.bs, n_epochs=args.n_epochs,
        patience=args.patience, force=args.force,
        use_valid=args.use_valid, root=args.root, ckpt_path=ckpt_path,
        prompter_type=args.prompter_type,
        teaching_prompt=args.teaching_prompt, prompt_teaching_coef=args.prompt_teaching_coef,
        fine_tune_mode=args.fine_tune_mode,
        prompt_start=saved_results['prompt_start'], prompt_end=args.prompt_end,
        prompt_dim=args.prompt_dim,
        **kwargs
    )
