# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Optional, Text

from ..config import C
from ..utils import Wrapper
from ..utils.exceptions import RecorderInitializationError
from .exp import Experiment
from .expm import ExpManager
from .recorder import MLflowRecorder, Recorder


class QlibRecorder:
    """A global system that helps to manage the experiments.

    Every operation is delegated to ``self.exp_manager`` (an ``ExpManager``).
    ``suffix`` names an optional sub-directory inside the recorder's artifact
    area: when it is non-empty, ``save_objects``, ``load_object`` and
    ``artifact_uri`` prefix their artifact paths with it.
    """

    suffix: str = ""  # suffix represents sub-dir in the experiment directory

    def __init__(self, exp_manager: ExpManager):
        self.exp_manager: ExpManager = exp_manager
        self.suffix = ""

    def __repr__(self):
        return "{name}(manager={manager})".format(
            name=self.__class__.__name__, manager=self.exp_manager
        )

    @contextmanager
    def start(
        self,
        *,
        experiment_id: Optional[Text] = None,
        experiment_name: Optional[Text] = None,
        recorder_id: Optional[Text] = None,
        recorder_name: Optional[Text] = None,
        uri: Optional[Text] = None,
        resume: bool = False,
    ):
        """Method to start an experiment. This method can only be called within
        a Python's `with` statement. Here is the example code:

        .. code-block:: Python

            # start new experiment and recorder
            with R.start(experiment_name='test', recorder_name='recorder_1'):
                model.fit(dataset)
                R.log...
                ... # further operations

            # resume previous experiment and recorder
            with R.start(experiment_name='test', recorder_name='recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
                ... # further operations

        When the `with` body finishes normally the experiment is ended with
        ``Recorder.STATUS_FI``. If the body raises anything (including
        ``KeyboardInterrupt``), the experiment is ended with
        ``Recorder.STATUS_FA`` and the exception is re-raised unchanged.

        Parameters
        ----------
        experiment_id : str
            id of the experiment one wants to start.
        experiment_name : str
            name of the experiment one wants to start.
        recorder_id : str
            id of the recorder under the experiment one wants to start.
        recorder_name : str
            name of the recorder under the experiment one wants to start.
        uri : str
            The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
            The default uri is set in the q4l.qlib.config. Note that this uri argument will not change the one defined in the config file.
            Therefore, the next time when users call this function in the same experiment,
            they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
        resume : bool
            whether to resume the specific recorder with given name under the given experiment.

        """
        run = self.start_exp(
            experiment_id=experiment_id,
            experiment_name=experiment_name,
            recorder_id=recorder_id,
            recorder_name=recorder_name,
            uri=uri,
            resume=resume,
        )
        try:
            yield run
        except BaseException:
            # Catch BaseException (not just Exception) so that interrupts such
            # as KeyboardInterrupt also end the experiment as FAILED instead of
            # leaving it running; a bare `raise` keeps the original traceback.
            self.end_exp(Recorder.STATUS_FA)
            raise
        else:
            self.end_exp(Recorder.STATUS_FI)

    def start_exp(
        self,
        *,
        experiment_id=None,
        experiment_name=None,
        recorder_id=None,
        recorder_name=None,
        uri=None,
        resume=False,
    ):
        """Lower level method for starting an experiment. When use this method,
        one should end the experiment manually and the status of the recorder
        may not be handled properly. Here is the example code:

        .. code-block:: Python

            R.start_exp(experiment_name='test', recorder_name='recorder_1')
            ... # further operations
            R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_FI)


        Parameters
        ----------
        experiment_id : str
            id of the experiment one wants to start.
        experiment_name : str
            the name of the experiment to be started
        recorder_id : str
            id of the recorder under the experiment one wants to start.
        recorder_name : str
            name of the recorder under the experiment one wants to start.
        uri : str
            the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
            The default uri are set in the q4l.qlib.config.
        resume : bool
            whether to resume the specific recorder with given name under the given experiment.

        Returns
        -------
        An experiment instance being started.

        """
        return self.exp_manager.start_exp(
            experiment_id=experiment_id,
            experiment_name=experiment_name,
            recorder_id=recorder_id,
            recorder_name=recorder_name,
            uri=uri,
            resume=resume,
        )

    def end_exp(self, recorder_status=Recorder.STATUS_FI):
        """Method for ending an experiment manually. It will end the current
        active experiment, as well as its active recorder with the specified
        `recorder_status`. Here is the example code of the method:

        .. code-block:: Python

            R.start_exp(experiment_name='test')
            ... # further operations
            R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_FI)

        Parameters
        ----------
        recorder_status : str
            The status of a recorder, which can be SCHEDULED, RUNNING, FINISHED, FAILED.
            Defaults to ``Recorder.STATUS_FI``.

        """
        self.exp_manager.end_exp(recorder_status)

    def search_records(self, experiment_ids, **kwargs):
        """Get a pandas DataFrame of records that fit the search criteria.

        The arguments of this function are not set to be rigid, and they will be different with different implementation of
        ``ExpManager`` in ``Qlib``. ``Qlib`` now provides an implementation of ``ExpManager`` with mlflow, and here is the
        example code of the method with the ``MLflowExpManager``:

        .. code-block:: Python

            R.log_metrics(m=2.50, step=0)
            records = R.search_records([experiment_id], order_by=["metrics.m DESC"])

        Parameters
        ----------
        experiment_ids : list
            list of experiment IDs.
        filter_string : str
            filter query string, defaults to searching all runs.
        run_view_type : int
            one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
        max_results  : int
            the maximum number of runs to put in the dataframe.
        order_by : list
            list of columns to order by (e.g., "metrics.rmse").

        Returns
        -------
        A pandas.DataFrame of records, where each metric, parameter, and tag
        are expanded into their own columns named metrics.*, params.*, and tags.*
        respectively. For records that don't have a particular metric, parameter, or tag, their
        value will be (NumPy) Nan, None, or None respectively.

        """
        return self.exp_manager.search_records(experiment_ids, **kwargs)

    def list_experiments(self):
        """Method for listing all the existing experiments (except for those
        being deleted.)

        .. code-block:: Python

            exps = R.list_experiments()

        Returns
        -------
        A dictionary (name -> experiment) of experiments information that being stored.

        """
        return self.exp_manager.list_experiments()

    def list_recorders(self, experiment_id=None, experiment_name=None):
        """Method for listing all the recorders of experiment with given id or
        name.

        If user doesn't provide the id or name of the experiment, this method will try to retrieve the default experiment and
        list all the recorders of the default experiment. If the default experiment doesn't exist, the method will first
        create the default experiment, and then create a new recorder under it. (More information about the default experiment
        can be found `here <../component/recorder.html#q4l.qlib.workflow.exp.Experiment>`__).

        Here is the example code:

        .. code-block:: Python

            recorders = R.list_recorders(experiment_name='test')

        Parameters
        ----------
        experiment_id : str
            id of the experiment.
        experiment_name : str
            name of the experiment.

        Returns
        -------
        A dictionary (id -> recorder) of recorder information that being stored.

        """
        return self.get_exp(
            experiment_id=experiment_id, experiment_name=experiment_name
        ).list_recorders()

    def get_exp(
        self,
        *,
        experiment_id=None,
        experiment_name=None,
        create: bool = True,
        start: bool = False,
    ) -> Experiment:
        """Method for retrieving an experiment with given id or name. Once the
        `create` argument is set to True, if no valid experiment is found, this
        method will create one for you. Otherwise, it will only retrieve a
        specific experiment or raise an Error.

        - If '`create`' is True:

            - If `active experiment` exists:

                - no id or name specified, return the active experiment.

                - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name.

            - If `active experiment` not exists:

                - no id or name specified, create a default experiment, and the experiment is set to be active.

                - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment.

        - Else If '`create`' is False:

            - If `active experiment` exists:

                - no id or name specified, return the active experiment.

                - if id or name is specified, return the specified experiment. If no such exp found, raise Error.

            - If `active experiment` not exists:

                - no id or name specified. If the default experiment exists, return it, otherwise, raise Error.

                - if id or name is specified, return the specified experiment. If no such exp found, raise Error.

        Here are some use cases:

        .. code-block:: Python

            # Case 1
            with R.start(experiment_name='test'):
                exp = R.get_exp()
                recorders = exp.list_recorders()

            # Case 2
            with R.start(experiment_name='test'):
                exp = R.get_exp(experiment_name='test1')

            # Case 3
            exp = R.get_exp() -> a default experiment.

            # Case 4
            exp = R.get_exp(experiment_name='test')

            # Case 5
            exp = R.get_exp(create=False) -> the default experiment if exists.

        Parameters
        ----------
        experiment_id : str
            id of the experiment.
        experiment_name : str
            name of the experiment.
        create : boolean
            an argument determines whether the method will automatically create a new experiment
            according to user's specification if the experiment hasn't been created before.
        start : bool
            when start is True,
            if the experiment has not started(not activated), it will start
            It is designed for R.log_params to auto start experiments

        Returns
        -------
        An experiment instance with given id or name.

        """
        return self.exp_manager.get_exp(
            experiment_id=experiment_id,
            experiment_name=experiment_name,
            create=create,
            start=start,
        )

    def set(self, experiment_name: str, recorder_name: str):
        """Set the active experiment and recorder.

        Will set `active_experiment` for self.exp_manager and `active_recorder`
        for self.exp_manager.active_experiment. Both lookups use
        ``create=False``, so the experiment and recorder must already exist.

        """
        active_experiment: Experiment = self.exp_manager.get_exp(
            experiment_name=experiment_name, start=False, create=False
        )
        active_recorder = active_experiment.get_recorder(
            recorder_name=recorder_name, create=False
        )
        active_experiment.active_recorder = active_recorder
        self.exp_manager.active_experiment = active_experiment

    def delete_exp(self, experiment_id=None, experiment_name=None):
        """Method for deleting the experiment with given id or name. At least
        one of id or name must be given, otherwise, error will occur.

        Here is the example code:

        .. code-block:: Python

            R.delete_exp(experiment_name='test')

        Parameters
        ----------
        experiment_id : str
            id of the experiment.
        experiment_name : str
            name of the experiment.

        """
        self.exp_manager.delete_exp(experiment_id, experiment_name)

    def get_uri(self):
        """Method for retrieving the uri of current experiment manager.

        Here is the example code:

        .. code-block:: Python

            uri = R.get_uri()

        Returns
        -------
        The uri of current experiment manager.

        """
        return self.exp_manager.uri

    def set_uri(self, uri: Optional[Text]):
        """Method to reset the **default** uri of current experiment manager.

        NOTE:

        - When the uri refers to a file path, please use an absolute path instead of strings like "~/mlruns/".
          The backend doesn't support strings like this.

        """
        self.exp_manager.default_uri = uri

    @contextmanager
    def uri_context(self, uri: Text):
        """Temporarily set the exp_manager's **default_uri** to uri.

        The previous default uri is restored when the context exits, even if
        the body raises.

        NOTE:
        - Please refer to the NOTE in the `set_uri`

        Parameters
        ----------
        uri : Text
            the temporary uri

        """
        prev_uri = self.exp_manager.default_uri
        self.exp_manager.default_uri = uri
        try:
            yield
        finally:
            self.exp_manager.default_uri = prev_uri

    @property
    def artifact_uri(self):
        """Artifact URI of the current recorder, joined with ``self.suffix``.

        Raises
        ------
        TypeError
            If the current recorder is not an ``MLflowRecorder``.
        """
        rec = self.get_recorder()
        if not isinstance(rec, MLflowRecorder):
            raise TypeError(
                "The `artifact_uri` is only available for MLflowRecorder, "
                "but the current recorder is {}".format(type(rec))
            )
        arti_root_uri = rec.get_artifact_uri()
        artifact_uri = os.path.join(arti_root_uri, self.suffix)
        return artifact_uri

    def get_recorder(
        self,
        *,
        recorder_id=None,
        recorder_name=None,
        experiment_id=None,
        experiment_name=None,
    ) -> Recorder:
        """Method for retrieving a recorder.

        - If `active recorder` exists:

            - no id or name specified, return the active recorder.

            - if id or name is specified, return the specified recorder.

        - If `active recorder` not exists:

            - no id or name specified, raise Error.

            - if id or name is specified, and the corresponding experiment_name must be given, return the specified recorder. Otherwise, raise Error.

        The recorder can be used for further process such as `save_object`, `load_object`, `log_params`,
        `log_metrics`, etc.

        Here are some use cases:

        .. code-block:: Python

            # Case 1
            with R.start(experiment_name='test'):
                recorder = R.get_recorder()

            # Case 2
            with R.start(experiment_name='test'):
                recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')

            # Case 3
            recorder = R.get_recorder() -> Error

            # Case 4
            recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d') -> Error

            # Case 5
            recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')


        Here are some things users may concern
        - Q: What recorder will it return if multiple recorder meets the query (e.g. query with experiment_name)
        - A: If mlflow backend is used, then the recorder with the latest `start_time` will be returned. Because MLflow's `search_runs` function guarantee it

        Parameters
        ----------
        recorder_id : str
            id of the recorder.
        recorder_name : str
            name of the recorder.
        experiment_id : str
            id of the experiment.
        experiment_name : str
            name of the experiment.

        Returns
        -------
        A recorder instance.

        """
        return self.get_exp(
            experiment_name=experiment_name,
            experiment_id=experiment_id,
            create=False,
        ).get_recorder(recorder_id, recorder_name, create=False, start=False)

    def delete_recorder(self, recorder_id=None, recorder_name=None):
        """Method for deleting the recorders with given id or name. At least one
        of id or name must be given, otherwise, error will occur.

        The recorder is looked up in the experiment returned by ``get_exp()``.

        Here is the example code:

        .. code-block:: Python

            R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')

        Parameters
        ----------
        recorder_id : str
            id of the recorder.
        recorder_name : str
            name of the recorder.

        """
        self.get_exp().delete_recorder(recorder_id, recorder_name)

    def save_objects(
        self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]
    ):
        """Method for saving objects as artifacts in the experiment to the uri.
        It supports either saving from a local file/directory, or directly
        saving objects. User can use valid python's keywords arguments to
        specify the object to be saved as well as its name (name: value).

        In summary, this API is designed for saving **objects** to **the experiments management backend path**,
        1. Qlib provides two methods to specify **objects**
        - Passing in the object directly by passing with `**kwargs` (e.g. R.save_objects(trained_model=model))
        - Passing in the local path to the object, i.e. `local_path` parameter.
        2. `artifact_path` represents the **the experiments management backend path**

        - If `active recorder` exists: it will save the objects through the active recorder.
        - If `active recorder` not exists: the system will create a default experiment, and a new recorder and save objects under it.

        If ``self.suffix`` is non-empty, the objects are stored under that
        sub-directory (``suffix/artifact_path``, or just ``suffix`` when
        `artifact_path` is None).

        .. note::

            If one wants to save objects with a specific recorder. It is recommended to first get the specific recorder through `get_recorder` API and use the recorder the save objects. The supported arguments are the same as this method.

        Here are some use cases:

        .. code-block:: Python

            # Case 1
            with R.start(experiment_name='test'):
                pred = model.predict(dataset)
                R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
                rid = R.get_recorder().id
            ...
            R.get_recorder(recorder_id=rid).load_object("prediction/pred.pkl")  #  after saving objects, you can load the previous object with this api

            # Case 2
            with R.start(experiment_name='test'):
                R.save_objects(local_path='results/pred.pkl', artifact_path="prediction")
                rid = R.get_recorder().id
            ...
            R.get_recorder(recorder_id=rid).load_object("prediction/pred.pkl")  #  after saving objects, you can load the previous object with this api


        Parameters
        ----------
        local_path : str
            if provided, then save the file or directory to the artifact URI.
        artifact_path : str
            the relative path for the artifact to be stored in the URI.
        **kwargs: Dict[Text, Any]
            the object to be saved.
            For example, `{"pred.pkl": pred}`

        Raises
        ------
        ValueError
            If both `local_path` and keyword objects are given.

        """
        if local_path is not None and len(kwargs) > 0:
            raise ValueError(
                "You can choose only one of `local_path`(save the files in a path) or `kwargs`(pass in the objects directly)"
            )
        # Use this instance's suffix (not the global `R`'s) so that
        # `set_suffix` on this recorder is honoured consistently.
        if artifact_path is None:
            artifact_path = self.suffix
        elif self.suffix:
            artifact_path = os.path.join(self.suffix, artifact_path)
        self.get_exp().get_recorder(start=True).save_objects(
            local_path, artifact_path, **kwargs
        )

    def load_object(self, name: Text):
        """Method for loading an object from artifacts in the experiment in the
        uri.

        Parameters
        ----------
        name : Text
            artifact path of the object; prefixed with ``self.suffix`` when
            the suffix is non-empty.

        Returns
        -------
        The object returned by the recorder's ``load_object``.
        """
        fname = os.path.join(self.suffix, name) if self.suffix else name
        return self.get_exp().get_recorder(start=True).load_object(fname)

    def log_params(self, **kwargs):
        """Method for logging parameters during an experiment. In addition to
        using ``R``, one can also log to a specific recorder after getting it
        with `get_recorder` API.

        - If `active recorder` exists: it will log parameters through the active recorder.
        - If `active recorder` not exists: the system will create a default experiment as well as a new recorder, and log parameters under it.

        Here are some use cases:

        .. code-block:: Python

            # Case 1
            with R.start(experiment_name='test'):
                R.log_params(learning_rate=0.01)

            # Case 2
            R.log_params(learning_rate=0.01)

        Parameters
        ----------
        keyword argument:
            name1=value1, name2=value2, ...

        """
        self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs)

    def log_metrics(self, step=None, **kwargs):
        """Method for logging metrics during an experiment. In addition to using
        ``R``, one can also log to a specific recorder after getting it with
        `get_recorder` API.

        - If `active recorder` exists: it will log metrics through the active recorder.
        - If `active recorder` not exists: the system will create a default experiment as well as a new recorder, and log metrics under it.

        Here are some use cases:

        .. code-block:: Python

            # Case 1
            with R.start(experiment_name='test'):
                R.log_metrics(train_loss=0.33, step=1)

            # Case 2
            R.log_metrics(train_loss=0.33, step=1)

        Parameters
        ----------
        step : int, optional
            step index forwarded to the recorder's ``log_metrics``.
        keyword argument:
            name1=value1, name2=value2, ...

        """
        self.get_exp(start=True).get_recorder(start=True).log_metrics(
            step, **kwargs
        )

    def log_artifact(
        self, local_path: str, artifact_path: Optional[str] = None
    ):
        """Log a local file or directory as an artifact of the currently active
        run.

        - If `active recorder` exists: it will log the artifact through the active recorder.
        - If `active recorder` not exists: the system will create a default experiment as well as a new recorder, and log the artifact under it.

        Parameters
        ----------
        local_path : str
            Path to the file or directory to log.
        artifact_path : Optional[str]
            If provided, the directory in ``artifact_uri`` to write to.

        """
        self.get_exp(start=True).get_recorder(start=True).log_artifact(
            local_path, artifact_path
        )

    def download_artifact(
        self, path: str, dst_path: Optional[str] = None
    ) -> str:
        """Download an artifact file or directory from a run to a local
        directory if applicable, and return a local path for it.

        Parameters
        ----------
        path : str
            Relative source path to the desired artifact.
        dst_path : Optional[str]
            Absolute path of the local filesystem destination directory to which to
            download the specified artifacts. This directory must already exist.
            If unspecified, the recorder backend chooses the destination
            (for MLflow, typically a new uniquely-named local directory).

        Returns
        -------
        str
            Local path of desired artifact.

        """
        # The recorder's result must be returned; previously it was discarded
        # and callers always received None despite the `-> str` contract.
        return (
            self.get_exp(start=True)
            .get_recorder(start=True)
            .download_artifact(path, dst_path)
        )

    def set_suffix(self, suffix: str):
        """Set the artifact sub-directory used by ``save_objects``,
        ``load_object`` and ``artifact_uri``."""
        self.suffix = suffix

    def get_suffix(self):
        """Return the current artifact sub-directory suffix."""
        return self.suffix

    def set_tags(self, **kwargs):
        """Method for setting tags for a recorder. In addition to using ``R``,
        one can also set the tag to a specific recorder after getting it with
        `get_recorder` API.

        - If `active recorder` exists: it will set tags through the active recorder.
        - If `active recorder` not exists: the system will create a default experiment as well as a new recorder, and set the tags under it.

        Here are some use cases:

        .. code-block:: Python

            # Case 1
            with R.start(experiment_name='test'):
                R.set_tags(release_version="2.2.0")

            # Case 2
            R.set_tags(release_version="2.2.0")

        Parameters
        ----------
        keyword argument:
            name1=value1, name2=value2, ...

        """
        self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs)


class RecorderWrapper(Wrapper):
    """Proxy around the global ``QlibRecorder``.

    The only addition over ``Wrapper`` is a guard in ``register``: replacing
    the provider while the current one still has an active experiment emits a
    warning, since the experiment storage location would change underneath it.
    """

    def register(self, provider):
        """Install ``provider`` as the wrapped recorder.

        Warns (but still proceeds) when a previously registered recorder has
        an active experiment.
        """
        current = self._provider
        has_running_exp = (
            current is not None
            and current.exp_manager.active_experiment is not None
        )
        if has_running_exp:
            warnings.warn(
                "Please don't reinitialize Qlib if QlibRecorder is already activated. Otherwise, the experiment stored location will be modified."
            )
        self._provider = provider


import sys

# Type alias used only to annotate `R`: on 3.9+ `Annotated` tells type
# checkers/IDEs to treat `R` as a QlibRecorder while keeping RecorderWrapper as
# metadata; older interpreters fall back to the plain QlibRecorder type.
if sys.version_info < (3, 9):
    QlibRecorderWrapper = QlibRecorder
else:
    from typing import Annotated

    QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]

# Global recorder proxy; `RecorderWrapper.register` attaches the concrete
# QlibRecorder it forwards to.
R: QlibRecorderWrapper = RecorderWrapper()
