import numpy as np
import torch
from tqdm import tqdm
from naslib.search_spaces import NasBench301SearchSpace
from naslib.utils import get_train_val_loaders, get_project_root, get_dataset_api
from naslib.predictors import ZeroCost
from fvcore.common.config import CfgNode
import pandas as pd
from naslib.search_spaces.core.query_metrics import Metric

def test_api(graph, dataset_api, metric):
    result = graph.query(metric, dataset="cifar10", dataset_api=dataset_api)
    assert result != -1
    return result

# 设置搜索空间参数
num_samples = 20000  # 随机抽样架构数量
top_k = 300     # 最终保留前 30 个架构
# 只关注两个指标：flops 和 nwot
all_metrics = ['params']

# 创建配置以获取训练和验证加载器
config = {
    'dataset': 'cifar10',       # 数据集名称
    'data': str(get_project_root()) + '/data',  # 数据路径
    'search': {
        'seed': 9001,           # 随机种子
        'train_portion': 0.7,     # 训练集所占比例
        'batch_size': 32,         # 批量大小
    }
}
config = CfgNode(config)

# 获取数据加载器
train_loader, val_loader, test_loader, train_transform, valid_transform = get_train_val_loaders(config)

# 使用 NASBench-301 API 获取数据集 API
space = "nasbench301"
task = "cifar10"
dataset_api = get_dataset_api(search_space=space, dataset=task)

# 存储所有采样架构的结果，结果格式：(spec, scores)
# scores 字典包含：'val_acc', 'flops', 'nwot'
all_results = []

for _ in tqdm(range(num_samples), desc="Sampling architectures"):
    # 随机生成架构
    graph = NasBench301SearchSpace()
    # graph.sample_random_architecture()
    graph.sample_random_architecture()
    # 调用 parse() 构造模型，确保模型包含参数信息
    graph.parse()
    
    # 计算验证准确率（这里仍然用 test_api 获取 ground truth）
    val_acc = test_api(graph, dataset_api, Metric.VAL_ACCURACY)
    
    # 获取架构的规格
    spec = graph.get_hash()
    
    # 计算指标分数
    scores = {}
    scores["val_acc"] = val_acc
    for metric in all_metrics:
        zc_predictor = ZeroCost(method_type=metric)
        # 对于所有指标均传入 train_loader
        score = zc_predictor.query(graph=graph, dataloader=train_loader)
        scores[metric] = score
        
    all_results.append((spec, scores))

# 根据 flops 分数排序（假设更高分数更好），选取前 1/3 架构
all_results_sorted = sorted(all_results, key=lambda x: x[1]["params"], reverse=True)

final_results = all_results_sorted[:top_k]

# 整理结果用于显示和保存
result_data = {
    "spec": [r[0] for r in final_results],
    "val_acc": [r[1]["val_acc"] for r in final_results],
    "flops": [r[1]["params"] for r in final_results],
}
result_df = pd.DataFrame(result_data, index=[f"Arch_{i+1}" for i in range(len(final_results))])

# 保存为 CSV 文件，方便日后读取
output_filename = "top30_architectures_by_para_nb301.csv"
result_df.to_csv(output_filename, index=True)
print(f"Top {len(final_results)} architectures saved to {output_filename}")

# 同时尝试利用 ace_tools 显示结果，如不可用则直接打印 DataFrame
try:
    import ace_tools as tools
    tools.display_dataframe_to_user(name="Top Architectures by NWOT (after FLOPS filtering)", dataframe=result_df)
except ModuleNotFoundError:
    print(result_df)
