import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from myutils import print_debug
from tqdm import tqdm
from selfcopy_models import store_loramlp_util, visualize_model_named_parameters, get_loramlp_paths, load_lora_model_util
import copy
import os

from selfcopy_models import LoRAMLP, get_online_warmup_postfix
from dataset_loader import generate_random_dataloader, generate_mlp_ios_from_dataloaders, generate_mlp_ios_from_dataloaders_fast
from transformers import LlamaTokenizer, LlamaForCausalLM

import time

# Basic training settings using trl's SFTTrainer. Longer training and a different
# learning-rate schedule will probably give better results.
def train(model, tokenizer, dataset_train, dataset_val, max_tokens=256, batch_size=2, batch_size_val=2, lr=1e-4, n_epochs=1, lr_scheduler_type='linear', verbose=True, is_bf16=False, grad_acc=2):
    """Fine-tune ``model`` with trl's ``SFTTrainer`` and return it.

    Args:
        model: causal LM, trained in place.
        tokenizer: tokenizer handed to ``SFTTrainer``.
        dataset_train: training dataset; text is read from its ``"text"`` field.
        dataset_val: evaluation dataset with the same layout.
        max_tokens: ``max_seq_length`` for ``SFTTrainer``.
        batch_size: per-device training batch size.
        batch_size_val: per-device evaluation batch size.
        lr: learning rate.
        n_epochs: number of training epochs.
        lr_scheduler_type: HF scheduler name passed to ``TrainingArguments``.
        verbose: if true, log/evaluate every ~10% of an epoch; otherwise once per epoch.
        is_bf16: use bf16 mixed precision when true, fp16 otherwise.
        grad_acc: gradient-accumulation steps. Should not be too large, otherwise the
            model will not learn well; increase ``batch_size`` instead.

    Returns:
        The same ``model`` object, after training.
    """
    max_steps = -1  # -1 -> no step limit; training length is set by n_epochs

    steps_per_epoch = len(dataset_train) // (batch_size * grad_acc)
    # Evaluate about every 10% of an epoch. Clamp to >= 1: for small datasets the
    # interval would otherwise be 0, which TrainingArguments rejects as logging_steps
    # under the default "steps" logging strategy (and which is a zero modulus for eval_steps).
    eval_steps = max(1, int(steps_per_epoch * 0.1))
    logging_st = min(1000, eval_steps)

    training_args = TrainingArguments(
        output_dir='.',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size_val,
        gradient_accumulation_steps=grad_acc,
        learning_rate=lr,
        logging_steps=logging_st,
        num_train_epochs=n_epochs,
        max_steps=max_steps,
        bf16=is_bf16,
        fp16=not is_bf16,
        max_grad_norm=1.0,
        save_steps=10000000,  # very large interval, presumably to avoid intermediate checkpoints
        lr_scheduler_type=lr_scheduler_type,
        report_to='none',  # disable wandb
    )

    # NOTE: no trailing commas on these assignments -- `x = "epoch",` stores the tuple
    # ("epoch",), which never compares equal to the "epoch"/"steps" strategy values.
    if not verbose:
        training_args.logging_strategy = "epoch"
        training_args.evaluation_strategy = "epoch"  # evaluate after each epoch
    else:
        training_args.logging_strategy = "steps"
        training_args.evaluation_strategy = "steps"  # evaluate every `eval_steps` steps
        training_args.eval_steps = eval_steps

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=max_tokens,
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        peft_config=None,
        args=training_args,
        dataset_text_field="text",
    )

    model.train()
    trainer.train()

    return model

# Custom LambdaLR schedule: linear warmup from 0 to the peak learning rate over
# `num_warmup_steps` steps (2000 by default), then linear decay that reaches
# `decay_factor` x peak (10% by default) at `num_training_steps` and stays there.

def my_get_linear_schedule_with_warmup(optimizer, num_training_steps, num_warmup_steps=2000, decay_factor=0.1, last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by linear decay to a floor.

    With ``W = num_warmup_steps`` and ``T = num_training_steps``, the multiplier applied
    to each param group's base lr at step ``s`` is:

    - ``s / max(1, W)`` while ``s < W``;
    - ``decay_factor + (1 - decay_factor) * max(0, (T - s) / max(1, T - W))`` otherwise.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.
        num_training_steps: step at which the decay reaches ``decay_factor``.
        num_warmup_steps: length of the linear warmup.
        decay_factor: final multiplier kept after ``num_training_steps``.
        last_epoch: forwarded to ``LambdaLR`` (``-1`` starts a fresh schedule).
    """
    warmup_denominator = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    peak_share = 1.0 - decay_factor

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            remaining = max(0, float(num_training_steps - current_step) / decay_span)
            return decay_factor + peak_share * remaining
        return float(current_step) / warmup_denominator

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def train2(model, tokenizer, dataset_train, dataset_val, max_tokens=256, batch_size=2, batch_size_val=2, lr=1e-4, n_epochs=1, lr_scheduler_type='linear', verbose=True, is_bf16=False, grad_acc=2):
    """Like ``train``, but with an explicit AdamW optimizer and warmup/linear-decay schedule.

    Arguments and return value match ``train``. The optimizer is AdamW over all of
    ``model.parameters()`` (betas=(0.9, 0.95), eps=1e-5, weight_decay=0.1). The schedule
    is ``my_get_linear_schedule_with_warmup`` with ``min(2000, total_steps // 10)``
    warmup steps, decaying to 10% of ``lr``. Both are handed to ``SFTTrainer`` through
    ``optimizers``, so ``lr_scheduler_type`` does not drive the schedule here.

    Returns:
        The same ``model`` object, after training.
    """
    print('!!!!!!!!!!!!!!! train2 !!!!!!!!!!!!!!!!!')

    max_steps = -1  # -1 -> no step limit; training length is set by n_epochs

    steps_per_epoch = len(dataset_train) // (batch_size * grad_acc)
    # Evaluate about every 10% of an epoch; clamp to >= 1 so small datasets do not
    # produce a zero logging/eval interval.
    eval_steps = max(1, int(steps_per_epoch * 0.1))
    logging_st = min(1000, eval_steps)

    training_args = TrainingArguments(
        output_dir='.',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size_val,
        gradient_accumulation_steps=grad_acc,
        learning_rate=lr,
        logging_steps=logging_st,
        num_train_epochs=n_epochs,
        max_steps=max_steps,
        bf16=is_bf16,
        fp16=not is_bf16,
        max_grad_norm=1.0,
        save_steps=10000000,  # very large interval, presumably to avoid intermediate checkpoints
        lr_scheduler_type=lr_scheduler_type,
        report_to='none',  # disable wandb
    )

    # NOTE: no trailing commas on these assignments -- `x = "epoch",` stores the tuple
    # ("epoch",), which never compares equal to the "epoch"/"steps" strategy values.
    if not verbose:
        training_args.logging_strategy = "epoch"
        training_args.evaluation_strategy = "epoch"  # evaluate after each epoch
    else:
        training_args.logging_strategy = "steps"
        training_args.evaluation_strategy = "steps"  # evaluate every `eval_steps` steps
        training_args.eval_steps = eval_steps

    # Hand-built optimizer. These hyper-parameters differ from torch's AdamW defaults
    # (betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2).
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        eps=1e-5,
        weight_decay=1e-1,
    )

    # Optimizer steps over the whole run (one step per `grad_acc` micro-batches).
    total_steps = steps_per_epoch * n_epochs

    # Linear warmup, then linear decay down to 10% of the peak learning rate.
    scheduler = my_get_linear_schedule_with_warmup(
        optimizer,
        num_training_steps=total_steps,
        num_warmup_steps=min(2000, total_steps // 10),
        decay_factor=0.1,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=max_tokens,
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        peft_config=None,
        args=training_args,
        dataset_text_field="text",
        optimizers=(optimizer, scheduler),
    )

    model.train()
    trainer.train()

    return model
 



def L2_loss(outputs, outputs_teacher):
    """Mean Euclidean distance between ``outputs`` and ``outputs_teacher``.

    The L2 norm of the difference is taken along the last dimension and then averaged
    over every remaining position, giving a 0-dim tensor.
    """
    residual = outputs - outputs_teacher
    per_row_distance = residual.norm(p=2, dim=-1)
    return per_row_distance.mean()

@torch.no_grad()
def evaluate_model(student_mlp, teacher_mlp, dataloader, criterion, input_process = lambda x: x, device = 'cpu'):
    """Measure how closely ``student_mlp`` reproduces ``teacher_mlp`` on ``dataloader``.

    Both modules are switched to eval mode, and the decorator disables gradient tracking.
    Each batch is mapped through ``input_process``, moved to ``device`` and fed to both
    modules. Per-batch values are summed and divided by ``len(dataloader)``:

    - ``criterion(student_out, teacher_out)`` -> ``avg_loss``
    - mean of the teacher output's L2 norm along ``dim=1`` -> ``avg_y_norm``
    - L2 norm of the whole ``student_out - teacher_out`` tensor (not per row) -> ``avg_l2_norm``

    For the first batch only, several diagnostic norms are sent to ``print_debug``.

    Returns:
        ``(avg_loss, avg_y_norm)``. ``avg_l2_norm`` is printed but not returned.
    """
    student_mlp.eval()
    teacher_mlp.eval()

    loss_sum = 0.0
    y_norm_sum = 0.0
    l2_norm_sum = 0.0
    debug_pending = True  # emit diagnostics for the first batch only

    for batch in tqdm(dataloader, desc='evaluating'):
        x = input_process(batch).to(device)
        student_out = student_mlp(x)
        teacher_out = teacher_mlp(x)

        loss_sum += criterion(student_out, teacher_out).item()
        y_norm_sum += torch.norm(teacher_out, p=2, dim=1).mean().item()
        l2_norm_sum += torch.norm(student_out - teacher_out, p=2).item()

        if debug_pending:
            debug_pending = False
            # Trailing numbers are values noted in earlier runs.
            print_debug(f'input = {torch.norm(x, p=2, dim=1).mean().item()}')  # 0.93
            print_debug(f'outputs_teacher = {torch.norm(teacher_out, p=2, dim=1).mean().item()}')  # 0.026 at the end
            print_debug(f'outputs_teacher - input = {torch.norm(teacher_out - x, p=2, dim=1).mean().item()}')  # 0.93
            print_debug(f'outputs = {torch.norm(student_out, p=2, dim=1).mean().item()}')  # 0.025 at the end
            print_debug(f'outputs - outputs_teacher = {torch.norm(student_out - teacher_out, p=2, dim=1).mean().item()}')  # 0.028 at the end
        # Earlier observation: torch.norm(outputs_teacher - input) = 4.5, and the norms of
        # input / outputs / outputs_teacher were 4.5 / 0.36 / 0.38 after 1 epoch for type 0.
        # That is much larger than when seeding the student with neighbor-layer weights,
        # so that seeding is necessary.

    n_batches = len(dataloader)
    avg_loss = loss_sum / n_batches
    avg_y_norm = y_norm_sum / n_batches
    avg_l2_norm = l2_norm_sum / n_batches
    print(f'avg_loss: {avg_loss}, avg_y_norm: {avg_y_norm}, avg_l2_norm: {avg_l2_norm}')
    return avg_loss, avg_y_norm

@torch.no_grad()
def online_evaluate_model(student_mlp, dataloader, criterion, input_process = lambda x: x, device = 'cpu'):
    """Return the average ``criterion`` loss of ``student_mlp`` over ``dataloader``.

    ``dataloader`` yields ``(input_data, label_data)`` pairs. The input is passed through
    ``input_process``, and both tensors are moved to ``device``. The module is put in
    eval mode, and the decorator disables gradient tracking.

    Raises:
        ZeroDivisionError: if ``dataloader`` has no batches.
    """
    student_mlp.eval()

    total_loss = 0.0
    for input_data, label_data in tqdm(dataloader, desc='evaluating'):
        student_input = input_process(input_data).to(device)
        student_output = student_mlp(student_input)
        total_loss += criterion(student_output, label_data.to(device)).item()

    return total_loss / len(dataloader)
            
    
# def get_online_warmup_postfix(subset_size, subset_idx):
#     return f'{(subset_idx + 1) * subset_size}'

def online_train_loramlp(
    student_loramlp,

    train_loader,
    eval_loader,

    optimizer,
    scheduler = None,
    criterion = L2_loss,

    num_epochs = 1,
    eval_freq = 1,
    store_model_dir = '',

    ref_layer = 2,
    target_layer = 3,
    input_process = lambda x: x,

    subset_size = 3000,
    subset_idx = 0,
    final_postfix = 0
):
    """Train ``student_loramlp`` on precomputed ``(input, label)`` batches and checkpoint improvements.

    The model is evaluated with ``online_evaluate_model`` before training, then every
    ``eval_freq`` epochs and after the last epoch. Whenever the eval loss beats the best
    so far, the model is saved with ``store_loramlp_util``. The note is ``''`` when
    ``final_postfix == 1``, otherwise ``get_online_warmup_postfix(subset_size, subset_idx)``.
    ``scheduler`` (if given) is stepped once per epoch.

    Assumes ``student_loramlp`` exposes a ``.device`` attribute (as read below).

    Returns:
        The best eval loss observed (the initial loss if training never improved it).
    """
    print(f'Len of train_loader: {len(train_loader)}, Len of eval_loader: {len(eval_loader)}')

    device = student_loramlp.device
    best_eval_loss = online_evaluate_model(student_loramlp, eval_loader, criterion, input_process, device)
    init_eval_loss = best_eval_loss
    print(f'\n########## initial eval loss: {best_eval_loss} ##########\n')

    for epoch in range(num_epochs):
        print(f'=> Epoch: {epoch}')

        student_loramlp.train()
        for input_data, label_data in tqdm(train_loader, desc='training'):
            input_data = input_process(input_data).to(device)
            label = label_data.to(device)

            optimizer.zero_grad()
            output = student_loramlp(input_data)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

        # The scheduler is stepped once per epoch, not per batch.
        if scheduler is not None:
            scheduler.step()
            print(f'current lr: {scheduler.get_last_lr()}')

        if epoch % eval_freq == 0 or epoch == num_epochs - 1:
            eval_loss = online_evaluate_model(student_loramlp, eval_loader, criterion, input_process, device)

            if epoch == num_epochs - 1:
                visualize_model_named_parameters(student_loramlp, print_grad=True, prefix='')

            # Avoid dividing by zero when the initial loss is already 0.
            improve_ratio = (init_eval_loss - eval_loss) / init_eval_loss if init_eval_loss != 0 else 0.0
            print(f'\n########## eval loss: {eval_loss}, improve ratio: {improve_ratio} ##########\n')

            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                note = '' if final_postfix == 1 else get_online_warmup_postfix(subset_size, subset_idx)
                store_loramlp_util(student_loramlp, ref_layer, target_layer, store_model_dir, note=note)

    return best_eval_loss


def check_finish_training_by_loss(best_loss, eval_loss, init_loss, threshold=0.01, check_type=0):
    """Decide whether training should continue after an evaluation.

    Args:
        best_loss: best evaluation loss seen so far.
        eval_loss: loss from the latest evaluation.
        init_loss: loss before training; currently unused (kept for API compatibility).
        threshold: minimum relative improvement ``(best_loss - eval_loss) / best_loss``
            needed to keep training when ``check_type == 1``.
        check_type: 0 -> continue on any improvement; 1 -> also require the relative
            improvement to exceed ``threshold``.

    Returns:
        1 if the loss improved enough to continue training; -1 if it improved but only by
        at most ``threshold`` (``check_type == 1`` only), meaning stop training; 0 if it
        did not improve (including NaN ``eval_loss``), meaning stop training.

    Raises:
        ValueError: if ``check_type`` is not 0 or 1 (this used to return ``None`` silently).
    """
    if check_type not in (0, 1):
        raise ValueError(f'unknown check_type: {check_type}')

    # Written as `not <` (rather than `>=`) so a NaN eval_loss counts as "no improvement".
    if not eval_loss < best_loss:
        return 0  # stop training

    if check_type == 0:
        return 1  # continue training

    relative_improve_ratio = (best_loss - eval_loss) / best_loss
    return 1 if relative_improve_ratio > threshold else -1
        

def train_loramlp(
    teacher_mlp,
    student_loramlp,

    train_loader,
    eval_loader,

    optimizer,
    scheduler = None,
    criterion = L2_loss,

    num_epochs = 1,
    eval_freq = 1,
    store_model_dir = '',

    ref_layer = 2,
    target_layer = 3,
    input_process = lambda x: x
):
    """Distill ``teacher_mlp`` into ``student_loramlp`` with checkpointing and early stopping.

    Each batch is mapped through ``input_process`` and moved to the teacher's device
    (taken from ``teacher_mlp.gate_proj.weight``). The student is trained to match the
    teacher's outputs under ``criterion``. ``scheduler`` (if given) is stepped once per
    epoch. After every ``eval_freq`` epochs and after the last epoch, the student is
    evaluated with ``evaluate_model`` and
    ``check_finish_training_by_loss(..., threshold=0.01)`` decides:

    - 0: no improvement -> stop without saving;
    - otherwise: save via ``store_loramlp_util``; stop as well if it returned -1.

    Returns:
        ``(best_eval_loss, y_norm)``, where ``y_norm`` comes from the most recent evaluation.
    """
    teacher_mlp.eval()
    device = teacher_mlp.gate_proj.weight.device

    best_eval_loss, y_norm = evaluate_model(student_loramlp, teacher_mlp, eval_loader, criterion, input_process, device)
    init_eval_loss = best_eval_loss

    print(f'\n########## initial eval loss: {best_eval_loss}, y_norm: {y_norm} ##########\n')

    for epoch in range(num_epochs):
        print(f'=> Epoch: {epoch}')

        student_loramlp.train()
        for data in tqdm(train_loader, desc='training'):
            batch_input = input_process(data).to(device)

            optimizer.zero_grad()

            outputs = student_loramlp(batch_input)
            # The teacher only provides targets: no graph is needed, and an unfrozen
            # teacher must not receive gradients from loss.backward().
            with torch.no_grad():
                outputs_teacher = teacher_mlp(batch_input)

            loss = criterion(outputs, outputs_teacher)
            loss.backward()
            optimizer.step()

        # The scheduler is stepped once per epoch, not per batch.
        if scheduler is not None:
            scheduler.step()
            print(f'current lr: {scheduler.get_last_lr()}')

        if epoch % eval_freq == 0 or epoch == num_epochs - 1:
            eval_loss, y_norm = evaluate_model(student_loramlp, teacher_mlp, eval_loader, criterion, input_process, device)

            if epoch == num_epochs - 1:
                visualize_model_named_parameters(student_loramlp, print_grad=True, prefix='')

            # Avoid dividing by zero when the initial loss is already 0.
            improve_ratio = (init_eval_loss - eval_loss) / init_eval_loss if init_eval_loss != 0 else 0.0
            print(f'\n########## eval loss: {eval_loss}, y_norm {y_norm}, improve ratio: {improve_ratio} ##########\n')

            check_final = check_finish_training_by_loss(best_eval_loss, eval_loss, init_eval_loss, threshold=0.01)
            if check_final == 0:
                print(f'\n########## overfit -> early stop ##########')
                break

            # Improved: keep this checkpoint, then stop if the gain was marginal.
            best_eval_loss = eval_loss
            store_loramlp_util(student_loramlp, ref_layer, target_layer, store_model_dir)

            if check_final == -1:
                print(f'\n########## small improve -> early stop ##########')
                break

    return best_eval_loss, y_norm


def get_optimizers(lr, model, op_type=0, weight_decay=0.1, warmup_steps=20, decay_factor=0.1, total_steps=100):
    """Create an AdamW optimizer over the trainable parameters of ``model`` and a scheduler.

    ``op_type`` selects the combination:

    - ``0``: AdamW with torch's default weight decay, plus ``StepLR(step_size=1, gamma=0.9)``.
    - ``1``: AdamW with ``weight_decay``, plus the same ``StepLR``.
    - ``2``: AdamW with ``weight_decay``, plus
      ``my_get_linear_schedule_with_warmup(total_steps, warmup_steps, decay_factor)``.
    - ``-1``: AdamW with torch's default weight decay, constant learning rate.
    - ``-2``: AdamW with ``weight_decay``, constant learning rate.

    Only parameters with ``requires_grad`` set are optimized.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: for any other ``op_type`` (this used to fail with ``UnboundLocalError``).
    """
    print(f'>>>>>>>>>>>>>>>>  get_optimizers: op_type={op_type}, weight_decay={weight_decay}, warmup_steps={warmup_steps}, decay_factor={decay_factor}, total_steps={total_steps}')
    if op_type not in (0, 1, 2, -1, -2):
        raise ValueError(f'unknown op_type: {op_type}')

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    # op_type 0 and -1 keep AdamW's default weight decay; the others use `weight_decay`.
    adamw_kwargs = {} if op_type in (0, -1) else {'weight_decay': weight_decay}
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, **adamw_kwargs)

    if op_type in (0, 1):
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    elif op_type == 2:
        scheduler = my_get_linear_schedule_with_warmup(optimizer, total_steps, warmup_steps, decay_factor)
    else:
        # constant learning rate
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)

    return optimizer, scheduler



def online_mlp_warmup(
    model_dir,
    device,

    ref_layer,
    target_layer,
    rank,
    forward_type,
    # dataset
    train_loader,
    eval_loader,
    # train
    lr,
    num_epochs,
    eval_freq,
    store_model_dir,

    subset_size = 3000,
    subset_idx = 0,

    final_postfix = 0
):
    """Warm up one LoRA MLP (``ref_layer`` -> ``target_layer``) on one precomputed data subset.

    The gate/up/down weights of ``ref_layer`` in the Llama checkpoint at ``model_dir``
    seed the student. If a checkpoint from an earlier subset can be loaded, training
    resumes from the newest one (trying ``subset_idx - 1`` down to ``0``). Otherwise a
    fresh ``LoRAMLP`` is created. The student is then trained with
    ``online_train_loramlp`` on ``train_loader``/``eval_loader``, which yield
    ``(input, label)`` pairs.
    """
    model = LlamaForCausalLM.from_pretrained(model_dir)

    ref_mlp = model.model.layers[ref_layer].mlp
    ref_gate = ref_mlp.gate_proj.weight.data
    ref_up = ref_mlp.up_proj.weight.data
    ref_down = ref_mlp.down_proj.weight.data

    ##### STEP1: resume from the newest earlier-subset checkpoint, else create a new student #####
    # Unlike mlp_warmup, no teacher model is needed: the loaders already hold the targets.
    student_mlp = None
    for prev_subset_idx in range(subset_idx - 1, -1, -1):
        # Recompute the note for every candidate index. Previously it was computed once for
        # `subset_idx - 1`, so each retry reloaded the same checkpoint and older ones were
        # never tried.
        note = '' if final_postfix == 1 else get_online_warmup_postfix(subset_size, prev_subset_idx)
        student_mlp = load_lora_model_util(
            gate_weights=ref_gate,
            up_weights=ref_up,
            down_weights=ref_down,
            rank=rank,
            forward_type=forward_type,
            device=device,
            ref_layer=ref_layer,
            target_layer=target_layer,
            store_model_dir=store_model_dir,
            note=note,
            strict=False
        )
        # A None result is treated as "no checkpoint for this index".
        if student_mlp is not None:
            print(f'=== load the existing student model ===')
            break

    if student_mlp is None:
        print(f'=== create a new student model ===')
        student_mlp = LoRAMLP(
            ref_gate,
            ref_up,
            ref_down,

            rank=rank,
            forward_type=forward_type,
            device=device
        )

    visualize_model_named_parameters(student_mlp, print_grad=False, prefix='stu init')

    ##### STEP2: define the optimizer, scheduler, and criterion #####
    # Other settings are available via get_optimizers' op_type (e.g. 1, -1, -2).
    optimizer, scheduler = get_optimizers(lr, student_mlp, op_type=0)
    criterion = L2_loss

    ##### STEP3: train the student mlp #####
    online_train_loramlp(
        student_mlp,
        train_loader,
        eval_loader,

        optimizer,
        scheduler = scheduler,
        criterion = criterion,

        num_epochs = num_epochs,
        eval_freq = eval_freq,
        store_model_dir = store_model_dir,

        ref_layer = ref_layer,
        target_layer = target_layer,
        input_process = lambda x: x,

        subset_size = subset_size,
        subset_idx = subset_idx,
        final_postfix = final_postfix
    )

    ##### STEP4: clean CUDA memory #####
    del model, student_mlp, ref_mlp, ref_gate, ref_up, ref_down
    torch.cuda.empty_cache()
        


def mlp_warmup(
    model_dir,
    device,

    ref_layer,
    target_layer,
    rank,

    forward_type,
    # dataset
    dataset_type,
    random_size,
    batch_size,
    batch_size_val,
    # train
    lr,
    num_epochs,
    eval_freq,
    store_model_dir,
    store_data_path
):
    """Distill the MLP of ``target_layer`` into a LoRA MLP seeded from ``ref_layer``.

    Loads the Llama model at ``model_dir`` onto ``device``. A frozen deep copy of the
    target layer's MLP serves as teacher, and a ``LoRAMLP`` built from the reference
    layer's gate/up/down weights is the student. The student is trained with
    ``train_loramlp`` on loaders from ``generate_random_dataloader``.

    Returns:
        ``(best_eval_loss, y_norm)`` as returned by ``train_loramlp``.
    """
    ##### STEP0: load the model #####
    model = LlamaForCausalLM.from_pretrained(model_dir).to(device)

    ##### STEP1: define student and teacher layers #####
    device = model.model.layers[target_layer].mlp.gate_proj.weight.device
    print(f'mlp_warmup device: {device}')

    teacher_mlp = copy.deepcopy(model.model.layers[target_layer].mlp).to(device)
    teacher_mlp.eval()

    # teacher mlp -> no gradient
    for param in teacher_mlp.parameters():
        param.requires_grad = False

    # Seed from the reference layer. Using the target layer's own gate weights instead
    # was noted to work worse.
    ref_mlp = model.model.layers[ref_layer].mlp
    ref_gate = ref_mlp.gate_proj.weight.data
    ref_up = ref_mlp.up_proj.weight.data
    ref_down = ref_mlp.down_proj.weight.data

    student_mlp = LoRAMLP(
        ref_gate,
        ref_up,
        ref_down,

        rank=rank,
        forward_type=forward_type,
        device=device
    )

    visualize_model_named_parameters(student_mlp, print_grad=False, prefix='stu init')

    ##### STEP2: define the dataloader #####
    train_loader, eval_loader = generate_random_dataloader(
        dataset_type=dataset_type,
        random_size=random_size,
        dim=ref_gate.size(1),
        batch_size=batch_size,
        batch_size_val=batch_size_val,
        store_data_dir=store_data_path,
        input_layer=target_layer
    )

    ##### STEP3: define the optimizer, scheduler, and criterion #####
    # Other settings are available via get_optimizers' op_type, e.g. op_type=2 with
    # total_steps=len(train_loader) * num_epochs for a linear warmup/decay schedule.
    optimizer, scheduler = get_optimizers(lr, student_mlp, op_type=0)
    criterion = L2_loss

    ##### STEP4: train the student mlp #####
    best_eval_loss, y_norm = train_loramlp(
        teacher_mlp,
        student_mlp,

        train_loader,
        eval_loader,

        optimizer,
        scheduler = scheduler,
        criterion = criterion,

        num_epochs = num_epochs,
        eval_freq = eval_freq,
        store_model_dir = store_model_dir,

        ref_layer = ref_layer,
        target_layer = target_layer,
        input_process = lambda x: x
    )

    ##### STEP5: clean CUDA memory #####
    del model, teacher_mlp, student_mlp, ref_mlp, ref_gate, ref_up, ref_down
    torch.cuda.empty_cache()

    return best_eval_loss, y_norm
    
    
    
    

def lorallama_warmup(
    model_dir,
    device,
    info_tuples, # only 1 stage,
    forward_type,
    # dataset
    dataset_type,
    random_size,
    batch_size,
    batch_size_val,
    # train
    lr,
    num_epochs,
    eval_freq,
    store_model_dir,
    store_data_path
):
    """Run ``mlp_warmup`` for each ``(ref_layer, target_layer, rank)`` in ``info_tuples``.

    Pairs whose LoRA-MLP file (from ``get_loramlp_paths``) already exists are skipped.
    The collected best eval losses and y-norms are printed at the end, along with their
    averages when more than one pair was trained. Returns ``None``.
    """
    best_eval_loss_list = []
    y_norm_list = []

    for ref_layer, target_layer, rank in info_tuples:
        print(f'lorallama_warmup: ref_layer={ref_layer}, target_layer={target_layer}, rank={rank}')

        file_path = get_loramlp_paths(ref_layer, target_layer, store_model_dir)
        if os.path.exists(file_path):
            print(f'!! loramlp exists in {file_path}, skip training')
            continue

        print(f'!! loramlp does not exist, start training and stored in {file_path}')
        best_eval_loss, y_norm = mlp_warmup(
            model_dir,
            device,
            ref_layer,
            target_layer,
            rank,

            forward_type,
            # dataset
            dataset_type,
            random_size,
            batch_size,
            batch_size_val,
            # train
            lr,
            num_epochs,
            eval_freq,
            store_model_dir,
            store_data_path
        )
        best_eval_loss_list.append(best_eval_loss)
        y_norm_list.append(y_norm)

        print(f'=== lorallama_warmup: ref_layer={ref_layer}, target_layer={target_layer}, rank={rank} is done, best_eval_loss={best_eval_loss}, y_norm={y_norm} ===')

    print(f'best_eval_loss_list: {best_eval_loss_list}')
    print(f'y_norm_list: {y_norm_list}')

    trained_count = len(best_eval_loss_list)
    if trained_count > 1:
        print(f'average best_eval_loss: {sum(best_eval_loss_list) / trained_count}')
        print(f'average y_norm: {sum(y_norm_list) / len(y_norm_list)}')
    
    

        
        
def init_teacher_model(model_dir, teacher_device):
    """Load a frozen Llama teacher from ``model_dir`` onto ``teacher_device``.

    The model is switched to eval mode and ``requires_grad`` is disabled on every
    parameter.
    """
    teacher = LlamaForCausalLM.from_pretrained(model_dir).to(teacher_device)
    teacher.eval()
    teacher.requires_grad_(False)  # teacher -> no gradient for any parameter
    return teacher

def parallel_online_warmup(
    # model / student setup
    model_dir,
    device,
    teacher_device,
    info_tuples, # only 1 stage,
    forward_type,

    # online loading
    dataloader_generator, # iterator
    start_layer_idx,
    end_layer_idx,
    # train args in each subset
    warmup_batch_size,
    lr,
    num_epochs,
    eval_freq,
    store_model_dir,

    subset_size = 3000,
):
    """Warm up every LoRA MLP in ``info_tuples``, one data subset at a time.

    Outer loop: each ``(train_loader, test_loader)`` pair from ``dataloader_generator``.
    For each subset, a teacher Llama is loaded once and used by
    ``generate_mlp_ios_from_dataloaders_fast`` to build MLP input/output loaders, which
    are indexed by target layer. The teacher is then freed.

    Inner loop: ``online_mlp_warmup`` trains each ``(ref_layer, target_layer, rank)``
    whose checkpoint for this subset does not exist yet.

    A subset is skipped entirely when the checkpoint of the last entry in
    ``info_tuples`` already exists for it. When ``dataloader_generator`` has exactly one
    subset, checkpoints use the final (empty) postfix instead of a per-subset one.
    ``dataloader_generator`` must support ``len()``.
    """
    final_postfix = 1 if len(dataloader_generator) == 1 else 0
    print(f'final_postfix: {final_postfix}, dataset_generator: {dataloader_generator}')

    for subset_idx, (train_loader, test_loader) in enumerate(dataloader_generator):
        start_time = time.time()
        model_postfix = '' if final_postfix else get_online_warmup_postfix(subset_size, subset_idx)

        # Students are trained in info_tuples order, so an existing checkpoint for the
        # last one is taken to mean this subset is done.
        last_ref, last_target, _ = info_tuples[-1]
        print(f'last_ref={last_ref}, last_target={last_target}')
        done_marker = get_loramlp_paths(last_ref, last_target, store_model_dir, note=model_postfix)

        if not os.path.exists(done_marker):
            print(f'======================= B:{subset_idx} --- tr: {len(train_loader)}, te: {len(test_loader)} ==========================')

            # Use the teacher only to build the per-layer MLP input/output loaders,
            # then release it before any student is trained.
            teacher_model = init_teacher_model(model_dir, teacher_device)
            train_ios = generate_mlp_ios_from_dataloaders_fast(train_loader, teacher_model, warmup_batch_size, teacher_device, start_layer_idx, end_layer_idx, is_val=False)
            test_ios = generate_mlp_ios_from_dataloaders_fast(test_loader, teacher_model, warmup_batch_size, teacher_device, start_layer_idx, end_layer_idx, is_val=True)
            del teacher_model
            torch.cuda.empty_cache()

            # Students are trained sequentially on the shared loaders.
            for ref_layer, target_layer, rank in info_tuples:
                print(f'online_warmup: ref_layer={ref_layer}, target_layer={target_layer}, rank={rank}')

                file_path = get_loramlp_paths(ref_layer, target_layer, store_model_dir, note=model_postfix)
                if os.path.exists(file_path):
                    print(f'!! loramlp exists in {file_path}, skip training')
                    continue

                print(f'!! loramlp does not exist, start training and stored in {file_path}')
                online_mlp_warmup(
                    model_dir,
                    device,

                    ref_layer,
                    target_layer,
                    rank,
                    forward_type,
                    # dataset
                    train_ios[target_layer],
                    test_ios[target_layer],
                    # train
                    lr,
                    num_epochs,
                    eval_freq,
                    store_model_dir,

                    subset_size = subset_size,
                    subset_idx = subset_idx,
                    final_postfix = final_postfix
                )
                print(f'=== online_warmup: ref_layer={ref_layer}, target_layer={target_layer}, rank={rank} is done ===')

            del train_ios, test_ios
            torch.cuda.empty_cache()  # clean gpu memory
        else:
            print(f'!! loramlp exists in {done_marker}, which is the last file in {model_postfix}, skip this loader!')

        # Reports the number of subsets processed so far (index + 1).
        print(f'>>>>>>>>>>>>>>>>>>>>>> Finish B:{subset_idx + 1} --- time: {time.time() - start_time} ==========================')
        
        

        