# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.model import BaseTTAModel

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class AverageClsScoreTTA(BaseTTAModel):
    """Test-time augmentation model that averages classification scores.

    Each group of predictions passed to :meth:`merge_preds` is reduced to a
    single data sample whose ``pred_score`` is the arithmetic mean of the
    ``pred_score`` values in that group.
    """

    def merge_preds(
        self,
        data_samples_list: List[List[DataSample]],
    ) -> List[DataSample]:
        """Collapse every group of predictions into one merged prediction.

        Args:
            data_samples_list (List[List[DataSample]]): Groups of
                predictions. Every inner list is merged into exactly one
                output data sample.

        Returns:
            List[DataSample]: One merged data sample per inner list, in the
            same order as ``data_samples_list``.
        """
        return [
            self._merge_single_sample(sample_group)
            for sample_group in data_samples_list
        ]

    def _merge_single_sample(
        self,
        data_samples: List[DataSample],
    ) -> DataSample:
        """Average the ``pred_score`` of a group of data samples.

        Args:
            data_samples (List[DataSample]): Predictions to merge. Must be
                non-empty, because the first element is used to create the
                output sample and the score sum is divided by the length.

        Returns:
            DataSample: A sample obtained from ``data_samples[0].new()``
            with the mean score assigned through ``set_pred_score``.
        """
        # The output sample is derived from the first prediction of the
        # group; only the averaged score is written onto it here.
        averaged_sample: DataSample = data_samples[0].new()
        score_total = sum(sample.pred_score for sample in data_samples)
        averaged_sample.set_pred_score(score_total / len(data_samples))
        return averaged_sample
