import sys
import getopt
from confidence_interval import ConfidenceInterval


def nextword(target, source):
    """Return the element that follows the first occurrence of *target*.

    *source* is scanned from the start; the element right after the first
    item equal to *target* is returned. When nothing in *source* equals
    *target* the result is None. If the first match is the last element,
    indexing past the end raises IndexError.
    """
    for position, word in enumerate(source):
        if word != target:
            continue
        return source[position + 1]
    return None


class LossStructOn:
    """Plain record holding the values collected for one run of a model.

    The constructor stores every argument unchanged on the attribute of the
    same name: model, n, h, loss, time, max_cost, nb_examples, ci_lower and
    ci_upper.
    """

    def __init__(
        self, model, n, h, loss, time, max_cost, nb_examples, ci_lower, ci_upper
    ):
        # Identification of the run.
        self.model, self.n, self.h = model, n, h
        # Measured results.
        self.loss, self.time = loss, time
        self.max_cost, self.nb_examples = max_cost, nb_examples
        # Confidence-interval bounds attached to the loss.
        self.ci_lower, self.ci_upper = ci_lower, ci_upper


class EvaluatorOnline:
    """Parse a results file into per-run records and track the best run per model.

    The file is scanned line by line. A line containing one of the model
    markers ("CATS-online", "Discretized-Tree-online",
    "Discretized-Linear-online") starts a new LossStructOn record. Lines
    containing "n = ", "h = ", "Max Cost=", "number of examples",
    "average loss" or "real" fill in the most recently started record, and a
    "timeout" line sets ``max_time``.

    Attributes:
        file_name: path of the results file read by eval().
        conf_alpha: alpha value forwarded to ConfidenceInterval.calculate.
        costs: list of parsed records, in file order.
        best_cats, best_disc_tree, best_disc_linear: record with the lowest
            loss for each model; initialised to a placeholder whose loss is
            sys.float_info.max.
        max_time: value read from the last "timeout" line (0.0 if none).
        quiet: when true, eval() skips printing the parsed results.
    """

    # (marker text searched for in a line, model label stored on the record).
    # Order matters: the first marker found in a line wins.
    _MODEL_MARKERS = (
        ("CATS-online", "cats"),
        ("Discretized-Tree-online", "disc_tree"),
        ("Discretized-Linear-online", "disc_linear"),
    )

    def __init__(self, file_name, alpha, quiet):
        self.file_name = file_name
        self.conf_alpha = alpha
        self.costs = []
        self.best_cats = LossStructOn("cats", 0, 0, sys.float_info.max, 0, 0, 0, 0, 0)
        self.best_disc_tree = LossStructOn(
            "disc_tree", 0, 0, sys.float_info.max, 0, 0, 0, 0, 0
        )
        self.best_disc_linear = LossStructOn(
            "disc_linear", 0, 0, sys.float_info.max, 0, 0, 0, 0, 0
        )
        self.max_time = 0.0
        self.quiet = quiet

    def eval(self):
        """Read ``self.file_name``, select the best runs and report them.

        Parsed records are appended to ``self.costs``. Afterwards the best
        record per model is selected, confidence intervals are stored on the
        three best records, results are printed unless ``self.quiet`` is set,
        and find_error() is run.

        Raises:
            IndexError: a field line is met while ``self.costs`` is empty.
            ValueError / TypeError: a value cannot be converted to a number.
        """
        # The context manager guarantees the file is closed, including when a
        # malformed line raises part-way through.
        with open(self.file_name, "r") as data_file:
            for line in data_file:
                self._parse_line(line)

        self.get_best_loss()

        self.saveConfidenceIntervals(self.best_cats)
        self.saveConfidenceIntervals(self.best_disc_tree)
        self.saveConfidenceIntervals(self.best_disc_linear)

        if not self.quiet:
            self.printAllResults()

            print("max_time = ", self.max_time)

            self.printBestResults(self.best_cats)
            self.printBestResults(self.best_disc_tree)
            self.printBestResults(self.best_disc_linear)

        self.find_error()

    def _parse_line(self, line):
        """Update the evaluator state from one line of the results file."""
        for marker, model in self._MODEL_MARKERS:
            if marker in line:
                self.costs.append(
                    LossStructOn(model, 0, 0, sys.float_info.max, 0, 0, 0, 0, 0)
                )
                return

        if "timeout" in line:
            self.max_time = float(nextword("timeout", line.split()))
        elif "n = " in line:
            # float() ignores surrounding whitespace, so the rest of the line
            # is passed as is; slicing up to find("\n") would drop the last
            # digit of a final line that has no trailing newline.
            # NOTE(review): the slice assumes the line starts with "n = ".
            self.costs[-1].n = float(line[len("n = "):])
        elif "h = " in line:
            # NOTE(review): the slice assumes the line starts with "h = ".
            self.costs[-1].h = float(line[len("h = "):])
        elif "Max Cost=" in line:
            self.costs[-1].max_cost = float(line[len("Max Cost="):])
        elif "number of examples" in line:
            self.costs[-1].nb_examples = int(nextword("=", line.split()))
        elif "average loss" in line:
            self.costs[-1].loss = float(nextword("=", line.split()))
        elif "real" in line:
            self.costs[-1].time = float(nextword("real", line.split()))

    def return_loss(self, model):
        """Return ``(loss, ci_lower, ci_upper)`` of the best record for *model*.

        *model* is one of "cats", "disc_tree" or "disc_linear"; any other
        value yields None.
        """
        if model == "cats":
            best = self.best_cats
        elif model == "disc_tree":
            best = self.best_disc_tree
        elif model == "disc_linear":
            best = self.best_disc_linear
        else:
            return None
        return best.loss, best.ci_lower, best.ci_upper

    def return_all(self, model):
        """Return ``(losses, times, ns, hs)`` for the records of *model*.

        Only records whose loss is below 1 are included; the four lists are
        parallel and follow the order of ``self.costs``.
        """
        kept = [c for c in self.costs if c.model == model and c.loss < 1]
        return (
            [c.loss for c in kept],
            [c.time for c in kept],
            [c.n for c in kept],
            [c.h for c in kept],
        )

    def get_best_loss(self):
        """Point each ``best_*`` attribute at the record with the lowest loss.

        The comparison is strict, so on ties the earliest record is kept.
        Records with an unrecognised model label are ignored.
        """
        for c in self.costs:
            if c.model == "cats":
                if c.loss < self.best_cats.loss:
                    self.best_cats = c
            elif c.model == "disc_tree":
                if c.loss < self.best_disc_tree.loss:
                    self.best_disc_tree = c
            elif c.model == "disc_linear":
                if c.loss < self.best_disc_linear.loss:
                    self.best_disc_linear = c

    def saveConfidenceIntervals(self, cost):
        """Store confidence bounds on *cost* in ``ci_lower`` / ``ci_upper``.

        The bounds come from ConfidenceInterval.calculate(nb_examples, loss,
        max_cost, alpha). Records whose ``max_cost`` is 0 are left untouched.
        """
        if cost.max_cost != 0:
            cost.ci_lower, cost.ci_upper = ConfidenceInterval.calculate(
                cost.nb_examples, cost.loss, cost.max_cost, self.conf_alpha
            )

    def getTime(self, model, n, hp, h, mode):
        """Return the run times of *model* records selected according to *mode*.

        Modes:
            "hp": records with ``c.h == hp``; returns ``(times, ns)``.
            "h":  records with ``c.h / c.n == h``; returns ``(times, ns)``.
                  Records with ``n == 0`` have no defined ratio and are
                  skipped rather than raising ZeroDivisionError.
            "n":  records with ``c.n == n``; returns ``(times, hs)``.

        Any other mode returns None. Results follow the order of
        ``self.costs``, which is not sorted here (the original author notes
        it is assumed to be sorted with respect to hp and n).
        """
        if mode == "hp":
            picked = [c for c in self.costs if c.model == model and c.h == hp]
            return [c.time for c in picked], [c.n for c in picked]

        if mode == "h":
            picked = [
                c
                for c in self.costs
                if c.model == model and c.n != 0 and (c.h / c.n) == h
            ]
            return [c.time for c in picked], [c.n for c in picked]

        if mode == "n":
            picked = [c for c in self.costs if c.model == model and c.n == n]
            return [c.time for c in picked], [c.h for c in picked]

        return None

    @staticmethod
    def _format_result(cost):
        """Build the one-line summary shared by the two print methods."""
        return "model, n, h, loss, time = {0}, {1}, {2}, {3}, {4}".format(
            cost.model, cost.n, cost.h, cost.loss, cost.time
        )

    def printAllResults(self):
        """Print one summary line per record in ``self.costs``."""
        for cost in self.costs:
            print(self._format_result(cost))

    def printBestResults(self, cost):
        """Print the summary line of *cost* followed by its confidence bounds."""
        print(self._format_result(cost))
        print("C.I. = {0}, {1}".format(cost.ci_lower, cost.ci_upper))

    def find_error(self):
        """Report records with no loss whose time is below ``max_time``.

        A record still holding the placeholder loss (sys.float_info.max)
        never received an "average loss" line; it is printed when its time is
        below ``self.max_time``.
        """
        for c in self.costs:
            if c.loss == sys.float_info.max and c.time < self.max_time:
                print("error in model={0}, n={1}, h={2}".format(c.model, c.n, c.h))


if __name__ == "__main__":
    # Defaults used when the matching command-line option is not given.
    dataset = "BNG_cpu_act"
    data_file = "../../results/" + dataset + "_online_validation.txt"
    alpha = 0.05
    model = "cats"
    quiet = False

    # Options: -d/--data_file, -a/--alpha, -r/--return_model, -q/--quiet.
    parsed_opts, _ = getopt.getopt(
        sys.argv[1:], "d:a:r:q", ["data_file=", "alpha=", "return_model=", "quiet"]
    )
    for flag, value in parsed_opts:
        if flag in ("-q", "--quiet"):
            quiet = True
        elif flag in ("-r", "--return_model"):
            model = value
        elif flag in ("-a", "--alpha"):
            alpha = float(value)
        elif flag in ("-d", "--data_file"):
            data_file = value

    # Parse the results file, then print the best loss (with its confidence
    # bounds) for the requested model and the disc_linear timings for h == 0.
    evaluator = EvaluatorOnline(data_file, alpha, quiet)
    evaluator.eval()
    print(evaluator.return_loss(model))
    print(evaluator.getTime("disc_linear", 0, 0, 0, "hp"))
