import os,glob, subprocess, re, argparse, math
from huggingface_hub import login
from datasets import Dataset
# #
from utils import get_dataset,get_processFunc
# from glue_tasks import get_glue_tasks, for_Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, AutoModelForSeq2SeqLM
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import Trainer, TrainingArguments
from train_flan import special_trainnig_process_flan
import torch
from tqdm import tqdm
import evaluate
import pandas as pd
import math,glob



# Disabled evaluation metric (accuracy + F1), kept for reference:
# acc_metric = evaluate.load("accuracy")
# f1_metric = evaluate.load("f1")
# def eval_metric(eval_predict):
#     predictions, labels = eval_predict
#     predictions = predictions.argmax(axis=-1)
#     acc = acc_metric.compute(predictions=predictions, references=labels)
#     f1 = f1_metric.compute(predictions=predictions, references=labels)
#     acc.update(f1)
#     return acc

def model_name_to_ckpt(model_name, use_checkpoints2=None):
    """Resolve a model name to the local directory that holds its weights.

    Two on-disk layouts are supported:

    * Hub-style models (``org/name``) are looked up under
      ``<root>/hub/models--org--name/snapshots/<snapshot>`` and resolve to a
      snapshot directory.
    * Locally fine-tuned models (any name containing ``my-``) are looked up
      under ``../llama_on_glue/<root>/hub/<org--name>/snapshots/<snapshot>/``
      and resolve to the ``checkpoint-<step>`` sub-directory with the highest
      step number.

    ``<root>`` is ``checkpoints2`` when the alternate cache is selected and
    ``checkpoints`` otherwise. The resolved path is also printed.

    Args:
        model_name: Model identifier, e.g. ``meta-llama/Llama-3.2-1B`` or
            ``my-llama2/llama-Llama-3.2-1Bsst2``.
        use_checkpoints2: Whether to use the ``checkpoints2`` root. When
            ``None`` (the default) the value is taken from the module-level
            ``args.checkpoints2`` if ``args`` exists, else ``False``.

    Returns:
        The resolved path as a ``/``-separated string relative to the
        current working directory.

    Raises:
        FileNotFoundError: If the snapshots directory is missing or empty, or
            if a local model has no ``checkpoint-<step>`` directory.
    """
    if use_checkpoints2 is None:
        # Fall back to the CLI namespace created by the script entry point;
        # it is absent when this module is imported rather than executed.
        cli_args = globals().get('args')
        use_checkpoints2 = bool(getattr(cli_args, 'checkpoints2', False))

    # Build the path from the chosen root directly instead of str.replace(),
    # which would also rewrite a 'checkpoints' substring inside the model name.
    root = 'checkpoints2' if use_checkpoints2 else 'checkpoints'
    dir_name = model_name.replace('/', '--')
    is_local = 'my-' in model_name

    if is_local:
        snapshots_dir = f'../llama_on_glue/{root}/hub/{dir_name}/snapshots'
    else:
        snapshots_dir = f'{root}/hub/models--{dir_name}/snapshots'

    # os.listdir() returns entries in arbitrary order; sort so the choice is
    # deterministic when more than one snapshot is present.
    snapshots = sorted(os.listdir(snapshots_dir))
    if not snapshots:
        raise FileNotFoundError(f'no snapshot found under {snapshots_dir}')
    path = f'{snapshots_dir}/{snapshots[0]}'

    if is_local:
        # Only consider real checkpoint directories; trainers may leave other
        # entries (logs, configs, ...) next to them.
        checkpoints_found = []
        for entry in os.listdir(path):
            match = re.fullmatch(r'checkpoint-(\d+)', entry)
            if match:
                checkpoints_found.append((int(match.group(1)), entry))
        if not checkpoints_found:
            raise FileNotFoundError(f'no checkpoint-<step> directory under {path}')
        _, latest = max(checkpoints_found)
        path = f'{path}/{latest}'

    print(path)
    return path


if __name__ == '__main__':
    # Script entry point: supervised fine-tuning of a causal (or seq2seq) LM
    # on one dataset with TRL's SFTTrainer, computing the loss on the
    # assistant completion only.
    #
    # Example:
    # CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python train.py --base_model=google/t5-v1_1-large --dataset=sst2 --output_model_prefix=my-t5-large/   --epochs=10 --batch_size=32

    # accelerator = Accelerator()
    # # Make one log on every process with the configuration for debugging.
    # logging.basicConfig(
    #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    #     datefmt="%m/%d/%Y %H:%M:%S",
    #     level=logging.INFO,
    # )
    # logger.info(accelerator.state)

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--cases", type=int, default=4000)
    # NOTE: --epochs and --no_eval are accepted but not read by the active
    # code in this block; they are kept so existing command lines still parse.
    parser.add_argument("--epochs", type=int, default=-1)
    parser.add_argument("--base_model", type=str, default="meta-llama/Llama-3.2-1B")
    parser.add_argument("--dataset", type=str, default="sst2")
    parser.add_argument("--output_model_prefix", type=str, default="my-llama2/llama-")
    parser.add_argument("--no_eval", action='store_true', default=False)
    parser.add_argument("--checkpoints2", action='store_true', default=False)
    parser.add_argument("--lr", type=float, default=5e-6)
    args = parser.parse_args()

    # Root directory of the local model cache.
    checkpoints = 'checkpoints'
    if args.checkpoints2:
        checkpoints = 'checkpoints2'

    # Output name = prefix + base model short name + dataset name.
    args.output_model_name = (
        args.output_model_prefix
        + args.base_model.split('/')[-1]
        + args.dataset
    )

    # From here on args.base_model is a local directory, not a hub id.
    args.base_model = model_name_to_ckpt(args.base_model)

    # Built outside the later f-strings/calls: reusing the same quote type
    # inside an f-string is a SyntaxError before Python 3.12.
    output_dir = (
        f"{checkpoints}/hub/{args.output_model_name.replace('/', '--')}"
        "/snapshots/whateverjustsomething/"
    )

    # Skip the run when the output directory already contains something.
    if glob.glob(output_dir + '*'):
        raise SystemExit(0)

    if 't5' in args.base_model and 'flan' not in args.base_model:
        # Plain (non-FLAN) T5 training is disabled; only a warning is printed.
        print('t5 warning!')
        # assert 0
        # args.batch = args.batch_size
        # ds_train, ds_val = get_glue_tasks(args)
        # if args.epochs>0:
        #     args.steps = len(ds_train) * args.epochs
        # else:
        #     args.steps = math.ceil(args.cases / args.batch_size)
        # print(len(ds_train),args.steps)
        # args.eval_steps = math.ceil(args.cases / 1 / args.batch_size)
        #
        # ds_train, ds_val = for_Trainer(ds_train), for_Trainer(ds_val)
        #
        # train_args = TrainingArguments(output_dir=f'{checkpoints}/hub/{args.output_model_name.replace('/','--')}/snapshots/whateverjustsomething/',  # output directory
        #                             per_device_train_batch_size=args.batch_size,  # training batch size
        #                             per_device_eval_batch_size=args.batch_size,  # evaluation batch size
        #                             logging_steps=10,                # logging frequency
        #                             eval_strategy="no",     # evaluation strategy
        #                             max_steps = args.steps,            # number of training steps
        #                             save_strategy="steps",           # checkpoint saving strategy
        #                             save_steps=args.eval_steps,
        #                             save_total_limit=1,              # maximum number of checkpoints kept
        #                             learning_rate=args.lr,              # learning rate
        #                             weight_decay=0.01,               # weight_decay
        #                             lr_scheduler_type='cosine',
        #                             bf16=True,
        #                             # max_seq_length=4096,
        #                             load_best_model_at_end=False,      # load the best model when training ends
        #                            )
        # model = AutoModelForSeq2SeqLM.from_pretrained(args.base_model)
        #
        # trainer = Trainer(model = model,
        #             args = train_args,
        #             train_dataset = ds_train,
        #             )
        # trainer.train()

    else:
        # One pass over `cases` examples; the single checkpoint is written on
        # the final step because save_steps equals max_steps.
        args.steps = math.ceil(args.cases / args.batch_size)
        args.eval_steps = args.steps
        # args.lr is left as parsed: it used to be overwritten with 5e-6 here,
        # which silently disabled the --lr flag (the default is still 5e-6).

        is_qwen = 'Qwen' in args.base_model or 'qwen' in args.base_model
        is_deepseek = 'deepseek' in args.base_model

        # FLAN models use their own training routine. Dispatch before loading
        # the dataset/tokenizer/model, which that routine never received.
        # Qwen/DeepSeek keep priority, matching the template selection below.
        if 'flan' in args.base_model and not (is_qwen or is_deepseek):
            # prompt = '''{}{}\nAnswer: {} </s>'''
            # response_template = "Answer: "
            # tokenizer.pad_token = tokenizer.pad_token
            # args.lr = 1e-4
            special_trainnig_process_flan(args)
            raise SystemExit(0)

        ds_train, ds_val = get_dataset(args.dataset)
        tokenizer = AutoTokenizer.from_pretrained(args.base_model)
        # NOTE(review): plain attribute assignment; unclear whether anything
        # reads it (the sequence limit is also passed to SFTConfig) — confirm.
        tokenizer.max_seq_length = 4096
        if 't5' not in args.base_model:
            model = AutoModelForCausalLM.from_pretrained(args.base_model)
        else:
            model = AutoModelForSeq2SeqLM.from_pretrained(args.base_model)

        # Pick the chat template, the marker that starts the completion (the
        # collator masks everything before it) and the padding token.
        if is_qwen:
            prompt = '''<|im_start|>system
{}<|im_end|>
<|im_start|>user
{}<|im_end|>
<|im_start|>assistant
{}<|im_end|>'''
            response_template = "<|im_start|>assistant"
            tokenizer.pad_token = '<|endoftext|>'
        elif is_deepseek:
            prompt = '''<｜begin▁of▁sentence｜> User: {}
{}
Assistant: {} <｜end▁of▁sentence｜>'''
            # Token ids instead of text; presumably the ids of "Assistant:"
            # for this tokenizer — TODO confirm.
            response_template = [77398, 25]
            # The tokenizer's own pad token is kept unchanged.
        else:
            prompt = '''<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{} <|eot_id|><|start_header_id|>user<|end_header_id|>
{} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
{} <|eot_id|>'''
            response_template = "<|start_header_id|>assistant<|end_header_id|>"
            tokenizer.pad_token = '<|reserved_special_token_0|>'

        train_args = SFTConfig(output_dir=output_dir,  # output directory
                            per_device_train_batch_size=args.batch_size,  # training batch size
                            per_device_eval_batch_size=args.batch_size,  # evaluation batch size
                            logging_steps=10,                # logging frequency
                            # eval_steps=args.eval_steps,
                            eval_strategy="no",     # evaluation strategy
                            # eval_accumulation_steps = 1,
                            max_steps = args.steps,            # total number of training steps
                            save_strategy="steps",           # checkpoint saving strategy
                            save_steps=args.eval_steps,
                            save_total_limit=1,              # maximum number of checkpoints kept
                            learning_rate=args.lr,              # learning rate
                            weight_decay=0.01,               # weight_decay
                            lr_scheduler_type='cosine',
                            bf16=True,
                            # fp16=True,
                            max_seq_length=4096,
                            load_best_model_at_end=False,      # do not reload the best model after training
                            deepspeed="deepspeed_config/stanford_alpaca.json",  # DeepSpeed configuration file
                            )

        def formatting_prompts_func(examples):
            """Render examples with the dataset's formatter and the chosen prompt."""
            return get_processFunc(args.dataset)(examples, prompt)

        collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

        trainer = SFTTrainer(model = model,
                    args = train_args,
                    train_dataset = ds_train,
                    # eval_dataset = ds_val,
                    formatting_func=formatting_prompts_func,
                    data_collator=collator,
                    # compute_metrics = eval_metric
                    )
        trainer.train()



# CUDA_VISIBLE_DEVICES=2,3 accelerate launch train.py > logs.txt --dataset=sst2
# CUDA_VISIBLE_DEVICES=0 accelerate launch train.py 
# CUDA_VISIBLE_DEVICES=2,3 NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python train.py > logs.txt --dataset=sst2 --output_model_prefix=my-llama-fewshot/llama- --cases=800
