import os
import sys
from typing import List
import json
import fire
import torch
import transformers
from datasets import load_dataset
from typing import List, Optional, Union
import pdb 
from tqdm import tqdm
import copy
import re
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import copy
import numpy as np
import random

from torch.nn.utils.rnn import pad_sequence
sys.path.append(os.path.join(os.getcwd(), "peft/src/"))
from peft import (  # noqa: E402
    LoraConfig,
    BottleneckConfig,
    PrefixTuningConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModel,
    GenerationConfig,
    DataCollatorForSeq2Seq,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    LlamaTokenizer,
)

from llama import LlamaForCausalLM  # noqa: F402
from llama_explore.modeling_llama_explore import New_LlamaForCausalLM  # noqa: F402

# Default compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# If the MPS backend reports itself available it overrides the choice above.
# The probe is deliberately best-effort (presumably because some torch builds
# lack ``torch.backends.mps``), but it must not swallow KeyboardInterrupt or
# SystemExit, so only ``Exception`` subclasses are suppressed.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:
    pass

import wandb
from typing import List, Any, Tuple

def _normalize_gpu_ids(gpus):
    """Return ``gpus`` as a list of integer CUDA device indices.

    Accepts a comma-separated string ("0,1"), a single int, or a list/tuple of
    ints or numeric strings.
    """
    if isinstance(gpus, str):
        return [int(part) for part in gpus.split(",") if part.strip()]
    if isinstance(gpus, (list, tuple)):
        return [int(part) for part in gpus]
    return [int(gpus)]


def _build_explore_inputs(prev_outputs, batch, pad_token_id, gpu_id):
    """Build the tensor keyword arguments for one explore-model forward pass.

    Everything derived from ``prev_outputs`` (the upstream model's output) is
    detached, so no gradient flows back into the upstream model.
    """
    target = f"cuda:{gpu_id}"
    labels = batch["labels"].detach()

    # Greedy predictions of the upstream model, re-padded per sequence to the
    # number of non-pad predicted tokens.
    pred_ids = prev_outputs.logits.detach().argmax(dim=-1)
    valid_lengths = (pred_ids != pad_token_id).sum(dim=1).tolist()
    pred_ids_padded = pad_sequence(
        [ids[:length] for ids, length in zip(pred_ids, valid_lengths)],
        batch_first=True,
        padding_value=pad_token_id,
    )

    # All hidden-state tensors returned by the upstream model, stacked on a new
    # leading dimension.
    hidden_states = torch.stack(prev_outputs.hidden_states, dim=0).detach()

    return {
        "input_ids": batch["input_ids"].detach().to(target),
        "attention_mask": batch["attention_mask"].detach().to(target),
        "labels": labels.to(target),
        "exploit_hidden_states": hidden_states.to(target),
        "exploit_labels": labels.to(target),
        "f1_pred_ids": pred_ids_padded.to(target),
    }


def train_runner(
    model,
    model_explore_list,
    tokenizer,
    tokenizer_explore,
    train_data,
    val_data,
    micro_batch_size,
    gradient_accumulation_steps,
    num_epochs,
    learning_rate,
    learning_rate_a,
    warmup_steps,
    fp16,
    eval_step,
    save_step,
    output_dir,
    explore_output_dir,
    device,
    val_set_size,
    use_wandb,
    wandb_project,
    wandb_run_name,
    ddp,
    group_by_length,
    resume_from_checkpoint,
    explore_flag,
    seed,
    seed_worker,
    topk_logits,
    clip_value,
    capacity,
    alpha,
    beta,
    model_type,
    gpus
):
    """
    Two-phase custom training loop that replaces ``transformers.Trainer``.

    Phase 1 ("F1") trains ``model`` on ``train_data``. When ``explore_flag`` is
    set, phase 2 ("F2") puts ``model`` in eval mode and trains each model in
    ``model_explore_list`` in turn. Explore model ``k`` receives the detached
    hidden states and greedy predictions obtained by running ``model`` and then
    explore models ``0..k-1`` under ``torch.no_grad()``.

    Args:
        model: Base model trained in phase 1; must expose ``save_pretrained``.
        model_explore_list: Explore models trained in phase 2 (only read when
            ``explore_flag`` is true).
        tokenizer: Tokenizer used for collation and for its ``pad_token_id``.
        train_data, val_data: Tokenized datasets. A validation loader is built
            from ``val_data`` but no evaluation is run in this function.
        micro_batch_size: Per-step DataLoader batch size.
        gradient_accumulation_steps: Micro-batches per optimizer update
            (values below 1 are treated as 1).
        num_epochs: Epochs per phase (and per explore model in phase 2).
        learning_rate, learning_rate_a: AdamW learning rates for the base model
            and the explore models respectively.
        warmup_steps: Linear warm-up length, in optimizer updates.
        fp16: When true, a ``GradScaler`` is used for every trained model.
        save_step: Phase-1 checkpoint interval in optimizer updates
            (``<= 0`` disables checkpointing).
        output_dir: Directory receiving ``checkpoint-<step>`` sub-directories.
        device: Device the batches are moved to for the base model.
        wandb_project, wandb_run_name: Passed to ``wandb.init`` (offline mode).
        ddp: Whether ``torch.distributed`` is in use (affects worker seeding).
        explore_flag: Enables phase 2.
        seed: Seed for the DataLoader generators and worker seeding.
        topk_logits, clip_value, beta: Forwarded to the explore models.
        gpus: CUDA device indices; entry ``k + 1`` hosts explore model ``k``.
        tokenizer_explore, eval_step, explore_output_dir, val_set_size,
        use_wandb, group_by_length, resume_from_checkpoint, seed_worker,
        capacity, alpha, model_type: Accepted for interface compatibility; apart
            from being recorded in the wandb config where listed below, they are
            not used by this function.

    Raises:
        ValueError: If ``explore_flag`` is set and ``gpus`` has fewer entries
            than ``len(model_explore_list) + 1``.
    """
    # A factor below 1 would make the modulo test below divide by zero.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps)

    run = wandb.init(
        project=wandb_project,
        name=wandb_run_name,
        config={
            "train_data": train_data,
            "micro_batch_size": micro_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "num_epochs": num_epochs,
            "learning_rate": learning_rate,
            "learning_rate_a": learning_rate_a,
            "output_dir": output_dir,
            "explore_output_dir": explore_output_dir,
            "device": device,
            "wandb_run_name": wandb_run_name,
            "explore_flag": explore_flag,
            "topk_logits": topk_logits,
            "seed": seed,
            "capacity": capacity,
            "alpha": alpha,
            "beta": beta,
            "model_type": model_type
        },
        mode="offline"
    )
    # ---------------------------
    # 1. Setup DataLoaders
    # ---------------------------
    collator = DataCollatorForSeq2Seq(
        tokenizer,
        pad_to_multiple_of=8,
        return_tensors="pt",
        padding=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=micro_batch_size,
        shuffle=True,
        collate_fn=collator,
        worker_init_fn=lambda _: np.random.seed(seed + torch.distributed.get_rank() if ddp else seed),
        generator=torch.Generator().manual_seed(seed)
    )
    # NOTE: the validation loader is constructed but not consumed below.
    if val_data is not None:
        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=micro_batch_size,
            shuffle=False,
            collate_fn=collator,
            worker_init_fn=lambda _: np.random.seed(seed + torch.distributed.get_rank() if ddp else seed),
            generator=torch.Generator().manual_seed(seed)
        )
    else:
        val_loader = None

    # ---------------------------
    # 2. Setup Optimizer / Scheduler
    # ---------------------------
    num_batches = len(train_loader)
    # The schedulers advance once per optimizer update, not once per
    # micro-batch, so their horizon is counted in updates (ceil division; the
    # last partial accumulation window of an epoch is applied too).
    updates_per_epoch = -(-num_batches // gradient_accumulation_steps)
    total_steps = updates_per_epoch * num_epochs

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )
    optimizer_explore_list, scheduler_explore_list, scaler_explore_list = [], [], []
    if explore_flag:
        for model_explore in model_explore_list:
            optimizer_explore = torch.optim.AdamW(model_explore.parameters(), lr=learning_rate_a)
            scheduler_explore = get_linear_schedule_with_warmup(
                optimizer_explore,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
            )
            optimizer_explore_list.append(optimizer_explore)
            scheduler_explore_list.append(scheduler_explore)

    # Optional gradient scaling: one scaler per trained model.
    scaler = None
    if fp16:
        scaler = torch.cuda.amp.GradScaler()
        if explore_flag:
            scaler_explore_list = [torch.cuda.amp.GradScaler() for _ in model_explore_list]

    # ---------------------------
    # 3. Phase 1: train the base model
    # ---------------------------
    global_step = 0
    model.train()
    for epoch in range(num_epochs):
        print(f"F1 Train Epoch {epoch+1}/{num_epochs}")

        with tqdm(total=num_batches, desc=f"Training Epoch {epoch+1}", unit="batch") as pbar:
            for step, batch in enumerate(train_loader):
                for k in batch:
                    batch[k] = batch[k].to(device)

                with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                    exploit_outputs = model(**batch)
                    # Scale so that accumulated gradients average over the window.
                    loss = exploit_outputs.loss / gradient_accumulation_steps

                loss_value = loss.item()
                run.log({
                    "loss": loss_value,
                    "step": step
                })

                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                # Update at the end of every accumulation window and on the
                # last batch of the epoch, so trailing gradients are not lost.
                is_update_step = (
                    (step + 1) % gradient_accumulation_steps == 0
                    or (step + 1) == num_batches
                )
                if is_update_step:
                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()

                    optimizer.zero_grad()
                    scheduler.step()
                    global_step += 1

                    # Console logging every 10 optimizer updates.
                    if global_step % 10 == 0:
                        current_loss = loss_value * gradient_accumulation_steps
                        print(f"  [train_runner] step={global_step}, loss={current_loss:.8f}")

                    # Save checkpoint at specified steps
                    if save_step > 0 and (global_step % save_step == 0):
                        ckpt_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        os.makedirs(ckpt_path, exist_ok=True)
                        print(f"  [train_runner] Saving checkpoint to {ckpt_path}")
                        model.save_pretrained(ckpt_path)

                # Update our training tqdm
                pbar.update(1)
                pbar.set_postfix(
                    {
                        "loss": f"{loss_value * gradient_accumulation_steps:.8f}",
                        "lr": f"{scheduler.get_last_lr()[0]:.2e}"
                    }
                )

    # ---------------------------
    # 4. Phase 2: train the explore models one after another
    # ---------------------------
    if explore_flag:
        gpu_ids = _normalize_gpu_ids(gpus)
        if len(gpu_ids) < len(model_explore_list) + 1:
            raise ValueError(
                f"gpus must list at least {len(model_explore_list) + 1} devices "
                f"(one for the base model plus one per explore model), got {gpu_ids}"
            )
        pad_token_id = tokenizer.pad_token_id
        explore_kwargs = {"topk_logits": topk_logits, "clip_value": clip_value, "beta": beta}

        global_step = 0
        model.eval()
        for explore_model in model_explore_list:
            explore_model.train()

        for model_explore_idx, explore_model in enumerate(model_explore_list):
            explore_optimizer = optimizer_explore_list[model_explore_idx]
            explore_scheduler = scheduler_explore_list[model_explore_idx]
            # The scaler list is empty when fp16 is disabled.
            explore_scaler = scaler_explore_list[model_explore_idx] if scaler_explore_list else None
            target_gpu = gpu_ids[model_explore_idx + 1]

            for epoch in range(num_epochs):
                print(f"F2 TRAIN Epoch {epoch+1}/{num_epochs}")

                with tqdm(total=num_batches, desc=f"Training Epoch {epoch+1}", unit="batch") as pbar:
                    for step, batch in enumerate(train_loader):
                        for k in batch:
                            batch[k] = batch[k].to(device)

                        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                            # Upstream chain (base model, then explore models
                            # 0..k-1) is evaluated without building a graph.
                            with torch.no_grad():
                                exploit_outputs = model(**batch)
                                for i in range(model_explore_idx):
                                    upstream_inputs = _build_explore_inputs(
                                        exploit_outputs, batch, pad_token_id, gpu_ids[i + 1]
                                    )
                                    exploit_outputs = model_explore_list[i](
                                        **upstream_inputs, **explore_kwargs
                                    )

                            explore_inputs = _build_explore_inputs(
                                exploit_outputs, batch, pad_token_id, target_gpu
                            )
                            # Release the upstream outputs before the forward
                            # pass that builds an autograd graph.
                            del exploit_outputs

                            explore_outputs = explore_model(**explore_inputs, **explore_kwargs)
                            explore_loss = explore_outputs.loss / gradient_accumulation_steps

                        loss_value = explore_loss.item()
                        run.log({
                            f"explore_loss_{model_explore_idx}": loss_value
                        })

                        if explore_scaler is not None:
                            explore_scaler.scale(explore_loss).backward()
                        else:
                            explore_loss.backward()

                        is_update_step = (
                            (step + 1) % gradient_accumulation_steps == 0
                            or (step + 1) == num_batches
                        )
                        if is_update_step:
                            # Gradients were produced from a scaled loss, so the
                            # matching scaler must perform the step (it unscales
                            # and skips updates containing inf/NaN).
                            if explore_scaler is not None:
                                explore_scaler.step(explore_optimizer)
                                explore_scaler.update()
                            else:
                                explore_optimizer.step()

                            explore_optimizer.zero_grad()
                            explore_scheduler.step()
                            global_step += 1

                            # Console logging every 10 optimizer updates.
                            if global_step % 10 == 0:
                                current_explore_loss = loss_value * gradient_accumulation_steps
                                print(f"[train_runner] step={global_step}, explore_loss={current_explore_loss:.8f}")
                            torch.cuda.empty_cache()

                        pbar.update(1)
                        pbar.set_postfix(
                            {
                                f"explore_loss_{model_explore_idx}": f"{loss_value * gradient_accumulation_steps:.8f}"
                            }
                        )

    print("[train_runner] Training complete.")

def _split_comma_separated(value):
    """Split a CLI value into a list of strings.

    Accepts a comma-separated string, or a list/tuple in case the command-line
    layer has already split the argument.
    """
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return str(value).split(",")


def _parse_gpu_ids(gpus):
    """Return ``gpus`` (string "0,1", int, or list/tuple) as a list of ints."""
    if isinstance(gpus, str):
        return [int(part) for part in gpus.split(",") if part.strip()]
    if isinstance(gpus, (list, tuple)):
        return [int(part) for part in gpus]
    return [int(gpus)]


def _use_adapter_only_state_dict(peft_model):
    """Patch ``peft_model.state_dict`` so it returns only the adapter weights.

    Done in a helper so each patched model captures its *own* original
    ``state_dict`` (a lambda created directly in a loop would late-bind to the
    last model's).
    """
    full_state_dict = peft_model.state_dict
    peft_model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, full_state_dict())
    ).__get__(peft_model, type(peft_model))


def train(
        base_model: str = "", 
        explore_base_model: str = "",
        data_path: str = "yahma/alpaca-cleaned",
        output_dir: str = "./lora-alpaca",
        explore_output_dir: str = "./lora-explore-alpaca",
        adapter_name: str = "lora",
        load_8bit : bool = False,
        # training hyperparams
        batch_size: int = 128,
        micro_batch_size: int = 4,
        num_epochs: int = 3,
        learning_rate: float = 2e-4,
        learning_rate_a: float = 3e-4,
        cutoff_len: int = 256,
        val_set_size: int = 2000,
        use_gradient_checkpointing: bool = False,
        eval_step: int = 200,
        save_step: int = 200,
        # lora hyperparams
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_dropout_a: float = 0.08,
        lora_target_modules: List[str] = None,
        # bottleneck adapter hyperparams
        bottleneck_size: int = 256,
        non_linearity: str = "tanh",
        adapter_dropout: float = 0.0,
        use_parallel_adapter: bool = False,
        use_adapterp: bool = False,
        target_modules: List[str] = None,
        scaling: Union[float, str] = 1.0,
        # prefix tuning hyperparams
        num_virtual_tokens: int = 30,
        # llm hyperparams
        train_on_inputs: bool = True,  # if False, masks out inputs in loss
        group_by_length: bool = False,  # faster, but produces an odd training loss curve
        # wandb params
        wandb_project: str = "BoostLLM",
        wandb_run_name: str = "",
        wandb_watch: str = "",  # options: false | gradients | all
        wandb_log_model: str = "",  # options: false | true
        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
        dataset_name: str = "",
        test_data_path: str="",
        direct_test: bool = False,
        explore_flag: bool = False,
        explore_logits_factor: float = 100.0,
        seed: int = 42,
        topk_logits: int=10,
        clip_value: float=0.1,
        fp16: bool = True,
        capacity: int=100,
        alpha: float=0.1,
        beta: float=0.1,
        model_type:str="BoostLLM",
        gpus:str="0"
):
    """Fine-tune a base model (and optional explore models) with PEFT adapters.

    Loads ``base_model`` onto GPU 0 and, when ``explore_flag`` is set, one
    explore model per comma-separated entry of ``explore_base_model`` onto
    ``gpus[1:]``. All models are wrapped with the adapter selected by
    ``adapter_name``, the dataset at ``data_path`` is tokenized with the prompt
    template from ``generate_prompt``, training is delegated to
    ``train_runner`` and the adapters are saved afterwards.

    Args:
        base_model: Path or hub id of the base model (required).
        explore_base_model: Comma-separated paths/ids of the explore models
            (required with ``explore_flag``).
        data_path: A ``.json`` file or a dataset name for ``load_dataset``.
        output_dir: Where the base-model adapter is saved.
        explore_output_dir: Comma-separated output directories, one per
            explore model.
        adapter_name: ``"lora"``, ``"bottleneck"`` or ``"prefix-tuning"``.
            Explore models are only supported with ``"lora"``.
        batch_size, micro_batch_size: Their quotient (at least 1) is the
            gradient-accumulation factor.
        target_modules: Module names the adapter is attached to; must be a
            list when ``explore_flag`` is set.
        gpus: CUDA device indices; index 0 is the base model's slot, entry
            ``k + 1`` hosts explore model ``k``.
        Remaining arguments are hyper-parameters that are echoed at start-up
        and forwarded to the adapter configs, model constructors or
        ``train_runner``.

    Raises:
        AssertionError: If ``base_model`` is missing, ``explore_base_model`` is
            missing with ``explore_flag``, or ``load_8bit`` is requested.
        ValueError: For an unknown ``adapter_name``, a non-LoRA adapter with
            ``explore_flag``, a missing/non-list ``target_modules`` with
            ``explore_flag``, or too few ``gpus`` / ``explore_output_dir``
            entries for the requested explore models.
    """
    print(
        f"Finetuning model with params:\n"
        f"base_model: {base_model}\n"
        f"explore_base_model: {explore_base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"explore_output_dir: {explore_output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"learning_rate_a: {learning_rate_a}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"use_gradient_checkpointing: {use_gradient_checkpointing}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_dropout_a: {lora_dropout_a}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"bottleneck_size: {bottleneck_size}\n"
        f"non_linearity: {non_linearity}\n"
        f"adapter_dropout: {adapter_dropout}\n"
        f"use_parallel_adapter: {use_parallel_adapter}\n"
        f"use_adapterp: {use_adapterp}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"scaling: {scaling}\n"
        f"adapter_name: {adapter_name}\n"
        f"target_modules: {target_modules}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
        f"dataset_name: {dataset_name}\n"
        f"test_data_path: {test_data_path}\n"
        f"direct_test: {direct_test}\n"
        f"explore_flag: {explore_flag}\n"
        f"explore_logits_factor: {explore_logits_factor}\n"
        f"seed: {seed}\n"
        f"topk_logits: {topk_logits}\n"
        f"clip_value: {clip_value}\n"
        f"fp16: {fp16}\n"
        f"capacity:{capacity}\n"
        f"alpha:{alpha}\n"
        f"beta:{beta}\n"
        f"model_type:{model_type}\n"
        f"gpus:{gpus}\n"
    )
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"

    if explore_flag:
        assert (
            explore_base_model
        ), "Please specify a --explore_base_model, e.g. --explore_base_model='decapoda-research/llama-7b-hf'"
    explore_base_model_list = _split_comma_separated(explore_base_model)
    explore_output_dir_list = _split_comma_separated(explore_output_dir)

    # Validate the per-explore-model resources up front, so a misconfiguration
    # fails before any model is loaded rather than after training has finished.
    gpu_ids = None
    if explore_flag:
        gpu_ids = _parse_gpu_ids(gpus)
        if len(gpu_ids) < len(explore_base_model_list) + 1:
            raise ValueError(
                f"--gpus must list at least {len(explore_base_model_list) + 1} devices "
                f"(one for the base model plus one per explore model), got {gpu_ids}"
            )
        if len(explore_output_dir_list) < len(explore_base_model_list):
            raise ValueError(
                f"--explore_output_dir must list one directory per explore model "
                f"({len(explore_base_model_list)} needed, got {len(explore_output_dir_list)})"
            )

    # Never let the accumulation factor drop to 0 (batch_size < micro_batch_size).
    gradient_accumulation_steps = max(1, batch_size // micro_batch_size)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if ddp:
        def seed_worker(worker_id):
            worker_seed = seed + torch.distributed.get_rank()
            np.random.seed(worker_seed)
            random.seed(worker_seed)
            torch.manual_seed(worker_seed)
    else:
        seed_worker = None

    if ddp:
        device = torch.device(f"cuda:{local_rank}")
        gradient_accumulation_steps = max(1, gradient_accumulation_steps // world_size)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device_map = "auto"

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
            "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    # ---------------------
    # Load models
    # ---------------------
    model_explore_list = []
    if load_8bit:
        # 8-bit loading is deliberately disabled in this script.
        assert not load_8bit, "load_8bit=True is not supported by this training script"
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map=device_map,
            trust_remote_code=True,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            device_map={"": 0},
            trust_remote_code=True,
            mode="Training"
        )
        model_hidden_size = model.config.hidden_size
        if explore_flag:
            for i, explore_model_path in enumerate(explore_base_model_list):
                # Explore model i lives on the (i + 1)-th listed GPU.
                device_id = gpu_ids[i + 1]
                model_device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")
                model_explore = New_LlamaForCausalLM.from_pretrained(
                    explore_model_path,
                    load_in_8bit=False,
                    torch_dtype=torch.float16,
                    device_map={"": model_device},
                    trust_remote_code=True,
                    exploit_hidden_size=model_hidden_size,
                    explore_logits_factor=explore_logits_factor,
                    alpha=alpha,
                    gpu_id=device_id
                )

                model_explore.to(model_device)
                model_explore_list.append(model_explore)

    # ---------------------
    # Tokenizers
    # ---------------------
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True, use_fast=False)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "right"

    tokenizer_explore = None
    if explore_flag:
        # The explore tokenizer is loaded from the last listed explore model.
        tokenizer_explore = AutoTokenizer.from_pretrained(
            explore_base_model_list[-1], trust_remote_code=True, use_fast=False
        )
        tokenizer_explore.pad_token_id = 0
        tokenizer_explore.padding_side = "right"

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` (truncated to ``cutoff_len``) and copy ids into labels."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=True,
            return_tensors=None,
        )
        # Append EOS only when there is room and it is not already present.
        if (
                result["input_ids"][-1] != tokenizer.eos_token_id
                and len(result["input_ids"]) < cutoff_len
                and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            if "chatglm" not in base_model:
                result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        if "chatglm" in base_model:
            return {"input_ids": result["input_ids"], "labels": result["labels"]}
        else:
            return result

    def generate_and_tokenize_prompt(data_point):
        """Render the prompt for ``data_point`` and tokenize it."""
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)

        if not train_on_inputs:
            # NOTE(review): this guard rejects train_on_inputs=False, so the
            # masking code below only runs when assertions are disabled.
            assert train_on_inputs == True
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            # Mask the prompt part of the labels with -100.
            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )
        return tokenized_full_prompt

    # Prepare model for int8 training if needed
    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=use_gradient_checkpointing)
    if explore_flag:
        for i in range(len(model_explore_list)):
            model_explore_list[i] = prepare_model_for_int8_training(model_explore_list[i], use_gradient_checkpointing=use_gradient_checkpointing)

    # ---------------------
    # PEFT Config
    # ---------------------
    if explore_flag and adapter_name != "lora":
        raise ValueError(
            f"explore models are only supported with adapter_name='lora', got {adapter_name!r}"
        )

    if adapter_name == "lora":
        config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        print(f"target_modules for base model: {target_modules}")

        if explore_flag:
            if target_modules is None or isinstance(target_modules, str):
                raise ValueError(
                    "--target_modules must be a list of module names when --explore_flag is set"
                )
            # Copy (also accepts a tuple) so the base-model list is not mutated.
            target_explore_modules = list(target_modules)
            target_explore_modules.append("fc_layer")
            # Hidden-size comparison uses the last explore model that was loaded.
            if model.config.hidden_size != model_explore_list[-1].config.hidden_size:
                target_explore_modules.append("match_hidden_states")

            print(f"target_modules for explore model: {target_explore_modules}")

            config_explore = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=target_explore_modules,
                lora_dropout=lora_dropout_a,
                bias="none",
                task_type="CAUSAL_LM",
            )

    elif adapter_name == "bottleneck":
        config = BottleneckConfig(
            bottleneck_size=bottleneck_size,
            non_linearity=non_linearity,
            adapter_dropout=adapter_dropout,
            use_parallel_adapter=use_parallel_adapter,
            use_adapterp=use_adapterp,
            target_modules=target_modules,
            scaling=scaling,
            bias="none",
            task_type="CAUSAL_LM",
        )
    elif adapter_name == "prefix-tuning":
        config = PrefixTuningConfig(
            num_virtual_tokens=num_virtual_tokens,
            task_type="CAUSAL_LM",
        )
    else:
        raise ValueError(
            f"unknown adapter_name {adapter_name!r}; expected 'lora', 'bottleneck' or 'prefix-tuning'"
        )

    model = get_peft_model(model, config)
    if explore_flag:
        for i in range(len(model_explore_list)):
            model_explore_list[i] = get_peft_model(model_explore_list[i], config_explore)

    if adapter_name == "prefix-tuning":
        model.to('cuda')

    # ---------------------
    # Load dataset
    # ---------------------
    if data_path.endswith(".json"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    model.print_trainable_parameters()
    if explore_flag:
        for explore_model in model_explore_list:
            explore_model.print_trainable_parameters()

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=seed
        )
        train_data = (
            train_val["train"]
            .shuffle()
            .map(generate_and_tokenize_prompt, remove_columns=train_val["train"].column_names)
        )
        val_data = (
            train_val["test"]
            .shuffle()
            .map(generate_and_tokenize_prompt, remove_columns=train_val["test"].column_names)
        )
    else:
        train_data = data["train"].shuffle().map(
            generate_and_tokenize_prompt, remove_columns=data["train"].column_names
        )
        val_data = None

    # Mark model to not use cache during training
    model.config.use_cache = False
    model.config.output_hidden_states = True

    if explore_flag:
        for explore_model in model_explore_list:
            explore_model.config.use_cache = False
            explore_model.config.output_hidden_states = True

    # Keep a special state_dict that only saves the adapter's parameters
    _use_adapter_only_state_dict(model)
    if explore_flag:
        for explore_model in model_explore_list:
            _use_adapter_only_state_dict(explore_model)

    # Equivalent warmup_steps as we had in the Trainer setup
    warmup_steps = 100

    train_runner(
        model=model,
        model_explore_list=model_explore_list if explore_flag else None,
        tokenizer=tokenizer,
        tokenizer_explore=tokenizer_explore,
        train_data=train_data,
        val_data=val_data,
        micro_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        learning_rate_a=learning_rate_a,
        warmup_steps=warmup_steps,
        fp16=fp16,  # from the original TrainingArguments(fp16=True)
        eval_step=eval_step if val_set_size > 0 else 0,
        save_step=save_step,
        output_dir=output_dir,
        explore_output_dir=explore_output_dir,
        device=device,
        val_set_size=val_set_size,
        use_wandb=use_wandb,
        wandb_project=wandb_project,
        wandb_run_name=wandb_run_name,
        ddp=ddp,
        group_by_length=group_by_length,
        resume_from_checkpoint=resume_from_checkpoint,
        explore_flag=explore_flag,
        seed=seed,
        seed_worker=seed_worker,
        topk_logits=topk_logits,
        clip_value=clip_value,
        capacity=capacity,
        alpha=alpha,
        beta=beta,
        model_type=model_type,
        # Explore runs get the parsed integer device list.
        gpus=gpu_ids if explore_flag else gpus
    )

    # ---------------------
    # After training, save final model
    # ---------------------
    model.save_pretrained(output_dir)
    if explore_flag:
        for explore_model, explore_dir in zip(model_explore_list, explore_output_dir_list):
            explore_model.save_pretrained(explore_dir)
    print("\n If there's a warning about missing keys above, please disregard :)\n")

    
def create_dir(dir_path):
    """Ensure that the directory ``dir_path`` exists.

    Missing parent directories are created as well, and an already existing
    directory is left untouched.  ``os.makedirs(..., exist_ok=True)`` is used
    instead of an ``os.path.exists`` check followed by ``os.mkdir``: that
    check-then-create pair can raise ``FileExistsError`` when two processes
    create the same directory at the same time, and ``os.mkdir`` fails with
    ``FileNotFoundError`` for nested paths whose parent does not exist yet.

    Args:
        dir_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

def generate_prompt(data_point):
    """Render one example as an instruction-following prompt string.

    ``data_point`` is indexed with the keys ``"instruction"``, ``"input"`` and
    ``"output"``.  When the ``"input"`` value is falsy the ``### Input:``
    section is left out and the shorter template is used.  The whitespace
    inside the template literals is part of the returned text.
    """
    has_context = bool(data_point["input"])
    if not has_context:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  

                ### Instruction:
                {data_point["instruction"]}
                
                ### Response:
                {data_point["output"]}""" # noqa: E501

    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 

                ### Instruction:
                {data_point["instruction"]}
                
                ### Input:
                {data_point["input"]}
                
                ### Response:
                {data_point["output"]}""" # noqa: E501

def load_data(file_path) -> list:
    """Read and parse a JSON dataset file.

    The file is opened inside a ``with`` block so the handle is closed as
    soon as parsing finishes (the handle used to be left for the garbage
    collector), and it is decoded as UTF-8, the standard JSON encoding,
    rather than with the platform's default encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"can not find dataset file : {file_path}")
    with open(file_path, 'r', encoding='utf-8') as dataset_file:
        return json.load(dataset_file)

def create_batch(dataset, batch_size):
    """Cut ``dataset`` into consecutive slices of at most ``batch_size`` items.

    Every slice but the last holds ``batch_size`` items; the last one holds
    whatever is left over.  The slices are returned in order, as a list.
    """
    total = len(dataset)
    full_batches, leftover = divmod(total, batch_size)
    # A partial trailing batch is needed when the split is not even.
    batch_count = full_batches + 1 if leftover != 0 else full_batches
    starts = (index * batch_size for index in range(batch_count))
    return [dataset[start: min(start + batch_size, total)] for start in starts]


# Answer labels that extract_answer looks for, keyed by dataset name.
# Compiled once at import time instead of on every call.
_MULTIPLE_CHOICE_PATTERN = re.compile(r'answer1|answer2|answer3|answer4|answer5')
_ANSWER_PATTERNS = {
    'boolq': re.compile(r'true|false'),
    'piqa': re.compile(r'solution1|solution2'),
    'social_i_qa': _MULTIPLE_CHOICE_PATTERN,
    'ARC-Challenge': _MULTIPLE_CHOICE_PATTERN,
    'ARC-Easy': _MULTIPLE_CHOICE_PATTERN,
    'openbookqa': _MULTIPLE_CHOICE_PATTERN,
    'hellaswag': re.compile(r'ending1|ending2|ending3|ending4'),
    'winogrande': re.compile(r'option1|option2'),
}


def extract_answer(dataset, sentence: str) -> Optional[str]:
    """Return the first answer label found in ``sentence`` for ``dataset``.

    The sentence is lower-cased before matching, so the returned label is
    always lower case (e.g. ``"true"``, ``"solution2"``, ``"answer3"``).

    Args:
        dataset: Dataset name; it selects which labels are searched for and
            is compared case-sensitively (e.g. ``'ARC-Easy'``).
        sentence: Text to search.

    Returns:
        The first matching label, ``""`` when the sentence contains none of
        the dataset's labels, or ``None`` when ``dataset`` is not one of the
        supported names.
    """
    pattern = _ANSWER_PATTERNS.get(dataset)
    if pattern is None:
        # Unsupported dataset name: no label set to search for.
        return None
    found = pattern.search(sentence.lower())
    return found.group(0) if found else ""

def generate_prompt_for_test(instruction, input=None):
    """Render an inference prompt that stops right after ``### Response:``.

    When ``input`` is falsy (``None`` or empty) the ``### Input:`` section is
    left out and the shorter template is used.  The whitespace inside the
    template literals is part of the returned text.
    """
    if not input:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 

                ### Instruction:
                {instruction}

                ### Response:
                """  # noqa: E501

    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

                ### Instruction:
                {instruction}

                ### Input:
                {input}

                ### Response:
                """  # noqa: E501

# Script entry point: only when this file is executed directly (not imported),
# hand `train` to python-fire, which runs it as the command-line interface.
if __name__ == "__main__":
    fire.Fire(train)
