

# import pandas as pd
# import tensorflow as tf
# from sklearn.model_selection import train_test_split
# from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig
# from transformers import DataCollatorWithPadding
# from transformers import create_optimizer
# import os
# import shutil
# import torch
#
#
# physical_devices = tf.config.list_physical_devices('GPU')
# if physical_devices:
#     tf.config.experimental.set_memory_growth(physical_devices[0], True)
#     print(f'Using GPU: {physical_devices[0].name}')
# else:
#     print('No GPU found, using CPU.')
#
# pretrained_model_name = 'google-bert/bert-base-uncased'
# data_dir = './finetune'
#
# # 初始化BERT tokenizer
# tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
# # 加载本地BERT模型配置
# config = BertConfig.from_json_file('./google-bert/bert-base-uncased/config.json')
# model = TFBertForSequenceClassification.from_pretrained(pretrained_model_name)
#
# for dirpath, dirnames, filenames in os.walk(data_dir):
#     for dirname in dirnames:
#         folder_path = os.path.join(dirpath, dirname)
#         train_data = os.path.join(folder_path, 'train.csv')
#         test_data = os.path.join(folder_path, 'test.csv')
#         dev_data = os.path.join(folder_path, 'dev.csv')
#

import gc
import os

import pandas as pd
import torch
from datasets import Dataset
from safetensors.torch import load_model, save_model
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# NOTE(review): presumably meant to disable TensorFlow's oneDNN custom ops. This can only have
# an effect if TensorFlow is imported after this line; this file never imports it directly.
os.environ.update(TF_ENABLE_ONEDNN_OPTS='0')
# def tokenize_function(examples):
#     try:
#         # 确保只处理有效的字符串输入
#         valid_texts = [text for text in examples['text'] if isinstance(text, str)]
#         return tokenizer(valid_texts, padding='max_length', truncation=True)
#     except Exception as e:
#         print(f"Error in tokenization: {e}")
#         return {}


# Fraction of GPU 0's total memory this process may allocate when CUDA is available.
GPU_MEMORY_FRACTION = 0.9

# Pick the compute device: the first CUDA GPU if one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced anywhere else in this file; Trainer chooses the
# device on its own.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f'Using GPU: {torch.cuda.get_device_name(0)}')

    # PyTorch has no TensorFlow-style "memory growth" switch; instead cap this process at
    # GPU_MEMORY_FRACTION (90%) of GPU 0's memory.
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, 0)
else:
    device = torch.device("cpu")
    print('No GPU found, using CPU.')


# Read one data split from disk.
def load_data(file_path):
    """Read a headerless, tab-separated split file into a DataFrame.

    Columns are named ``file_id``, ``label``, ``star`` and ``text`` in file order.

    ``keep_default_na=False`` stops pandas from converting empty fields and tokens such as
    ``NA``/``null``/``nan`` into float NaN. Without it those rows get a non-string ``text``
    value, which ``check_dataset`` reports and which cannot be tokenized.

    Args:
        file_path: Path to the ``.tsv`` file.

    Returns:
        pandas.DataFrame with the four named columns.
    """
    return pd.read_csv(
        file_path,
        sep='\t',
        header=None,
        names=["file_id", "label", "star", "text"],
        keep_default_na=False,
    )


# Convert a DataFrame into a Hugging Face Dataset.
def create_dataset(data):
    """Build a ``datasets.Dataset`` from the ``text`` and ``label`` columns of *data*.

    All other columns (e.g. ``file_id``, ``star``) are left out.

    Args:
        data: pandas.DataFrame containing at least ``text`` and ``label`` columns.

    Returns:
        A ``datasets.Dataset`` with the selected columns.
    """
    kept_columns = ['text', 'label']
    subset = data[kept_columns]
    return Dataset.from_pandas(subset)


# Tokenize a batch of examples.
def tokenize_function(examples):
    """Tokenize one batch for ``Dataset.map(tokenize_function, batched=True)``.

    Every text is padded/truncated to the tokenizer's maximum length so all rows share one
    length. Reads the module-level ``tokenizer``, which must be assigned before ``map`` runs.

    Args:
        examples: Batch dict whose ``'text'`` entry is a list of strings.

    Returns:
        The tokenizer's encoding (e.g. ``input_ids``, ``attention_mask``), which ``map``
        merges into the dataset as new columns.
    """
    return tokenizer(examples['text'], padding='max_length', truncation=True)


def check_dataset(dataset):
    """Return the positions of examples whose ``text`` field is not a ``str``.

    Used as a sanity check before tokenization, to spot rows such as NaN that the
    tokenizer cannot handle.

    Args:
        dataset: Iterable of mappings with a ``'text'`` key (e.g. a ``datasets.Dataset``).

    Returns:
        List of 0-based indices, in iteration order.
    """
    return [
        position
        for position, row in enumerate(dataset)
        if not isinstance(row['text'], str)
    ]


# Walk the data directory. Every sub-directory (at any depth) holding train/dev/test TSV
# splits is fine-tuned as a separate task, except directories whose name contains
# "delegatecall".
data_dir = './finetune'

# Load the tokenizer once for all tasks; tokenize_function reads this module-level name.
tokenizer = BertTokenizer.from_pretrained('./google-bert/bert-base-uncased')

for dirpath, dirnames, filenames in os.walk(data_dir):
    for dirname in dirnames:
        if "delegatecall" in dirname:
            continue
        print("Dirname", dirname)
        folder_path = os.path.join(dirpath, dirname)

        train_file = os.path.join(folder_path, 'train.tsv')
        test_file = os.path.join(folder_path, 'test.tsv')
        dev_file = os.path.join(folder_path, 'dev.tsv')

        # os.walk recurses, so nested folders without data are visited too. Skip them instead
        # of crashing inside pd.read_csv.
        missing = [path for path in (train_file, dev_file, test_file) if not os.path.isfile(path)]
        if missing:
            print(f"Skipping {folder_path}: missing {missing}")
            continue

        train_dataset = create_dataset(load_data(train_file))
        dev_dataset = create_dataset(load_data(dev_file))
        test_dataset = create_dataset(load_data(test_file))

        # Report rows whose text is not a string; tokenization below fails on such rows.
        print("Non-string indices in train dataset:", check_dataset(train_dataset))
        print("Non-string indices in dev dataset:", check_dataset(dev_dataset))
        print("Non-string indices in test dataset:", check_dataset(test_dataset))

        train_dataset = train_dataset.map(tokenize_function, batched=True)
        dev_dataset = dev_dataset.map(tokenize_function, batched=True)
        test_dataset = test_dataset.map(tokenize_function, batched=True)

        # Expose only the model inputs and the label, as torch tensors.
        model_columns = ['input_ids', 'attention_mask', 'label']
        train_dataset.set_format(type='torch', columns=model_columns)
        dev_dataset.set_format(type='torch', columns=model_columns)
        test_dataset.set_format(type='torch', columns=model_columns)

        # Fresh pretrained BERT with a 2-class classification head for each task.
        model = BertForSequenceClassification.from_pretrained('./google-bert/bert-base-uncased', num_labels=2)

        training_args = TrainingArguments(
            output_dir=f'./results/{dirname}',
            # NOTE(review): newer transformers releases rename this argument to
            # `eval_strategy`; confirm against the installed version.
            evaluation_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=50,
            weight_decay=0.01,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=dev_dataset,
        )

        trainer.train()

        dev_metrics = trainer.evaluate()
        print(f"Dev metrics for {dirname}: {dev_metrics}")

        # Saves only the weights as safetensors (no config/tokenizer files are written).
        model_dir = f'./models/{dirname}'
        os.makedirs(model_dir, exist_ok=True)
        save_model(model, f"{model_dir}/model.safetensors")

        # Predict on the test split.
        predictions = trainer.predict(test_dataset)
        print(f"Predictions for {dirname}: {predictions.predictions.argmax(-1)}")
        print(f"Test metrics for {dirname}: {predictions.metrics}")

        # Release this task's model/trainer before the next iteration loads a new model;
        # otherwise the old model (still referenced by `trainer`) stays resident alongside it.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
