from abc import abstractmethod

import cvxopt
import cvxpy as cp
import numpy as np
import torch

import pdb
"""Implementation of Pareto HyperNetworks with:
1. Linear scalarization
2. EPO


EPO code from: https://github.com/dbmptr/EPOSearch
"""


class Solver:
    """Base class for objects that turn per-task losses into one scalar loss.

    Subclasses override :meth:`get_weighted_loss`; calling an instance simply
    forwards to it.
    """

    def __init__(self, n_tasks):
        super().__init__()
        self.n_tasks = n_tasks  # number of tasks (loss terms) being combined

    @abstractmethod
    def get_weighted_loss(self, losses, ray, parameters):
        """Return ``(weighted_loss, weights)`` for ``losses`` under preference ``ray``.

        Must be overridden. ``Solver`` does not derive from ``abc.ABC``, so the
        ``@abstractmethod`` marker is not enforced at instantiation time;
        raising here makes a missing override fail loudly instead of silently
        returning ``None``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_weighted_loss()"
        )

    def __call__(self, losses, ray, parameters):
        return self.get_weighted_loss(losses, ray, parameters)


class LinearScalarizationSolver(Solver):
    """Linear scalarization (LS): weight each task loss by its ray entry.

    The combined loss is ``sum_i losses[i] * ray[i]``; the second return value
    is an all-zero tensor with one entry per component of ``ray``.
    """

    def __init__(self, n_tasks):
        super().__init__(n_tasks)

    def get_weighted_loss(self, losses, ray, parameters):
        # ``parameters`` is not needed: LS uses no gradient information.
        weighted = losses * ray
        zero_weights = torch.Tensor([0] * len(ray))
        return weighted.sum(), zero_weights


class EPOSolver(Solver):
    """Solver adapter that hands loss weighting off to an :class:`EPO` instance.

    EPO differentiates each task loss, so ``parameters`` must be provided.
    """

    def __init__(self, n_tasks):
        super().__init__(n_tasks)
        self.solver = EPO(n_tasks=n_tasks)

    def get_weighted_loss(self, losses, ray, parameters):
        assert parameters is not None
        # EPO.__call__ forwards directly to EPO.get_weighted_loss.
        return self.solver(losses, ray, parameters)


class EPO:
    """EPO loss weighting adapted from EPOSearch.

    Each call computes one gradient per task w.r.t. ``parameters``, passes the
    Gram matrix ``G @ G.T`` of those gradients to :class:`ExactParetoLP`, and
    combines the losses with the resulting weights scaled by ``n_tasks``.
    """

    def __init__(self, n_tasks):
        self.n_tasks = n_tasks

    def __call__(self, losses, ray, parameters):
        return self.get_weighted_loss(losses, ray, parameters)

    @staticmethod
    def _flattening(grad, params=None):
        """Concatenate a sequence of gradient tensors into one flat vector.

        Args:
            grad: gradients as returned by ``torch.autograd.grad``; entries may
                be ``None`` for inputs the loss does not depend on.
            params: optional tensors aligned position-by-position with
                ``grad``. When given, a ``None`` gradient is replaced by zeros
                shaped like its parameter, so every task's vector has the same
                length and the vectors can be stacked. When omitted, ``None``
                entries are dropped (legacy behavior).
        """
        if params is None:
            return torch.cat(tuple(g.reshape(-1) for g in grad if g is not None), dim=0)
        return torch.cat(
            tuple(
                (torch.zeros_like(p) if g is None else g).reshape(-1)
                for g, p in zip(grad, params)
            ),
            dim=0,
        )

    def get_weighted_loss(self, losses, ray, parameters):
        """Combine ``losses`` using EPO weights for preference ``ray``.

        Args:
            losses: tensor of per-task losses, iterated one entry per task.
            ray: preference tensor with one entry per task.
            parameters: a tensor or an iterable of tensors to differentiate
                each task loss with respect to.

        Returns:
            ``(weighted_loss, alpha)``: ``alpha`` are the per-task weights
            multiplied by ``n_tasks``, placed on ``losses.device`` with
            ``losses.dtype``; ``weighted_loss = sum(losses * alpha)``.
        """
        # autograd.grad is called once per task, so a one-shot iterator such
        # as model.parameters() must be materialized up front; a lone tensor
        # must not be iterated (that would split it into rows).
        if isinstance(parameters, torch.Tensor):
            parameters = (parameters,)
        else:
            parameters = tuple(parameters)

        lp = ExactParetoLP(m=self.n_tasks, r=ray.detach().cpu().numpy())

        grads = []
        for loss in losses:
            g = torch.autograd.grad(loss, parameters, retain_graph=True, allow_unused=True)
            # Zero-fill unused parameters so all task vectors share one length.
            grads.append(self._flattening(g, parameters).detach())

        G = torch.stack(grads)
        GG_T = (G @ G.T).detach().cpu().numpy()
        numpy_losses = losses.detach().cpu().numpy()
        try:
            alpha = lp.get_alpha(numpy_losses, G=GG_T, C=True)
        except Exception as excep:  # best effort: fall back to the ray below
            print(excep)
            alpha = None

        if alpha is None:  # A patch for the issue in cvxpy
            alpha = (ray / ray.sum()).detach().cpu().numpy()

        # Out-of-place so the LP variable's value buffer is not mutated.
        alpha = alpha * self.n_tasks
        # Match the losses' dtype so both branches give the same dtype and the
        # loss is not silently promoted to float64 by the cvxpy weights.
        alpha = torch.from_numpy(alpha).to(losses.device, dtype=losses.dtype)

        weighted_loss = torch.sum(losses * alpha)
        return weighted_loss, alpha


class ExactParetoLP:
    """LPs that choose EPO task weights (modified from https://github.com/dbmptr/EPOSearch).

    The problems are built once over cvxpy parameters and re-solved with GLPK
    on every :meth:`get_alpha` call:

    * balance LP (``prob_bal``), used when the non-uniformity ``mu_rl``
      exceeds ``eps``;
    * dominance LP (``prob_dom``), or its relaxed form (``prob_rel``) when
      ``relax=True``, used otherwise.
    """

    def __init__(self, m, r, eps=1e-4):
        cvxopt.glpk.options["msg_lev"] = "GLP_MSG_OFF"  # silence GLPK output
        self.m = m
        self.r = r
        self.eps = eps
        self.last_move = None

        # LP data, refreshed on every get_alpha() call.
        self.a = cp.Parameter(m)        # adjustments returned by adjustments()
        self.C = cp.Parameter((m, m))   # gradient inner products, G @ G.T (one row of G per task)
        self.Ca = cp.Parameter(m)       # C @ a
        self.rhs = cp.Parameter(m)      # lower bounds on C @ alpha in the balance LP

        self.alpha = cp.Variable(m)     # task weights being optimized

        def simplex():
            # Fresh constraint objects per problem: alpha lies on the simplex.
            return [self.alpha >= 0, cp.sum(self.alpha) == 1]

        balance_objective = cp.Maximize(self.alpha @ self.Ca)
        self.prob_bal = cp.Problem(
            balance_objective,
            simplex() + [self.C @ self.alpha >= self.rhs],
        )

        descent_objective = cp.Maximize(cp.sum(self.alpha @ self.C))
        restricted = simplex() + [
            self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)),
            self.C @ self.alpha >= 0,
        ]
        relaxed = simplex() + [self.C @ self.alpha >= 0]
        self.prob_dom = cp.Problem(descent_objective, restricted)
        self.prob_rel = cp.Problem(descent_objective, relaxed)

        self.gamma = 0     # optimum value of the most recently solved LP
        self.mu_rl = 0     # most recent non-uniformity of the weighted losses

    def get_alpha(self, l, G, r=None, C=False, relax=False):
        """Solve the balance or dominance LP and return the task weights.

        Args:
            l: per-task losses (length ``m``).
            G: per-task gradients stacked row-wise, or their Gram matrix when
                ``C`` is True.
            r: preference ray; defaults to the one given at construction.
            C: whether ``G`` is already the ``m x m`` Gram matrix.
            relax: in the dominance step, solve ``prob_rel`` instead of
                ``prob_dom``.

        Returns:
            ``self.alpha.value`` after solving; may be ``None`` when cvxpy
            does not produce a solution.
        """
        r = self.r if r is None else r
        assert len(l) == len(G) == len(r) == self.m, "length != m"
        rl, self.mu_rl, self.a.value = adjustments(l, r)
        self.C.value = G if C else G @ G.T
        self.Ca.value = self.C.value @ self.a.value

        if self.mu_rl > self.eps:
            ca = self.Ca.value
            positive = ca > 0
            if positive.any():
                bounds = ca.copy()
                bounds[positive] = -np.max(np.abs(ca))     # Not efficient; but works.
                bounds[np.where(rl == np.max(rl))[0]] = 0  # largest weighted losses
            else:
                bounds = np.zeros_like(ca)
            self.rhs.value = bounds
            self.gamma = self.prob_bal.solve(solver=cp.GLPK, verbose=False)
            self.last_move = "bal"
        else:
            problem = self.prob_rel if relax else self.prob_dom
            self.gamma = problem.solve(solver=cp.GLPK, verbose=False)
            self.last_move = "dom"

        return self.alpha.value


def mu(rl, normed=False):
    """Return the non-uniformity of ``rl``: ``sum(l_hat * log(m * l_hat))``.

    ``l_hat`` is ``rl`` divided by its sum (or ``rl`` itself when ``normed`` is
    True) and ``m = len(rl)``. Entries of ``l_hat`` not exceeding machine
    epsilon are skipped. When ``l_hat`` sums to 1 this is the KL divergence of
    ``l_hat`` from the uniform distribution, so it is 0 for uniform input.

    Args:
        rl: non-negative array-like of (ray-weighted) losses.
        normed: whether ``rl`` is already normalized to sum to 1.

    Raises:
        ValueError: if any entry of ``rl`` is negative.
    """
    rl = np.asarray(rl)
    if np.any(rl < 0):
        raise ValueError(f"rl<0 \n rl={rl}")
    m = len(rl)  # count taken before filtering tiny entries
    l_hat = rl if normed else rl / rl.sum()
    # np.finfo rejects integer dtypes; eps must come from a floating array.
    if not np.issubdtype(l_hat.dtype, np.floating):
        l_hat = l_hat.astype(np.float64)
    eps = np.finfo(l_hat.dtype).eps
    l_hat = l_hat[l_hat > eps]  # avoid log(0)
    return np.sum(l_hat * np.log(l_hat * m))


def adjustments(l, r=1):
    """Compute the EPO adjustment terms for losses ``l`` under preference ``r``.

    Returns:
        ``(rl, mu_rl, a)`` where ``rl = r * l`` are the ray-weighted losses,
        ``mu_rl = mu(l_hat, normed=True)`` with ``l_hat = rl / sum(rl)``, and
        ``a = r * (log(m * l_hat) - mu_rl)`` with ``m = len(l)``.
    """
    count = len(l)
    weighted = r * l
    share = weighted / weighted.sum()
    spread = mu(share, normed=True)
    adjust = r * (np.log(share * count) - spread)
    return weighted, spread, adjust
