import csv


class MV:
    """Majority-voting (MV) truth inference over crowdsourced labels.

    The constructor loads a CSV file of (example, worker, label) answers;
    ``run`` turns them into a per-example distribution over labels.
    """

    def __init__(self, datafile):
        """Load answers from the CSV file at path ``datafile``."""
        e2wl, w2el, label_set = self.gete2wlandw2el(datafile)
        # example -> list of [worker, label] answers for that example
        self.e2wl = e2wl
        # worker -> list of [example, label] answers given by that worker
        self.w2el = w2el
        # keys view over every worker id seen in the file
        self.workers = self.w2el.keys()
        # distinct labels, in the order they first appear in the file
        self.label_set = label_set

    def run(self):
        """Do majority voting.

        Returns:
            e2lpd: dict mapping each example to a dict ``{label: fraction}``,
            where ``fraction`` is the share of that example's answers that
            chose ``label`` (fractions for one example sum to 1). If an
            example has no answers at all, every label gets the uniform
            value ``1.0 / len(self.label_set)``.
        """
        e2lpd = {}
        n_labels = len(self.label_set)
        for example, answers in self.e2wl.items():
            # Start every known label at zero so unseen labels still appear.
            counts = dict.fromkeys(self.label_set, 0)
            for _worker, label in answers:
                counts[label] += 1

            total = sum(counts.values())
            if total:
                e2lpd[example] = {
                    label: 1.0 * count / total for label, count in counts.items()
                }
            else:
                # No votes: fall back to a uniform distribution.
                e2lpd[example] = {label: 1.0 / n_labels for label in counts}
        return e2lpd

    def gete2wlandw2el(self, datafile):
        """Read the answer CSV and build the lookup structures.

        The first row of ``datafile`` is treated as a header and skipped;
        every other non-blank row must have exactly three fields:
        example, worker, label.

        Returns:
            e2wl: example as key, list of [worker, label] answers as value.
            w2el: worker as key, list of [example, label] answers as value.
            label_set: distinct labels in first-seen order.

        Raises:
            ValueError: if a non-blank row does not have exactly three fields.
        """
        e2wl = {}
        w2el = {}
        label_set = []
        seen_labels = set()  # O(1) membership; label_set keeps the order

        # newline='' is what the csv module requires for correct parsing;
        # the with-block guarantees the file is closed.
        with open(datafile, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerate an empty file

            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; skip them.
                    continue
                example, worker, label = row
                e2wl.setdefault(example, []).append([worker, label])
                w2el.setdefault(worker, []).append([example, label])

                if label not in seen_labels:
                    seen_labels.add(label)
                    label_set.append(label)
        return e2wl, w2el, label_set


if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default demo data file.
    datafile = sys.argv[1] if len(sys.argv) > 1 else 'demo_small.csv'
    # Run majority voting and print each example's label distribution.
    e2lpd = MV(datafile).run()
    print(e2lpd)
