#!/usr/bin/env python
# -*- coding: utf-8 -*-
# From https://github.com/notani/python-glad

import numpy as np
import scipy as sp
import scipy.stats
import scipy.optimize

# Convergence tolerance for GLAD.EM: iteration stops once the relative change
# of the expected log-likelihood Q between two EM iterations is not above this.
THRESHOLD = 1e-7


def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x)), applied elementwise to scalars or arrays."""
    denominator = np.exp(-x) + 1.0
    return 1.0 / denominator


def logsigmoid(x):
    """Return log(sigmoid(x)) = -log(1 + exp(-x)), elementwise.

    Computed as ``-logaddexp(0, -x)`` so that large negative ``x`` yields
    approximately ``x`` instead of overflowing ``exp(-x)`` to ``inf`` (which
    made the naive formula return ``-inf``).
    """
    return -np.logaddexp(0.0, -x)


# warnings.filterwarnings('error')
class GLAD:
    """GLAD label aggregation, fitted with EM.

    Jointly infers each task's true label, each task's difficulty (``beta``,
    used through ``exp(beta)``) and each worker's ability (``alpha``). The
    probability that worker ``i`` labels task ``j`` correctly is modelled as
    ``sigmoid(alpha_i * exp(beta_j))``; a wrong answer is spread uniformly
    over the other ``n_classes - 1`` classes.

    Args:
        n_classes: number of label classes.
        n_workers: total number of workers, including spam workers.
        n_task: number of tasks.
        answers: mapping ``task -> {worker -> label}``. Task and worker keys
            are converted with ``int()``; labels are 0-based class indices.
            String workers named ``"spam<n>"`` are mapped to worker index
            ``n_workers - n_spam - 1 + n``. ``None`` means no answers.
        n_iter: maximum number of EM iterations after the first one.
        n_spam: number of spam workers placed at the end of the worker range.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        n_classes=-1,
        n_workers=-1,
        n_task=-1,
        answers=None,
        n_iter=200,
        n_spam=0,
        **kwargs,
    ):
        if answers is None:
            answers = {}
        self.n_classes = n_classes
        self.n_workers = n_workers
        self.n_task = n_task
        self.answers = answers
        # Uniform prior over the true class of every task.
        self.priorZ = np.array([1 / n_classes] * n_classes)
        self.n_iter = n_iter
        self.n_spam = n_spam

        # labels[j, i] = (class answered by worker i on task j) + 1;
        # 0 means worker i gave no answer for task j.
        self.labels = np.zeros((self.n_task, self.n_workers))
        # "spam<n>" workers are numbered relative to the last regular worker.
        worker_last = self.n_workers - self.n_spam - 1
        for task, ans in answers.items():
            for worker, lab in ans.items():
                if isinstance(worker, str) and worker.startswith("spam"):
                    worker = worker_last + int(worker.split("spam")[1])
                self.labels[int(task)][int(worker)] = lab + 1

        # Initialize Probs
        self.priorAlpha = np.ones(self.n_workers)
        self.priorBeta = np.ones(self.n_task)
        self.probZ = np.empty((self.n_task, self.n_classes))
        self.beta = np.empty(self.n_task)
        self.alpha = np.empty(self.n_workers)

    @staticmethod
    def _relative_change(Q, lastQ):
        """Return |Q - lastQ| / |lastQ|, the EM convergence criterion.

        Falls back to the absolute change when ``lastQ`` is 0 (the ratio
        would divide by zero) and returns ``inf`` when either value is
        non-finite, so such iterations are never mistaken for convergence.
        """
        if not (np.isfinite(Q) and np.isfinite(lastQ)):
            return np.inf
        if lastQ == 0:
            return abs(Q - lastQ)
        return abs((Q - lastQ) / lastQ)

    @staticmethod
    def _safe_logsigmoid(x):
        """Elementwise log(sigmoid(x)) that stays finite for finite ``x``.

        ``logsigmoid`` may overflow to ``-inf`` for large negative ``x``; in
        that regime log(sigmoid(x)) ~= x, so that asymptote is substituted.
        """
        x = np.asarray(x, dtype=float)
        out = np.asarray(logsigmoid(x), dtype=float)
        return np.where(np.isfinite(out) | ~np.isfinite(x), out, x)

    def EM(self):
        """Infer true labels, tasks' difficulty and workers' ability."""
        # Initialize parameters to starting values
        self.alpha = self.priorAlpha.copy()
        self.beta = self.priorBeta.copy()
        self.probZ[:] = self.priorZ[:]

        self.EStep()
        lastQ = self.computeQ()
        self.MStep()
        Q = self.computeQ()
        counter = 1
        while (
            self._relative_change(Q, lastQ) > THRESHOLD
            and counter <= self.n_iter
        ):
            lastQ = Q
            self.EStep()
            self.MStep()
            Q = self.computeQ()
            counter += 1
        err = self._relative_change(Q, lastQ)
        if err > THRESHOLD:
            print(f"GLAD did not converge: err={err}")

    def calcLogProbL(self, item, *args):
        """Log-likelihood of one task's observed labels for an assumed true class.

        Args:
            item: row ``[j, alpha_0 * exp(beta_j), ..., alpha_{m-1} * exp(beta_j)]``.
            *args: ``(delta, noResp)`` boolean matrices of shape
                (n_task, n_workers); ``delta`` marks answers equal to the
                assumed class, ``noResp`` marks missing answers.

        Returns:
            Sum over workers of log P(l_ij | z_j = assumed class).
        """
        j = int(item[0])  # task ID
        delta = args[0][j]
        noResp = args[1][j]
        oneMinusDelta = (~delta) & (~noResp)
        # List[float]: alpha_i * exp(beta_j) for i = 0, ..., m-1
        exponents = item[1:]
        # Log likelihood for the observations s.t. l_ij == z_j
        correct = self._safe_logsigmoid(exponents[delta]).sum()
        # Log likelihood for the observations s.t. l_ij != z_j
        wrong = (
            self._safe_logsigmoid(-exponents[oneMinusDelta])
            - np.log(float(self.n_classes - 1))
        ).sum()
        # Return log likelihood
        return correct + wrong

    def EStep(self):
        """Evaluate the posterior probability of true labels given observed labels and parameters."""
        ab = np.dot(np.array([np.exp(self.beta)]).T, np.array([self.alpha]))
        ab = np.c_[np.arange(self.n_task), ab]

        logProbZ = np.empty((self.n_task, self.n_classes))
        for k in range(self.n_classes):
            logProbZ[:, k] = np.apply_along_axis(
                self.calcLogProbL,
                1,
                ab,
                (self.labels == k + 1),
                (self.labels == 0),
            )
        # Posterior is proportional to prior * likelihood.
        logProbZ += np.log(self.priorZ)

        # Subtract each row's maximum before exponentiating (log-sum-exp):
        # otherwise tasks with many answers underflow to 0 for every class,
        # and the normalisation below becomes 0 / 0 = NaN.
        logProbZ -= logProbZ.max(axis=1, keepdims=True)
        probZ = np.exp(logProbZ)
        self.probZ = probZ / probZ.sum(axis=1, keepdims=True)
        assert not np.any(np.isnan(self.probZ)), "Invalid Value [EStep]"
        assert not np.any(np.isinf(self.probZ)), "Invalid Value [EStep]"

    def packX(self):
        """Concatenate ``alpha`` then ``beta`` into one parameter vector."""
        return np.r_[self.alpha.copy(), self.beta.copy()]

    def unpackX(self, x):
        """Split a vector laid out as by ``packX`` back into ``alpha`` and ``beta``."""
        self.alpha = x[: self.n_workers].copy()
        self.beta = x[self.n_workers :].copy()

    def getBoundsX(self, alpha=(-100, 100), beta=(-100, 100)):
        """Return (low, high) bounds for each entry of the ``packX`` vector.

        Rows ``0 .. n_workers-1`` bound ``alpha``; the following ``n_task``
        rows bound ``beta``.
        """
        alpha_bounds = np.tile([alpha[0], alpha[1]], (self.n_workers, 1))
        # One beta per task (previously sized by n_workers, misaligned with packX).
        beta_bounds = np.tile([beta[0], beta[1]], (self.n_task, 1))
        return np.r_[alpha_bounds, beta_bounds]

    def f(self, x):
        """Return the value of the objective function (-Q) at ``x``."""
        self.unpackX(x)
        return -self.computeQ()

    def df(self, x):
        """Return the gradient of the objective (-Q) at ``x``."""
        self.unpackX(x)
        dQdAlpha, dQdBeta = self.gradientQ()
        # Flip the sign since we want to minimize
        assert not np.any(np.isinf(dQdAlpha)), "Invalid Gradient Value [Alpha]"
        assert not np.any(np.isinf(dQdBeta)), "Invalid Gradient Value [Beta]"
        assert not np.any(np.isnan(dQdAlpha)), "Invalid Gradient Value [Alpha]"
        assert not np.any(np.isnan(dQdBeta)), "Invalid Gradient Value [Beta]"
        return np.r_[-dQdAlpha, -dQdBeta]

    def MStep(self):
        """Maximise Q over (alpha, beta) with a few CG steps, with probZ held fixed."""
        initial_params = self.packX()
        params = sp.optimize.minimize(
            fun=self.f,
            x0=initial_params,
            method="CG",
            jac=self.df,
            tol=0.01,
            options={"maxiter": 25},
        )
        self.unpackX(params.x)

    def computeQ(self):
        """Calculate the expectation of the joint log-likelihood (EM function Q).

        Returns ``-inf`` if the result is NaN.
        """
        Q = 0
        # Start with the expectation of the sum of priors over all tasks
        Q += (self.probZ * np.log(self.priorZ)).sum()

        # the expectation of the sum of posteriors over all tasks
        ab = np.dot(np.array([np.exp(self.beta)]).T, np.array([self.alpha]))

        # log P(correct label) = log(sigmoid(ab))
        logSigma = self._safe_logsigmoid(ab)
        # log P(one specific wrong label) = log((1 - sigmoid(ab)) / (K - 1))
        logOneMinusSigma = self._safe_logsigmoid(-ab) - np.log(
            float(self.n_classes - 1)
        )

        noResp = self.labels == 0  # label == 0 -> no response
        for k in range(self.n_classes):
            delta = self.labels == k + 1
            oneMinusDelta = (~delta) & (~noResp)
            weight = self.probZ[:, [k]]  # (n_task, 1), broadcast over workers
            Q += (weight * logSigma)[delta].sum()
            Q += (weight * logOneMinusSigma)[oneMinusDelta].sum()

        # Gaussian (unit variance) priors centred on priorAlpha / priorBeta;
        # logpdf avoids log(pdf) underflowing to log(0) = -inf.
        Q += sp.stats.norm.logpdf(self.alpha - self.priorAlpha).sum()
        Q += sp.stats.norm.logpdf(self.beta - self.priorBeta).sum()

        if np.isnan(Q):
            return -np.inf
        return Q

    def dAlpha(self, item, *args):
        """Gradient of the likelihood part of Q w.r.t. one worker's ``alpha_i``, for one class.

        Args:
            item: column ``[i, sigma_0i, ..., sigma_{n-1,i}]`` where
                ``sigma_ji = sigmoid(alpha_i * exp(beta_j))``.
            *args: ``(delta, noResp, probZ_k)``: the boolean label matrices
                and the posterior of the class for every task.
        """
        i = int(item[0])  # worker ID
        sigma_ab = item[1:]
        delta = args[0][:, i]
        noResp = args[1][:, i]
        oneMinusDelta = (~delta) & (~noResp)

        probZ = args[2]

        correct = (
            probZ[delta] * np.exp(self.beta[delta]) * (1 - sigma_ab[delta])
        )
        wrong = (
            probZ[oneMinusDelta]
            * np.exp(self.beta[oneMinusDelta])
            * (-sigma_ab[oneMinusDelta])
        )
        # Note: The derivative in Whitehill et al.'s appendix has the term ln(K-1), which is incorrect.

        return correct.sum() + wrong.sum()

    def dBeta(self, item, *args):
        """Gradient of the likelihood part of Q w.r.t. one task's ``beta_j``, for one class.

        The result is missing the ``exp(beta_j)`` chain-rule factor, which
        ``gradientQ`` multiplies in.

        Args:
            item: row ``[j, sigma_j0, ..., sigma_j,m-1]``.
            *args: ``(delta, noResp, probZ_k)`` as in ``dAlpha``.
        """
        j = int(item[0])  # task ID
        sigma_ab = item[1:]
        delta = args[0][j]
        noResp = args[1][j]
        oneMinusDelta = (~delta) & (~noResp)

        probZ = args[2][j]

        correct = probZ * self.alpha[delta] * (1 - sigma_ab[delta])
        wrong = probZ * self.alpha[oneMinusDelta] * (-sigma_ab[oneMinusDelta])

        return correct.sum() + wrong.sum()

    def gradientQ(self):
        """Return (dQ/dalpha, dQ/dbeta) at the current parameters."""
        # prior prob.
        dQdAlpha = -(self.alpha - self.priorAlpha)
        dQdBeta = -(self.beta - self.priorBeta)

        ab = np.dot(np.array([np.exp(self.beta)]).T, np.array([self.alpha]))

        sigma = sigmoid(ab)
        sigma[np.isnan(sigma)] = 0  # :TODO check if this is correct

        # Prepend a row of worker indices and a column of task indices
        # (-1 for the header row) so dAlpha/dBeta can recover their index.
        labelersIdx = np.arange(self.n_workers).reshape((1, self.n_workers))
        sigma = np.r_[labelersIdx, sigma]
        sigma = np.c_[np.arange(-1, self.n_task), sigma]

        for k in range(self.n_classes):
            dQdAlpha += np.apply_along_axis(
                self.dAlpha,
                0,
                sigma[:, 1:],
                (self.labels == k + 1),
                (self.labels == 0),
                self.probZ[:, k],
            )

            dQdBeta += (
                np.apply_along_axis(
                    self.dBeta,
                    1,
                    sigma[1:],
                    (self.labels == k + 1),
                    (self.labels == 0),
                    self.probZ[:, k],
                )
                * np.exp(self.beta)
            )

        return dQdAlpha, dQdBeta

    def run(self):
        """Fit the model with EM."""
        self.EM()

    def get_probas(self):
        """Return the (n_task, n_classes) posterior over true labels."""
        return self.probZ
