import inspect
import os
import warnings
from typing import Dict

import torch
import argparse
from transformers import (
    Trainer,
    AutoConfig,
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoModelForSequenceClassification,
    DataCollatorForSeq2Seq,
    DataCollatorWithPadding,
    set_seed,
    TrainingArguments
)
from transformers.trainer_utils import EvalPrediction, is_torch_xla_available
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

import wandb
import evaluate
import datetime
import json
import numpy as np

from peft import get_peft_model, GIFTConfig, LoraConfig, VeraConfig, PeftModelForCausalLM, prepare_model_for_kbit_training

from task_config import task_config
from dataset import LoReftGLUEDataset, SupervisedDataset
from compute_metrics import compute_metrics

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tasks that are trained with a sequence-classification head.
classification_tasks = {"glue"}

# CLI dtype names -> torch dtypes. "float8" is kept as a plain string sentinel;
# code in this file compares against it to switch on 8-bit loading.
dtype_mapping = {name: getattr(torch, name) for name in ("float32", "float16", "bfloat16")}
dtype_mapping["float8"] = "float8"


def prepare_model_for_peft(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
    """Freeze every parameter of ``model`` and optionally turn on gradient checkpointing.

    Args:
        model: Object exposing ``named_parameters()`` and, when checkpointing is
            requested, ``gradient_checkpointing_enable()``.
        use_gradient_checkpointing: When false, only the freezing step is done.
        gradient_checkpointing_kwargs: Optional dict forwarded to
            ``model.gradient_checkpointing_enable`` if that method accepts a
            ``gradient_checkpointing_kwargs`` parameter. ``None`` means ``{}``.

    Returns:
        The same ``model`` object, modified in place.
    """
    gc_kwargs = {} if gradient_checkpointing_kwargs is None else gradient_checkpointing_kwargs

    # Freeze the base model; adapters added afterwards bring their own trainable weights.
    for _, weight in model.named_parameters():
        weight.requires_grad = False

    if not use_gradient_checkpointing:
        return model

    # The "inputs require grad" hack is only skipped when the caller explicitly
    # asks for `use_reentrant=False`; a missing key counts as reentrant.
    if gc_kwargs.get("use_reentrant", True):
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Backward-compatible fallback: mark the embedding output as requiring grad.
            def _force_grad_on_output(module, inputs, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(_force_grad_on_output)

    # Older transformers versions do not take `gradient_checkpointing_kwargs`.
    enable_signature = inspect.signature(model.gradient_checkpointing_enable)
    accepts_gc_kwargs = "gradient_checkpointing_kwargs" in enable_signature.parameters

    if accepts_gc_kwargs:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
    else:
        if gc_kwargs:
            warnings.warn(
                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
                " if you want to use that feature, please upgrade to the latest version of transformers.",
                FutureWarning,
            )
        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    return model


class LLMTrainer(Trainer):
    """`Trainer` subclass whose logging step additionally reports peak GPU memory.

    Only `_maybe_log_save_evaluate` is overridden. NOTE(review): the body looks
    like a copy of the parent implementation with a `gpu_memory` log entry
    added — confirm the signature still matches the installed transformers
    version, since this is a private hook.
    """

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
        """Log, evaluate and save according to the flags on `self.control`.

        Args:
            tr_loss: Running training-loss tensor; it is zeroed in place once logged.
            grad_norm: Gradient norm to log (tensor or number), or None to skip it.
            model: Model handed to `_save_checkpoint` when saving is requested.
            trial: Hyper-parameter-search trial passed through to reporting/saving.
            epoch: Not used by this override.
            ignore_keys_for_eval: Forwarded to `self.evaluate` as `ignore_keys`.
        """
        # Log only when asked to and at least one optimizer step happened since the last log.
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            if is_torch_xla_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            # (in-place subtraction, so the caller's tensor object is the one that is reset)
            tr_loss -= tr_loss

            # Average the accumulated loss over the steps since the previous log.
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            logs["learning_rate"] = self._get_learning_rate()

            # Track GPU memory for the current device
            # (peak allocated bytes reported by torch, divided by 1024**3, i.e. GiB)
            logs["gpu_memory"] = torch.cuda.max_memory_allocated(device=self.args.device) / (1024.0 * 1024.0 * 1024.0)

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, self.state.global_step, metrics)

            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                # Metrics returned by evaluate() are looked up under an "eval_" prefix.
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])

        if self.control.should_save:
            # `metrics` is None here unless an evaluation ran in this same call.
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)


def _gift_module_layout(description, num_layers=32):
    """Return ``(target_modules, tied_modules, transform_dim)`` for a named GIFT preset.

    Args:
        description: One of ``"config_alpha"``, ``"config_beta"``,
            ``"config_gamma"`` or ``"config_delta"``.
        num_layers: Number of decoder layers to enumerate for ``"config_delta"``,
            the only preset that lists per-layer module names instead of regex
            patterns. Defaults to 32 (llama-7b/8b).

    Raises:
        ValueError: If ``description`` is not a known preset.
    """
    # Patterns are raw strings so that `\d` reaches the consumer unchanged
    # without triggering Python's invalid-escape-sequence warning.
    if description == "config_alpha":
        target_modules = ["q_proj", "v_proj"]
        tied_modules = [
            [r"model.layers.\d+.self_attn.q_proj"],
            [r"model.layers.\d+.self_attn.v_proj"],
        ]
        transform_dim = {
            "q_proj": "input",
            "v_proj": "input",
        }
    elif description == "config_beta":
        target_modules = ["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"]
        tied_modules = [
            [r"model.layers.\d+.self_attn.q_proj"],
            [r"model.layers.\d+.self_attn.k_proj"],
            [r"model.layers.\d+.self_attn.v_proj"],
            [r"model.layers.\d+.mlp.up_proj"],
            [r"model.layers.\d+.mlp.down_proj"],
        ]
        transform_dim = {
            "q_proj": "input",
            "k_proj": "input",
            "v_proj": "input",
            "up_proj": "input",
            "down_proj": "input",
        }
    elif description == "config_gamma":
        target_modules = ["o_proj", "down_proj"]
        tied_modules = [
            [r"model.layers.\d+.self_attn.o_proj"],
            [r"model.layers.\d+.mlp.down_proj"],
        ]
        transform_dim = {
            "o_proj": "output",
            "down_proj": "output",
        }
    elif description == "config_delta":
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
        tied_modules = []
        for layer in range(num_layers):
            prefix = f"model.layers.{layer}"
            tied_modules += [
                [f"{prefix}.self_attn.q_proj", f"{prefix}.self_attn.k_proj", f"{prefix}.self_attn.v_proj"],
                [f"{prefix}.self_attn.o_proj"],
                [f"{prefix}.mlp.up_proj", f"{prefix}.mlp.gate_proj"],
                [f"{prefix}.mlp.down_proj"],
            ]
        transform_dim = {
            "q_proj": "input",
            "k_proj": "input",
            "v_proj": "input",
            "o_proj": "output",
            "up_proj": "input",
            "gate_proj": "input",
            "down_proj": "output",
        }
    else:
        raise ValueError(f"Unrecognized description: {description}")
    return target_modules, tied_modules, transform_dim


def finetune(
    model: str,
    epochs: int,
    seed: int,
    max_n_train_example: int,
    max_n_eval_example: int,
    is_wandb: bool,
    wandb_name: str,
    gradient_accumulation_steps: int,
    batch_size: int,
    output_dir: str,
    task: str,
    lr: float,
    schedule: str,
    data_dir: str,
    train_dataset: str,
    eval_dataset: str,
    save_model: bool,
    eval_batch_size: int,
    warmup_ratio: float,
    warmup_steps: int,
    weight_decay: float,
    test_split: str,
    train_on_inputs: bool,
    max_length: int,
    allow_cls_grad: bool,
    metric_for_best_model: str,
    dtype: str,
    logging_steps: int,
    wandb_dir: str,
    wandb_proj: str,
    greedy_decoding: bool,
    temperature: float,
    top_p: float,
    top_k: float,
    args,
    **kwargs,
):
    """Fine-tune a pretrained model with a PEFT adapter, then evaluate it.

    The function loads the tokenizer, datasets and model, wraps the model with
    the adapter chosen by ``args.tuner``, trains it with ``LLMTrainer`` and
    finally runs ``compute_metrics`` on every evaluation dataset/split. Run
    arguments and evaluation outputs are written as JSON files under
    ``{output_dir}/{run_name}``.

    Args:
        model: Model name or path handed to the ``from_pretrained`` loaders.
        epochs: Number of training epochs.
        seed: Seed for ``set_seed``, the datasets, the trainer and (for VeRA)
            the projection PRNG key.
        max_n_train_example: ``max_n_example`` for the training dataset.
        max_n_eval_example: ``max_n_example`` for the evaluation datasets.
        is_wandb: Whether to initialise and log to Weights & Biases.
        wandb_name: W&B entity.
        gradient_accumulation_steps: Forwarded to ``TrainingArguments``.
        batch_size: Per-device training batch size.
        output_dir: Root directory for run outputs.
        task: Task name; must be one of the supported tasks and a key of
            ``task_config``.
        lr: Learning rate.
        schedule: Learning-rate scheduler type.
        data_dir: Directory joined with dataset names for non-GLUE tasks
            (not joined for the ``ultrafeedback_pair`` training set).
        train_dataset: Training dataset name; ``None`` uses the task default.
            For GLUE it is also the evaluation dataset and metric name.
        eval_dataset: Evaluation dataset name; ``None`` uses the task default.
        save_model: Whether to call ``save_pretrained`` after training.
        eval_batch_size: Per-device evaluation batch size.
        warmup_ratio: Warm-up ratio; mutually exclusive with ``warmup_steps``.
        warmup_steps: Warm-up steps; mutually exclusive with ``warmup_ratio``.
        weight_decay: Weight decay.
        test_split: Evaluation split name(s), separated by ``";"``.
        train_on_inputs: Only printed by this function.
        max_length: ``model_max_length`` of the tokenizer.
        allow_cls_grad: For GLUE, enable gradients on the classifier head.
        metric_for_best_model: Model-selection metric (GLUE only).
        dtype: Key of ``dtype_mapping``.
        logging_steps: Logging interval in steps.
        wandb_dir: W&B working directory.
        wandb_proj: W&B project name.
        greedy_decoding: Forwarded to ``compute_metrics``.
        temperature: Forwarded to ``compute_metrics``.
        top_p: Forwarded to ``compute_metrics``.
        top_k: Forwarded to ``compute_metrics``.
        args: Parsed argument namespace. Read for ``rank``, ``tuner``,
            ``description``, ``target_modules``, ``scaling_factor``,
            ``use_gradient_checkpointing``, ``autocast`` and ``max_grad_norm``.
            It is mutated: GIFT presets overwrite ``target_modules``,
            ``tied_modules`` and ``transform_dim``, and ``n_params`` is added
            before the namespace is dumped to ``args.json``.
        **kwargs: Extra keyword arguments; ignored.

    Raises:
        AssertionError: On an unsupported task or when both warm-up settings
            are given.
        ValueError: If ``args.tuner`` contains ``"gift"`` and
            ``args.description`` is not a known preset.
    """
    # Print all the arguments
    print(
        f"model: {model}, epochs: {epochs}, seed: {seed}, max_n_train_example: {max_n_train_example}, "
        f"max_n_eval_example: {max_n_eval_example}, is_wandb: {is_wandb}, wandb_name: {wandb_name}, "
        f"gradient_accumulation_steps: {gradient_accumulation_steps}, batch_size: {batch_size}, "
        f"output_dir: {output_dir}, task: {task}, lr: {lr}, schedule: {schedule}, data_dir: {data_dir}, "
        f"train_dataset: {train_dataset}, eval_dataset: {eval_dataset}, save_model: {save_model}, "
        f"eval_batch_size: {eval_batch_size}, warmup_ratio: {warmup_ratio}, weight_decay: {weight_decay}, "
        f"test_split: {test_split}, train_on_inputs: {train_on_inputs}, max_length: {max_length}, "
        f"allow_cls_grad: {allow_cls_grad}, metric_for_best_model: {metric_for_best_model}, dtype: {dtype}, "
        f"logging_steps: {logging_steps}, wandb_dir: {wandb_dir}, wandb_proj: {wandb_proj}, "
        f"greedy_decoding: {greedy_decoding}, temperature: {temperature}, top_p: {top_p}, top_k: {top_k}"
    )

    assert task in {
        "commonsense", "math", "alpaca", "instruct", "ultrafeedback", "glue", "gsm8k",
        "ultrafeedback_pair", "boolq"
    }

    # From here on `dtype` is a torch dtype, or the string "float8" (8-bit loading).
    dtype = dtype_mapping[dtype]

    # store/log run details
    print(
        f"task: {task}, model: {model}, lr: {lr}, weight_decay: {weight_decay}, rank: {args.rank}, "
        f"epoch: {epochs}, train_on_inputs: {train_on_inputs}, "
        f"max_length: {max_length}, allow_cls_grad: {allow_cls_grad}"
    )

    # everything is guarded by a single seed
    set_seed(seed)

    model_name = model
    model_str = model.split("/")[-1]
    train_dataset_str = train_dataset
    now = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    if train_dataset is not None:
        run_name = f"{model_str}.{task}.{train_dataset_str}.{test_split}.{now}.{args.rank}.{lr}.{weight_decay}"
    else:
        run_name = f"{model_str}.{args.tuner}.{args.description}.{task}.{now}.{args.rank}.{lr}.{weight_decay}"
    run_dir = f"{output_dir}/{run_name}"

    if warmup_steps > 0:
        assert warmup_ratio == 0., "Cannot specify both warmup_steps and warmup_ratio."
    if warmup_ratio > 0.:
        assert warmup_steps == 0, "Cannot specify both warmup_steps and warmup_ratio."

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        model_max_length=max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.unk_token is None and tokenizer.pad_token is None:
        # raw llama3
        print("adding a special padding token...")
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        need_resize = True
    else:
        # Reuse the unk token for padding, but never replace an existing pad
        # token with None when the tokenizer has no unk token.
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        need_resize = False

    # load dataset splits
    assert task in task_config, f"Unrecognized task: {task}"
    train_datasets = task_config[task]["train_datasets"] if train_dataset is None else [train_dataset]
    if task == "glue":
        eval_datasets = [train_dataset]
    else:
        eval_datasets = task_config[task]["eval_datasets"] if eval_dataset is None else [eval_dataset]

    Dataset = LoReftGLUEDataset if task == "glue" else SupervisedDataset
    if task == "glue" or task == "ultrafeedback_pair" or data_dir is None:
        train_source = train_datasets[0]
    else:
        train_source = os.path.join(data_dir, train_datasets[0])
    train_data = Dataset(
        task, train_source,
        tokenizer, data_split="train", seed=seed, max_n_example=max_n_train_example,
        **{"test_split": test_split}
    )
    trigger_tokens = train_data.trigger_tokens
    num_labels = train_data.num_labels

    # {dataset name: {split: [dataset object, raw dataset]}}
    test_splits = test_split.split(";")
    all_eval_datasets = {}
    for eval_name in eval_datasets:
        all_eval_datasets[eval_name] = {}
        # Same `data_dir is None` guard as the training path above.
        if task == "glue" or data_dir is None:
            eval_source = eval_name
        else:
            eval_source = os.path.join(data_dir, eval_name)
        for split in test_splits:
            raw_eval = Dataset(
                task, eval_source,
                tokenizer, data_split=split, seed=seed, max_n_example=max_n_eval_example,
            )
            all_eval_datasets[eval_name][split] = [raw_eval, raw_eval.raw_dataset]
    eval_datasets = all_eval_datasets

    if task == "glue":
        # we repartition the eval_datasets into [1] 50% validation + [2] 50% test
        # we select the best model on [1] during training
        # we test the selected model on [2] to ensure fairness
        to_split_eval_datasets = eval_datasets[train_dataset_str][test_split][0]
        if len(to_split_eval_datasets) > 5000:
            in_train_n_eval_sample = 1000
        else:
            in_train_n_eval_sample = len(to_split_eval_datasets) // 2

        new_splits = torch.utils.data.random_split(
            to_split_eval_datasets, [len(to_split_eval_datasets)-in_train_n_eval_sample, in_train_n_eval_sample]
        )

        in_test_eval_datasets, in_train_eval_datasets = new_splits[0], new_splits[1]
        eval_datasets[train_dataset_str][test_split][0] = in_test_eval_datasets
        print("GLUE validation split (in training): ", len(in_train_eval_datasets))
        print("GLUE validation split (testing): ", len(eval_datasets[train_dataset_str][test_split][0]))

        is_regression = train_dataset_str == "stsb"
        metric = evaluate.load("glue", train_dataset_str)
        # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
        # predictions and label_ids field) and has to return a dictionary string to float.
        def in_training_compute_metrics(p: EvalPrediction):
            preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
            preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result

    # load model based on task type.
    if task in classification_tasks:
        config = AutoConfig.from_pretrained(
            model_name, num_labels=num_labels,
            finetuning_task=train_dataset_str,
            load_in_8bit=True if dtype == "float8" else False,
            device_map=device
        )
        # full precision loading since usually for small models
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=config, # just providing the label
            torch_dtype=dtype if dtype != "float8" else None,
            load_in_8bit=True if dtype == "float8" else False,
            device_map=device
        )
        task_type = "SEQ_CLS"
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype if dtype != "float8" else None,  # save memory
            load_in_8bit=True if dtype == "float8" else False,
            device_map=device,
        )
        config = model.config
        task_type = "CAUSAL_LM"
    if need_resize:
        model.resize_token_embeddings(len(tokenizer))

    # select collator based on the type
    if task in classification_tasks:
        data_collator = DataCollatorWithPadding(
            tokenizer=tokenizer,
            padding="longest"
        )
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100,
            padding="longest",
        )

    model = prepare_model_for_peft(model, use_gradient_checkpointing=args.use_gradient_checkpointing, gradient_checkpointing_kwargs={"use_reentrant": False})

    peft_config_cls = {"gift": GIFTConfig, "lora": LoraConfig, "vera": VeraConfig, "dora": LoraConfig}[args.tuner]
    peft_config_args = {"r": args.rank, "target_modules": args.target_modules, "task_type": task_type}
    if "gift" in args.tuner:
        # Presets are selected by description for now. The layer count comes
        # from the loaded model's config, falling back to 32 (llama-7b/8b).
        num_layers = getattr(config, "num_hidden_layers", None) or 32
        args.target_modules, args.tied_modules, args.transform_dim = _gift_module_layout(
            args.description, num_layers
        )
        peft_config_args["target_modules"] = args.target_modules
        peft_config_args["tied_modules"] = args.tied_modules
        peft_config_args["transform_dim"] = args.transform_dim
        peft_config_args["gift_alpha"] = args.scaling_factor
    if args.tuner == "vera":
        peft_config_args["projection_prng_key"] = seed
    if "lora" in args.tuner or "dora" in args.tuner:
        peft_config_args["lora_alpha"] = args.scaling_factor
    if "dora" in args.tuner:
        peft_config_args["use_dora"] = True

    peft_config = peft_config_cls(**peft_config_args)

    wrapped_model: PeftModelForCausalLM = get_peft_model(model, peft_config, adapter_name=args.tuner)
    num_trainable = sum([p.numel() for p in wrapped_model.parameters() if p.requires_grad])
    num_non_trainable = sum([p.numel() for p in wrapped_model.parameters() if not p.requires_grad])
    # Reported relative to the frozen parameter count, not the total.
    percent_trainable = num_trainable / num_non_trainable * 100

    print(wrapped_model)

    print(f"Num. trainable parameters: {num_trainable/1e6:.4f}M. Percent of trainable parameters: {percent_trainable:.4f}%")

    # for GLUE tasks, we enable gradients on the classifier head.
    if task == "glue" and allow_cls_grad:
        for param in wrapped_model.model.classifier.parameters():
            # wrapped_model with HF trainer will automatically pick up these params to optimize
            param.requires_grad = True

    n_params = num_trainable

    # start wandb logging
    if is_wandb:
        run = wandb.init(
            project=f"{wandb_proj}",
            entity=wandb_name,
            name=run_name,
            dir=wandb_dir,
        )
        run.summary.update(vars(args))
        wandb.log(
            {"train/n_params": n_params, "train/percent_trainable": percent_trainable})

    # training args
    if dtype==torch.float16 and args.autocast:
        print("Using FP16 with autocast.")
    if dtype==torch.bfloat16 and args.autocast:
        print("Using BF16 with autocast.")
    training_args = TrainingArguments(
        output_dir=run_dir,
        run_name=run_name,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        evaluation_strategy="epoch" if task == "glue" else "no",
        save_strategy="epoch" if task == "glue" else "no",
        metric_for_best_model=metric_for_best_model if task == "glue" else None,
        load_best_model_at_end=True if task == "glue" else False,
        logging_strategy="steps",
        save_total_limit=1, # for GLUE, it will save 2 at max.
        logging_steps=logging_steps,
        lr_scheduler_type=schedule,
        learning_rate=lr,
        warmup_ratio=warmup_ratio,
        warmup_steps=warmup_steps,
        optim="adamw_torch",
        weight_decay=weight_decay,
        report_to="wandb" if is_wandb else "none",
        use_cpu=False if device == "cuda" else True,
        seed=seed,
        remove_unused_columns=True,
        fp16=dtype==torch.float16 and args.autocast,
        bf16=dtype==torch.bfloat16 and args.autocast,
        gradient_checkpointing=args.use_gradient_checkpointing,
        max_grad_norm=args.max_grad_norm,
    )

    # make trainer
    trainer_class = LLMTrainer # Check if we need a different class for cls tasks
    trainer = trainer_class(
        model=wrapped_model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
        # The GLUE-only names below exist only when task == "glue"; the
        # conditional expressions never evaluate them otherwise.
        eval_dataset=in_train_eval_datasets if task == "glue" else None,
        data_collator=data_collator,
        compute_metrics=in_training_compute_metrics if task == "glue" else None,
    )
    trainer.train()

    # The run directory may not exist yet when nothing was checkpointed
    # (save_strategy="no"), so create it before writing any JSON file.
    os.makedirs(run_dir, exist_ok=True)

    # dump config
    args_dict = vars(args)
    args_dict["n_params"] = n_params
    json_file_name = f"{run_dir}/args.json"
    with open(json_file_name, 'w') as json_file:
        json.dump(args_dict, json_file, indent=4)

    # save model
    if save_model:
        wrapped_model.save_pretrained(run_dir)

    # ensure everything is in eval mode
    wrapped_model.eval()

    print({"n_params": n_params})
    # do eval
    eval_results = {}
    for dataset_name in eval_datasets:
        # split evalset into chunks
        for split, (eval_data, data_items) in eval_datasets[dataset_name].items():

            generations, stats = compute_metrics(
                task, dataset_name, wrapped_model, tokenizer, eval_data, data_items,
                trigger_tokens, run_name, eval_batch_size,
                data_collator if task in classification_tasks else None,
                split, greedy_decoding, temperature, top_p, top_k
            )

            # log
            eval_results.update(stats)
            if is_wandb:
                wandb.log(stats)
            # Fall back to the stats dict when there are no generations to store.
            generations = stats if generations is None else generations
            result_json_file_name = f"{run_dir}/{dataset_name}_{split}_outputs.json"
            with open(result_json_file_name, 'w') as json_file:
                json.dump(generations, json_file, indent=4)

    # log final eval stats
    result_json_file_name = f"{run_dir}/eval_results.json"
    eval_results["n_params"] = n_params
    with open(result_json_file_name, 'w') as json_file:
        json.dump(eval_results, json_file, indent=4)

    print(f"Training results can be found in {output_dir}/{run_name}/checkpoint")

def main():
    """Parse command-line arguments and run `finetune` with them.

    Every parsed option is passed to `finetune` as a keyword argument, and the
    namespace itself is passed as `args` so that options which are not explicit
    `finetune` parameters (e.g. `--rank`, `--tuner`) remain reachable.
    """
    parser = argparse.ArgumentParser(description="A simple script that takes different arguments.")

    # Data
    parser.add_argument('-task', '--task', type=str, default=None)
    parser.add_argument('-data_dir', '--data_dir', type=str, default="./datasets")
    parser.add_argument('-train_dataset', '--train_dataset', type=str, default=None)
    parser.add_argument('-eval_dataset', '--eval_dataset', type=str, default=None)
    parser.add_argument('-model', '--model', type=str, help='yahma/llama-7b-hf', default='yahma/llama-7b-hf')

    # Experiment
    parser.add_argument('-seed', '--seed', type=int, help='42', default=42)
    parser.add_argument('-is_wandb', '--is_wandb', action='store_true')
    parser.add_argument('-wandb_name', '--wandb_name', type=str, default="reft")
    parser.add_argument('-eval_batch_size', '--eval_batch_size', type=int, default=4)
    parser.add_argument('-output_dir', '--output_dir', type=str, default="./official_results")
    parser.add_argument('-act_fn', '--act_fn', type=str, default=None)
    parser.add_argument('-add_bias', '--add_bias', action='store_true')
    parser.add_argument('-test_split', '--test_split', type=str, default="validation")
    parser.add_argument('-train_on_inputs', '--train_on_inputs', action='store_true')
    # `help` must be a string: argparse %-formats it when rendering --help,
    # and an int there raises TypeError.
    parser.add_argument('-max_length', '--max_length', type=int, help='512', default=512)
    parser.add_argument('-nt', '--use_normalized_template', action='store_true')
    parser.add_argument('-dtype', '--dtype', type=str, default="bfloat16" if device == "cuda" else "float32")
    parser.add_argument('-logging_steps', '--logging_steps', type=int, help='1', default=1)
    parser.add_argument('-wandb_dir', '--wandb_dir', type=str, default='wandb')
    parser.add_argument('-wandb_proj', '--wandb_proj', type=str, default='GIFT')

    # decoding params
    parser.add_argument('-e', '--epochs', type=int, help='1', default=1)
    parser.add_argument('-t', '--temperature', type=float, default=None)
    parser.add_argument('-top_p', '--top_p', type=float, default=None)
    # top-k is a token count, so it is parsed as an int.
    # NOTE(review): Hugging Face top-k sampling rejects non-int values; confirm
    # that compute_metrics forwards this value to generation unchanged.
    parser.add_argument('-top_k', '--top_k', type=int, default=None)
    parser.add_argument('-gd', '--greedy_decoding', action='store_true')

    # Training params
    parser.add_argument('--use_gradient_checkpointing', action='store_true', default=False)
    parser.add_argument('--autocast', action='store_true', default=False)
    parser.add_argument('-lr', '--lr', type=float, default=5e-3)
    parser.add_argument('-schedule', '--schedule', type=str, default='linear')
    parser.add_argument('-wu', '--warmup_ratio', type=float, default=0.)
    parser.add_argument('-ws', '--warmup_steps', type=int, default=0)
    parser.add_argument('-wd', '--weight_decay', type=float, default=0.00)
    parser.add_argument('-dropout', '--dropout', type=float, default=0.00)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('-allow_cls_grad', '--allow_cls_grad', action='store_true')
    parser.add_argument('-save_model', '--save_model', action='store_true')
    parser.add_argument('-max_n_train_example', '--max_n_train_example', type=int, default=None)
    parser.add_argument('-max_n_eval_example', '--max_n_eval_example', type=int, default=None)
    parser.add_argument('-gradient_accumulation_steps', '--gradient_accumulation_steps', type=int, default=4)
    parser.add_argument('-batch_size', '--batch_size', type=int, default=4)
    parser.add_argument('-metric_for_best_model', '--metric_for_best_model', type=str, default="accuracy")

    # GIFT params
    group = parser.add_argument_group("GIFT parameters")
    group.add_argument("--tuner", type=str, default="gift", help="PEFT Tuner.")
    group.add_argument("--rank", type=int, default=32, help="Rank r.")
    group.add_argument(
        "--target_modules", type=str, nargs="+",
        default=["q_proj", "v_proj"],
        help="Modules to apply finetuning on.",
    )
    group.add_argument(
        "--tied_modules", type=str, nargs="+",
        default=None,
        help="Modules to tie together.",
    )
    group.add_argument(
        "--transform_dim", type=str,
        default="input",
        help="GIFT dimension to transform.",
    )
    group.add_argument(
        "--scaling_factor", type=int, default=8,
        help="Scaling ratio for residuals.",
    )
    group.add_argument(
        "--description", type=str, default="",
        help="Description of exp.",
    )

    args = parser.parse_args()

    # Options without a matching named parameter land in finetune's **kwargs.
    finetune(**vars(args), args=args)


if __name__ == "__main__":
    main()