"""
=============================
Dawid and Skene model (1979)
=============================

Assumptions:
- workers answer independently of each other given the true class

Using:
- EM algorithm

Estimating:
- One confusion matrix per worker
"""

import numpy as np
import pandas as pd


class Dawid_Skene:
    """Dawid & Skene (1979) label-aggregation model fitted with EM.

    Args:
        answers: mapping ``task -> {worker: label}``. Task keys and regular
            worker keys are converted with ``int()`` and used as indices;
            labels are used directly as class indices. String worker keys
            starting with ``"spam"`` (e.g. ``"spam1"``) are remapped to the
            last ``n_spam`` worker slots.
        n_workers: total number of workers (regular + spammers).
        n_classes: number of classes.
        n_spam: number of spammer workers.
        cut: tasks removed from ``answers``; ``len(cut)`` extra rows are
            reserved on the task axis (presumably the remaining task keys
            still index into ``0..n_task-1`` -- confirm with caller).
            Defaults to a fresh empty list.

    After ``run_em()`` the instance holds ``crowd_matrix``
    (n_task, n_workers, n_classes) answer counts, ``T`` (n_task, n_classes)
    posterior class probabilities, ``p`` (n_classes,) class marginals and
    ``pi`` (n_workers, n_classes, n_classes) per-worker confusion matrices.
    """

    def __init__(self, answers, n_workers, n_classes, n_spam=0, cut=None):
        # A fresh list per instance: a shared mutable default would leak
        # state between instances.
        cut = [] if cut is None else cut
        self.answers = answers
        self.n_workers = n_workers
        self.n_classes = n_classes
        # Cut tasks are absent from ``answers`` but still occupy a row.
        self.n_task = len(answers) + len(cut)
        # n_workers = n_spam + n_regular
        self.n_spam = n_spam
        self.cut = cut

    def get_crowd_matrix(self):
        """Build ``crowd_matrix[task, worker, label]`` answer counts.

        Each answer adds 1 to its cell. The result has shape
        (n_task, n_workers, n_classes) and is stored on ``self.crowd_matrix``.
        """
        matrix = np.zeros((self.n_task, self.n_workers, self.n_classes))
        # "spam<n>" is mapped to index spam_offset + n.
        # NOTE(review): with this offset "spam0" would collide with the last
        # regular worker, so spam ids presumably start at 1 -- confirm.
        spam_offset = self.n_workers - self.n_spam - 1
        for task, ans in self.answers.items():
            for worker, label in ans.items():
                # Only string keys can carry the "spam" prefix; integer worker
                # ids are used as-is.
                if isinstance(worker, str) and worker.startswith("spam"):
                    num = worker.split("spam")[1]
                    worker = spam_offset + int(num)
                matrix[int(task), int(worker), label] += 1
        self.crowd_matrix = matrix

    def init_T(self):
        """Initialise ``T`` with per-task normalised vote counts.

        This is a soft majority vote. Tasks without any answer get an
        all-zero row.
        """
        T = self.crowd_matrix.sum(axis=1)
        tdim = T.sum(1, keepdims=True)
        # Silence the 0/0 warnings for unanswered tasks; np.where discards them.
        with np.errstate(invalid="ignore", divide="ignore"):
            self.T = np.where(tdim > 0, T / tdim, 0)

    def m_step(self):
        """Maximise the expected log-likelihood (eq. 2.3 and 2.4, Dawid & Skene 1979).

        Sets:
            p: (p_j)_j probabilities that an instance drawn at random has true
                class j (class marginals), shape (n_classes,).
            pi: pi[k, j, q] = (soft) number of times worker k records q when j
                is correct / (soft) number of instances seen by worker k where
                j is correct, shape (n_workers, n_classes, n_classes). Rows
                with no (soft) evidence are all-zero.
        """
        p = self.T.sum(0) / self.n_task
        # pi[k, j, q] = sum_i T[i, j] * crowd_matrix[i, k, q]
        pi = np.einsum("ij,ikq->kjq", self.T, self.crowd_matrix)
        denom = pi.sum(axis=2, keepdims=True)
        with np.errstate(invalid="ignore", divide="ignore"):
            pi = np.where(denom > 0, pi / denom, 0.0)
        self.p, self.pi = p, pi

    def e_step(self):
        """Estimate posterior class probabilities (eq. 2.5, Dawid & Skene 1979).

        For task i and class j the unnormalised posterior is
        ``p[j] * prod_{k,q} pi[k, j, q] ** crowd_matrix[i, k, q]``. It is
        computed in log space with a log-sum-exp normalisation, so it does not
        underflow to zero when many workers answer a task.

        Sets:
            T: new posterior estimate, shape (n_task, n_classes); rows whose
                normaliser is zero stay all-zero.
            log_denom_e_step: per-task log normaliser, shape (n_task, 1).
            denom_e_step: the per-task normaliser itself, shape (n_task, 1)
                (may underflow to 0; use ``log_denom_e_step`` for likelihoods).
        """
        crowd = self.crowd_matrix
        with np.errstate(divide="ignore"):
            log_p = np.log(self.p)
            log_pi = np.log(self.pi)
        # Replace log(0) by 0 so 0 * log(0) does not produce NaN (0 ** 0 == 1)...
        finite_log_pi = np.where(self.pi > 0, log_pi, 0.0)
        log_num = np.einsum("ikq,kjq->ij", crowd, finite_log_pi) + log_p
        # ...then give -inf to classes that need an observed (k, q) with
        # pi == 0, matching 0 ** n == 0 for n > 0.
        zero_pi = (self.pi == 0).astype(float)
        impossible = np.einsum("ikq,kjq->ij", crowd, zero_pi) > 0
        log_num[impossible] = -np.inf

        # Log-sum-exp over classes, shifted by the row max for stability.
        row_max = log_num.max(axis=1, keepdims=True, initial=-np.inf)
        shift = np.where(np.isfinite(row_max), row_max, 0.0)
        scaled = np.exp(log_num - shift)
        norm = scaled.sum(axis=1, keepdims=True)
        with np.errstate(divide="ignore", invalid="ignore"):
            self.log_denom_e_step = np.log(norm) + shift
            self.T = np.where(norm > 0, scaled / norm, 0.0)
        self.denom_e_step = np.exp(self.log_denom_e_step)

    def log_likelihood(self):
        """Return the observed-data log-likelihood after the last E-step.

        This is ``sum_i log(sum_j p_j prod_{k,q} pi_{kjq} ** n_{ikq})``: the sum
        over tasks of the log normalisers computed in ``e_step``. It is a sum of
        per-task logs, not the log of their sum.
        """
        return np.sum(self.log_denom_e_step)

    def run_em(self, epsilon=1e-7, maxiter=100):
        """Fit the model with EM, starting from the soft majority vote.

        Iterates M-step then E-step until the absolute change in
        log-likelihood is at most ``epsilon`` or ``maxiter`` iterations ran.

        Returns:
            ll: list of log-likelihood values, one per iteration.
            k: number of iterations performed (also stored on ``self.c``).
        """
        self.get_crowd_matrix()
        self.init_T()
        ll = []
        # Start at +inf so at least one iteration runs for any epsilon.
        k, eps = 0, np.inf
        while k < maxiter and eps > epsilon:
            self.m_step()
            self.e_step()
            ll.append(self.log_likelihood())
            if len(ll) >= 2:
                eps = np.abs(ll[-1] - ll[-2])
            k += 1
        self.c = k
        if eps > epsilon:
            print(f"DS did not converge: err={eps}")
        return ll, k

    def get_predictions(self):
        """Return the most probable class per task (class 0 for all-zero rows)."""
        return np.argmax(self.T, axis=1)

    def get_probas(self):
        """Return the posterior class probabilities ``T``, shape (n_task, n_classes)."""
        return self.T
