import os,glob, subprocess, re, argparse, math
# os.environ['HF_HOME'] = "checkpoints"
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:8118'
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:8118'
from huggingface_hub import login
from datasets import Dataset
# #
from utils import get_dataset,get_processFunc,get_texts
# from glue_tasks import get_glue_tasks, for_Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, AutoModelForSeq2SeqLM
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import Trainer, TrainingArguments
import torch
from tqdm import tqdm
import evaluate
import pandas as pd
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import math,glob

def special_trainnig_process_flan(args):
    """Fine-tune a seq2seq model on an instruction dataset with Seq2SeqTrainer.

    The training texts are loaded with ``get_texts(args.dataset)``; each
    record is expected to provide ``'instruction'``, ``'input'`` and
    ``'output'`` entries. The whole dataset is tokenized up front, padded to
    the longest example, and handed to a ``Seq2SeqTrainer`` configured with
    DeepSpeed.

    Args:
        args: Namespace-like object. Attributes read here:
            ``checkpoints2`` (truthy selects the ``checkpoints2`` root),
            ``base_model``, ``dataset``, ``epochs``, ``batch_size`` and
            ``output_model_name``. ``args.steps`` is written as a side effect.

    Raises:
        ValueError: If ``args.epochs`` is not positive (step-based training
            is not supported by this routine).
    """
    # Validate before any expensive loading. An explicit raise (instead of
    # ``assert``) keeps the check alive under ``python -O``.
    if args.epochs <= 0:
        raise ValueError(
            f"args.epochs must be > 0 for FLAN training, got {args.epochs!r}"
        )

    checkpoints = 'checkpoints2' if args.checkpoints2 else 'checkpoints'

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    # NOTE(review): truncation below is presumably governed by the
    # tokenizer's ``model_max_length``, not by this attribute -- confirm
    # whether ``max_seq_length`` has any effect before relying on it.
    tokenizer.max_seq_length = 4096

    tx_train = get_texts(args.dataset)

    # Total optimizer steps implied by the epoch count. Only stored on
    # ``args``; it is not passed to the trainer (which uses num_train_epochs).
    args.steps = math.ceil(len(tx_train) * args.epochs / args.batch_size)

    # Prompt = the part of the instruction after the last 'And', a newline,
    # then the input text.
    prompts = [
        i['instruction'].split('And')[-1] + '\n' + i['input'] for i in tx_train
    ]
    tokenized_inputs = tokenizer(
        prompts, padding=True, truncation=True, return_tensors="pt"
    )

    tokenized_outputs = tokenizer(
        [i['output'] for i in tx_train], padding=True, return_tensors="pt"
    )
    # Padding positions in the targets must not contribute to the loss:
    # -100 is the ignore index used by the model's cross-entropy loss.
    labels = tokenized_outputs['input_ids'].masked_fill(
        tokenized_outputs['attention_mask'] == 0, -100
    )

    data = {
        'input_ids': tokenized_inputs['input_ids'],
        # Without the mask the encoder would attend to the padding tokens.
        'attention_mask': tokenized_inputs['attention_mask'],
        'labels': labels,
    }

    model = AutoModelForSeq2SeqLM.from_pretrained(args.base_model)

    # Built outside the f-string: reusing the same quote character inside an
    # f-string expression is a SyntaxError before Python 3.12.
    model_dir_name = args.output_model_name.replace('/', '--')
    output_dir = (
        f'{checkpoints}/hub/{model_dir_name}/snapshots/whateverjustsomething/'
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=args.batch_size,
        learning_rate=2e-5,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        lr_scheduler_type='cosine',
        bf16=True,
        logging_steps=10,              # logging frequency (in steps)
        eval_strategy="no",            # no evaluation during training
        save_strategy="steps",         # checkpoint on a step schedule
        save_steps=0.4999,             # float < 1: ratio of total training steps
        save_total_limit=1,            # keep at most one checkpoint on disk
        load_best_model_at_end=False,  # do not reload a "best" model after training
        deepspeed="deepspeed_config/stanford_alpaca.json",  # DeepSpeed config file
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=Dataset.from_dict(data),
    )
    trainer.train()
