import pandas as pd
import loguru
from pollinator.data.normalizer.abstract_normalizer import AbstractNormalizer
from yaml import safe_load
from collections import Counter


class IRTRouterNormalizer(AbstractNormalizer):
    """Reshape the CSV files named in a YAML config into long-format frames.

    The config is expected to provide an ``input_path`` entry under each of
    the ``quality``, ``quality_estimate``, ``cost``, ``train`` and ``test``
    keys, plus a ``producer`` entry holding records with ``name``,
    ``input_cost`` and ``output_cost`` fields.

    NOTE(review): ``self.cost`` is read by ``_normalize_cost_estimate`` but is
    never assigned in this class; presumably ``AbstractNormalizer`` stores the
    result of ``_normalize_cost`` there before the estimate is built --
    confirm against the base class.
    """

    def __init__(self, config_path: str):
        """Load the YAML config and derive the producer list.

        Args:
            config_path: Path to the YAML configuration file.

        Raises:
            AssertionError: If the quality and quality-estimate datasets do
                not list the same producers.
        """
        self.logger = loguru.logger
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(config_path, encoding="utf-8") as config_file:
            self.config = safe_load(config_file)
        self.producers = self._get_producer_list()

    def _get_producer_list(self) -> list[str]:
        """Return the producer columns shared by both quality datasets.

        Every column other than ``question_id`` is treated as a producer.

        Returns:
            The producer column names, in the order they appear in the
            quality dataset.

        Raises:
            AssertionError: If the two datasets do not contain exactly the
                same producer columns.
        """
        producers_in_quality_dataset = (
            pd.read_csv(self.config["quality"]["input_path"])
            .drop(columns=["question_id"])
            .columns.tolist()
        )
        self.logger.debug(
            f"Producers in quality dataset: {producers_in_quality_dataset}."
        )

        producers_in_quality_estimate_dataset = (
            pd.read_csv(self.config["quality_estimate"]["input_path"])
            .drop(columns=["question_id"])
            .columns.tolist()
        )
        self.logger.debug(
            f"Producers in quality-estimate dataset: {producers_in_quality_estimate_dataset}."
        )

        quality_counts = Counter(producers_in_quality_dataset)
        estimate_counts = Counter(producers_in_quality_estimate_dataset)

        # Explicit comparison rather than an `assert` statement: asserts are
        # stripped under `python -O`, which would silently skip this check.
        if quality_counts != estimate_counts:
            # Counter subtraction keeps only positive counts, so both
            # directions are needed to show every mismatching producer.
            only_in_quality = quality_counts - estimate_counts
            only_in_quality_estimate = estimate_counts - quality_counts
            self.logger.error(
                f"Diff in producers: only in quality dataset: {only_in_quality}; "
                f"only in quality-estimate dataset: {only_in_quality_estimate}."
            )
            raise AssertionError(
                "Producers in quality dataset and quality-estimate dataset are not the same!"
            )

        self.logger.info(f"Producers: {quality_counts}.")
        return producers_in_quality_dataset

    def _normalize_cost(self) -> pd.DataFrame:
        """Load the cost CSV and reshape it to one row per (id, producer).

        Columns that do not correspond to a known producer are dropped, with
        a warning.

        Returns:
            A frame with columns ``id``, ``producer`` and ``cost``, sorted by
            ``id``.
        """
        cost = pd.read_csv(self.config["cost"]["input_path"]).rename(
            columns={"question_id": "id"}
        )

        # Strip the "_cost" suffix so the column names can be compared with
        # the producer names.
        cost = cost.rename(
            columns={
                col: col.removesuffix("_cost") for col in cost.columns if col != "id"
            }
        )

        producers_to_drop = [
            col for col in cost.columns if col not in self.producers and col != "id"
        ]
        if producers_to_drop:
            self.logger.warning(f"Producers to drop: {producers_to_drop}.")
        cost = cost.drop(columns=producers_to_drop)

        return (
            pd.melt(cost, id_vars=["id"], value_vars=self.producers)
            .rename(columns={"variable": "producer", "value": "cost"})
            .sort_values(by=["id"])
        )

    def _normalize_cost_estimate(self) -> pd.DataFrame:
        """Build a per-(id, producer) cost estimate from the producer config.

        The estimate for a producer is the mean of its configured
        ``input_cost`` and ``output_cost``. Producers that are configured but
        absent from ``self.producers`` are filtered out, with a warning.

        Returns:
            ``self.cost`` with its ``cost`` column replaced by the estimate,
            joined on ``producer`` with a left merge.
        """
        self.logger.debug(f"Producer config: {self.config['producer']}.")

        cost_estimate = pd.DataFrame(self.config["producer"]).astype(
            {"input_cost": "float64", "output_cost": "float64"}
        )

        cost_estimate["blended_cost"] = (
            cost_estimate["input_cost"] + cost_estimate["output_cost"]
        ) / 2.0

        cost_estimate = cost_estimate.drop(
            columns=["input_cost", "output_cost"]
        ).rename(columns={"blended_cost": "cost", "name": "producer"})

        num_producers_in_config = cost_estimate["producer"].nunique()

        cost_estimate = cost_estimate[cost_estimate["producer"].isin(self.producers)]

        num_producers_post_filter = cost_estimate["producer"].nunique()

        if num_producers_in_config > num_producers_post_filter:
            self.logger.warning(
                f"Dropping {num_producers_in_config - num_producers_post_filter} producers from config."
            )

        self.logger.debug(
            f"Columns in cost estimate before merge: {cost_estimate.columns.tolist()}."
        )

        # NOTE(review): relies on `self.cost`, which is not assigned in this
        # class -- presumably populated by the base class beforehand; confirm.
        return self.cost.drop(columns=["cost"]).merge(
            cost_estimate, on="producer", how="left"
        )

    def _melt_producer_scores(self, config_key: str) -> pd.DataFrame:
        """Read a wide per-producer CSV and reshape it to long format.

        Args:
            config_key: Top-level config key whose ``input_path`` names the
                CSV to read.

        Returns:
            A frame with columns ``id``, ``producer`` and ``quality``, sorted
            by ``id``.
        """
        wide = pd.read_csv(self.config[config_key]["input_path"]).rename(
            columns={"question_id": "id"}
        )
        return (
            pd.melt(wide, id_vars=["id"], value_vars=self.producers)
            .rename(columns={"variable": "producer", "value": "quality"})
            .sort_values(by=["id"])
        )

    def _load_question_split(self, config_key: str) -> pd.DataFrame:
        """Read a split CSV, drop its ``question`` column and rename the id.

        Malformed lines produce a warning rather than an error
        (``on_bad_lines="warn"``).

        Args:
            config_key: Top-level config key whose ``input_path`` names the
                CSV to read.

        Returns:
            The CSV contents without ``question`` and with ``question_id``
            renamed to ``id``.
        """
        return (
            pd.read_csv(self.config[config_key]["input_path"], on_bad_lines="warn")
            .drop(columns=["question"])
            .rename(columns={"question_id": "id"})
        )

    def _normalize_quality(self) -> pd.DataFrame:
        """Return the quality dataset as ``id`` / ``producer`` / ``quality`` rows."""
        return self._melt_producer_scores("quality")

    def _normalize_quality_estimate(self) -> pd.DataFrame:
        """Return the quality-estimate dataset in the same long format.

        The value column is named ``quality`` here as well, matching
        ``_normalize_quality``.
        """
        return self._melt_producer_scores("quality_estimate")

    def _normalize_train(self) -> pd.DataFrame:
        """Return the train split without ``question`` and keyed by ``id``."""
        return self._load_question_split("train")

    def _normalize_test(self) -> pd.DataFrame:
        """Return the test split without ``question`` and keyed by ``id``."""
        return self._load_question_split("test")


if __name__ == "__main__":
    # Script entry point: build the normalizer from the default config file
    # and run normalize_and_write(), which is not defined in this file
    # (presumably inherited from AbstractNormalizer).
    default_config_path = "config/irtrouter-normalizer.yaml"
    IRTRouterNormalizer(default_config_path).normalize_and_write()
