import os, re, glob, math, time, json, pickle
os.environ["CUDA_VISIBLE_DEVICES"] = '0' 
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
    TrainerCallback,
)
from peft import LoraConfig, get_peft_model, TaskType


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Dataset identifiers. In this script they are only used to build the
# directory names below.
# presumably they are the generator parameters of the corpus -- TODO confirm.
n = 1000
k_ratio = 0.1
p_in = 0.1
p_out = 0.01
method = 'path'

# Description of the previously trained backbone; n_layer and hidden_size are
# only used as components of output_dir.
backbone_model = 'llama'  # 'llama' or 'mixtral' or 'qwen'
n_layer = 6
hidden_size = 512
train_num_ratio = 0.5     # also used when slicing the corpus further down

block_size = 128          # chunk length used by group_texts
max_length = 64           # not referenced by the rest of this script

# Fine-tuning specific switches.
overlap_method = 'partial'  # 'none' or 'full' or 'partial'
model_selected = 3          # index passed to pick_checkpoint
new_train_ratio = 0.2

print(overlap_method)

# ---------------------------------------------------------------------------
# Derived paths
# ---------------------------------------------------------------------------
base_model_path = ''  # data root; an empty string keeps the paths relative
_run_tag = f"{n}_{k_ratio}_{p_in}_{p_out}"
base_model_path = os.path.join(base_model_path, _run_tag)
tokenizer_path = os.path.join(base_model_path, "baby_tokenizer.json")

output_dir = f'{method}_{_run_tag}/{backbone_model}_{n_layer}_{hidden_size}_{train_num_ratio}/'


# Load the fast tokenizer from the JSON file that lives in the data directory.
tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
vocab_size = len(tokenizer.get_vocab())  # dict keys are already unique
print("vocab_size =", vocab_size)


# Load the pickled corpus from the same data directory.
# NOTE: unpickling can execute arbitrary code; only load files you produced
# yourself or otherwise trust.
corpus_file = os.path.join(base_model_path, f'{method}_soft_train.pkl')
with open(corpus_file, 'rb') as f:
    new_training_corpus = pickle.load(f)

# ---------------------------------------------------------------------------
# Carve the fine-tuning corpus out of the pickled corpus
# ---------------------------------------------------------------------------
# The first 10% of the pickled corpus is dropped before any ratio is applied.
# presumably that slice was the validation split of the original training
# run -- TODO confirm against the pre-training script.
train_num = len(new_training_corpus)
valid_num = int(train_num * 0.1)
train_num = int(len(new_training_corpus) * train_num_ratio)  # not read again in this script
new_training_corpus = new_training_corpus[valid_num:]

# ori_train_num: size of the leading portion selected by train_num_ratio.
# new_train_num: number of samples to select for this fine-tuning run.
ori_train_num = max(1, int(len(new_training_corpus) * train_num_ratio))
new_train_num = max(1, int(len(new_training_corpus) * new_train_ratio))

# Choose where the fine-tuning window starts relative to the first
# ori_train_num samples:
#   'full'    -> the window ends exactly at ori_train_num
#   'partial' -> the window starts about half a window before ori_train_num
#   'none'    -> the window starts at ori_train_num
if overlap_method == 'full':
    new_corpus_start = ori_train_num - new_train_num
elif overlap_method == 'partial':
    new_overlap_train_num = max(1, int(len(new_training_corpus) * (new_train_ratio / 2)))
    new_corpus_start = ori_train_num - new_overlap_train_num
elif overlap_method == 'none':
    new_corpus_start = ori_train_num
else:
    # Without this branch a typo would only surface later as a NameError.
    raise ValueError(
        f"Unknown overlap_method {overlap_method!r}; expected 'none', 'full' or 'partial'"
    )

if new_corpus_start < 0:
    # A negative start would be interpreted by slicing as an offset from the
    # end of the list and silently select the wrong (usually empty) window.
    raise ValueError(
        f"overlap_method={overlap_method!r} needs {-new_corpus_start} more sample(s) "
        f"before index {ori_train_num}; lower new_train_ratio or raise train_num_ratio"
    )

new_corpus = new_training_corpus[new_corpus_start:new_corpus_start + new_train_num]

if len(new_corpus) < 2:
    # One sample cannot be split into a non-empty train and validation part.
    raise ValueError(
        f"Selected fine-tuning corpus has {len(new_corpus)} sample(s); at least 2 are required"
    )

print("new_corpus size =", len(new_corpus))

# Hold out the first 10% (at least one sample) of the window for validation.
valid_num = max(1, int(len(new_corpus) * 0.1))
valid_corpus = new_corpus[:valid_num]
training_corpus = new_corpus[valid_num:]
print("train/val =", len(training_corpus), len(valid_corpus))

# Wrap both splits as datasets with a single "text" column.
ds = DatasetDict({
    split_name: Dataset.from_dict({"text": split_texts})
    for split_name, split_texts in (
        ("train", training_corpus),
        ("validation", valid_corpus),
    )
})

def tok_fn(batch):
    """Tokenize the ``"text"`` entries of *batch* with the module-level tokenizer.

    The tokenizer is called with its default arguments; no truncation or
    padding options are passed here.
    """
    texts = batch["text"]
    return tokenizer(texts)

def group_texts(examples, chunk_len=None):
    """Concatenate tokenized rows and cut them into fixed-length chunks.

    Args:
        examples: Mapping from column name to a list of per-row lists. It must
            contain an ``"input_ids"`` column, which determines how many
            complete chunks are produced.
        chunk_len: Length of every output chunk. Defaults to the module-level
            ``block_size`` when omitted.

    Returns:
        A dict with the same columns as *examples*, each holding a list of
        ``chunk_len``-long chunks, plus a ``"labels"`` column that is a
        shallow copy of the chunked ``"input_ids"``. Trailing items that do
        not fill a complete chunk are dropped.
    """
    from itertools import chain  # local import keeps this block self-contained

    if chunk_len is None:
        chunk_len = block_size

    # chain.from_iterable flattens in linear time; sum(rows, []) re-copies the
    # growing list for every row and is quadratic in the number of tokens.
    concat = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}

    # Round down to a whole number of chunks.
    total = len(concat["input_ids"]) // chunk_len * chunk_len
    result = {
        k: [t[i:i + chunk_len] for i in range(0, total, chunk_len)]
        for k, t in concat.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Step 1: tokenize every split in batches; the raw "text" column is removed.
tokenized = ds.map(function=tok_fn, batched=True, remove_columns=["text"])
# Step 2: repack the token columns of each batch into fixed-length chunks.
lm_ds = tokenized.map(function=group_texts, batched=True)

def pick_checkpoint(which=1, root=None):
    """Return the path of one checkpoint directory, ordered by step number.

    Sub-directories of *root* whose name ends in ``-<digits>`` are treated as
    checkpoints and sorted by that trailing number. Entries whose name ends
    with ``.pkl`` or contains ``lora_finetune`` or ``ppo`` are ignored, as are
    plain files and directories without a numeric suffix.

    Args:
        which: Index into the sorted checkpoint list (negative indices count
            from the end, as with any list).
        root: Directory to scan. Defaults to the module-level ``output_dir``.

    Returns:
        ``os.path.join(root, name)`` of the selected checkpoint directory.

    Raises:
        IndexError: If *which* is out of range for the checkpoints found.
    """
    if root is None:
        root = output_dir

    numbered = []
    for name in os.listdir(root):
        if name.endswith('.pkl') or 'lora_finetune' in name or 'ppo' in name:
            continue
        full = os.path.join(root, name)
        if not os.path.isdir(full):
            continue
        # Parse the step from the directory name only, so dashes in parent
        # directories cannot leak into the sort key, and skip unrelated
        # directories instead of crashing in int().
        suffix = re.search(r'-(\d+)$', name)
        if suffix is None:
            continue
        numbered.append((int(suffix.group(1)), full))

    numbered.sort()
    try:
        return numbered[which][1]
    except IndexError:
        raise IndexError(
            f"checkpoint index {which} is out of range: "
            f"found {len(numbered)} checkpoint(s) in {root!r}"
        ) from None

base_or_ckpt_dir = output_dir
ckpt_path = pick_checkpoint(model_selected)
print("Loading checkpoint:", ckpt_path)


# Request bf16 weights when a CUDA device is visible; otherwise pass None and
# let the library decide.
if torch.cuda.is_available():
    load_dtype = torch.bfloat16
else:
    load_dtype = None

model = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype=load_dtype)
# Turn off use_cache for training (gradient checkpointing is switched on in
# the training arguments further down).
model.config.use_cache = False

# Names of the sub-modules that receive LoRA adapters, chosen by backbone.
if backbone_model in ("llama", "qwen", "mixtral"):
    # NOTE(review): 'mixtral' uses the same list as the llama-style models;
    # confirm that these names match the Mixtral feed-forward modules.
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
else:
    # Fallback list for any other backbone value.
    target_modules = ["c_attn", "c_proj"]

lora_cfg = LoraConfig(
    target_modules=target_modules,
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Best-effort call made before wrapping the model with LoRA.
# presumably needed so gradients reach the adapters when gradient
# checkpointing is on and the base weights are frozen -- see the Transformers
# documentation for enable_input_require_grads.
try:
    model.enable_input_require_grads()
except Exception as exc:
    # Keep going as before, but report the failure instead of hiding it: a
    # later "does not require grad" error during training would otherwise be
    # hard to trace back to this step.
    print(f"[warn] model.enable_input_require_grads() failed: {exc!r}; continuing without it")

# Wrap the base model with the LoRA adapters described by lora_cfg and report
# how many parameters remain trainable.
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

# Causal-LM collation (mlm=False).
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Batch geometry and the step counts derived from it.
per_device_bs = 64
grad_acc = 1
N = len(lm_ds["train"])
true_steps_per_epoch = max(1, math.ceil(N / per_device_bs / grad_acc))
eval_steps = max(1, true_steps_per_epoch // 5)  # about five evaluations per epoch

# Every combination of checkpoint index, overlap mode and ratio gets its own
# sub-directory under output_dir.
ft_run_name = f"lora_finetune_{model_selected}_{overlap_method}_{new_train_ratio}"
ft_output_dir = os.path.join(output_dir, ft_run_name)
os.makedirs(ft_output_dir, exist_ok=True)

training_args = TrainingArguments(
    # Where checkpoints and trainer state are written.
    output_dir=ft_output_dir,
    report_to="none",

    # Amount of training and batch geometry.
    num_train_epochs=20,
    per_device_train_batch_size=per_device_bs,
    gradient_accumulation_steps=grad_acc,

    # Optimizer and learning-rate schedule.
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.0,

    # Logging, evaluation and checkpoint cadence, all measured in steps.
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=eval_steps,
    save_steps=true_steps_per_epoch,  # the estimated number of steps per epoch
    save_total_limit=None,

    # Precision and memory options.
    bf16=torch.cuda.is_available(),
    gradient_checkpointing=True,
)

class PrintInfoCallback(TrainerCallback):
    """Print a one-line progress summary every ``true_steps_per_epoch`` steps.

    The line contains the wall-clock time, the global step, the approximate
    epoch, the most recently logged training loss and the CUDA memory
    currently allocated by torch (0.0 when CUDA is unavailable).
    """

    def on_step_end(self, args, state, control, **kwargs):
        """Emit the summary when *state.global_step* completes an epoch.

        Nothing is printed at step 0, between epoch boundaries, or while no
        training loss has been logged yet. *control* is always returned
        unchanged.
        """
        step = state.global_step
        if step == 0 or step % true_steps_per_epoch != 0:
            return control

        # Walk the history backwards to the newest entry carrying a training
        # loss. The last entry alone is not enough: if it is an evaluation
        # record without "loss", the summary would be skipped entirely.
        current_loss = None
        for entry in reversed(state.log_history):
            current_loss = entry.get("loss", entry.get("train_loss"))
            if current_loss is not None:
                break
        if current_loss is None:
            return control

        gpu_mem = (torch.cuda.memory_allocated() / 1024**3) if torch.cuda.is_available() else 0.0
        print(
            f"[{time.strftime('%H:%M:%S')}] step={step} "
            f"(~epoch {step/true_steps_per_epoch:.2f}) "
            f"| loss={current_loss:.4f} | gpu={gpu_mem:.2f} GB"
        )
        return control

# Assemble the trainer from the pieces prepared above.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=lm_ds["train"],
    eval_dataset=lm_ds["validation"],
    callbacks=[PrintInfoCallback()],
)

trainer.train()

# Save the final state of the LoRA-wrapped model next to the step checkpoints.
model.save_pretrained(ft_output_dir)

# Split the trainer's log history into training and evaluation records.
history = trainer.state.log_history
train_entries = [e for e in history if "loss" in e]
eval_entries = [e for e in history if "eval_loss" in e]

train_steps = [e.get("step", None) for e in train_entries]
train_losses = [e["loss"] for e in train_entries]
eval_steps_list = [e.get("step", None) for e in eval_entries]
eval_losses = [e["eval_loss"] for e in eval_entries]

# Store both curves as (step, loss) pairs in one JSON file.
loss_dict = {
    "train": list(zip(train_steps, train_losses)),
    "eval": list(zip(eval_steps_list, eval_losses)),
}
loss_path = os.path.join(ft_output_dir, "loss_all.json")
with open(loss_path, "w") as f:
    json.dump(loss_dict, f, indent=2)

print(" LoRA finetune done. Adapter saved to:", ft_output_dir)


# Optional export: merge the adapters into the base model and save the result
# together with the tokenizer. Switched off by default.
do_merge = False
if do_merge:
    merged_dir = ft_output_dir + "_merged"
    merged = model.merge_and_unload()
    for artifact in (merged, tokenizer):
        artifact.save_pretrained(merged_dir)
    print(" Merged model saved to:", merged_dir)