import json
import random
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Load every sample from dataset/all_features.json.
with open('dataset/all_features.json', 'r', encoding='utf-8') as features_fh:
    data = json.load(features_fh)

# Partition the samples by their 'y' label in a single pass, keeping file order.
# Samples whose label is neither 1 nor 0 end up in neither list.
data_class_1, data_class_0 = [], []
for record in data:
    label = record['y']
    if label == 1:
        data_class_1.append(record)
    elif label == 0:
        data_class_0.append(record)

print("类别1样本数:", len(data_class_1))
print("类别0样本数:", len(data_class_0))

# Number of class-0 samples.
target_count = len(data_class_0)


def _blend_features(primary, secondary):
    """Return ``0.8*primary + 0.2*secondary`` plus N(0, 0.01) noise, clipped to [0, 1].

    Inputs are coerced to float arrays, so numeric strings are accepted.
    Returns a list of plain Python floats.
    """
    mixed = np.asarray(primary, dtype=float) * 0.8 + np.asarray(secondary, dtype=float) * 0.2
    # Noise takes the shape of the data instead of a hard-coded 512, so feature
    # vectors of any length work.
    noisy = mixed + np.random.normal(0, 0.01, size=mixed.shape)
    return np.clip(noisy, 0, 1).tolist()


def generate_sample(sample1, sample2):
    """Synthesize one new sample by mixing two existing samples.

    Tokens: the de-duplicated union of both samples' ``token`` lists is built
    and a random half of it (rounded down) is kept. Each kept token carries the
    ``comments_feature`` entry paired with it by position in its source sample.
    When a token occurs more than once, the pairing from its last occurrence
    wins, scanning sample1's tokens first and then sample2's.

    Vector features: ``ast_features`` and ``cfg_features`` are each blended as
    ``0.8 * sample1 + 0.2 * sample2``, perturbed with Gaussian noise
    (mean 0, std 0.01) and clipped to [0, 1].

    Args:
        sample1: Primary sample; its ``y`` and ``contract_name`` are copied.
        sample2: Secondary sample contributing tokens and 20% of each feature.

    Returns:
        A new dict with keys ``y``, ``contract_name``, ``token``,
        ``ast_features``, ``cfg_features`` and ``comments_feature``.
    """
    combined_tokens_group = sample1["token"] + sample2["token"]
    combined_comments_group = sample1["comments_feature"] + sample2["comments_feature"]

    # dict.fromkeys de-duplicates while keeping first-seen order. The iteration
    # order of list(set(...)) depends on string hash randomization, which made
    # the selection irreproducible across runs even with a fixed random seed.
    unique_tokens = list(dict.fromkeys(combined_tokens_group))
    selected_tokens = random.sample(unique_tokens, len(unique_tokens) // 2)

    token_to_comment = dict(zip(combined_tokens_group, combined_comments_group))
    selected_comments_feature = [token_to_comment[token] for token in selected_tokens]

    return {
        "y": sample1["y"],  # class label is kept from the primary sample
        "contract_name": sample1["contract_name"],
        "token": selected_tokens,
        "ast_features": _blend_features(sample1["ast_features"], sample2["ast_features"]),
        "cfg_features": _blend_features(sample1["cfg_features"], sample2["cfg_features"]),
        "comments_feature": selected_comments_feature,
    }


def generate_new_samples(existing_samples, target_count):
    """Top up a sample pool with random pairwise mixes until it reaches ``target_count``.

    Each new sample is ``generate_sample(first, second)`` for two samples drawn
    uniformly at random (with replacement) from ``existing_samples``.

    Side effect: every sample's ``cfg_features`` is rewritten in place as a list
    of floats before any mixing happens.

    Returns:
        Only the newly synthesized samples; empty when the pool is already at or
        above ``target_count``.
    """
    for record in existing_samples:
        record["cfg_features"] = list(map(float, record["cfg_features"]))

    shortfall = target_count - len(existing_samples)
    synthesized = []
    while len(synthesized) < shortfall:
        first, second = random.choice(existing_samples), random.choice(existing_samples)
        synthesized.append(generate_sample(first, second))
    return synthesized


def smote_samples(existing_samples, target_count, k_neighbors=5):
    """SMOTE-style oversampling: mix random seed samples with their nearest neighbours.

    Samples are embedded as the list concatenation
    ``ast_features + cfg_features``. Until ``len(existing_samples) +
    len(new_samples)`` reaches ``target_count``, the function picks a random
    seed sample and one of its ``k_neighbors`` nearest *other* samples, then
    combines the pair with ``generate_sample``.

    Side effect: each input sample's ``cfg_features`` is rewritten in place as a
    list of floats.

    Args:
        existing_samples: Sequence of sample dicts to oversample.
        target_count: Desired total size (existing + generated).
        k_neighbors: Neighbourhood size. It is effectively clamped to the number
            of other samples available.

    Returns:
        The newly generated samples; empty if the pool is already at or above
        ``target_count``.

    Raises:
        ValueError: if samples must be generated but ``existing_samples`` is empty.
    """
    for sample in existing_samples:
        sample["cfg_features"] = [float(value) for value in sample["cfg_features"]]

    needed = target_count - len(existing_samples)
    if needed <= 0:
        # Nothing to generate, so skip fitting the neighbour model entirely.
        return []
    if not existing_samples:
        raise ValueError("smote_samples needs at least one existing sample to generate from")

    n_samples = len(existing_samples)
    features = np.array([item["ast_features"] + item["cfg_features"] for item in existing_samples])
    # Each point is its own nearest neighbour, so ask for one extra. Clamp to
    # n_samples because sklearn raises when n_neighbors exceeds the fitted size.
    n_query = min(k_neighbors + 1, n_samples)
    nbrs = NearestNeighbors(n_neighbors=n_query).fit(features)
    # Neighbour lists for all samples are computed once, not once per iteration.
    all_neighbors = nbrs.kneighbors(features, return_distance=False)

    new_samples = []
    while len(new_samples) < needed:
        idx = random.randint(0, n_samples - 1)
        # SMOTE interpolates with a *different* sample, so drop the seed itself
        # and keep at most k_neighbors candidates. Fall back to the seed only
        # when it is the sole sample in the pool.
        candidates = [int(j) for j in all_neighbors[idx] if j != idx][:k_neighbors] or [idx]
        neighbor_idx = random.choice(candidates)

        sample1 = existing_samples[idx]
        sample2 = existing_samples[neighbor_idx]
        new_samples.append(generate_sample(sample1, sample2))
    return new_samples


# Split each class 80/20 BEFORE any resampling.
# Previously, class-1 samples were duplicated (data_class_1 * 2) and used to
# synthesize new samples first, and only then split. Identical copies of the
# same contracts, and interpolations between them, therefore landed in both
# train.json and evaluate.json. That leaked training data into evaluation and
# inflated the measured scores.
# Now only the training portion is rebalanced. evaluate.json contains only
# original, non-synthetic samples in their natural class ratio.
random.shuffle(data_class_1)
random.shuffle(data_class_0)

split_index_1 = int(len(data_class_1) * 0.8)
split_index_0 = int(len(data_class_0) * 0.8)

train_class_1 = data_class_1[:split_index_1]
evaluate_class_1 = data_class_1[split_index_1:]
train_class_0 = data_class_0[:split_index_0]
evaluate_class_0 = data_class_0[split_index_0:]

# Class 1 in the training split is brought up to the size of class 0 in the
# training split.
train_target_count = len(train_class_0)

# Step 1: naive oversampling -- every training class-1 sample appears twice.
oversampled_class_1 = train_class_1 * 2

print("过采样后类别1样本数:", len(oversampled_class_1))

# Step 2: SMOTE closes half of the remaining gap to the target (never a negative gap).
smote_target = len(oversampled_class_1) + max(0, (train_target_count - len(oversampled_class_1)) // 2)
smote_generated_samples = smote_samples(oversampled_class_1, smote_target)

# Step 3: random pairwise mixing over the oversampled + SMOTE pool fills the rest.
final_generated_samples = generate_new_samples(oversampled_class_1 + smote_generated_samples, train_target_count)

# Truncate in case oversampling alone already overshot the target.
balanced_class_1 = (oversampled_class_1 + smote_generated_samples + final_generated_samples)[:train_target_count]

print("SMOTE生成的新样本数:", len(smote_generated_samples))
print("最终生成的新样本数:", len(final_generated_samples))
print("增强后类别1样本数:", len(balanced_class_1))

train_data = balanced_class_1 + train_class_0
evaluate_data = evaluate_class_1 + evaluate_class_0

# Shuffle again so the two classes are interleaved rather than grouped.
random.shuffle(train_data)
random.shuffle(evaluate_data)

# Write the training split.
with open('dataset/train.json', 'w', encoding='utf-8') as train_file:
    json.dump(train_data, train_file, ensure_ascii=False)

# Write the evaluation split.
with open('dataset/evaluate.json', 'w', encoding='utf-8') as evaluate_file:
    json.dump(evaluate_data, evaluate_file, ensure_ascii=False)

print("数据划分完成，已生成 train.json 和 evaluate.json 文件。")


# import json
# import random
#
#
# # Function to load data from JSON file
# def load_data(filename):
#     with open(filename, 'r', encoding='utf-8') as file:
#         return json.load(file)
#
#
# # Function to save data to JSON file
# def save_data(data, filename):
#     with open(filename, 'w', encoding='utf-8') as file:
#         json.dump(data, file, ensure_ascii=False)
#
#
# # Load all data from all_features.json
# data = load_data('dataset/all_features.json')
# print(len(data))
#
# # Filter evaluate_data based on contract_name condition
# filtered_datas = [item for item in data if
#                   (int(item['contract_name'][:-4]) < 1061) or (int(item['contract_name'][:-4]) > 2918)]
# rest_data = [item for item in data if
#              (int(item['contract_name'][:-4]) >= 1061) and (int(item['contract_name'][:-4]) <= 2918)]
#
# random.shuffle(filtered_datas)
#
# evaluate_data_filtered = filtered_datas[:667]
# train_data = filtered_datas[667:]
#
# train_data.extend(rest_data)  # list.extend mutates in place and returns None; do not reassign
#
# # Save train_data to dataset.json
# save_data(train_data, r'E:\2024\experiment_code_clone\total4\all_features\dataset_clone\dataset.json')
#
# # Save evaluate_data_filtered to evaluate_clone.json
# save_data(evaluate_data_filtered,
#           r'E:\2024\experiment_code_clone\total4\all_features\dataset_clone\evaluate_clone.json')
#
# print("数据划分完成，已生成 dataset.json 和 evaluate_clone.json 文件。")
