import csv


class EM:
    """Container for the annotation data that an EM run operates on.

    Attributes set by the constructor:
        e2wl: dict mapping each example to a list of ``[worker, label]`` pairs.
        w2el: dict mapping each worker to a list of ``[example, label]`` pairs.
        label_set: list of the distinct labels, in first-seen order.
        workers: keys view over ``w2el``.
        initalquality: the ``initquality`` argument, stored unchanged
            (the attribute name is spelled this way in the original code).

    NOTE(review): a module-level function also named ``EM`` is defined later
    in this file and rebinds the name, so after import ``EM`` no longer refers
    to this class.
    """

    def __init__(self, datafile, initquality):
        # Parse the CSV once, then fan the three results out to attributes.
        parsed = gete2wlandw2el(datafile)
        self.e2wl, self.w2el, self.label_set = parsed
        self.workers = self.w2el.keys()
        self.initalquality = initquality


def E_step(theta_A, theta_B):
    """Compute the expected counts attributed to A and to B.

    Every sequence in the module-level ``flip_seqs`` is tallied with
    ``Counter``.  Its likelihood under ``theta_A`` and under ``theta_B`` is
    obtained from ``likelihood_of_theta``, the two likelihoods are normalised
    into a pair of weights, and ``expected_counts`` is called with the tally
    and the matching weight.

    Returns:
        ``(exp_counts_A, exp_counts_B)``: two lists with one entry per
        sequence, holding the ``expected_counts`` results for A and for B.

    NOTE(review): ``Counter``, ``flip_seqs``, ``likelihood_of_theta`` and
    ``expected_counts`` are not defined or imported in the visible part of
    this file -- confirm where they come from.
    """
    tallies = [Counter(seq) for seq in flip_seqs]

    # Likelihood of each tally under the current parameters (all A, then all B).
    like_a = [likelihood_of_theta(theta_A, **tally) for tally in tallies]
    like_b = [likelihood_of_theta(theta_B, **tally) for tally in tallies]

    # Normalise each likelihood pair into weights.
    # NOTE(review): raises ZeroDivisionError if both likelihoods are zero.
    weight_a = []
    for la, lb in zip(like_a, like_b):
        weight_a.append(la / (la + lb))
    weight_b = []
    for la, lb in zip(like_a, like_b):
        weight_b.append(lb / (la + lb))

    exp_counts_A = list(map(expected_counts, tallies, weight_a))
    exp_counts_B = list(map(expected_counts, tallies, weight_b))
    return exp_counts_A, exp_counts_B


def M_step(exp_counts_A, exp_counts_B):
    """Re-estimate both parameters from the expected counts.

    The per-sequence expected counts for A, and then for B, are added
    together starting from an empty ``Counter``.  Each pooled result is
    unpacked as keyword arguments into ``estimate_theta``.

    Returns:
        ``(theta_A, theta_B)``: the two values returned by ``estimate_theta``.

    NOTE(review): ``Counter`` and ``estimate_theta`` are not defined or
    imported in the visible part of this file -- confirm where they come from.
    """
    pooled_a = Counter()
    for counts in exp_counts_A:
        pooled_a = pooled_a + counts

    pooled_b = Counter()
    for counts in exp_counts_B:
        pooled_b = pooled_b + counts

    return estimate_theta(**pooled_a), estimate_theta(**pooled_b)


def EM(theta_A, theta_B, iter_num):
    """Run ``iter_num`` rounds of E-step / M-step and return the estimates.

    Each round passes the current parameters to ``E_step`` and feeds the
    resulting expected counts to ``M_step``, whose return value becomes the
    parameters for the next round.

    NOTE(review): this function has the same name as the ``EM`` class defined
    earlier in the file and therefore rebinds the module-level name ``EM``.

    Args:
        theta_A: initial value of the first parameter.
        theta_B: initial value of the second parameter.
        iter_num: number of rounds to run; zero or a negative value runs none.

    Returns:
        ``(theta_A, theta_B)`` after the last round, or the initial values
        unchanged when no round is run.
    """
    for _ in range(iter_num):
        exp_counts_A, exp_counts_B = E_step(theta_A, theta_B)
        theta_A, theta_B = M_step(exp_counts_A, exp_counts_B)
    # Hand the fitted parameters back; without this the loop has no
    # observable result for the caller.
    return theta_A, theta_B


###################################
# The above is the EM implementation (a class plus the E-step / M-step functions)
# The following are external helper functions
###################################

def gete2wlandw2el(datafile):
    """Load ``example, worker, label`` rows from a CSV file.

    The first row is treated as a header and skipped.  Blank rows are
    ignored.  All values are kept as the strings read from the file.

    Args:
        datafile: path of the CSV file to read.

    Returns:
        A tuple ``(e2wl, w2el, label_set)`` where
        ``e2wl`` maps each example to a list of ``[worker, label]`` pairs,
        ``w2el`` maps each worker to a list of ``[example, label]`` pairs, and
        ``label_set`` is a list of the distinct labels in first-seen order.
        An empty file yields ``({}, {}, [])``.

    Raises:
        ValueError: if a non-blank row does not have exactly three fields.
    """
    e2wl = {}
    w2el = {}
    label_set = []

    # The context manager closes the handle even when a malformed row raises.
    with open(datafile, 'r') as f:
        reader = csv.reader(f)
        # Skip the header; the default keeps an empty file from raising
        # StopIteration.
        next(reader, None)

        for line in reader:
            if not line:
                # csv.reader yields [] for a blank line; nothing to record.
                continue

            example, worker, label = line
            e2wl.setdefault(example, []).append([worker, label])
            w2el.setdefault(worker, []).append([example, label])

            if label not in label_set:
                label_set.append(label)

    return e2wl, w2el, label_set


if __name__ == "__main__":
    # Script entry point: set up a small demo configuration and start a run.
    datafile = 'demo_small.csv'
    iterations = 20  # EM iteration number
    initquality = 0.7
    # NOTE(review): by the time this line runs, the module-level name ``EM``
    # has been rebound from the class to the function
    # ``EM(theta_A, theta_B, iter_num)`` defined later in this file, so this
    # two-argument call raises TypeError.
    em = EM(datafile, initquality)
    # NOTE(review): the ``EM`` class visible in this file defines no ``run``
    # method -- confirm where it is expected to come from.
    e2lpd, w2cm = em.run(iterations)

    # getaccuracy(truthfile, e2lpd, EM.label_set)
    # print(w2cm)
    # print(e2lpd)

    # truthfile = r'./truth.csv'
    # accuracy = getaccuracy(truthfile, e2lpd, label_set)
    # print accuracy
