# coding=utf-8
import csv
import random
import time

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from bwa import bwa
from ebcc import ebcc_vb
from ibcc import ibcc
from CATD import CATD
from DS import DS
from GLAD import GLAD
from MV import MV
from PM import PM


class INFERENCE:
    """Run one truth-inference algorithm on an answer file and score it.

    The algorithm is selected by name (``alg``):

    * ``'MV'``, ``'DS'``, ``'GLAD'``, ``'PM'`` produce a per-item label-score
      mapping which is scored by :meth:`getaccuracy`.
    * ``'CATD'`` produces a per-item predicted label scored by
      :meth:`getaccuracy_catd`.
    * ``'BWA'``, ``'IBCC'``, ``'EBCC'`` produce a (tasks x classes) score matrix
      scored by :meth:`get_acc` against a truth CSV with ``task`` / ``truth``
      columns.
    """

    # Number of random EBCC restarts; the run with the highest ELBO is kept.
    ebcc_restarts = 1

    def __init__(self, answer_file, truth_file, result_file=None, alg='MV',
                 tasktype='categorical', show_acc=True):
        """Store the configuration; no work is done until :meth:`infernce`.

        Args:
            answer_file: path to the crowd-answer CSV consumed by the algorithm.
            truth_file: path to the ground-truth CSV used for scoring.
            result_file: stored on the instance; not used by this class.
            alg: algorithm name (see class docstring).
            tasktype: stored on the instance; not used by this class.
            show_acc: when False, MV/DS/GLAD/CATD/PM are run but not scored.
        """
        self.algorithm = alg
        self.answer_file = answer_file
        self.truth_file = truth_file
        self.result_file = result_file
        self.tasktype = tasktype
        self.show_acc = show_acc

    def calc_report(self, truth_list, prediction_list):
        """Return a dict of sklearn classification metrics.

        Precision/recall/fscore use micro averaging. The former per-algorithm
        branches differed only in ``pos_label``, which sklearn ignores whenever
        ``average != 'binary'``, so a single computation covers every algorithm.
        """
        micro_f1 = f1_score(truth_list, prediction_list, average='micro')
        return {
            'accuracy': accuracy_score(truth_list, prediction_list),
            'precision': precision_score(truth_list, prediction_list, average='micro'),
            'recall': recall_score(truth_list, prediction_list, average='micro'),
            'fscore': micro_f1,
            'microfscore': micro_f1,
            # NOTE: key spelled 'marcofscore' (sic) for backward compatibility.
            'marcofscore': f1_score(truth_list, prediction_list, average='macro'),
            'weightedfscore': f1_score(truth_list, prediction_list, average='weighted'),
        }

    @staticmethod
    def _load_truth(truthfile):
        """Read a headed two-column CSV (item, truth) into a dict of strings."""
        e2truth = {}
        with open(truthfile, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row; tolerate an empty file
            for line in reader:
                if not line:  # csv yields [] for blank lines
                    continue
                example, truth = line
                e2truth[example] = truth
        return e2truth

    @staticmethod
    def _pick_label(label_scores):
        """Return a highest-scoring label, breaking ties uniformly at random.

        Uses the true maximum, so all-negative scores (e.g. log-probabilities)
        are handled instead of yielding an empty candidate list.
        """
        best = max(label_scores.values())
        candidates = [label for label, s in label_scores.items() if s == best]
        return random.choice(candidates)

    def getaccuracy(self, truthfile, e2lpd, label_set):
        """Score a per-item label-score mapping against the truth CSV.

        Args:
            truthfile: headed CSV of (item, truth) rows.
            e2lpd: mapping item -> {label: score}; items absent from the truth
                file are ignored.
            label_set: unused; kept for interface compatibility.

        Returns:
            (accuracy, report) where report comes from :meth:`calc_report`.

        Raises:
            ValueError: if no predicted item appears in the truth file.
        """
        e2truth = self._load_truth(truthfile)
        truth_list = []
        prediction_list = []
        for e, label_scores in e2lpd.items():
            if e not in e2truth:
                continue
            prediction_list.append(self._pick_label(label_scores))
            truth_list.append(e2truth[e])
        if not truth_list:
            raise ValueError('no predicted item appears in truth file %r' % (truthfile,))
        correct = sum(p == t for p, t in zip(prediction_list, truth_list))
        return correct * 1.0 / len(truth_list), self.calc_report(truth_list, prediction_list)

    def getaccuracy_catd(self, truthfile, predict_truth, datatype):
        """Score a per-item predicted value against the truth CSV.

        Args:
            truthfile: headed CSV of (item, truth) rows.
            predict_truth: mapping item -> predicted label/value.
            datatype: ``'continuous'`` for RMSE; anything else for accuracy.

        Returns:
            RMSE (float) for continuous data, else (accuracy, report).

        Raises:
            ValueError: if no predicted item appears in the truth file.
        """
        e2truth = self._load_truth(truthfile)
        truth_list = []
        prediction_list = []
        tcount = 0
        count = 0
        for e, ptruth in predict_truth.items():
            if e not in e2truth:
                continue
            count += 1
            if datatype == 'continuous':
                tcount += (ptruth - float(e2truth[e])) ** 2
            else:
                prediction_list.append(ptruth)
                truth_list.append(e2truth[e])
                if ptruth == e2truth[e]:
                    tcount += 1
        if count == 0:
            raise ValueError('no predicted item appears in truth file %r' % (truthfile,))
        if datatype == 'continuous':
            return (tcount / count) ** 0.5
        return tcount * 1.0 / count, self.calc_report(truth_list, prediction_list)

    @staticmethod
    def _tie_split_scores(predictions):
        """Return per-task credit: 1/k for each of the k top-scoring classes.

        ``predictions`` is indexed as (task, class). The result is ``float64``.
        """
        predictions = np.asarray(predictions)
        is_top = predictions == predictions.max(axis=1, keepdims=True)
        score = is_top.astype(np.float64)  # was np.cfloat: complex, removed in NumPy 2.0
        score /= score.sum(axis=1, keepdims=True)
        return score

    def get_acc(self, predictions, df_truth):
        """Score a (tasks x classes) prediction matrix.

        Args:
            predictions: 2-D score array; row = task index, column = class index.
            df_truth: DataFrame with integer ``task`` and ``truth`` (class index)
                columns.

        Returns:
            (accuracy, report). Accuracy gives fractional credit when several
            classes tie for the top score; the report uses the argmax class.
        """
        score = self._tie_split_scores(predictions)
        option = score.argmax(axis=1)
        tasks = df_truth.task.values
        truths = df_truth.truth.values
        truth_list = truths.tolist()
        prediction_list = option[tasks].tolist()
        report = self.calc_report(truth_list, prediction_list)
        acc = score[tasks, truths].sum() / df_truth.shape[0]
        return acc, report

    def _read_answers(self):
        """Load the answer CSV with pandas and drop exact duplicate rows."""
        df_label = pd.read_csv(self.answer_file)
        return df_label.drop_duplicates(keep='first')

    def _run_ebcc(self, df_label):
        """Run EBCC ``ebcc_restarts`` times (at least once); return the best-ELBO prediction."""
        results = []
        for _ in range(max(1, self.ebcc_restarts)):
            seed = np.random.randint(10 ** 8)
            prediction, elbo = ebcc_vb(df_label.values, num_groups=10, seed=seed,
                                       empirical_prior=True)
            results.append((prediction, elbo))
        best = int(np.argmax([elbo for _, elbo in results]))
        return results[best][0]

    def infernce(self):
        """Run the configured algorithm and return its evaluation.

        Returns:
            For BWA/IBCC/EBCC: (accuracy, report), always computed.
            For CATD: the result of :meth:`getaccuracy_catd`, or None when
            ``show_acc`` is False.
            For MV/DS/GLAD/PM: (accuracy, report), or None when ``show_acc``
            is False.

        Raises:
            ValueError: if ``self.algorithm`` is not a supported name.
        """
        alg = self.algorithm

        if alg in ('BWA', 'IBCC', 'EBCC'):
            df_label = self._read_answers()
            if alg == 'BWA':
                prediction_ik = bwa(df_label.values)
            elif alg == 'IBCC':
                prediction_ik = ibcc(df_label.values)
            else:
                prediction_ik = self._run_ebcc(df_label)
            # Matrix-based methods are scored regardless of show_acc (prior behavior).
            return self.get_acc(prediction_ik, pd.read_csv(self.truth_file))

        if alg == 'MV':
            model = MV(self.answer_file)
            e2lpd = model.run()
        elif alg == 'DS':
            model = DS(self.answer_file, 0.8)
            e2lpd, _ = model.run(50)
        elif alg == 'GLAD':
            model = GLAD(self.answer_file)
            e2lpd, _ = model.run(1e-4)
        elif alg == 'PM':
            model = PM(self.answer_file, 'categorical', '0/1 loss')
            e2lpd, _ = model.run(10)
        elif alg == 'CATD':
            model = CATD(self.answer_file, 'categorical')
            e2lpd, _ = model.run(0.05, 100)
            if not self.show_acc:
                return None
            return self.getaccuracy_catd(self.truth_file, e2lpd, 'categorical')
        else:
            raise ValueError(
                'unsupported algorithm %r; expected one of MV, DS, GLAD, CATD, '
                'PM, BWA, IBCC, EBCC' % (alg,))

        if not self.show_acc:
            return None
        return self.getaccuracy(self.truth_file, e2lpd, model.label_set)

    def inference(self):
        """Correctly spelled alias for :meth:`infernce`."""
        return self.infernce()

