import pandas as pd
import pandera.pandas as pa
import loguru
from yaml import safe_load
import ray
import numpy as np
from pollinator.type import DatasetSplit, OptimizationInputType


class DAO:
    """Load, validate and reshape the tables that feed the optimizer.

    On construction the object reads a YAML config, loads the cost, quality,
    estimate and split CSV files named in it, validates each with a pandera
    schema, derives a reference policy from the training split and builds a
    Ray dataset holding one wide row per ``id``.

    Attributes set by ``__init__``:
        cost, cost_estimate, quality, quality_estimate, train, test:
            validated pandas frames read from the configured paths.
        all_optimization_input: Ray dataset with the merged wide table.
        batch_optimization_input_iterable: batches of ``batch_size`` rows.
        row_optimization_input_iterable: batches of exactly one row.
    """

    def __init__(
        self, config_path: str, batch_size: int, split: DatasetSplit, use_estimate: bool
    ) -> None:
        """Read every configured table and prepare the optimization inputs.

        Args:
            config_path: Path to a YAML file whose ``cost``, ``cost_estimate``,
                ``quality``, ``quality_estimate``, ``train`` and ``test``
                entries each carry an ``output_path`` pointing at a CSV file.
            batch_size: Number of rows per batch in
                ``batch_optimization_input_iterable``.
            split: Which dataset split the optimization input is built from.
            use_estimate: Use the estimated cost/quality tables instead of the
                observed ones.
        """
        self.logger = loguru.logger
        # Use a context manager so the config file handle is always closed.
        with open(config_path) as config_file:
            self.config = safe_load(config_file)
        # Calling ray.init() a second time in the same process is an error, so
        # only start Ray when it is not already running (e.g. a second DAO).
        if not ray.is_initialized():
            ray.init()  # FIXME: Handle the warning on Ray's object-store size.

        # NOTE: the schemas below do not declare the "producer" column even
        # though the group-by and pivot steps further down require it.
        cost_schema = pa.DataFrameSchema(
            {
                "id": pa.Column(int, pa.Check.ge(1)),
                "cost": pa.Column(
                    float, pa.Check.ge(0)
                ),  # FIXME: Change it back to greater-than 0.
            },
        )
        self.cost = self._read_and_validate_data(
            self.config["cost"]["output_path"], cost_schema
        )

        # Estimates share the schema of the observed values.
        cost_estimate_schema = cost_schema
        self.cost_estimate = self._read_and_validate_data(
            self.config["cost_estimate"]["output_path"], cost_estimate_schema
        )

        quality_schema = pa.DataFrameSchema(
            {
                "id": pa.Column(int, pa.Check.ge(1)),
                "quality": pa.Column(float, pa.Check.between(0, 1, include_min=True)),
            },
        )
        self.quality = self._read_and_validate_data(
            self.config["quality"]["output_path"], quality_schema
        )

        quality_estimate_schema = quality_schema
        self.quality_estimate = self._read_and_validate_data(
            self.config["quality_estimate"]["output_path"],
            quality_estimate_schema,
        )

        # The split files only need to list the ids that belong to the split.
        train_schema = pa.DataFrameSchema(
            {
                "id": pa.Column(int, pa.Check.ge(1)),
            },
        )
        self.train = self._read_and_validate_data(
            self.config["train"]["output_path"], train_schema
        )

        test_schema = train_schema
        self.test = self._read_and_validate_data(
            self.config["test"]["output_path"], test_schema
        )

        self._is_producer_at_top = self._create_reference_policy()
        self.logger.debug(f"Is producer at top? {self._is_producer_at_top}.")

        self.all_optimization_input = self._create_all_optimization_input(
            use_estimate=use_estimate, split=split
        )
        self.batch_optimization_input_iterable = (
            self.all_optimization_input.iter_batches(
                batch_size=batch_size, batch_format="pandas"
            )
        )
        self.row_optimization_input_iterable = self.all_optimization_input.iter_batches(
            batch_size=1, batch_format="pandas"
        )

    def _read_and_validate_data(
        self, path: str, schema: pa.DataFrameSchema
    ) -> pd.DataFrame:
        """Read the CSV at ``path`` and return it after validating against ``schema``."""
        return schema.validate(pd.read_csv(path))

    def _create_reference_policy(self) -> pd.DataFrame:
        """Flag the producer(s) with the best mean quality on the training split.

        Returns:
            A frame with a ``producer`` column and a ``quality`` column that
            holds 1 for every producer whose mean training quality equals the
            maximum and 0 for all others.
        """
        leaderboard_on_train = (
            (self.quality.merge(self.train, on="id", how="inner"))
            .groupby("producer")["quality"]
            .mean()
        ).sort_values(ascending=False)

        self.logger.debug(f"Leaderboard (train): {leaderboard_on_train}.")

        # The boolean series keeps the name "quality", so after reset_index the
        # 0/1 indicator lives in a column that is also called "quality".
        return (
            (leaderboard_on_train == leaderboard_on_train.max())
            .astype(int)
            .reset_index()
        )

    def _create_all_optimization_input(
        self, use_estimate: bool, split: DatasetSplit
    ) -> ray.data.Dataset:
        """Build the wide per-id table of cost, quality and reference columns.

        Args:
            use_estimate: Take values from the estimate tables when true,
                otherwise from the observed tables.
            split: ``DatasetSplit.TEST`` selects the test ids; any other value
                selects the training ids.

        Returns:
            A Ray dataset with an ``id`` column, the pivoted cost and quality
            columns (overlapping producer names get the ``_cost`` and
            ``_quality`` suffixes) and one ``<producer>_reference`` column per
            producer.
        """
        # Previously both branches returned the test split, which made the
        # ``split`` argument a no-op; fall back to the training ids instead.
        split_ids = (
            self.test if split == DatasetSplit.TEST else self.train
        )  # FIXME: Handle validation data.

        cost_source = self.cost_estimate if use_estimate else self.cost
        quality_source = self.quality_estimate if use_estimate else self.quality

        cost = (
            cost_source.merge(split_ids, on="id", how="inner").pivot(
                index="id", columns="producer", values="cost"
            )
        ).reset_index()

        quality = (
            quality_source.merge(split_ids, on="id", how="inner").pivot(
                index="id", columns="producer", values="quality"
            )
        ).reset_index()

        # Swap the real quality values for the 0/1 "is top producer" indicator
        # (also stored in a column named "quality") before pivoting.
        reference = (
            quality_source.merge(split_ids, on="id", how="inner")
            .drop(columns=["quality"])
            .merge(self._is_producer_at_top, on="producer", how="inner")
            .pivot(index="id", columns="producer", values="quality")
        ).reset_index()

        optimization_input = cost.merge(
            quality,
            on=["id"],
            how="inner",
            suffixes=("_cost", "_quality"),
        ).merge(
            reference.rename(
                columns={
                    col: col + "_reference" for col in reference.columns if col != "id"
                }
            ),
            on=["id"],
            how="inner",
        )

        return ray.data.from_pandas(optimization_input)

    def get_all_optimization_input(self) -> OptimizationInputType:
        """Return the whole optimization input, converted to pandas and wrapped."""
        return OptimizationInputType(self.all_optimization_input.to_pandas())


if __name__ == "__main__":
    # Manual smoke run: build the DAO on the test split with estimated values
    # and touch each of the three ways of reading the optimization input.
    smoke_dao = DAO(
        config_path="config/irtrouter-normalizer.yaml",
        batch_size=10000,
        split=DatasetSplit.TEST,
        use_estimate=True,
    )

    # Whole table at once.
    full_input = smoke_dao.get_all_optimization_input()

    # Every batch, one after another.
    for frame in smoke_dao.batch_optimization_input_iterable:
        batched_input = OptimizationInputType(frame)

    # Only the first single-row batch, if there is one.
    _nothing = object()
    first_row = next(iter(smoke_dao.row_optimization_input_iterable), _nothing)
    if first_row is not _nothing:
        single_row_input = OptimizationInputType(first_row)
