############## import libraries ################
import argparse
import glob
import os
import shutil
import sys

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

from myutils import visualize_output, calculate_ppl, set_seed
from selfcopy_models import LoRALlama, InfoTupleManager

from dataset_loader import get_datasets_for_SFT, DatasetManager, get_train_val_dataloader_from_texts

from mytraining import train, lorallama_warmup, parallel_online_warmup, get_online_warmup_postfix, train2



# Seed the random number generators via myutils.set_seed before anything random
# happens (presumably python/numpy/torch — confirm in myutils).
set_seed(42)



# Command-line interface. Most "if_*"/"use_*"/"*_flag" options are ints used as
# 0/1 switches (compared with `== 1` / `!= 0` further down), not argparse booleans.
parser = argparse.ArgumentParser(description='llama test SFT')
# model
# --model_dir is the base HF Llama checkpoint: it is used for the tokenizer, the
# student model and every warmup / subtrain run.
parser.add_argument('--model_dir', type=str, default='xxx/llama/llama-2-7b-hf', help='model directory')
# --forward_type is passed straight through to LoRALlama and the warmup helpers.
parser.add_argument('--forward_type', type=int, default=0, help='forward_type')
parser.add_argument('--lr', type=float, default=1e-3, help='lr')
parser.add_argument('--n_epochs', type=int, default=2, help='n_epochs')
parser.add_argument('--lr_scheduler_type', type=str, default='linear', help='lr_scheduler_type')
parser.add_argument('--batch_size', type=int, default=1, help='batch_size')
parser.add_argument('--batch_size_val', type=int, default=1, help='batch_size_val')
# --use_stored: reuse LoRA weights already in --store_model_dir. It is forced to 1
# after warmup/subtrain, for evaluation-only runs, and when --use_stored_dir is given.
parser.add_argument('--use_stored', type=int, default=0, help='use_stored')
parser.add_argument('--is_bf16', type=int, default=1, help='is_bf16')

# --use_stored_dir: its *.pth files (LoRA-MLP weights, not a whole model) are
# copied into --store_model_dir before the student model is built.
parser.add_argument('--use_stored_dir', type=str, default=None, help='path to load the model that need to be used for SFT')
parser.add_argument('--original_skip_size', type=int, default=0, help='original_skip_size for fineweb')

# Only the end-to-end training call receives grad_acc; the subtrain calls do not.
parser.add_argument('--grad_acc', type=int, default=2, help='gradient accumulation')

# --if_store_baseline 1: export the model before end-to-end training, then exit.
parser.add_argument('--if_store_baseline', type=int, default=0, help='if store the baseline model')
parser.add_argument('--if_replace_self_attn', type=int, default=0, help='if replace the self attention layer with the self sora layer')

# dataset
# --downsample_rate is forwarded to get_datasets_for_SFT for the end-to-end SFT set.
parser.add_argument('--downsample_rate', type=float, default=0.01, help='downsample_rate')
parser.add_argument('--max_tokens', type=int, default=512, help='max_tokens')
parser.add_argument('--dataset_name', type=str, default='arxiv-math', help='dataset_name, support :: for concatenation')
# --test_type selects the test texts that perplexity is computed on.
parser.add_argument('--test_type', type=str, default='arxiv-math', help='test_type')

# info tuples
# These values are forwarded to InfoTupleManager(info_tuples_type).get_info_tuples.
parser.add_argument('--info_tuples_type', type=int, default=0, help='info_tuples_type')
parser.add_argument('--rank', type=int, default=5, help='rank')
parser.add_argument('--start_layer', type=int, default=2, help='start_layer')
parser.add_argument('--end_layer', type=int, default=30, help='end_layer')
parser.add_argument('--step', type=int, default=2, help='layer stride')

# subtrain
# Optional staged training of subsets of the info tuples before the end-to-end run.
parser.add_argument('--if_subtrain', type=int, default=0, help='if_subtrain')
parser.add_argument('--subtrain_type', type=int, default=0, help='subtrain_type')
parser.add_argument('--subtrain_downsample_rate', type=float, default=0.01, help='subtrain_downsample_rate')
parser.add_argument('--subtrain_use_stored', type=int, default=0, help='subtrain_use_stored')

# warmup
# The warmup runs only when --if_warmup 1 and --run_flag 1. dataset_type,
# random_size and store_data_path are used only by the offline
# (lorallama_warmup) path.
parser.add_argument('--if_warmup', type=int, default=0, help='warmup')
parser.add_argument('--dataset_type', type=int, default=2, help='dataset_type')
parser.add_argument('--random_size', type=int, default=100000, help='random_size')
parser.add_argument('--warmup_batch_size', type=int, default=500, help='batch size')

# Text batch size of the dataloaders built for the online warmup.
parser.add_argument('--warmup_text_batch_size', type=int, default=4, help='batch size')

# NOTE(review): --warmup_weight_decay is not referenced anywhere in this script.
parser.add_argument('--warmup_weight_decay', type=float, default=0.01, help='weight decay for warmup')

parser.add_argument('--warmup_lr', type=float, default=1e-4, help='lr')
parser.add_argument('--warmup_n_epochs', type=int, default=1, help='n_epochs')
parser.add_argument('--warmup_eval_freq', type=int, default=1, help='eval_freq')
parser.add_argument('--store_data_path', type=str, default='xxx/llama_researcher/', help='store_data_path')
parser.add_argument('--warmup_dataset_names', type=str, default='arxiv-math', help='warmup_dataset_names')
parser.add_argument('--warmup_downsample_rate', type=float, default=0.1, help='warmup_downsample_rate')
# --warmup_model_dir: the warmup writes here and its contents are then copied
# into --store_model_dir, so it must be set whenever --if_warmup 1 is used.
parser.add_argument('--warmup_model_dir', type=str, default='', help='store_dir for warmup model')



# online loading
# --if_online 1 builds subset_num (train_loader, val_loader) pairs for
# parallel_online_warmup. For non-fineweb data, subset_size is overwritten with
# the computed per-subset size.
parser.add_argument('--if_online', type=int, default=0, help='if_online')
parser.add_argument('--subset_size', type=int, default=30000, help='subset_size')
parser.add_argument('--subset_num', type=int, default=4, help='subset_num')
# parser.add_argument('--end2end_postfix', type=str, default=None, help='postfix for end2end model for start training')

# store whole model
# NOTE: the whole model is also stored whenever --run_flag 1, regardless of this flag.
parser.add_argument('--if_store_whole_model', type=int, default=0, help='if store the whole llama model for evaluating')

# others
# --device is the torch device the models are moved to; --visible_devices is
# exported as CUDA_VISIBLE_DEVICES.
parser.add_argument('--device', type=str, default='cuda:0', help='device')
parser.add_argument('--visible_devices', type=str, default='1', help='visible_device')
# --store_model_dir: output directory for the LoRA weights; must be non-empty.
parser.add_argument('--store_model_dir', type=str, default='', help='store_dir for lora model')
# --run_flag 0 skips all training (warmup, subtrain, SFT) and reuses stored weights.
parser.add_argument('--run_flag', type=int, default=1, help='run_flag')
parser.add_argument('--verbose', type=int, default=1, help='verbose')

# if only_teacher_test
# --only_init_test 1 forces run_flag=0, evaluates the initial student model and exits.
parser.add_argument('--only_init_test', type=int, default=0, help='only_init_test')

args = parser.parse_args()

# Post-process the parsed arguments, configure the environment and load the tokenizer.
if args.only_init_test:
    # An init-only evaluation never trains, so switch all training stages off.
    args.run_flag = 0

# Echo every argument so the run configuration ends up in the log.
for arg_name, arg_value in vars(args).items():
    print(arg_name, arg_value)

# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet; torch and the project modules are imported before this
# line — confirm none of them touch the GPU at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_devices

# The default --store_model_dir is '', for which os.makedirs('') raises an
# opaque FileNotFoundError; fail early with a usage message instead.
if not args.store_model_dir:
    parser.error('--store_model_dir must be a non-empty directory path')
os.makedirs(args.store_model_dir, exist_ok=True)

# Reuse EOS as the padding token (presumably because the Llama tokenizer
# defines no pad token) so batches can be padded.
tokenizer = LlamaTokenizer.from_pretrained(args.model_dir)
tokenizer.pad_token = tokenizer.eos_token

# assert that if_warmup == 1 when if_online == 1
# assert (args.if_online == 0 or args.if_warmup == 1), f'if_online == 1, but if_warmup == 0'

################# First: get datasets or dataset_generator #################
# For --if_online 1, build `dataloader_generator`: a sequence of
# (train_loader, val_loader) pairs, one per data subset, consumed by
# parallel_online_warmup.
DM = DatasetManager()

if args.if_online == 1:
    if args.warmup_dataset_names == 'fineweb':
        # The DatasetManager builds the fineweb loaders itself, split into
        # subset_num subsets of subset_size texts.
        dataloader_generator, test_texts = DM.get_fineweb_dataloaders_online(
            # encode
            args.warmup_text_batch_size,
            1,  # batch_size_val is not actually used here
            tokenizer,
            args.max_tokens,
            # subset splits
            subset_size=args.subset_size,
            subset_num=args.subset_num,
        )
    else:
        # Other datasets: load the full texts, cut train/val into subset_num
        # contiguous, equally sized slices and build one loader pair per slice,
        # so the warmup consumes them the same way as the fineweb subsets.
        train_texts, val_texts, test_texts = DM.get_dataset_texts(
            args.warmup_dataset_names,
            args.warmup_dataset_names,  # also used as the test_type argument
        )

        warmup_subset_num = args.subset_num
        if warmup_subset_num <= 0:
            raise ValueError(f'--subset_num must be positive, got {warmup_subset_num}')

        # Integer division: up to subset_num - 1 trailing texts are dropped.
        warmup_train_subset_size = len(train_texts) // warmup_subset_num
        warmup_val_subset_size = len(val_texts) // warmup_subset_num
        if warmup_train_subset_size == 0:
            # Otherwise every training subset would be empty and the warmup
            # would silently train on nothing.
            raise ValueError(
                f'cannot split {len(train_texts)} warmup training texts into '
                f'{warmup_subset_num} non-empty subsets; lower --subset_num'
            )

        # The warmup and LoRALlama receive args.subset_size, so it must reflect
        # the actual per-subset size.
        args.subset_size = warmup_train_subset_size

        train_texts_splits = [
            train_texts[i * warmup_train_subset_size:(i + 1) * warmup_train_subset_size]
            for i in range(warmup_subset_num)
        ]
        val_texts_splits = [
            val_texts[i * warmup_val_subset_size:(i + 1) * warmup_val_subset_size]
            for i in range(warmup_subset_num)
        ]

        dataloader_generator = []
        for train_split, val_split in zip(train_texts_splits, val_texts_splits):
            train_loader, val_loader = get_train_val_dataloader_from_texts(
                train_split,
                val_split,
                batch_size=args.warmup_text_batch_size,
                tokenizer=tokenizer,
                MAX_LEN=args.max_tokens,
                downsample_rate=args.warmup_downsample_rate,
            )
            dataloader_generator.append((train_loader, val_loader))

        print(f'WARMUP: len(dataloader_generator): {len(dataloader_generator)}, args.subset_size: {args.subset_size}, args.subset_num: {args.subset_num}')
  

 
 

# Build the list of info tuples (from rank, layer range and layer stride).
# The warmup, subtrain and LoRALlama stages below all consume this list.
ITM = InfoTupleManager(args.info_tuples_type)
info_tuples = ITM.get_info_tuples(
    rank=args.rank, start_layer=args.start_layer, end_layer=args.end_layer, step=args.step
)
print(f'info_tuples: {info_tuples}')


if args.if_warmup == 1 and args.run_flag == 1:
    print('========================= START WARMUP =========================')

    # The warmup writes into --warmup_model_dir, whose contents are copied into
    # --store_model_dir afterwards. With the '' default, the shell copy used to
    # expand to `cp -r /* ...` (the filesystem root), so refuse to start rather
    # than run the expensive warmup first.
    if not args.warmup_model_dir:
        raise ValueError('--warmup_model_dir must be set when --if_warmup 1 is used')

    # Skip tuples whose first two entries are equal (the "approximate layers").
    tmp_info_tuples = [info_tuple for info_tuple in info_tuples if info_tuple[0] != info_tuple[1]]

    if args.if_online == 1:
        # Online warmup over the (train_loader, val_loader) pairs in dataloader_generator.
        parallel_online_warmup(
            model_dir=args.model_dir,
            device=args.device,
            teacher_device='cuda:0',
            info_tuples=tmp_info_tuples,
            forward_type=args.forward_type,
            dataloader_generator=dataloader_generator,
            # dataset
            start_layer_idx=args.start_layer,
            end_layer_idx=args.end_layer,
            # train
            warmup_batch_size=args.warmup_batch_size,
            lr=args.warmup_lr,
            num_epochs=args.warmup_n_epochs,
            eval_freq=args.warmup_eval_freq,
            store_model_dir=args.warmup_model_dir,
            subset_size=args.subset_size,
        )
    else:
        # Offline warmup, configured by dataset_type / random_size /
        # store_data_path instead of prebuilt dataloaders.
        lorallama_warmup(
            model_dir=args.model_dir,
            device=args.device,
            info_tuples=tmp_info_tuples,
            forward_type=args.forward_type,
            # dataset
            dataset_type=args.dataset_type,
            random_size=args.random_size,
            batch_size=args.warmup_batch_size,
            batch_size_val=args.warmup_batch_size,
            # train
            lr=args.warmup_lr,
            num_epochs=args.warmup_n_epochs,
            eval_freq=args.warmup_eval_freq,
            store_model_dir=args.warmup_model_dir,
            store_data_path=args.store_data_path,
        )

    # Copy all warmup outputs into the main model directory. copytree avoids the
    # shell (no quoting/glob pitfalls) and raises on failure instead of
    # continuing silently.
    shutil.copytree(args.warmup_model_dir, args.store_model_dir, dirs_exist_ok=True)
    
    


# Load the SFT corpus. Train/val texts come from --dataset_name; the test texts
# (used for perplexity later) come from --test_type.
train_texts, val_texts, test_texts = DM.get_dataset_texts(
    args.dataset_name,
    args.test_type,
    args.original_skip_size,
)

if args.if_subtrain == 1 and args.run_flag == 1:
    # Staged sub-training. Each stage is a sequence of splits of the info
    # tuples. Every split is trained on a fresh copy of the base model, then
    # persisted via store_loramlp() before the next one starts.
    it_stages_splits = ITM.get_subtrain_splits(info_tuples, subtrain_type=args.subtrain_type)

    # Separate downsampled SFT set used only for the sub-training stages.
    dataset_train_sub, dataset_val_sub = get_datasets_for_SFT(
        train_texts,
        val_texts,
        batch_size_val=args.batch_size_val,
        tokenizer=tokenizer,
        downsample_rate=args.subtrain_downsample_rate,
    )

    for idx, it_splits in enumerate(it_stages_splits):
        print('#' * 40)
        print(f'####### Stage {idx}: {it_splits} #######')
        print('#' * 40)

        # --subtrain_use_stored only matters for the first stage; later stages
        # always build on the weights stored by the earlier ones.
        subtrain_use_stored = args.subtrain_use_stored if idx == 0 else 1

        for it_split in it_splits:
            # NOTE(review): unlike the end-to-end run, this model is not moved
            # to args.device and train() gets no grad_acc — confirm intended.
            tmp_model = LlamaForCausalLM.from_pretrained(args.model_dir)
            student_model_sub = LoRALlama(
                tmp_model,
                it_split,
                forward_type=args.forward_type,
                store_model_dir=args.store_model_dir,
                use_stored=subtrain_use_stored,
            )
            train(
                student_model_sub,
                tokenizer,
                dataset_train_sub,
                dataset_val_sub,
                max_tokens=args.max_tokens,
                batch_size=args.batch_size,
                batch_size_val=args.batch_size_val,
                lr=args.lr,
                n_epochs=args.n_epochs,
                lr_scheduler_type=args.lr_scheduler_type,
                verbose=args.verbose,
                is_bf16=args.is_bf16,
            )
            student_model_sub.store_loramlp()
            student_model_sub.visualize_trainable_params()

            # Release this copy before the next split loads another one.
            del tmp_model, student_model_sub
            torch.cuda.empty_cache()
        
#### get the dataset ####
# The end-to-end SFT datasets are only needed when training actually runs.
if args.run_flag == 1:
    dataset_train, dataset_val = get_datasets_for_SFT(
        train_texts,
        val_texts,
        batch_size_val=args.batch_size_val,
        tokenizer=tokenizer,
        downsample_rate=args.downsample_rate,
    )

# Perplexity is computed over all test texts joined into a single sequence.
test_encodings = tokenizer("\n".join(test_texts), return_tensors='pt')

########## train all layers ##########
# Reuse weights already on disk when an earlier stage (warmup or subtrain) is
# enabled, or when this is an evaluation-only run (no training, not init-only).
# Otherwise honour --use_stored.
use_stored = 1 if (
    (args.run_flag == 0 and args.only_init_test == 0)
    or args.if_subtrain != 0
    or args.if_warmup != 0
) else args.use_stored

# Base model on the target device; the LoRA student is built on top of it.
tmp_model = LlamaForCausalLM.from_pretrained(args.model_dir).to(args.device)

# Reference generation from the unmodified base model (skipped for init-only evaluation).
if not args.only_init_test:
    visualize_output(tmp_model, None, tokenizer)

# Optionally seed --store_model_dir with *.pth LoRA-MLP checkpoints from an
# earlier run (not a whole-model checkpoint). glob + shutil instead of shelling
# out to `cp`/`ls`: paths with spaces or glob characters work, and a missing
# source is reported instead of being hidden in a shell error.
if args.use_stored_dir and args.run_flag == 1:
    stored_pth_paths = sorted(glob.glob(os.path.join(glob.escape(args.use_stored_dir), '*.pth')))
    if not stored_pth_paths:
        print(f'WARNING: no .pth files found in {args.use_stored_dir}')
    for src_path in stored_pth_paths:
        if os.path.isdir(src_path):
            # Mirror `cp -r` for the unusual case of a directory named *.pth.
            shutil.copytree(
                src_path,
                os.path.join(args.store_model_dir, os.path.basename(src_path)),
                dirs_exist_ok=True,
            )
        else:
            shutil.copy(src_path, args.store_model_dir)
    print(f'>>>> copy all the pth files from {args.use_stored_dir} to {args.store_model_dir}')
    print(sorted(os.listdir(args.store_model_dir)))
    use_stored = 1

# use_postfix is only enabled for a training run with --if_online 1 whose warmup
# data is fineweb; how it is interpreted is up to LoRALlama.
use_postfix = args.if_online if (args.run_flag == 1 and args.warmup_dataset_names == 'fineweb') else 0
student_model = LoRALlama(
    tmp_model,
    info_tuples,
    store_model_dir=args.store_model_dir,
    forward_type=args.forward_type,
    use_stored=use_stored,
    subset_size=args.subset_size,
    subset_num=args.subset_num,
    use_postfix=use_postfix,
)

print(f'\n######### if_store_baseline = {args.if_store_baseline} #############\n')
if args.if_store_baseline == 1:
    # Export the student as a whole Llama model before end-to-end training, then stop.
    note = f'_baseline_warmup{args.if_warmup}'
    if args.if_replace_self_attn == 1:
        note += '_replace_SA'
    student_model.store_whole_llama_model(args.model_dir, tokenizer, if_remove_current_model=1, note=note, if_replace_self_attn=args.if_replace_self_attn)
    # sys.exit() instead of the site-module exit(), which is intended for the
    # interactive shell and is absent under `python -S`.
    sys.exit()

# Initial sample + perplexity of the student before end-to-end training.
if args.run_flag == 1:
    visualize_output(student_model, None, tokenizer)
    ppl = calculate_ppl(student_model, test_encodings, stride=512, device=args.device)
    student_model.visualize_trainable_params()
    print(f'original ppl: {ppl}')



################## some results 
# # max token 512, bs 2, epoch 2, lr 1e-3, ft 0, add self sora layer 6.4 (0.005), 5.8 for (0.01), 4.6 (0.1) 
# # max token 512, bs 1, epoch 2, lr 1e-3, ft 7, add self sora layer 6.4 (0.005), 5.9 for (0.01), 4.4 (0.1). But I think its generation is better. Especially for dialog




#### train the student model ####
# End-to-end SFT of the student built from the full info_tuples list, followed
# by persisting its weights via store_loramlp(). (mytraining also provides an
# alternative `train2` loop, imported at the top of the file but not used here.)
if args.run_flag == 1:
    sft_train_kwargs = dict(
        max_tokens=args.max_tokens,
        batch_size=args.batch_size,
        batch_size_val=args.batch_size_val,
        lr=args.lr,
        n_epochs=args.n_epochs,
        lr_scheduler_type=args.lr_scheduler_type,
        verbose=args.verbose,
        is_bf16=args.is_bf16,
        grad_acc=args.grad_acc,
    )
    train(student_model, tokenizer, dataset_train, dataset_val, **sft_train_kwargs)

    student_model.store_loramlp()
    student_model.visualize_trainable_params()


# Final sample + perplexity of the student.
if args.only_init_test:
    # Init-only evaluation: short sample and perplexity, then stop before any
    # model is stored. sys.exit() instead of the interactive-shell exit() helper.
    visualize_output(student_model, tokenizer=tokenizer, device=args.device, if_short=1)
    ppl = calculate_ppl(student_model, test_encodings, stride=512, device=args.device)
    print(f'ppl: {ppl}')
    sys.exit()

visualize_output(student_model, tokenizer=tokenizer, device=args.device)
# calculate_ppl is given an explicit device; model parallelism (device=None) is not supported.
ppl = calculate_ppl(student_model, test_encodings, stride=512, device=args.device)
print(f'ppl: {ppl}')

# Persist the whole Llama model after every training run, or on explicit
# request via --if_store_whole_model 1 (e.g. re-exporting with --run_flag 0).
# NOTE: with --run_flag 1 the model is stored even if --if_store_whole_model is 0.
if args.run_flag == 1 or args.if_store_whole_model == 1:
    student_model.store_whole_llama_model(args.model_dir, tokenizer)