# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import abc
import pickle
import warnings
from pathlib import Path
from typing import List, Tuple, Union

import pandas as pd

from ...log import get_module_logger
from ...utils import init_instance_by_config, load_dataset, time_to_slc_point
from ...utils.serial import Serializable
from ..data import D


class DataLoader(abc.ABC):
    """Abstract interface for objects that load raw data from an original data source."""

    @abc.abstractmethod
    def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
        """Load the requested data and return it as a ``pd.DataFrame``.

        Illustration of the returned frame (the multi-level column index is
        optional):

            .. code-block:: text

                                        feature                                                             label
                                        $close     $volume     Ref($close, 1)  Mean($close, 3)  $high-$low  LABEL0
                datetime    instrument
                2010-01-04  SH600000    81.807068  17145150.0       83.737389        83.016739    2.741058  0.0032
                            SH600004    13.313329  11800983.0       13.313329        13.317701    0.183632  0.0042
                            SH600005    37.796539  12231662.0       38.258602        37.919757    0.970325  0.0289


        Parameters
        ----------
        instruments : str or dict
            Either a market name or the instruments config generated by
            InstrumentProvider.
        start_time : str
            Beginning of the time range to load.
        end_time : str
            End of the time range to load.

        Returns
        -------
        pd.DataFrame:
            The data loaded from the underlying source.

        """


class DLWParser(DataLoader):
    """(D)ata(L)oader (W)ith (P)arser for features and names.

    This base class holds the field-parsing logic so that QlibDataLoader and
    other data loaders (such as QdbDataLoader) can share it.

    """

    def __init__(self, config: Union[list, tuple, dict]):
        """
        Parameters
        ----------
        config : Union[list, tuple, dict]
            Describes the fields (expressions) and their column names.

            .. code-block::

                <config> := {
                    "group_name1": <fields_info1>
                    "group_name2": <fields_info2>
                }
                or
                <config> := <fields_info>

                <fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
                # NOTE: list and tuple are treated the same way when parsing
        """
        self.logger = get_module_logger(self.__class__.__name__)
        # A dict config means one <fields_info> per named group.
        self.is_group = isinstance(config, dict)

        if not self.is_group:
            self.fields = self._parse_fields_info(config)
            return

        parsed = {}
        for group_name, group_fields in config.items():
            parsed[group_name] = self._parse_fields_info(group_fields)
        self.fields = parsed

    def _parse_fields_info(
        self, fields_info: Union[list, tuple]
    ) -> Tuple[list, list]:
        """Split one ``<fields_info>`` entry into ``(exprs, names)``.

        A flat sequence of strings is used both as the expressions and as the
        column names; a pair of sequences is unpacked into expressions and
        names respectively.
        """
        if len(fields_info) == 0:
            raise ValueError("The size of fields must be greater than 0")

        if not isinstance(fields_info, (list, tuple)):
            raise TypeError("Unsupported type")

        first = fields_info[0]
        if isinstance(first, str):
            # Expressions double as their own column names.
            return fields_info, fields_info
        if isinstance(first, (list, tuple)):
            exprs, names = fields_info
            return exprs, names
        raise NotImplementedError("This type of input is not supported")

    @abc.abstractmethod
    def load_group_df(
        self,
        instruments,
        exprs: list,
        names: list,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        gp_name: str = None,
    ) -> pd.DataFrame:
        """Load the dataframe of a single group.

        Parameters
        ----------
        instruments :
            the instruments.
        exprs : list
            the expressions describing the content of the data.
        names : list
            the names of the data columns.
        start_time : Union[str, pd.Timestamp]
            start of the time range.
        end_time : Union[str, pd.Timestamp]
            end of the time range.
        gp_name : str
            name of the group being loaded (``None`` when the config is not grouped).

        Returns
        -------
        pd.DataFrame:
            the queried dataframe.

        """

    def load(
        self, instruments=None, start_time=None, end_time=None
    ) -> pd.DataFrame:
        """Load every configured field, concatenating groups column-wise when grouped."""
        if not self.is_group:
            exprs, names = self.fields
            return self.load_group_df(
                instruments, exprs, names, start_time, end_time
            )

        group_frames = {}
        for group_name, (exprs, names) in self.fields.items():
            group_frames[group_name] = self.load_group_df(
                instruments, exprs, names, start_time, end_time, group_name
            )
        # The dict keys become the outer level of the column index.
        return pd.concat(group_frames, axis=1)


class QlibDataLoader(DLWParser):
    """DataLoader that queries features through the data API ``D``.

    The fields (expressions and column names) are defined by ``config``.

    """

    def __init__(
        self,
        config: Union[list, tuple, dict],
        filter_pipe: List = None,
        swap_level: bool = True,
        freq: Union[str, dict] = "day",
        inst_processors: Union[dict, list] = None,
    ):
        """
        Parameters
        ----------
        config : Union[list, tuple, dict]
            Please refer to the doc of DLWParser
        filter_pipe :
            Filter pipe for the instruments
        swap_level :
            Whether to swap level of MultiIndex
        freq:  dict or str
            If type(config) == dict and type(freq) == str, load config data using freq.
            If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>]
        inst_processors: dict | list
            If inst_processors is not None and type(config) == dict; load config[<group_name>] data using inst_processors[<group_name>]
            If inst_processors is a list, then it will be applied to all groups.

        Raises
        ------
        ValueError
            If ``config`` is grouped, ``freq`` is a dict and a group of
            ``config`` has no entry in ``freq``.
        """
        self.filter_pipe = filter_pipe
        self.swap_level = swap_level
        self.freq = freq

        # sample
        self.inst_processors = (
            inst_processors if inst_processors is not None else {}
        )
        assert isinstance(
            self.inst_processors, (dict, list)
        ), f"inst_processors(={self.inst_processors}) must be dict or list"

        super().__init__(config)

        if self.is_group:
            # check sample config
            if isinstance(freq, dict):
                # every group needs its own frequency
                for _gp in config.keys():
                    if _gp not in freq:
                        raise ValueError(f"freq(={freq}) missing group(={_gp})")
                assert (
                    self.inst_processors
                ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"

    def load_group_df(
        self,
        instruments,
        exprs: list,
        names: list,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        gp_name: str = None,
    ) -> pd.DataFrame:
        """Query ``D.features`` for one group and rename its columns to ``names``.

        Parameters
        ----------
        instruments :
            A market name (resolved through ``D.instruments`` together with
            ``filter_pipe``), an already-resolved instruments object, or
            ``None`` to load ``"all"``.
        exprs : list
            the expressions to query.
        names : list
            the column names assigned to the result.
        start_time : Union[str, pd.Timestamp]
            start of the time range.
        end_time : Union[str, pd.Timestamp]
            end of the time range.
        gp_name : str
            group name used to pick the per-group ``freq`` / ``inst_processors``.

        Returns
        -------
        pd.DataFrame:
            the queried dataframe.
        """
        if instruments is None:
            warnings.warn("`instruments` is not set, will load all stocks")
            instruments = "all"
        if isinstance(instruments, str):
            instruments = D.instruments(
                instruments, filter_pipe=self.filter_pipe
            )
        elif self.filter_pipe is not None:
            warnings.warn(
                "`filter_pipe` is not None, but it will not be used with `instruments` as list"
            )

        freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
        # A list of processors applies to every group; a dict is looked up per group.
        inst_processors = (
            self.inst_processors
            if isinstance(self.inst_processors, list)
            else self.inst_processors.get(gp_name, [])
        )
        self.logger.info(
            'Loading data for group "%s" with freq "%s" from disk',
            gp_name,
            freq,
        )

        # Query exactly once, using the processors resolved above.
        df = D.features(
            instruments,
            exprs,
            start_time,
            end_time,
            freq=freq,
            inst_processors=inst_processors,
        )
        df.columns = names
        self.logger.info(
            'Loaded data for group "%s" with freq "%s" from disk', gp_name, freq
        )
        if self.swap_level:
            df = (
                df.swaplevel().sort_index()
            )  # NOTE: if swaplevel, return <datetime, instrument>
        return df


class StaticDataLoader(DataLoader, Serializable):
    """DataLoader that supports loading data from file or as provided."""

    include_attr = ["_config"]

    def __init__(
        self, config: Union[dict, str, Path, pd.DataFrame], join="outer"
    ):
        """
        Parameters
        ----------
        config : dict, str, Path or pd.DataFrame
            - dict: ``{fields_group: <path or object>}``, each value is passed to ``load_dataset``
            - str / Path: path of a pickled object
            - pd.DataFrame: the data itself
        join : str
            How to align different dataframes
        """
        self._config = config  # using "_" to avoid confliction with the method `config` of Serializable
        self.join = join
        self._data = None  # lazily filled cache, see `_maybe_load_raw_data`

    def __getstate__(self) -> dict:
        """Return the state to pickle.

        Public attributes and the private ones listed in ``include_attr``
        (i.e. ``_config``) are kept so that the restored loader can still
        load its data. The cached ``_data`` is not pickled; it is stored as
        ``None`` so that the attribute exists after unpickling and the data is
        reloaded lazily.
        """
        state = {
            k: v
            for k, v in self.__dict__.items()
            if not k.startswith("_") or k in self.include_attr
        }
        # avoid pickling `self._data`, but keep the attribute defined
        state["_data"] = None
        return state

    def load(
        self, instruments=None, start_time=None, end_time=None
    ) -> pd.DataFrame:
        """Return the (optionally filtered) static data.

        Parameters
        ----------
        instruments :
            Selector applied to the second level of the row index; ``None``
            keeps every instrument.
        start_time :
            start of the time range; ``None`` leaves it open.
        end_time :
            end of the time range; ``None`` leaves it open.

        Returns
        -------
        pd.DataFrame:
            The cached data itself when no filter is given, otherwise the
            selected part of it.
        """
        self._maybe_load_raw_data()
        if instruments is None:
            df = self._data
        else:
            df = self._data.loc(axis=0)[:, instruments]
        if start_time is None and end_time is None:
            return df  # NOTE: avoid copy by loc
        # pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None.
        start_time = time_to_slc_point(start_time)
        end_time = time_to_slc_point(end_time)
        return df.loc[start_time:end_time]

    def _maybe_load_raw_data(self):
        """Populate ``self._data`` from ``self._config`` on first use.

        Any config type other than dict, str, Path or DataFrame leaves
        ``self._data`` as ``None``.
        """
        if self._data is not None:
            return
        if isinstance(self._config, dict):
            # One frame per fields group; the group names become the outer column level.
            self._data = pd.concat(
                {
                    fields_group: load_dataset(path_or_obj)
                    for fields_group, path_or_obj in self._config.items()
                },
                axis=1,
                join=self.join,
            )
            self._data.sort_index(inplace=True)
        elif isinstance(self._config, (str, Path)):
            # SECURITY: unpickling can execute arbitrary code; only point
            # `config` at files from a trusted source.
            with Path(self._config).open("rb") as f:
                self._data = pickle.load(f)
        elif isinstance(self._config, pd.DataFrame):
            self._data = self._config


class DataLoaderDH(DataLoader):
    """DataLoader based on (D)ata (H)andler.

    It is designed to load multiple data from data handlers.

    - If you just want to load data from a single data handler, you can write them in a single data handler

    TODO: What makes this module not that easy to use.

    - For online scenario

        - The underlying data handler should be configured. But data loader doesn't provide such interface & hook.

    """

    def __init__(
        self, handler_config: dict, fetch_kwargs: dict = None, is_group=False
    ):
        """
        Parameters
        ----------
        handler_config : dict
            handler_config will be used to describe the handlers

            .. code-block::

                <handler_config> := {
                    "group_name1": <handler>
                    "group_name2": <handler>
                }
                or
                <handler_config> := <handler>
                <handler> := DataHandler Instance | DataHandler Config

        fetch_kwargs : dict
            fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
            ``None`` (the default) means no extra arguments.

        is_group: bool
            is_group will be used to describe whether the key of handler_config is group

        """
        from ...data.dataset.handler import DataHandler  # pylint: disable=C0415

        if is_group:
            self.handlers = {
                grp: init_instance_by_config(config, accept_types=DataHandler)
                for grp, config in handler_config.items()
            }
        else:
            self.handlers = init_instance_by_config(
                handler_config, accept_types=DataHandler
            )

        self.is_group = is_group
        # `col_set` defaults to the raw column set; user-supplied kwargs override it.
        self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
        if fetch_kwargs is not None:
            self.fetch_kwargs.update(fetch_kwargs)

    def load(
        self, instruments=None, start_time=None, end_time=None
    ) -> pd.DataFrame:
        """Fetch ``[start_time, end_time]`` from the handler(s) along the ``datetime`` level.

        ``instruments`` is not supported: when given, it is ignored and a
        warning is logged. With ``is_group`` the per-group results are
        concatenated column-wise under their group names.
        """
        if instruments is not None:
            get_module_logger(self.__class__.__name__).warning(
                f"instruments[{instruments}] is ignored"
            )

        if self.is_group:
            df = pd.concat(
                {
                    grp: dh.fetch(
                        selector=slice(start_time, end_time),
                        level="datetime",
                        **self.fetch_kwargs,
                    )
                    for grp, dh in self.handlers.items()
                },
                axis=1,
            )
        else:
            df = self.handlers.fetch(
                selector=slice(start_time, end_time),
                level="datetime",
                **self.fetch_kwargs,
            )
        return df
