# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List


logger = logging.getLogger(__name__)


def uniform(dataset_sizes: List[int]):
    """Return equal sampling weights, one ``1.0`` per dataset.

    Only the number of entries in ``dataset_sizes`` matters; the size
    values themselves are ignored.
    """
    count = len(dataset_sizes)
    return [1.0 for _ in range(count)]


def temperature_sampling(dataset_sizes, temp):
    """Return temperature-scaled sampling weights for the given sizes.

    Each weight is the dataset's share of the summed sizes raised to the
    power ``1 / temp``. The weights are returned as computed; they are not
    renormalized to sum to one.

    An empty ``dataset_sizes`` yields an empty list. A non-empty input
    whose sizes sum to zero, or a ``temp`` of zero, raises
    ``ZeroDivisionError``.
    """
    grand_total = sum(dataset_sizes)
    weights = []
    for size in dataset_sizes:
        share = size / grand_total
        # The exponent is evaluated per item so that an empty input never
        # touches ``temp`` at all.
        weights.append(share ** (1.0 / temp))
    return weights


def make_temperature_sampling(temp=1.0):
    """Build a sampling function with the temperature fixed to ``temp``.

    The returned callable takes only ``dataset_sizes`` and forwards to
    ``temperature_sampling`` with the captured temperature.
    """

    def sampling_func(dataset_sizes):
        # Delegate with the temperature captured by the closure.
        return temperature_sampling(dataset_sizes=dataset_sizes, temp=temp)

    return sampling_func


def make_ratio_sampling(ratios):
    """Build a sampling function that always returns the given ``ratios``.

    The returned callable accepts ``dataset_sizes`` so that it has the same
    signature as the other sampling functions, but it never consults them.
    The very same ``ratios`` object is handed back on every call (no copy).
    """

    def sampling_func(dataset_sizes):
        # Sizes are intentionally ignored; the ratios are fixed up front.
        fixed_ratios = ratios
        return fixed_ratios

    return sampling_func


class SamplingMethod:
    """Picks a dataset sampling function based on parsed arguments.

    The strategy is chosen by ``--sampling-method``; ``--sampling-temperature``
    supplies the temperature used by the temperature strategy.
    """

    @staticmethod
    def add_arguments(parser):
        """Register the sampling-related command-line options on ``parser``."""
        method_choices = [
            "uniform",
            "temperature",
            "concat",
            "RoundRobin",
        ]
        parser.add_argument(
            "--sampling-method",
            type=str,
            choices=method_choices,
            default="concat",
            help="The method to sample data per language pairs",
        )
        parser.add_argument(
            "--sampling-temperature",
            type=float,
            default=1.5,
            help="only work with --sampling-method temperature",
        )

    @staticmethod
    def build_sampler(args, task):
        """Create a ``SamplingMethod`` from ``args`` and ``task``."""
        sampler = SamplingMethod(args, task)
        return sampler

    def __init__(self, args, task):
        """Keep references to the parsed arguments and the task."""
        self.args, self.task = args, task

    def is_adaptive(self):
        """Report whether this sampler is adaptive; always ``False`` here."""
        return False

    def sampling_method_selector(self):
        """Return the sampling function matching ``args.sampling_method``.

        Returns:
            ``uniform`` when the method is ``"uniform"``; a temperature
            sampling function built from ``args.sampling_temperature`` when
            the method is ``"temperature"`` or ``is_adaptive()`` is true;
            otherwise ``None``.
        """
        args = self.args
        logger.info(f"selected sampler: {args.sampling_method}")

        method = args.sampling_method
        if method == "uniform":
            return uniform
        if method != "temperature" and not self.is_adaptive():
            # Default: concatenate all datasets together.
            return None
        return make_temperature_sampling(float(args.sampling_temperature))
