from cvxpy.constraints import constraint
from pollinator.optimizer.abstract_optimizer import AbstractOptimizer
from pollinator.type import OptimizationInputType
from pollinator.data.dao import DAO
from pollinator.type import DatasetSplit

import pandas as pd
import numpy as np
import cvxpy as cp
import loguru


class CostQualityBatchOptimizer(AbstractOptimizer):
    """Allocate each item to one producer, trading off cost against quality.

    The allocation is first solved as a convex relaxation over a matrix ``X``
    of shape ``(m, n)`` (``m`` items, ``n`` producers)::

        minimize    0.5 * gamma * ||X - R||_F ** 2 + 1e6 * sum(C * X)
        subject to  X >= 0
                    sum(X[i, :]) == 1               for every item i
                    sum(Q * X) / m >= min_quality

    where ``C``, ``Q`` and ``R`` are ``optimization_input.cost``, ``.quality``
    and ``.reference``. :meth:`get_optimal_solution` then rounds the fractional
    solution so that every item is assigned to exactly one producer.

    Args:
        optimization_input: Provides the ``cost``, ``quality`` and
            ``reference`` matrices, ``id`` (row labels) and ``cost_producer``
            (column labels). ``cost`` and ``quality`` must have the same shape.
        gamma: Weight of the quadratic term pulling ``X`` towards ``reference``.
        min_quality: Lower bound on the mean over items of
            ``sum_j Q[i, j] * X[i, j]``.

    Raises:
        ValueError: If ``cost`` and ``quality`` have different shapes.
    """

    def __init__(
        self,
        optimization_input: OptimizationInputType,
        gamma: float,
        min_quality: float,
    ):
        self.logger = loguru.logger
        self.optimization_input = optimization_input
        self.gamma = gamma
        self.min_quality = min_quality

        cost_shape = self.optimization_input.cost.shape
        quality_shape = self.optimization_input.quality.shape
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if cost_shape != quality_shape:
            raise ValueError(
                "cost and quality must have the same shape, "
                f"got {cost_shape} and {quality_shape}."
            )

        self.optimization_problem = self.create_optimization_problem()
        self.logger.debug(f"Optimization problem: {self.optimization_problem}.")

    def create_optimization_problem(self) -> cp.Problem:
        """Build the cvxpy problem and keep handles to its parts.

        Sets ``self.m``, ``self.n``, the decision variable ``self.X``, the
        expression ``self.average_quality`` and the three constraint
        attributes, so that :meth:`optimize` can inspect them after solving.

        Returns:
            The (unsolved) ``cp.Problem``.
        """
        self.m, self.n = self.optimization_input.cost.shape
        self.logger.debug(f"m: {self.m}, n: {self.n}.")

        self.X = cp.Variable((self.m, self.n))
        C = self.optimization_input.cost
        Q = self.optimization_input.quality
        R = self.optimization_input.reference  # assumed broadcastable to (m, n)

        # Multiplier on the linear cost term. NOTE(review): this was named
        # PER_MILLE, but 1e6 is a per-million factor (per mille would be 1e3);
        # presumably it rescales per-unit prices - confirm the intended units.
        COST_SCALE = 1e6

        objective = cp.Minimize(
            0.5 * self.gamma * cp.norm(self.X - R, "fro") ** 2
            + COST_SCALE * cp.sum(cp.multiply(C, self.X))
        )

        constraints = []

        self.nonnegativity_constraint = self.X >= 0
        constraints.append(self.nonnegativity_constraint)

        # Each row is a probability distribution over producers.
        self.sum_to_one_constraint = cp.sum(self.X, axis=1) == 1
        constraints.append(self.sum_to_one_constraint)

        # Mean over items of the allocation-weighted quality.
        self.average_quality = cp.sum(cp.multiply(Q, self.X)) / self.m
        self.min_quality_constraint = self.average_quality >= self.min_quality
        constraints.append(self.min_quality_constraint)

        return cp.Problem(objective, constraints)

    def optimize(self, verbose: bool = True, **solver_kwargs) -> None:
        """Solve the problem in place; the relaxed solution ends up in ``self.X``.

        Args:
            verbose: Forwarded to ``cp.Problem.solve``. It defaults to ``True``,
                which matches the previously hard-coded behaviour.
            **solver_kwargs: Extra keyword arguments for ``cp.Problem.solve``
                (e.g. ``solver=...``).

        If the solver produces no solution, a warning is logged and the
        constraint diagnostics are skipped, because they need a solution.
        :meth:`get_optimal_solution` then raises.
        """
        self.optimization_problem.solve(verbose=verbose, **solver_kwargs)
        status = self.optimization_problem.status
        self.logger.debug(f"Solver status: {status}.")
        self.logger.debug(f"OPT: {self.optimization_problem.value}.")

        if self.X.value is None:
            self.logger.warning(
                f"Solver finished with status {status!r} and no solution; "
                "skipping constraint diagnostics."
            )
            return

        self.logger.debug(
            f"Average quality of relaxed solution: {self.average_quality.value}."
        )
        # ``Constraint.value()`` only returns a bool (satisfied or not);
        # ``residual`` is the actual violation amount (0 when satisfied).
        self.logger.debug(
            f"Residual of min-quality constraint: {self.min_quality_constraint.residual}."
        )
        self.logger.debug(
            f"Dual variable of min-quality constraint: {self.min_quality_constraint.dual_value}."
        )
        self.logger.debug(f"Optimal solution: {self.X.value}.")

    def round_optimal_solution(self, optimal_solution: pd.DataFrame) -> pd.DataFrame:
        """Round a fractional allocation to a one-hot allocation per row.

        Each row is assigned entirely to its largest-weight producer. Ties go to
        the first such column, so every row contains exactly one 1. Marking
        every tied maximum would allocate one id to several producers.

        Args:
            optimal_solution: Wide frame with an ``id`` column plus one numeric
                column per producer.

        Returns:
            Frame with the same producer columns holding 0/1 ints, followed by
            the ``id`` column.
        """
        allocations = optimal_solution.drop(columns=["id"])
        one_hot = np.zeros(allocations.shape, dtype=int)
        if allocations.shape[1] > 0:  # argmax is undefined with no producers
            winners = allocations.to_numpy().argmax(axis=1)
            one_hot[np.arange(len(allocations)), winners] = 1
        rounded_optimal_solution = pd.DataFrame(
            one_hot, columns=allocations.columns, index=allocations.index
        )
        # Copy the ids positionally from the input frame so they cannot be
        # misaligned by pandas index-label alignment.
        rounded_optimal_solution["id"] = optimal_solution["id"].to_numpy()
        self.logger.debug(
            f"Columns in rounded optimal solution: {rounded_optimal_solution.columns.tolist()}."
        )
        self.logger.debug(
            f"Rounded optimal solution (head): {rounded_optimal_solution.head(2)}."
        )
        self.logger.debug(
            f"Rounded optimal solution (tail): {rounded_optimal_solution.tail(3)}."
        )
        return rounded_optimal_solution

    def get_optimal_solution(self) -> pd.DataFrame:
        """Return the rounded allocation in long format.

        Returns:
            DataFrame with columns ``id``, ``producer`` and ``allocation``
            (0/1), one row per (id, producer) pair, with exactly one producer
            per id set to 1.

        Raises:
            RuntimeError: If no solution is available, either because
                :meth:`optimize` was not run or because the solver failed.
        """
        if self.X.value is None:
            raise RuntimeError(
                "No solution available: call optimize() first and check the solver status."
            )
        optimal_solution = pd.DataFrame(
            self.X.value,
            columns=self.optimization_input.cost_producer,
            # Name the index explicitly so reset_index() always yields an "id"
            # column, whatever container type optimization_input.id is.
            index=pd.Index(self.optimization_input.id, name="id"),
        ).reset_index()
        self.logger.debug(
            f"Columns in optimal solution: {optimal_solution.columns.tolist()}."
        )
        self.logger.debug(f"Optimal solution (head): {optimal_solution.head(2)}.")
        self.logger.debug(f"Optimal solution (tail): {optimal_solution.tail(3)}.")

        rounded_optimal_solution = self.round_optimal_solution(optimal_solution)
        return rounded_optimal_solution.melt(
            id_vars=["id"], var_name="producer", value_name="allocation"
        )


if __name__ == "__main__":
    # Script entry: solve the relaxed allocation on the test split, round it,
    # then report the realised mean quality and summed cost of the chosen pairs.
    logger = loguru.logger

    dao = DAO(
        "config/irtrouter-normalizer.yaml",
        batch_size=10000,
        use_estimate=True,
        split=DatasetSplit.TEST,
    )
    optimization_input = dao.get_all_optimization_input()

    optimizer = CostQualityBatchOptimizer(
        optimization_input,
        gamma=1e-5,
        min_quality=0.8161,
    )
    optimizer.optimize()

    # Keep only the (id, producer) pairs that received the allocation.
    allocation = optimizer.get_optimal_solution()
    chosen = allocation["allocation"] == 1
    selected_pairs = allocation.loc[chosen, ["id", "producer"]]

    logger.debug(
        f"Columns in rounded optimal solution: {selected_pairs.columns.tolist()}."
    )
    logger.debug(f"Rounded optimal solution (head): {selected_pairs.head(2)}.")
    logger.debug(f"Rounded optimal solution (tail): {selected_pairs.tail(3)}.")

    def _restrict_to_selection(table):
        """Inner-join a per-(id, producer) table onto the selected pairs."""
        return table.merge(selected_pairs, on=["id", "producer"], how="inner")

    average_quality = _restrict_to_selection(dao.quality)["quality"].mean()
    logger.debug(f"Average quality (accuracy): {average_quality}.")

    total_cost = _restrict_to_selection(dao.cost)["cost"].sum()
    logger.debug(f"Total cost ($): {total_cost}.")
