from typing import Callable

import cvxpy as cp
import torch
import torch.nn.functional as F
from torch import FloatTensor, Tensor
from trl.trainer.dpo_config import FDivergenceConstants, FDivergenceType
from trl.trainer.dpo_trainer import DPOTrainer
from trl.trainer.utils import cap_exp

from .utils import ERROR_STATUSES, EXCEPTION, SOLVER_KWARGS, QParams


class DRDPOTrainer(DPOTrainer):
    """DPO trainer with an additional reweighted ``"dr_sigmoid"`` loss.

    For ``loss_type == "dr_sigmoid"`` each per-example (label-smoothed) sigmoid loss
    is multiplied by a weight ``q_i``. ``q`` is a probability vector over the batch
    obtained by solving a small convex program with cvxpy; see
    :meth:`get_optimization_problem`. Every loss type listed in
    ``_UPSTREAM_LOSS_TYPES`` is delegated unchanged to ``DPOTrainer.dpo_loss``.
    """

    # Loss types handled entirely by the parent DPOTrainer.
    _UPSTREAM_LOSS_TYPES: tuple[str, ...] = (
        "sigmoid",
        "robust",
        "exo_pair",
        "hinge",
        "ipo",
        "bco_pair",
        "sppo_hard",
        "nca_pair",
        "aot_pair",
        "aot",
        "apo_zero",
        "apo_down",
        "discopop",
    )

    # Compiled cvxpy problems cached per batch size: (q_var, log_sigmoid_par, q_problem).
    q_params: dict[int, QParams]
    # Upper bound on phi(q), the distance between q and the uniform distribution.
    eps: float
    # Name of the distance phi; get_optimization_problem accepts "tv" or "chi2o".
    dist_fn: str

    def __init__(self, *args, **kwargs):
        """Create the trainer.

        The extra keyword arguments ``eps`` and ``dist_fn`` are required. They are
        popped before the remaining arguments are forwarded to
        ``DPOTrainer.__init__``. A missing one raises ``KeyError``.
        """
        eps = kwargs.pop("eps")
        dist_fn = kwargs.pop("dist_fn")
        super().__init__(*args, **kwargs)
        self.eps = eps
        self.dist_fn = dist_fn
        self.q_params = {}

    def get_base_logits(
        self,
        chosen_logps: FloatTensor,
        rejected_logps: FloatTensor,
        ref_chosen_logps: FloatTensor,
        ref_rejected_logps: FloatTensor,
    ) -> Tensor:
        """Return per-example preference logits (not yet scaled by ``beta``).

        * Alpha divergence:
          ``(cap_exp(-a * rejected_logratio) - cap_exp(-a * chosen_logratio)) / a``.
          ``a`` is taken from ``f_divergence_params`` when present, otherwise the
          TRL default.
        * Otherwise: the policy log-ratio ``chosen - rejected`` minus the reference
          log-ratio. The reference log-ratio is 0 when ``reference_free``. For JS
          divergence, ``softplus(chosen_logratio) - softplus(rejected_logratio)``
          is additionally subtracted.

        Here ``chosen_logratio`` / ``rejected_logratio`` are policy log-probs minus
        reference log-probs. The reference term is dropped when ``reference_free``.
        All results are placed on the accelerator device.
        """
        device = self.accelerator.device

        chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
        rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)

        if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
            if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
                alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
            logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
        else:
            logratios = chosen_logps - rejected_logps
            if self.reference_free:
                ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
            else:
                ref_logratios = ref_chosen_logps - ref_rejected_logps

            logratios = logratios.to(device)
            ref_logratios = ref_logratios.to(device)
            logits = logratios - ref_logratios

            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)

        return logits

    def get_optimization_problem(self, batch_size: int) -> QParams:
        """Return the (cached) cvxpy problem that computes the batch weights ``q``.

        The problem is::

            minimize    sum_i q_i * par_i
            subject to  sum(q) == 1,  0 <= q <= 1,  phi(q) <= eps

        ``par`` is a cvxpy ``Parameter`` whose value is set by :meth:`solve_q`.
        ``phi`` is selected by ``self.dist_fn`` (``n`` = ``batch_size``):

        * ``"tv"``: ``1/2 * ||q - 1/n||_1``, the total-variation distance between
          ``q`` and the uniform distribution.
        * ``"chi2o"``: ``1/(2n) * ||n*q - 1||_2^2``, i.e. half the chi-squared
          divergence of ``q`` from uniform.

        A problem is built once per batch size and stored in ``self.q_params``. The
        cached problem keeps the ``eps`` / ``dist_fn`` that were in effect when it
        was built.

        Returns:
            ``(q_var, log_sigmoid_par, q_problem)``.

        Raises:
            NotImplementedError: if ``self.dist_fn`` is neither ``"tv"`` nor ``"chi2o"``.
        """
        if batch_size in self.q_params:
            return self.q_params[batch_size]

        distances: dict[str, Callable[[cp.Expression], cp.Expression]] = {
            "tv": lambda x: 1 / 2 * cp.norm(x - 1 / batch_size, p=1),
            "chi2o": lambda x: 1 / (2 * batch_size) * cp.sum_squares(batch_size * x - 1),
        }
        try:
            phi = distances[self.dist_fn]
        except KeyError:
            raise NotImplementedError(f"Distance function '{self.dist_fn}' is not implemented.") from None

        log_sigmoid_par = cp.Parameter((batch_size,))

        q_var = cp.Variable((batch_size,), nonneg=True)
        q_constraints: list[cp.Constraint] = [cp.sum(q_var) == 1, q_var <= 1, phi(q_var) <= self.eps]  # type: ignore

        q_expr = cp.multiply(q_var, log_sigmoid_par).sum()
        q_problem = cp.Problem(cp.Minimize(q_expr), q_constraints)

        self.q_params[batch_size] = (q_var, log_sigmoid_par, q_problem)

        return q_var, log_sigmoid_par, q_problem

    def solve_q(self, log_sigmoid_logits: torch.Tensor) -> torch.Tensor:
        """Solve for the batch weights ``q`` given per-example values.

        Args:
            log_sigmoid_logits: tensor whose first dimension is the batch size. It is
                detached, copied to CPU as float32, and used as the objective
                coefficients of :meth:`get_optimization_problem`.

        Returns:
            ``q`` with shape ``(batch_size,)``, on the same device and with the same
            dtype as ``log_sigmoid_logits``. Matching the dtype avoids promoting
            the caller's loss to float64, which some backends (e.g. MPS) do not
            support.

        Raises:
            RuntimeError: if the solver raises, reports a status in
                ``ERROR_STATUSES``, or returns no solution.
        """
        q_var, log_sigmoid_par, q_problem = self.get_optimization_problem(log_sigmoid_logits.shape[0])

        log_sigmoid_par.value = log_sigmoid_logits.detach().cpu().float().numpy()

        try:
            q_problem.solve(**SOLVER_KWARGS)
        except cp.SolverError as err:
            raise RuntimeError(
                f"q not found: {EXCEPTION} ({err}); log_sigmoid_par: {log_sigmoid_par.value}"
            ) from err

        status = q_problem.status
        q = q_var.value
        # Raise instead of exit(1) so the failure surfaces as an ordinary exception.
        if status in ERROR_STATUSES or q is None:
            raise RuntimeError(f"q not found: {status}; log_sigmoid_par: {log_sigmoid_par.value}")

        return torch.as_tensor(q, dtype=log_sigmoid_logits.dtype, device=log_sigmoid_logits.device).reshape(
            q_var.shape
        )

    def dpo_loss(
        self,
        chosen_logps: FloatTensor,
        rejected_logps: FloatTensor,
        ref_chosen_logps: FloatTensor,
        ref_rejected_logps: FloatTensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute per-example losses and rewards.

        Loss types in ``_UPSTREAM_LOSS_TYPES`` are delegated to ``DPOTrainer.dpo_loss``.
        For ``"dr_sigmoid"`` with ``z = get_base_logits(...)`` and ``ls = label_smoothing``::

            loss_i = n * q_i * (-(1 - ls) * logsigmoid(beta * z_i) - ls * logsigmoid(-beta * z_i))

        ``n`` is the batch size and ``q`` comes from :meth:`solve_q`. With a uniform
        ``q`` this reduces to the unweighted label-smoothed sigmoid loss.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)``. The rewards are
            ``beta * (policy_logps - ref_logps)``, detached.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        if self.loss_type in self._UPSTREAM_LOSS_TYPES:
            return super().dpo_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps)

        if self.loss_type != "dr_sigmoid":
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of {['dr_sigmoid', *self._UPSTREAM_LOSS_TYPES]}"
            )

        device = self.accelerator.device
        logits = self.get_base_logits(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps)

        log_sigmoid_logits = F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
        # The q program minimizes sum_i q_i * (per-example loss >= 0), so weight moves
        # toward lower-loss examples.
        # NOTE(review): confirm this direction is intended; a worst-case (DRO)
        # weighting would maximize instead.
        q = self.solve_q(-log_sigmoid_logits)

        losses = -q * log_sigmoid_logits - q * F.logsigmoid(-self.beta * logits) * self.label_smoothing
        # Rescale by batch size so the trainer's mean over examples equals the q-weighted sum.
        losses = losses * log_sigmoid_logits.shape[0]

        chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
        rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()

        return losses, chosen_rewards, rejected_rewards
