import sys
from contextlib import redirect_stdout
from tqdm import tqdm

from neurobench.benchmarks import static_metrics
from . import workload_metrics
import torch

# Workload metrics that can only be measured after activation/connection
# hooks have been attached to the model; Benchmark.run consults this list.
requires_hooks = list(
    (
        "activation_sparsity",
        "number_neuron_updates",
        "synaptic_operations",
        "membrane_updates",
    )
)


class Benchmark:
    """Top-level benchmark class for running benchmarks."""

    def __init__(self, model, dataloader, preprocessors, postprocessors, metric_list):
        """
        Args:
            model: A NeuroBenchModel.
            dataloader: A PyTorch DataLoader.
            preprocessors: A list of NeuroBenchPreProcessors.
            postprocessors: A list of NeuroBenchPostProcessors.
            metric_list: A list of lists of strings of metrics to run.
                First item is static metrics, second item is data metrics.

        Raises:
            AttributeError: if a metric name is not defined in the
                ``static_metrics`` / ``workload_metrics`` module.
        """
        self.model = model
        self.dataloader = dataloader  # a DataLoader, not a Dataset
        self.preprocessors = preprocessors
        self.postprocessors = postprocessors

        # Resolve metric names to the objects exported by the metric modules.
        self.static_metrics = {m: getattr(static_metrics, m) for m in metric_list[0]}
        self.workload_metrics = {
            m: getattr(workload_metrics, m) for m in metric_list[1]
        }

    def run(
        self,
        quiet=False,
        verbose: bool = False,
        dataloader=None,
        preprocessors=None,
        postprocessors=None,
        device=None,
    ):
        """
        Runs batched evaluation of the benchmark.

        Args:
            quiet (bool, default=False): If True, all stdout output and the
                tqdm progress bar are suppressed.
            verbose (bool, default=False): If True, AccumulatedMetrics are
                computed and the running results dict is printed after every
                batch. If False (default), AccumulatedMetrics are computed
                once, after all batches have been processed.
            dataloader (optional): override DataLoader for this run.
            preprocessors (optional): override preprocessors for this run.
            postprocessors (optional): override postprocessors for this run.
            device (optional): use device for this run (e.g. 'cuda' or 'cpu').
                Both the model's ``net`` and the first two items of every
                batch are moved to it.

        Returns:
            results: A dictionary of results, keyed by metric name. Static
            metrics hold their single value; plain workload metrics hold the
            dataset-size-weighted mean of their per-batch values;
            AccumulatedMetrics hold the value of their ``compute()``.

        Hooks attached for hook-based metrics are always closed and detached
        from the model before returning, even if an exception is raised.
        """
        # When quiet, redirect_stdout(None) makes print() a no-op.
        with redirect_stdout(None if quiet else sys.stdout):
            print("=" * 20)
            print("Running benchmark")

            results = self._run_static_metrics()

            # Everything from hook attachment onwards is guarded so the hooks
            # are torn down even when a metric, processor or the model raises.
            try:
                if any(m in requires_hooks for m in self.workload_metrics):
                    workload_metrics.detect_activations_connections(self.model)

                # Per-run overrides fall back to the values given at construction.
                if dataloader is None:
                    dataloader = self.dataloader
                if preprocessors is None:
                    preprocessors = self.preprocessors
                if postprocessors is None:
                    postprocessors = self.postprocessors

                self._init_workload_metrics()

                dataset_len = len(dataloader.dataset)  # total number of samples
                num_batches = len(dataloader)

                if device is not None:
                    self.model.net.to(device)

                progress = tqdm(dataloader, total=num_batches, disable=quiet)
                for batch_num, data in enumerate(progress):
                    if device is not None:
                        data = (data[0].to(device), data[1].to(device))

                    # Taken before preprocessing; used to weight per-batch means.
                    batch_size = data[0].size(0)

                    if not isinstance(data, tuple):
                        data = tuple(data)

                    for alg in preprocessors:
                        data = alg(data)

                    with torch.no_grad():
                        preds = self.model(data[0])

                    for alg in postprocessors:
                        preds = alg(preds)

                    batch_results = {
                        m: metric(self.model, preds, data)
                        for m, metric in self.workload_metrics.items()
                    }

                    print("output:")
                    for m, v in batch_results.items():
                        print(m, v)
                        # AccumulatedMetrics keep their own state and are
                        # read via compute(); they are not averaged here.
                        if isinstance(
                            self.workload_metrics[m], workload_metrics.AccumulatedMetric
                        ):
                            print("accumulated metric")
                            continue
                        assert isinstance(v, float) or isinstance(v, int), "Data metric must return float or int to be accumulated"
                        # Weighted by batch size so uneven final batches
                        # contribute proportionally to the dataset mean.
                        results[m] = results.get(m, 0) + v * batch_size / dataset_len

                    # Clear recorded hook contents before the next batch.
                    self.model.reset_hooks()

                    if verbose:
                        self._compute_accumulated_metrics(results)
                        print(f"\nBatch num {batch_num + 1}/{num_batches}")
                        print(results)

                # In verbose mode these were already computed after the last batch.
                if not verbose:
                    self._compute_accumulated_metrics(results)
            finally:
                self._close_hooks()

        return results

    def _run_static_metrics(self):
        """Evaluate each static metric once on the model, printing and returning the values."""
        results = {}
        for name, metric in self.static_metrics.items():
            print("-" * 10)
            print(name)
            value = metric(self.model)
            print(value)
            results[name] = value
        return results

    def _init_workload_metrics(self):
        """Instantiate AccumulatedMetric classes, or reset existing instances, for a fresh run."""
        for name, metric in self.workload_metrics.items():
            if isinstance(metric, type) and issubclass(
                metric, workload_metrics.AccumulatedMetric
            ):
                # Replacing the value of an existing key is safe during iteration.
                self.workload_metrics[name] = metric()
            elif isinstance(metric, workload_metrics.AccumulatedMetric):
                metric.reset()

    def _compute_accumulated_metrics(self, results):
        """Store compute() of every AccumulatedMetric into ``results`` in place."""
        for name, metric in self.workload_metrics.items():
            if isinstance(metric, workload_metrics.AccumulatedMetric):
                results[name] = metric.compute()

    def _close_hooks(self):
        """Reset and close all activation/connection hooks and detach them from the model."""
        for hook in [*self.model.activation_hooks, *self.model.connection_hooks]:
            hook.reset()
            hook.close()
        self.model.activation_hooks = []
        self.model.connection_hooks = []
