# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os

from mmpt.utils import recursive_config


class BaseJob(object):
    """Common base for job launchers driven by a YAML config file.

    Loads the config on construction and offers ``_normalize_cmd`` to fill
    the ``"[yaml]"`` placeholder of a command template. Concrete launchers
    must override ``submit``.
    """

    def __init__(self, yaml_file, dryrun=False):
        # Path to the YAML config; substituted into command templates.
        self.yaml_file = yaml_file
        # Parsed config as returned by ``mmpt.utils.recursive_config``.
        self.config = recursive_config(yaml_file)
        # Flag consulted by subclasses (e.g. LocalJob) to skip actually
        # launching the command.
        self.dryrun = dryrun

    def submit(self, **kwargs):
        """Launch the job. Must be implemented by subclasses."""
        raise NotImplementedError

    def _normalize_cmd(self, cmd_list):
        """Return a new list equal to ``cmd_list`` with its first ``"[yaml]"``
        token replaced by ``self.yaml_file``.

        The input sequence is never modified. Raises ``ValueError`` when the
        template contains no ``"[yaml]"`` placeholder.
        """
        normalized = [*cmd_list]
        placeholder_pos = normalized.index("[yaml]")
        normalized[placeholder_pos] = self.yaml_file
        return normalized


class LocalJob(BaseJob):
    """Launch MMPT training or prediction on the local machine via the shell.

    The command template is chosen from ``CMD_CONFIG`` by ``job_type``. For
    non-prediction jobs, every ``key: value`` under each group of the
    config's ``fairseq`` section is appended as a ``--key value`` argument.
    """

    # Command templates keyed by job type; "[yaml]" is replaced with the
    # config path by ``_normalize_cmd``.
    CMD_CONFIG = {
        "local_single": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
        ],
        "local_small": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
            "--distributed-world-size", "2"
        ],
        "local_big": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
            "--distributed-world-size", "8"
        ],
        "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
    }

    # fairseq options that are bare on/off switches: they take no value on
    # the command line.
    BINARY_FLAGS = ("fp16", "reset_optimizer", "reset_dataloader", "reset_meters")

    def __init__(self, yaml_file, job_type=None, dryrun=False):
        """Create a local job.

        Args:
            yaml_file: path to the MMPT YAML config.
            job_type: key into ``CMD_CONFIG``. When None, the config's
                ``task_type`` is used if set, otherwise ``"local_single"``.
            dryrun: if True, ``submit`` prints the command without running it.
        """
        super().__init__(yaml_file, dryrun)
        if job_type is None:
            self.job_type = "local_single"
            if self.config.task_type is not None:
                self.job_type = self.config.task_type
        else:
            self.job_type = job_type
        if self.job_type in ["local_single", "local_small"]:
            # Only a warning: the configured batch size is NOT changed here.
            if self.config.fairseq.dataset.batch_size > 32:
                print("decreasing batch_size to 32 for local testing?")

    def submit(self, **kwargs):
        """Build the command line, print it, and run it unless ``dryrun``.

        ``**kwargs`` is accepted for signature compatibility with
        ``BaseJob.submit`` and is currently ignored.

        Returns:
            A ``JobStatus`` with a fixed placeholder id; the launched process
            is not tracked.

        Raises:
            KeyError: if ``job_type`` is not a key of ``CMD_CONFIG``.
        """
        cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
        if "predict" not in self.job_type:
            # append fairseq args.
            from mmpt.utils import load_config

            config = load_config(config_file=self.yaml_file)
            cmd_list.extend(self._fairseq_args(config.fairseq))

        print("launching", " ".join(cmd_list))
        if not self.dryrun:
            # Run as one shell string so the manual quoting applied to
            # ``adam_betas`` is interpreted by the shell.
            os.system(" ".join(cmd_list))
        return JobStatus("12345678")

    @classmethod
    def _fairseq_args(cls, fairseq_config):
        """Flatten a ``{group: {key: value}}`` mapping into CLI arguments."""
        args = []
        for group in fairseq_config:
            group_config = fairseq_config[group]
            for key in group_config:
                args.extend(cls._format_arg(key, group_config[key]))
        return args

    @classmethod
    def _format_arg(cls, key, value):
        """Render one config entry as a list of fairseq CLI tokens."""
        flag = "--" + key.replace("_", "-")
        if key in cls.BINARY_FLAGS:
            # An explicit ``false`` must not switch the option on. Other
            # values, including an empty (null) YAML entry, still emit it.
            return [] if value is False else [flag]
        if key == "lr":
            # lr is normally a list (one entry per schedule step); pass the
            # first entry. A plain scalar is used as-is.
            if not isinstance(value, (int, float, str)):
                value = value[0]
            rendered = str(value)
        elif key == "adam_betas":
            # Quoted so the shell passes e.g. "(0.9, 0.98)" as one token.
            rendered = "'" + str(value) + "'"
        else:
            rendered = str(value)
        return [flag, rendered]


class JobStatus(object):
    """Minimal job handle returned by ``submit``.

    ``done()`` and ``running()`` are stubs that always return False, so
    ``result()`` always reports the job as running.
    """

    def __init__(self, job_id):
        # Identifier of the job; rendered with ``str()`` when displayed.
        self.job_id = job_id

    def __repr__(self):
        # str() so a non-string id (e.g. an int) doesn't make repr() raise
        # "TypeError: __repr__ returned non-string".
        return str(self.job_id)

    def __str__(self):
        return str(self.job_id)

    def done(self):
        """Stub: always returns False; completion is not tracked."""
        return False

    def running(self):
        """Stub: always returns False; process state is not tracked."""
        return False

    def result(self):
        """Return a one-line, human-readable status for this job."""
        if self.done():
            return "{} is done.".format(self.job_id)
        return "{} is running.".format(self.job_id)

    def stderr(self):
        """Return the same status line as ``result()``; nothing is captured here."""
        return self.result()

    def stdout(self):
        """Return the same status line as ``result()``; nothing is captured here."""
        return self.result()
