import torch
import datasets
import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    SchedulerType,
    default_data_collator,
    get_scheduler
)
import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

import os
import argparse

from minillm.pipelines import LMPipeline

import sys

class Logger(object):
    """Tee stream: mirror every write to the original ``sys.stdout`` and a log file.

    Intended to be installed as ``sys.stdout = Logger(path)`` (and optionally
    ``sys.stderr``). The log file is opened in ``"w"`` mode, so an existing
    file at ``filename`` is truncated. It is line-buffered so that output
    survives a crash mid-training.

    A few read-only stream attributes (``encoding``, ``errors``, ``isatty``,
    ``fileno``) are delegated to the wrapped terminal stream. Libraries probe
    these on ``sys.stdout``/``sys.stderr``. Anything written through ``fileno()``
    bypasses this object and reaches only the terminal.
    """

    # Attributes forwarded to the terminal stream by __getattr__.
    _DELEGATED = frozenset({"encoding", "errors", "isatty", "fileno"})

    def __init__(self, filename):
        # Capture stdout *before* the caller replaces sys.stdout with this
        # object, so console output keeps working.
        self.terminal = sys.stdout
        # buffering=1 -> line-buffered text file: each newline is flushed.
        self.log = open(filename, "w", encoding="utf-8", buffering=1)

    def write(self, message):
        """Write ``message`` to the terminal and, while open, to the log file."""
        written = self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)
        return written

    def flush(self):
        """Flush the terminal stream and, while open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Flush and close the log file; the terminal stream is left open. Idempotent."""
        if not self.log.closed:
            self.log.flush()
            self.log.close()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Unknown names
        # (including "terminal" before __init__ ran) raise AttributeError,
        # so this can never recurse.
        if name in self._DELEGATED:
            return getattr(self.terminal, name)
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

# data path: /opt/dpcvol/datasets

def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")

    parser.add_argument("--dataset_name", type=str)
    parser.add_argument("--data_path", type=str)
    parser.add_argument("--save_path", type=str)
    parser.add_argument("--config_path", type=str)
    parser.add_argument("--tokenizer_path", type=str)
    parser.add_argument("--max_tokens", type=int)
    parser.add_argument("--trust_remote_code", type=bool, 
                        default=True)
    parser.add_argument("--use_slow_tokenizer", action="store_true",)
    parser.add_argument("--model_parallel", action="store_true")
    parser.add_argument("--train-num", type=int, default=-1)
    parser.add_argument("--dev-num", type=int, default=-1)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--seed_lm", type=int, default=7)
    parser.add_argument("--model_type", type=str, default='gpt2')
    parser.add_argument("--LInit_path", type=str, default=None)

    args, unknown = parser.parse_known_args()

    return args

def _load_openwebtext_100k(myargs, tokenizer, context_length=512):
    """Load the ``openwebtext-100k`` parquet data and tokenize it into fixed-length chunks.

    The parquet rows are split 90/10 (seed 2333) into ``train``/``test``. Each
    text is tokenized with ``return_overflowing_tokens=True``, and only chunks
    of exactly ``context_length`` tokens are kept.

    Args:
        myargs: Parsed CLI namespace; ``data_path`` is passed as ``data_files``.
        tokenizer: Hugging Face tokenizer; its ``pad_token`` is set to ``eos_token``.
        context_length: Chunk length in tokens (a small context of 512 by default).

    Returns:
        ``(tokenized_datasets, total_samples)``. ``total_samples`` is the number
        of tokenized training chunks, i.e. what the Trainer actually iterates.
    """
    raw_datasets = datasets.load_dataset('parquet', data_files=myargs.data_path)
    print("Total dataset info...")
    print(raw_datasets)
    raw_datasets = raw_datasets["train"].train_test_split(test_size=0.1, seed=2333)
    print("Splited dataset info...")
    print(raw_datasets)

    def tokenize(element):
        outputs = tokenizer(
            element["text"],
            truncation=True,
            max_length=context_length,
            return_overflowing_tokens=True,
            return_length=True,
        )
        # Drop chunks shorter than context_length (e.g. the tail of each text)
        # so every training example has the same length.
        input_batch = [
            input_ids
            for length, input_ids in zip(outputs["length"], outputs["input_ids"])
            if length == context_length
        ]
        return {"input_ids": input_batch}

    tokenized_datasets = raw_datasets.map(
        tokenize, batched=True, remove_columns=raw_datasets["train"].column_names)

    # Reuse EOS as the padding token. Presumably the tokenizer (e.g. GPT-2)
    # has no pad token of its own; confirm this for other tokenizers.
    tokenizer.pad_token = tokenizer.eos_token

    return tokenized_datasets, len(tokenized_datasets["train"])


def _load_openwebtext(myargs, tokenizer, context_length=512):
    """Build ``LMPipeline`` train/valid datasets for the ``openwebtext`` dataset.

    The training pipeline is limited to ``max_tokens // context_length`` rows.
    The validation pipeline uses ``myargs.dev_num`` rows.

    Returns:
        ``({'train': ..., 'test': ...}, total_samples)``, where ``total_samples``
        is ``len`` of the training pipeline.

    Raises:
        ValueError: if ``--data_path`` or ``--max_tokens`` is missing.
    """
    if myargs.data_path is None:
        raise ValueError("--data_path is required for --dataset_name openwebtext")
    if myargs.max_tokens is None:
        raise ValueError("--max_tokens is required for --dataset_name openwebtext")

    # Single-process group. NOTE(review): presumably LMPipeline needs
    # torch.distributed initialised, and 'nccl' is remapped for NPUs by
    # torch_npu's transfer_to_npu; confirm both.
    torch.distributed.init_process_group('nccl', world_size=1, rank=0)
    # Integer division stays exact for large budgets (unlike int(x / 512)).
    num_row = myargs.max_tokens // context_length
    # Prints the training token budget in units of 10^8 (亿).
    print(f"训练所用的token数目: {num_row*context_length / 10**8}亿")
    lm_pipeline = LMPipeline(
        myargs, tokenizer, "train", myargs.data_path, num=num_row)
    eval_lm_pipeline = LMPipeline(
        myargs, tokenizer, "valid", myargs.data_path, num=myargs.dev_num)
    return {'train': lm_pipeline, 'test': eval_lm_pipeline}, len(lm_pipeline)


def main():
    """Train a causal LM (from scratch or from a Learngene init) and save it.

    Parses CLI args and tees stdout/stderr into ``<save_path>/train.log``. It
    then loads the tokenizer and dataset, builds the model, runs
    ``transformers.Trainer``, and writes the final weights to ``--save_path``.

    Raises:
        ValueError: if ``--save_path`` is missing, if ``--dataset_name`` is not
            ``'openwebtext-100k'`` or ``'openwebtext'``, or if the openwebtext
            branch lacks ``--data_path``/``--max_tokens``.
    """
    myargs = parse_args()

    if myargs.save_path is None:
        raise ValueError("--save_path is required")
    os.makedirs(myargs.save_path, exist_ok=True)

    # ===== Log redirection: tee stdout and stderr into <save_path>/train.log =====
    log_file = os.path.join(myargs.save_path, "train.log")
    sys.stdout = Logger(log_file)
    sys.stderr = sys.stdout
    print(f"所有输出将同时记录到: {log_file}")

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        myargs.tokenizer_path,
        use_fast=not myargs.use_slow_tokenizer,
        trust_remote_code=myargs.trust_remote_code,
    )

    # load and tokenize dataset
    if myargs.dataset_name == 'openwebtext-100k':
        tokenized_datasets, total_samples = _load_openwebtext_100k(myargs, tokenizer)
    elif myargs.dataset_name == 'openwebtext':
        tokenized_datasets, total_samples = _load_openwebtext(myargs, tokenizer)
    else:
        # Fail fast instead of reaching an undefined `tokenized_datasets` below.
        raise ValueError(
            f"Unsupported --dataset_name {myargs.dataset_name!r}; "
            "expected 'openwebtext-100k' or 'openwebtext'"
        )

    print("Tokenize dataset info...")
    print(tokenized_datasets)

    # Causal-LM collation (mlm=False).
    data_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # prepare a model from scratch or initialized by learngene
    if myargs.LInit_path is None:
        print('Training from scratch!')
        config = AutoConfig.from_pretrained(
            myargs.config_path,
            trust_remote_code=myargs.trust_remote_code,
        )
        model = AutoModelForCausalLM.from_config(config)
    else:
        print('Training from Learngene Init!')
        config = AutoConfig.from_pretrained(myargs.LInit_path)
        model = AutoModelForCausalLM.from_pretrained(myargs.LInit_path, config=config)

    model_size = sum(t.numel() for t in model.parameters())
    print("Model Config:")
    print(config)
    print(f"Model Size: {model_size/1000**2:.1f}M parameters")

    # train
    Train_args = transformers.TrainingArguments(
        # NOTE(review): intermediate checkpoints go to ./gpt2_results (relative
        # to the working directory), not to --save_path.
        output_dir="gpt2_results",
        per_device_train_batch_size=16,  # training batch size per device
        per_device_eval_batch_size=16,  # evaluation batch size per device
        # eval_steps only takes effect with eval_strategy="steps"; that line is
        # commented out, so no periodic evaluation runs.
        # eval_strategy="steps",
        eval_steps=5000,
        logging_steps=50,
        gradient_accumulation_steps=2,  # number of gradient-accumulation steps
        num_train_epochs=2,  # number of training epochs
        weight_decay=0.1,
        warmup_steps=2_00,
        optim="adamw_torch",  # AdamW optimizer
        lr_scheduler_type="cosine",  # learning-rate decay schedule
        learning_rate=5e-4,  # base learning rate
        save_steps=5_00,
        save_total_limit=10,
        bf16=True,  # bf16 training; on GPUs older than Ampere use fp16=True instead
        # max_steps=20000,  # optionally cap training at 20,000 steps
    )
    print("Train Args:")
    print(Train_args)

    # Estimate total optimizer steps. One step consumes
    # per_device_batch * grad_accum * world_size samples.
    batch_size = Train_args.per_device_train_batch_size
    num_epochs = Train_args.num_train_epochs
    gradient_accumulation_steps = Train_args.gradient_accumulation_steps
    samples_per_step = batch_size * gradient_accumulation_steps * Train_args.world_size
    steps_per_epoch = total_samples // samples_per_step
    total_steps = int(steps_per_epoch * num_epochs)
    print(f"总训练步数: {total_steps}")

    # enjoy training
    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        args=Train_args,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
    )
    trainer.train()

    # save the final model weights
    model.save_pretrained(myargs.save_path)
    print(f"模型权重保存到: {myargs.save_path}!")

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()