# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
import logging


logger = logging.getLogger(__name__)


def uniform(dataset_sizes: List[int]):
    """Return a weight of 1.0 for each entry of ``dataset_sizes``.

    The sizes themselves are ignored; only the number of datasets matters.
    """
    return [1.0 for _ in range(len(dataset_sizes))]


def temperature_sampling(dataset_sizes, temp):
    """Return temperature-scaled weights for the given dataset sizes.

    Each size is divided by the sum of all sizes and the resulting
    fraction is raised to the power ``1 / temp``. The returned weights
    are not re-normalized.
    """
    denominator = sum(dataset_sizes)
    weights = []
    for n_examples in dataset_sizes:
        fraction = n_examples / denominator
        weights.append(fraction ** (1.0 / temp))
    return weights


def make_temperature_sampling(temp=1.0):
    """Build a sampling function bound to a fixed temperature.

    The returned callable takes ``dataset_sizes`` and forwards it, together
    with the captured ``temp``, to ``temperature_sampling``.
    """
    def sampling_func(dataset_sizes):
        # ``temp`` is captured from the enclosing call.
        weights = temperature_sampling(dataset_sizes, temp)
        return weights

    return sampling_func


def make_ratio_sampling(ratios):
    """Build a sampling function that always yields the given ``ratios``.

    The returned callable accepts ``dataset_sizes`` but does not use it; it
    returns the very same ``ratios`` object that was passed in here.
    """
    fixed_ratios = ratios

    def sampling_func(dataset_sizes):
        # The sizes are intentionally unused: the ratios are predetermined.
        return fixed_ratios

    return sampling_func


class SamplingMethod:
    """Select a dataset sampling function from command-line arguments.

    An instance stores the parsed ``args`` and the ``task`` it was built
    with; ``sampling_method_selector`` maps ``args.sampling_method`` to a
    weighting function (or ``None``).
    """

    @staticmethod
    def add_arguments(parser):
        """Register the sampling-related options on ``parser``."""
        parser.add_argument(
            '--sampling-method',
            type=str,
            default='concat',
            choices=['uniform', 'temperature', 'concat', 'RoundRobin'],
            help='The method to sample data per language pairs',
        )
        parser.add_argument(
            '--sampling-temperature',
            type=float,
            default=1.5,
            help='only work with --sampling-method temperature',
        )

    @staticmethod
    def build_sampler(args, task):
        """Create a ``SamplingMethod`` from ``args`` and ``task``."""
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        """Store ``args`` and ``task`` for later use."""
        self.task = task
        self.args = args

    def is_adaptive(self):
        """Report whether this sampler is adaptive; always ``False`` here."""
        return False

    def sampling_method_selector(self):
        """Return the weighting function for the configured sampling method.

        Returns:
            ``uniform`` when the method is ``'uniform'``; a temperature
            sampler built from ``args.sampling_temperature`` when the method
            is ``'temperature'`` or ``is_adaptive()`` is true; otherwise
            ``None``.
        """
        args = self.args
        logger.info(f'selected sampler: {args.sampling_method}')

        if args.sampling_method == 'uniform':
            return uniform

        # ``is_adaptive`` is only consulted when the method is not
        # 'temperature' (short-circuit evaluation).
        use_temperature = (
            args.sampling_method == 'temperature' or self.is_adaptive()
        )
        if not use_temperature:
            # Any other method falls back to concatenating all datasets,
            # which is signalled by returning no sampling function.
            return None

        return make_temperature_sampling(float(args.sampling_temperature))
