import os

import wandb 



from datasets import load_dataset, Dataset
import random
import torch 
import numpy as np
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
)

from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

# from accelerate.utils import save_model_on_main_process

from typing import List
import datetime


from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer
from peft import LoraConfig, get_peft_model
from torch.utils.data.sampler import BatchSampler, RandomSampler
import math
import random
from typing import Iterator, List, Optional, Dict
import torch
import torch.distributed as dist
from torch.utils.data import Sampler, DataLoader



from huggingface_hub import login



from inference.sample.prompt import system_msg_gpqa, user_msg_gpqa, system_msg_csqa, user_msg_csqa, csqa_001, csqa_002
from inference.sample.seq_prompt import seq_msg, seq_msg_hg
import copy



import argparse
import string

import torch.nn as nn
import types, os


### BATCH SAMPLER FOR FSDP ###

class DistributedGroupedMacroBatchSampler(Sampler[List[int]]):
    """
    Yield per-rank local batches of dataset indices, grouped by a key column.

    Guarantees:
      - Every micro-batch (one item of ``__iter__``) on every rank is drawn from the
        SAME value of ``dataset[group_key]``.
      - Each group is consumed in "macro" chunks of
        ``local_batch_size * world_size * accum_steps`` indices, i.e. ``accum_steps``
        consecutive global batches, so that gradient accumulation windows line up
        with group chunks whenever each group fills whole macro chunks.
      - Every emitted global batch is full, so all ranks yield the same number of
        batches and no local batch is empty.

    Args:
        dataset: object where ``dataset[group_key]`` returns one hashable group id per
            sample (e.g. a ``datasets.Dataset`` column).
        group_key: name of the group-id column.
        local_batch_size: samples per rank per micro-batch.
        accum_steps: gradient-accumulation steps (global batches per macro chunk).
        shuffle_groups: shuffle the group order every epoch.
        shuffle_within_group: shuffle indices inside each group every epoch.
        seed: base seed; the per-epoch RNG is ``random.Random(seed + epoch)``.
        drop_last: only used when ``updates_per_group`` is None. If True, each group's
            trailing incomplete global batch is dropped. If False, the group is padded
            by cycling its own indices up to a whole number of macro chunks.
        rank, world_size: override the values read from ``torch.distributed``.
        updates_per_group: if set, every group contributes exactly this many macro
            chunks, truncating or repeating (cycling) its indices as needed.

    Raises:
        ValueError: on non-positive sizes, a negative ``updates_per_group``, an
            out-of-range ``rank``, or (with ``drop_last``) a group smaller than one
            global batch.
    """

    def __init__(
        self,
        dataset,
        group_key: str,
        local_batch_size: int,
        accum_steps: int = 1,          # gradient_accumulation_steps
        shuffle_groups: bool = False,
        shuffle_within_group: bool = False,
        seed: int = 0,
        drop_last: bool = True,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
        updates_per_group: Optional[int] = None,  # None: use every sample of each group
    ):
        if local_batch_size <= 0:
            raise ValueError("local_batch_size must be > 0")
        if accum_steps <= 0:
            raise ValueError("accum_steps must be > 0")
        if updates_per_group is not None and updates_per_group < 0:
            raise ValueError("updates_per_group must be >= 0")

        self.dataset = dataset
        self.group_key = group_key
        self.local_batch_size = local_batch_size
        self.accum_steps = accum_steps
        self.shuffle_groups = shuffle_groups
        self.shuffle_within_group = shuffle_within_group
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0
        self.updates_per_group = updates_per_group

        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank() if rank is None else rank
            self.world_size = dist.get_world_size() if world_size is None else world_size
        else:
            self.rank = 0 if rank is None else rank
            self.world_size = 1 if world_size is None else world_size
        if self.world_size <= 0 or not 0 <= self.rank < self.world_size:
            raise ValueError(f"invalid rank={self.rank} for world_size={self.world_size}")

        # group_id -> [dataset indices], in first-seen order
        group_ids = self.dataset[self.group_key]
        self.group_to_indices: Dict[str, List[int]] = {}
        for idx, gid in enumerate(group_ids):
            self.group_to_indices.setdefault(gid, []).append(idx)

        self.global_batch = self.local_batch_size * self.world_size
        self.macro_size = self.global_batch * self.accum_steps  # samples per optimizer step

        if self.drop_last:
            bad = [gid for gid, idxs in self.group_to_indices.items() if len(idxs) < self.global_batch]
            if bad:
                raise ValueError(
                    f"Some groups have fewer samples than global_batch={self.global_batch}. "
                    f"Example: {bad[0]} size={len(self.group_to_indices[bad[0]])}"
                )

    def set_epoch(self, epoch: int):
        """Set the epoch mixed into the shuffling seed."""
        self.epoch = int(epoch)

    def _target_size(self, n: int) -> int:
        """Number of indices a group of size ``n`` contributes (always a multiple of global_batch)."""
        if self.updates_per_group is not None:
            return self.updates_per_group * self.macro_size
        if self.drop_last:
            return (n // self.global_batch) * self.global_batch
        return math.ceil(n / self.macro_size) * self.macro_size

    def _plan_group(self, idxs: List[int]) -> List[int]:
        """Truncate, or extend by cycling the group's own indices, to ``_target_size``."""
        target = self._target_size(len(idxs))
        if target <= len(idxs):
            return idxs[:target]
        reps = math.ceil(target / len(idxs))
        return (idxs * reps)[:target]  # repetition = sampling with replacement

    def __iter__(self) -> Iterator[List[int]]:
        rng = random.Random(self.seed + self.epoch)
        group_keys = list(self.group_to_indices.keys())
        if self.shuffle_groups:
            rng.shuffle(group_keys)

        lo = self.rank * self.local_batch_size
        hi = lo + self.local_batch_size
        for gid in group_keys:
            idxs = self.group_to_indices[gid].copy()
            if self.shuffle_within_group:
                rng.shuffle(idxs)
            idxs = self._plan_group(idxs)
            # len(idxs) is a multiple of global_batch, so each rank's slice is full.
            for start in range(0, len(idxs), self.global_batch):
                yield idxs[start + lo:start + hi]

    def __len__(self) -> int:
        """Number of local batches yielded per epoch (identical on every rank)."""
        return sum(
            self._target_size(len(idxs)) // self.global_batch
            for idxs in self.group_to_indices.values()
        )


### LOSS FUNCTION ###

class MPTrainer(SFTTrainer):
    """
    SFTTrainer whose loss matches batch-averaged next-token distributions at two
    randomly chosen answer slots, plus an optional score-function (PG) term.

    Each step draws slots ``k1 < k2``. The batch-mean softmax at one slot (the
    "student", with gradient) is trained by cross-entropy toward the batch-mean
    softmax at the other slot (the "teacher", no gradient; read from
    ``target_model`` when it is set). ``backward=False`` trains slot k1 toward k2;
    ``backward=True`` trains k2 toward k1.

    The training dataloader yields micro-batches that each come from a single value
    of the ``group_key`` column (see ``DistributedGroupedMacroBatchSampler``).
    """

    def __init__(
        self,
        *args,
        use_pg: bool = True,          # add the policy-gradient (score-function) term
        beta_pg: float = 0.01,        # weight for PG term (tune: 1e-3 ~ 1e-2)
        pg_warmup_steps: int = 0,     # delay PG until the baseline stabilizes
        baseline_tau: float = 0.95,   # EMA decay for the control-variate moments
        adv_clip: float = 5.0,        # clamp |advantage|
        temperature: float = 0.9,     # softmax temperature
        ema_decay: float = 0.99,
        m_train: int = 20,            # k1 is drawn from [2, m_train - 2]
        delta: int = 5,               # caps the k2 - k1 gap
        backward: bool = False,       # True: stop-grad on slot k1 (train k2 toward k1)
        group_key="group_id",         # dataset column used to group micro-batches
        samples_per_group: int = 500,  # samples drawn from each group per epoch (rounded up)
        dataloader_accum_steps: Optional[int] = 4,  # if set, overrides args.gradient_accumulation_steps
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.use_pg = use_pg
        self.beta_pg = beta_pg
        self.pg_warmup_steps = pg_warmup_steps
        self.baseline_tau = baseline_tau
        self.adv_clip = adv_clip
        self.temperature = temperature
        self._pg_baseline = None
        self.ema_decay = ema_decay
        self.target_model = None      # optional teacher network; None = use the live logits
        self.m_train = m_train
        self.delta = delta
        self.backward = backward
        self.group_key = group_key
        self.samples_per_group = samples_per_group
        self.dataloader_accum_steps = dataloader_accum_steps

    def get_train_dataloader(self):
        """
        Build a DataLoader over ``train_dataset`` using
        ``DistributedGroupedMacroBatchSampler``.

        Each group contributes ``ceil(samples_per_group / (b * W * G)) * G`` micro-batches,
        where b is the per-device batch size, W the world size and G the accumulation steps.
        """
        ds = self.train_dataset
        if self.dataloader_accum_steps is not None:
            # NOTE(review): this overwrites the value configured on TrainingArguments after
            # the Trainer (and its Accelerator) were constructed — confirm both agree.
            self.args.gradient_accumulation_steps = self.dataloader_accum_steps

        b = self.args.per_device_train_batch_size
        G = self.args.gradient_accumulation_steps
        W = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
        updates_per_group = math.ceil(self.samples_per_group / (b * W * G))

        batch_sampler = DistributedGroupedMacroBatchSampler(
            dataset=ds,
            group_key=self.group_key,
            local_batch_size=b,
            accum_steps=G,
            updates_per_group=updates_per_group,
            shuffle_groups=False,
            shuffle_within_group=True,
            seed=self.args.seed,
            drop_last=False,
        )

        dataloader = DataLoader(
            ds,
            batch_sampler=batch_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=self.args.dataloader_num_workers > 0,
        )

        print("W,b,G =", W, b, G)
        print("updates_per_group =", batch_sampler.updates_per_group)
        print("len(dataloader) microbatches =", len(dataloader))
        print("update_steps_per_epoch =", len(dataloader) // G)
        return dataloader

    def _logit_pos_indices(self, input_ids, labels, allowed_ids):
        """Bool mask [B, L]: True at position p when labels[:, p + 1] is in ``allowed_ids``."""
        B, L = input_ids.shape
        is_letter = torch.isin(labels, allowed_ids)         # [B, L]
        pos_indices = is_letter[:, 1:]                      # [B, L-1]
        pad_false = torch.zeros((B, 1), dtype=torch.bool, device=input_ids.device)
        pos_indices = torch.cat([pos_indices, pad_false], dim=1)
        return pos_indices

    @staticmethod
    def _rank() -> int:
        """Global process rank, or 0 when torch.distributed is unavailable/uninitialized."""
        try:
            return dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        except Exception:
            return 0

    def _update_control_variate(self, x: float, y: float) -> None:
        """EMA-update E[x], E[x^2], E[y], E[xy] (decay ``baseline_tau``) and refit y ≈ c + w·x."""
        tau = self.baseline_tau
        mx = tau * self._cv_mx + (1 - tau) * x
        mx2 = tau * self._cv_mx2 + (1 - tau) * (x * x)
        my = tau * self._cv_my + (1 - tau) * y
        mxy = tau * self._cv_mxy + (1 - tau) * (x * y)

        varx = max(1e-6, mx2 - mx * mx)
        w = (mxy - mx * my) / varx
        c = my - w * mx

        self._cv_mx, self._cv_mx2, self._cv_my, self._cv_mxy = mx, mx2, my, mxy
        self._cv_w, self._cv_c = w, c

    def _pg_score_term(self, logits, labels, output_start, pos2, loss_value):
        """
        Return ``beta_pg * clip(loss_value - baseline) * mean log p(window labels)``.

        The baseline is ``c + w * x``, where x is the batch-mean entropy of the
        temperature-scaled distribution at logits position ``output_start + 1``;
        (w, c) are refit after each call. ``loss_value`` must be detached.
        """
        B, T = logits.shape[:2]
        L = labels.size(1)
        device = logits.device
        rows = torch.arange(B, device=device)

        with torch.no_grad():
            first = (output_start + 1).clamp(max=T - 1)
            p1 = F.softmax(logits[rows, first, :] / self.temperature, dim=-1)   # [B, V]
            H1 = -(p1 * p1.clamp_min(1e-12).log()).sum(dim=-1)                # [B] entropy
            x = H1.mean().item()                                               # scalar feature
        y = float(loss_value.item())

        if not hasattr(self, "_cv_init"):
            self._cv_init = True
            self._cv_mx = x
            self._cv_mx2 = x * x
            self._cv_my = y
            self._cv_mxy = x * y
            self._cv_w = 0.0   # slope
            self._cv_c = 0.0   # intercept
        baseline = torch.tensor(self._cv_c + self._cv_w * x, device=device)

        # Target tokens labels[output_start + 1 .. pos2]. Row 0's positions are used for the
        # whole batch (assumes output_start is identical within a batch — TODO confirm).
        # Causal logits at position p predict token p + 1, so the logits slice is shifted by one.
        t_start = int(output_start[0]) + 1
        t_end = min(int(pos2[0]), L - 1)
        if t_end >= t_start:
            tgt = labels[:, t_start:t_end + 1]
            valid = tgt != -100
            logp = torch.log_softmax(logits[:, t_start - 1:t_end, :], dim=-1)  # [B, W, V]
            sel = logp.gather(2, tgt.masked_fill(~valid, 0).unsqueeze(-1)).squeeze(-1)  # [B, W]
            sel = sel * valid.float()
            tokens_per_seq = valid.sum(dim=1).clamp_min(1)
            avg_logp_per_token = (sel.sum(dim=1) / tokens_per_seq).mean()
        else:
            avg_logp_per_token = torch.tensor(0.0, device=device)

        adv = (loss_value - baseline).clamp(-self.adv_clip, self.adv_clip)
        score_term = self.beta_pg * adv * avg_logp_per_token

        # Refit the baseline for the NEXT step, after this step's term is formed.
        self._update_control_variate(x, y)
        return score_term

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Slot-consistency cross-entropy plus the optional PG term.

        Expected ``inputs`` keys: ``input_ids`` [B, L], ``attention_mask`` [B, L],
        ``labels`` [B, L] (-100 = ignored), ``output_start`` [B], and optionally
        ``group_id`` (list of group ids, used only for debug prints).

        Slot k (1-based) is read from logits position ``output_start + 2 * (k - 1)``;
        this presumably assumes each answer line is two tokens (letter + newline) — confirm
        against the tokenizer.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        output_start = inputs["output_start"]  # [B]
        labels = inputs["labels"]

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [B, T, V]
        B, L = input_ids.shape
        T = logits.shape[1]
        device = logits.device
        rows = torch.arange(B, device=device)
        eps = 1e-8

        allowed_ids = self.processing_class(
            ['A', 'B', 'C', 'D', 'E'], add_special_tokens=False
        )["input_ids"]
        allowed_ids = torch.tensor(allowed_ids, device=device, dtype=torch.long).view(-1)

        # k1 in [2, m_train - 2]; gap k2 - k1 in [1, max(3, min(m_train - 1 - k1, delta)) - 1].
        k1 = int(np.random.randint(low=2, high=self.m_train - 1))
        k2 = k1 + int(np.random.randint(low=1, high=max(3, min(self.m_train - 1 - k1, self.delta))))

        def logits_pos(k):
            pos = output_start + (k - 1) * 2
            return torch.clamp(pos, min=0, max=T - 1)

        pos1 = logits_pos(k1).to(device=device, dtype=torch.long)
        pos2 = logits_pos(k2).to(device=device, dtype=torch.long)

        def batch_marginal(slot_logits):
            # Per-row softmax averaged over the batch, then renormalized.
            q_hat = F.softmax(slot_logits, dim=-1).mean(dim=0)
            return q_hat / q_hat.sum().clamp_min(eps)

        student_pos, teacher_pos = (pos2, pos1) if self.backward else (pos1, pos2)
        student_hat = batch_marginal(logits[rows, student_pos, :] / self.temperature)
        with torch.no_grad():
            if self.target_model is None:
                teacher_logits = logits
            else:
                teacher_logits = self.target_model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits
            teacher_hat = batch_marginal(teacher_logits[rows, teacher_pos, :] / self.temperature)

        # CE: H(teacher, student) = -sum_z teacher(z) log student(z)   (scalar)
        ce_loss = -(teacher_hat * torch.log(student_hat + eps)).sum()
        with torch.no_grad():
            # KL(teacher || student) for logging only, scaled by temperature^2 (distillation convention).
            kl_loss = (
                teacher_hat * (teacher_hat.clamp_min(eps).log() - student_hat.clamp_min(eps).log())
            ).sum() * (self.temperature * self.temperature)
        q1_hat, q2_hat = (teacher_hat, student_hat) if self.backward else (student_hat, teacher_hat)

        if self.use_pg and (self.state.global_step >= self.pg_warmup_steps):
            score_term = self._pg_score_term(logits, labels, output_start, pos2, ce_loss.detach())
        else:
            score_term = torch.tensor(0.0, device=device)

        # Out-of-place so ce_loss keeps its own value for logging.
        loss = ce_loss + score_term

        rank = self._rank()
        if "group_id" in inputs:
            gids = inputs["group_id"]
            unique = sorted(set(gids))
            ok = (len(unique) == 1)
            if rank == 0 and self.state.global_step < 20:
                print(f"[dbg step={self.state.global_step}] batch group_id unique={unique} (B={len(gids)}) ok={ok}")
        elif rank == 0 and self.state.global_step < 5:
            print("[dbg] inputs has no 'group_id'. Add it in collator for debugging.")

        if hasattr(self, "state") and (self.state.global_step % 10 == 0):
            print(f"[DBG] B={B}  ce={ce_loss.detach().item():.6f}   kl={kl_loss.detach().item():.6f}   "
                  f"pg={float(score_term.detach()):.6f} "
                  f"k1={k1} k2={k2} q_k1 = {torch.round(q1_hat[allowed_ids], decimals = 4)} q_k2 = {torch.round(q2_hat[allowed_ids], decimals = 4)}")

        if return_outputs:
            return loss, outputs
        return loss









###### LoRA with FSDP (quantization_config is disabled, so this is not QLoRA) ########

class MPMaker:
    """Load a causal LM + tokenizer, build a grouped SFT trainset and train it with MPTrainer (LoRA)."""

    def __init__(self, model_dir, train_arg):
        """
        Args:
            model_dir: model id or local path used for both model and tokenizer.
            train_arg: ``(learning_rate, qa_num)`` pair.
        """
        self.model_dir = model_dir
        self.lr, self.qa_num = train_arg

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir,
            # quantization_config = bnb_config,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=False,
            device_map=None
        )
        # Left padding only when model_dir is literally "mistral"; load the tokenizer once.
        tokenizer_kwargs = {"padding_side": "left"} if self.model_dir == "mistral" else {}
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, **tokenizer_kwargs)

        self.eos_token = self.tokenizer.eos_token
        self.tokenizer.pad_token = self.eos_token

        self.choices_label = list(string.ascii_uppercase[:5])  # A-E

    def create_trainset(self, ds, pre_ds, user_msg, system_msg, J, qa, qa_num_list, pre_len: List,
                        base_path: str = "eval/mistral_train/trainset_seed1-50-32-180.jsonl",
                        base_J: int = 32):
        '''
        Build the tokenized trainset: one group of rows per (question, prefix length, prefix rank).

        For each question index ``i`` in ``qa_num_list``:
          - group ``Q{i}|n0``: the assistant turn is one of the first ``J`` base rollouts of
            question i (``base_J`` rows per question in ``base_path``), each repeated 3 times.
          - groups ``Q{i}|n{n}|r{r}`` for n in ``pre_len[1:]`` and r in 0..2: the assistant turn
            is prefix r of length n followed by 100 letters sampled from a
            0.1 * uniform + 0.9 * ``exp_next_prob`` mixture; ``J`` rows per group.

        Args:
            ds: question dataset with ``question`` and ``choices`` (``choices['text']`` has at
                least 5 entries) columns.
            pre_ds: per-question prefix records; ``pre_ds[i][f"n_{n}"][r]`` has ``seq``
                (list of str) and ``exp_next_prob`` (5 probabilities).
            user_msg, system_msg: prompt text for the user / system turns.
            J: rows per group (the n0 group gets 3 * J); must not exceed ``base_J``.
            qa: unused here.
            qa_num_list: question indices to include.
            pre_len: prefix lengths; ``pre_len[0]`` is assumed to be 0 (covered by the n0 group).
            base_path: JSONL file of base rollouts with an ``output_str`` column.
            base_J: base rollouts stored per question in ``base_path``.

        Returns:
            ``datasets.Dataset`` with ``group_id``, ``prefix_token_len`` and ``input_ids`` columns.
        '''
        base_ds = load_dataset("json", data_files=base_path)["train"]
        uni_prob = [0.2, 0.2, 0.2, 0.2, 0.2]   # uniform over A-E
        trainset = []

        def make_row(group_id, prefix_token_len, assistant_content):
            return {
                "group_id": group_id,
                "prefix_token_len": prefix_token_len,
                "messages": [
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": assistant_content}
                ]
            }

        for qa_num in qa_num_list:
            contents = ds["choices"][qa_num]
            prefix_dict = pre_ds[qa_num]
            output_str_list = base_ds[qa_num * base_J:(qa_num + 1) * base_J]["output_str"]

            choices_text = "".join(
                c + " - " + contents['text'][j] + "\n" for j, c in enumerate(self.choices_label)
            )
            user_content = (
                user_msg
                + "QUESTION:\n" + ds["question"][qa_num] + "\n\n"
                + "CHOICES:\n" + choices_text + "\n\n"
                + "SAMPLE: \n"
            )

            # n0 group: 3 passes over the first J base rollouts (the pass index is unused).
            for _ in range(3):
                for j in range(J):
                    trainset.append(make_row(f"Q{qa_num}|n0", 0, "\n".join(output_str_list[j])))

            for pre_n in pre_len[1:]:
                pre_str_list = prefix_dict[f"n_{pre_n}"]
                for r in range(3):
                    pre_str = "\n".join(pre_str_list[r]["seq"]) + "\n"
                    ref_prob = pre_str_list[r]["exp_next_prob"]
                    ref_prob = 0.1 * np.array(uni_prob) + 0.9 * np.array(ref_prob)
                    ref_prob /= ref_prob.sum()
                    for _ in range(J):
                        cid_sample = np.random.choice(self.choices_label, size=100, p=ref_prob)
                        output_str = '\n'.join(cid_sample)
                        trainset.append(make_row(f"Q{qa_num}|n{pre_n}|r{r}", pre_n, pre_str + output_str))

        trainset = Dataset.from_list(trainset)

        def formatting_prompts_func(examples):
            ids = []
            for convo in examples["messages"]:
                token_ids = self.tokenizer.apply_chat_template(
                    convo,
                    tokenize=True,
                    add_generation_prompt=False
                )
                if hasattr(token_ids, "tolist"):
                    token_ids = token_ids.tolist()
                ids.append(token_ids)
            return {"input_ids": ids}

        trainset = trainset.map(
            formatting_prompts_func,
            batched=True,
            remove_columns=["messages"]
        )

        print("✅ Trainset created ! Length: ", len(trainset))

        return trainset

    def train(self, J, qa, qa_num, m_train, delta, backward, save_path, pre_len, epoch=2,
              ds=None, pre_ds=None, user_msg=None, system_msg=None):
        """
        Build the trainset, configure LoRA + TrainingArguments, train with MPTrainer and save.

        Args:
            J: rows per group (see create_trainset); also sets
                gradient_accumulation_steps = ceil(J / (per_device_batch * world_size)).
                NOTE(review): MPTrainer.get_train_dataloader may overwrite that value — confirm.
            qa: dataset tag forwarded to create_trainset.
            qa_num: ignored; ``self.qa_num`` from the constructor is used.
            m_train, delta, backward: forwarded to MPTrainer.
            save_path: TrainingArguments.output_dir and where the model is saved.
            pre_len: prefix lengths forwarded to create_trainset.
            epoch: num_train_epochs.
            ds, pre_ds, user_msg, system_msg: required inputs of create_trainset.

        Raises:
            ValueError: if ds, pre_ds, user_msg or system_msg is not provided.
        """
        missing = [name for name, val in
                   (("ds", ds), ("pre_ds", pre_ds), ("user_msg", user_msg), ("system_msg", system_msg))
                   if val is None]
        if missing:
            raise ValueError(f"train() needs {', '.join(missing)} to build the trainset")

        # Llama-3 chat-template assistant header. NOTE(review): templates without it (e.g. other
        # model families) make the collator raise — confirm it matches the tokenizer in use.
        assist_ids = self.tokenizer(
            "<|start_header_id|>assistant<|end_header_id|>",
            add_special_tokens=False
        )["input_ids"]
        pad_id = self.tokenizer.pad_token_id

        def find_subsequence(sequence, subsequence):
            """First index where ``subsequence`` occurs in ``sequence``, else -1."""
            for i in range(len(sequence) - len(subsequence) + 1):
                if sequence[i:i + len(subsequence)] == subsequence:
                    return i
            return -1

        def collator(batch):
            """Right-pad a batch; mask labels before output_start and report output_start/group_id."""
            input_ids = [torch.tensor(item["input_ids"], dtype=torch.long) for item in batch]
            attention_mask = [torch.ones_like(ids) for ids in input_ids]
            group_ids = [item.get("group_id", "MISSING") for item in batch]

            labels = []
            output_start = []
            for item, ids in zip(batch, input_ids):
                idx = find_subsequence(ids.tolist(), assist_ids)
                if idx == -1:
                    raise ValueError("Assistant token not found")

                # content_start does not skip the separator after the header. MPTrainer reads
                # slot k at output_start + 2*(k-1), presumably relying on logits at the separator
                # predicting the first answer token — confirm before changing.
                content_start = idx + len(assist_ids)
                pre_tok = int(item.get("prefix_token_len", 0))
                # Each prefix line is assumed to be 2 tokens (letter + newline).
                out_st = min(content_start + pre_tok * 2, ids.numel())
                output_start.append(out_st)

                label = ids.clone()
                label[:out_st] = -100
                labels.append(label)

            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
            attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
            labels = pad_sequence(labels, batch_first=True, padding_value=-100)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "output_start": torch.tensor(output_start, dtype=torch.long),
                "group_id": group_ids,
            }

        trainset = self.create_trainset(
            ds, pre_ds, user_msg, system_msg, J, qa, self.qa_num, pre_len=pre_len
        )

        train_args = TrainingArguments(
            output_dir=save_path,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            learning_rate=self.lr,
            dataloader_num_workers=8,
            num_train_epochs=epoch,
            warmup_steps=5,
            weight_decay=0.01,
            lr_scheduler_type='cosine',
            gradient_checkpointing=False,
            logging_steps=1,
            optim='adamw_torch',   # paged_adamw_8bit
            bf16=True,
            seed=2025,
            report_to="wandb",
            remove_unused_columns=False
        )
        world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
        b = train_args.per_device_train_batch_size
        train_args.gradient_accumulation_steps = math.ceil(J / (b * world_size))

        peft_config = LoraConfig(
            lora_alpha=8,
            lora_dropout=0.05,
            r=16,
            bias='none',
            task_type='CAUSAL_LM',
            target_modules='all-linear'
        )

        if train_args.gradient_checkpointing:
            train_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

        self.model.config.use_cache = False

        trainer = MPTrainer(
            model=self.model,
            processing_class=self.tokenizer,
            train_dataset=trainset,
            args=train_args,
            data_collator=collator,
            peft_config=peft_config,
            m_train=m_train,
            delta=delta,
            backward=backward
        )

        trainer.train()
        print('TRAIN FINISHED!')

        trainer.save_model(train_args.output_dir)
        print('MODEL SAVED!: ', train_args.output_dir)





if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="default")
    # Rows per trainset group; must be an int (argparse does not apply `type` to defaults).
    parser.add_argument("--J", type=int, required=True)
    parser.add_argument("--qa", type=str, default="gpqa")
    parser.add_argument("--qa_num", type=int, nargs='*', default=[0, 1, 2, 3, 4])
    parser.add_argument("--lr", type=float, required=True)
    parser.add_argument("--m_train", type=int, default=20)
    parser.add_argument("--delta", type=int, default=5)
    parser.add_argument("--seed_num", type=int, default=2025)
    parser.add_argument("--pre_len", type=int, nargs='*', default=[0, 2, 4, 6])
    parser.add_argument("--epoch", type=int, default=2)
    # Train slot k2 toward a stop-grad slot k1 instead of k1 toward k2.
    parser.add_argument("--backward", action="store_true")

    args = parser.parse_args()

    torch.manual_seed(args.seed_num)
    np.random.seed(args.seed_num)

    MPSurvey = MPMaker(args.model_dir, [args.lr, args.qa_num])

    # Timestamp like 20250101_123045 for the output directory name.
    train_time = str(datetime.datetime.now()).split('.')[0].replace(":", "").replace("-", "").replace(" ", "_")

    qa_num_str = str(args.qa_num)
    wandb.init(name=qa_num_str + "_lr_" + str(args.lr))

    save_path = "llama_adapter-" + str(args.lr) + '-' + str(args.qa_num) + '-' + train_time
    # NOTE(review): building the trainset (create_trainset) needs the question dataset, prefix
    # records and prompt strings, none of which are passed from here — confirm how they are supplied.
    MPSurvey.train(args.J, args.qa, args.qa_num, args.m_train, args.delta, args.backward,
                   save_path, args.pre_len, args.epoch)





    