import warnings
from typing import Any, Callable, Union

import cvxpy as cp
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from trl import RewardTrainer

from .utils import ERROR_STATUSES, EXCEPTION, SOLVER_KWARGS, QParams


def cp_log_sigmoid(z: cp.Expression):
    """Return a CVXPY expression for the element-wise log-sigmoid of ``z``.

    Uses the identity ``log(sigmoid(z)) = -log(1 + exp(-z))`` together with
    ``cp.logistic`` (``logistic(x) = log(1 + exp(x))``).  The previous form,
    ``cp.log(1 / (1 + cp.exp(z)))``, evaluated to ``log(sigmoid(-z))`` (sign
    of ``z`` flipped) and divided by a non-constant expression, which is not
    a DCP-compliant construction.

    Args:
        z: CVXPY expression to apply the log-sigmoid to.

    Returns:
        A concave CVXPY expression with the same shape as ``z``.
    """
    return -cp.logistic(-z)


class DRRewardTrainer(RewardTrainer):
    """Reward trainer whose batch loss is re-weighted by an adversarial distribution.

    For each batch the per-pair log-sigmoid terms are weighted by a
    distribution ``q`` over the batch.  ``q`` is obtained by solving a small
    convex problem: it maximizes the weighted negative log-sigmoid while
    staying within ``eps`` of the uniform distribution under the divergence
    selected by ``dist_fn``.

    Extra keyword arguments (both required, removed before the remaining
    arguments are forwarded to ``RewardTrainer``):
        eps: Upper bound on the divergence between ``q`` and uniform.
        dist_fn: ``"tv"`` (total variation) or ``"chi2o"`` (chi-square).
    """

    # Cache of (variable, parameter, problem) triples keyed by batch size.
    q_params: dict[int, QParams]
    eps: float
    dist_fn: str

    def __init__(self, *args, **kwargs):
        """Pop ``eps`` and ``dist_fn`` from ``kwargs`` and initialise the base trainer.

        Raises:
            KeyError: If ``eps`` or ``dist_fn`` is not supplied.
        """
        eps = kwargs.pop("eps")
        dist_fn = kwargs.pop("dist_fn")
        super().__init__(*args, **kwargs)
        self.eps = eps
        self.dist_fn = dist_fn
        self.q_params = {}

    def get_optimization_problem(self, batch_size: int) -> QParams:
        """Return the (cached) CVXPY problem used to find ``q`` for ``batch_size``.

        The problem is built once per batch size and reused; only the
        parameter value changes between solves.  ``self.eps`` and
        ``self.dist_fn`` are read when the problem is first built, so later
        changes to them do not affect an already cached problem.

        Args:
            batch_size: Number of preference pairs in the batch.

        Returns:
            Tuple ``(q_var, log_sigmoid_par, q_problem)``.

        Raises:
            NotImplementedError: If ``self.dist_fn`` is not ``"tv"`` or ``"chi2o"``.
        """
        cached = self.q_params.get(batch_size)
        if cached is not None:
            return cached

        if self.dist_fn not in ("tv", "chi2o"):
            raise NotImplementedError(f"Distance function '{self.dist_fn}' is not implemented.")

        log_sigmoid_par = cp.Parameter((batch_size, 1))
        q_var = cp.Variable((batch_size, 1), nonneg=True)

        if self.dist_fn == "tv":
            # Total variation distance between q and the uniform distribution.
            divergence = 1 / 2 * cp.norm(q_var - 1 / batch_size, p=1)
        else:
            # Chi-square divergence between q and the uniform distribution.
            divergence = 1 / (2 * batch_size) * cp.sum_squares(batch_size * q_var - 1)

        q_constraints: list[cp.Constraint] = [cp.sum(q_var) == 1, q_var <= 1, divergence <= self.eps]  # type: ignore
        # cp.sum(...) instead of the Expression.sum() method: works on every CVXPY version.
        q_expr = -cp.sum(cp.multiply(q_var, log_sigmoid_par))

        q_problem = cp.Problem(cp.Maximize(q_expr), q_constraints)

        q_params = (q_var, log_sigmoid_par, q_problem)
        self.q_params[batch_size] = q_params
        return q_params

    def solve_q(self, log_sigmoid_logits: torch.Tensor) -> torch.Tensor:
        """Solve for the adversarial weights ``q`` given per-pair log-sigmoid values.

        Args:
            log_sigmoid_logits: Tensor whose first dimension is the batch
                size; its values are copied (detached, on CPU, as float32)
                into the problem parameter of shape ``(batch_size, 1)``.

        Returns:
            Tensor of shape ``(batch_size, 1)`` on the device and in the
            dtype of ``log_sigmoid_logits``.  It carries no gradient.

        Raises:
            RuntimeError: If the solver raises, reports a status listed in
                ``ERROR_STATUSES``, or produces no value for ``q``.
        """
        q_var, log_sigmoid_par, q_problem = self.get_optimization_problem(log_sigmoid_logits.shape[0])

        log_sigmoid_par.value = log_sigmoid_logits.detach().cpu().float().numpy()

        try:
            q_problem.solve(**SOLVER_KWARGS)
        except cp.SolverError as err:
            raise RuntimeError(f"q not found: {EXCEPTION}\nreward diff: {log_sigmoid_par.value}") from err

        status = q_problem.status
        q = q_var.value
        if status in ERROR_STATUSES or q is None:
            # Raise instead of calling exit(): callers get a traceback and can
            # still clean up, and the `site`-provided builtin is not required.
            raise RuntimeError(f"q not found: {status}\nreward diff: {log_sigmoid_par.value}")

        # The solver returns float64; cast to the logits dtype so the loss is
        # not promoted to float64 (which some devices do not support).
        return torch.as_tensor(
            q,
            dtype=log_sigmoid_logits.dtype,
            device=log_sigmoid_logits.device,
        ).reshape(q_var.shape)

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
        """Compute the ``q``-weighted pairwise reward loss for one batch.

        Args:
            model: Reward model; called with ``input_ids``/``attention_mask``
                and ``return_dict=True``, and its ``"logits"`` are the rewards.
            inputs: Batch with ``input_ids_chosen``, ``attention_mask_chosen``,
                ``input_ids_rejected``, ``attention_mask_rejected`` and
                optionally ``margin``.
            return_outputs: If true, also return the chosen/rejected rewards.
            num_items_in_batch: Accepted for interface compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, {"rewards_chosen", "rewards_rejected"})``
            when ``return_outputs`` is true.
        """
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]

        reward_diff = rewards_chosen - rewards_rejected
        if "margin" in inputs:
            margin = inputs["margin"]
            # A margin holding one value per pair but shaped differently from
            # the reward difference (e.g. flat vs. column) would broadcast
            # into a 2-D grid; align the shapes when the element counts match.
            if (
                isinstance(margin, torch.Tensor)
                and margin.shape != reward_diff.shape
                and margin.numel() == reward_diff.numel()
            ):
                margin = margin.reshape(reward_diff.shape)
            reward_diff = reward_diff - margin
        log_sigmoid_logits = nn.functional.logsigmoid(reward_diff)

        # q is a constant w.r.t. autograd: gradients flow only through the logits.
        q = self.solve_q(log_sigmoid_logits)

        self.log({"q_mod_std": torch.std(q - 1 / rewards_chosen.shape[0]).item(), "q_max": q.max().item(), "q_min": q.min().item()})

        loss = -(q * log_sigmoid_logits).sum()

        if self.args.center_rewards_coefficient is not None:  # type: ignore
            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)  # type: ignore

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss
