import torch
import os
import argparse
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset
import json

def load_and_process_json_data(input_path):
    """Load query/category pairs from one JSON file or a directory of them.

    Each file is parsed with ``json.load`` and iterated item by item; every
    item must provide a ``query`` key (stored as ``str``) and a ``category``
    key (converted with ``int``).

    Args:
        input_path: Path to a single JSON file, or to a directory whose
            ``*.json`` entries are all loaded.

    Returns:
        A tuple ``(texts, encoded_labels, num_classes)`` where ``texts`` is a
        list of strings, ``encoded_labels`` is the output of
        ``LabelEncoder.fit_transform`` on the integer categories, and
        ``num_classes`` is the number of distinct categories seen.

    Raises:
        ValueError: If an item lacks ``query`` or ``category`` (``int()`` may
            also raise it for a non-numeric category string).
    """
    texts, raw_labels = [], []

    if os.path.isfile(input_path):
        files = [input_path]
    else:
        # os.listdir returns entries in arbitrary order; sort so the sample
        # order (and therefore any seeded split downstream) is reproducible.
        files = sorted(
            os.path.join(input_path, f)
            for f in os.listdir(input_path)
            if f.endswith('.json')
        )

    for file in files:
        # Parse the whole file first so the handle is released before the
        # items are validated.
        with open(file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for item in data:
            if 'query' not in item or 'category' not in item:
                raise ValueError(f"Invalid item format in {file}")
            texts.append(str(item['query']))
            raw_labels.append(int(item['category']))

    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(raw_labels)
    return texts, encoded_labels, len(label_encoder.classes_)

class CustomDataCollator:
    """Collate tokenized features into equally sized ``torch.long`` tensors.

    Every batch is padded to the length of its longest ``input_ids`` entry,
    capped at ``max_length``; longer entries are cut down to that width.
    """

    def __init__(self, tokenizer, max_length=8192):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        longest = max(len(item["input_ids"]) for item in features)
        width = min(longest, self.max_length)

        id_rows, mask_rows, label_column = [], [], []
        for item in features:
            # Slicing yields fresh lists, so the dataset rows stay untouched.
            ids = item["input_ids"][:width]
            mask = item["attention_mask"][:width]

            missing = width - len(ids)
            if missing > 0:
                # Right-pad with the pad token; padded positions are masked out.
                ids = ids + [self.tokenizer.pad_token_id] * missing
                mask = mask + [0] * missing

            id_rows.append(ids)
            mask_rows.append(mask)
            label_column.append(item["labels"])

        collated = {
            "input_ids": id_rows,
            "attention_mask": mask_rows,
            "labels": label_column,
        }
        return {
            key: torch.tensor(rows, dtype=torch.long)
            for key, rows in collated.items()
        }

def main():
    """Fine-tune a sequence-classification model on JSON query/category data.

    Parses command-line arguments, loads and label-encodes the data, creates
    a stratified train/eval split, tokenizes both splits (truncation only;
    padding is done per batch by ``CustomDataCollator``), builds the model and
    a ``Trainer``, and runs training with per-epoch evaluation and saving.

    Raises:
        RuntimeError: Any runtime error raised by ``trainer.train()`` is
            re-raised; for CUDA out-of-memory errors the currently allocated
            GPU memory is printed first.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', required=True)
    parser.add_argument('--model_path', required=True)
    parser.add_argument('--output_dir', default='')
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--per_device_train_batch_size', type=int, default=32)
    parser.add_argument('--per_device_eval_batch_size', type=int, default=32)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--shuffle_seed', type=int, default=43)
    parser.add_argument('--test_size', type=float, default=0.04)
    parser.add_argument('--bf16', action='store_true')
    parser.add_argument('--gradient_checkpointing', action='store_true')
    parser.add_argument('--local_rank', type=int, default=-1, help="Local rank passed from distributed launcher")
    parser.add_argument('--deepspeed', type=str, default=None, help="Path to DeepSpeed config file")
    parser.add_argument('--max_length', type=int, default=8192, help="Maximum number of tokens kept per sample")
    args = parser.parse_args()

    # Load the data and encode the labels.
    texts, labels, num_labels = load_and_process_json_data(args.data_path)

    # Shuffle with a seeded generator: an unseeded permutation would make the
    # split irreproducible and would give every distributed rank a different
    # train/eval partition.
    rng = np.random.default_rng(args.shuffle_seed)
    indices = rng.permutation(len(texts))
    texts = [texts[i] for i in indices]
    labels = labels[indices]

    # Stratified train/eval split.
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        texts, labels,
        test_size=args.test_size,
        random_state=args.shuffle_seed,
        stratify=labels
    )

    # Tokenizer setup.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
        # The collator pads with pad_token_id; without one it would try to
        # build a tensor containing None. Fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    # Truncate only; padding is deferred to the collator so that each batch is
    # padded to its own longest sequence.
    def encode_data(texts, labels):
        encodings = tokenizer(
            texts,
            truncation=True,
            padding=False,
            max_length=args.max_length,
            return_overflowing_tokens=False
        )
        return Dataset.from_dict({
            "input_ids": encodings["input_ids"],
            "attention_mask": encodings["attention_mask"],
            "labels": labels.astype(np.int64)
        })

    train_dataset = encode_data(train_texts, train_labels)
    test_dataset = encode_data(test_texts, test_labels)

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_path,
        # Size the classification head from the data instead of a fixed 8.
        num_labels=num_labels,
        torch_dtype=torch.bfloat16 if args.bf16 else torch.float32,
        # Flash Attention 2 only supports half-precision dtypes, so request
        # it only when the model is loaded in bf16.
        use_flash_attention_2=args.bf16
    )
    if model.config.pad_token_id is None:
        # Keep the model config in sync with the tokenizer's pad token.
        model.config.pad_token_id = tokenizer.pad_token_id
    data_collator = CustomDataCollator(tokenizer, max_length=args.max_length)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy="epoch",
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        bf16=args.bf16,
        # Evaluate in bf16 only when bf16 was actually requested.
        bf16_full_eval=args.bf16,
        fp16=False,
        fp16_full_eval=False,
        # Honour the CLI flag instead of always enabling checkpointing.
        gradient_checkpointing=args.gradient_checkpointing,
        optim="adamw_torch",
        logging_dir='./logs',
        logging_steps=50,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        dataloader_num_workers=16,
        dataloader_pin_memory=True,
        report_to="none",
        remove_unused_columns=True,
        group_by_length=True,
        deepspeed=args.deepspeed
    )

    def compute_accuracy(p):
        # Fraction of samples whose arg-max logit matches the reference label.
        return {
            "accuracy": (np.argmax(p.predictions, axis=1) == p.label_ids).mean()
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=data_collator,
        compute_metrics=compute_accuracy
    )

    # Report allocated and total GPU memory before training starts.
    if torch.cuda.is_available():
        print(f"{torch.cuda.memory_allocated()/1024**3:.2f} GB")
        print(f"{torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

    try:
        trainer.train()
    except RuntimeError as e:
        if 'CUDA out of memory' in str(e):
            current_mem = torch.cuda.memory_allocated()/1024**3
            print(f"{current_mem:.2f}GB")
        # Always propagate: previously only OOM errors were re-raised and
        # every other RuntimeError was silently swallowed.
        raise

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()