import sys
import getopt
import math
from confidence_interval import ConfidenceInterval


def nextword(target, source):
    """Return the element of ``source`` that directly follows ``target``.

    Only the first occurrence of ``target`` is considered.

    Args:
        target: Value to look for in ``source``.
        source: Indexable sequence of tokens (e.g. the result of
            ``str.split()``).

    Returns:
        The element right after the first match, or ``None`` when ``target``
        does not occur in ``source``.

    Raises:
        IndexError: If the first match is the last element of ``source``.
    """
    following = 0  # index of the element after the one being inspected
    for word in source:
        following += 1
        if word == target:
            return source[following]
    return None


class LossStructOff:
    """Mutable record holding the metrics of one evaluated configuration.

    All fields are stored exactly as passed in and are plain attributes that
    callers are free to overwrite later.

    Attributes:
        n, h: Configuration identifiers (parsed from ``n = `` / ``h = `` lines
            by the evaluator in this file).
        srm: SRM loss value for the configuration.
        srm_penalty: Penalised SRM value used for model selection.
        test_loss, max_loss, nb_examples: Test-set statistics.
        ci_lower, ci_upper: Bounds of the confidence interval on the test loss.
    """

    def __init__(
        self,
        n,
        h,
        srm,
        srm_penalty,
        test_loss,
        max_loss,
        nb_examples,
        ci_lower,
        ci_upper,
    ):
        # Configuration identifiers.
        self.n, self.h = n, h
        # Structural-risk figures.
        self.srm, self.srm_penalty = srm, srm_penalty
        # Statistics measured on the test set.
        self.test_loss, self.max_loss, self.nb_examples = (
            test_loss,
            max_loss,
            nb_examples,
        )
        # Confidence interval bounds.
        self.ci_lower, self.ci_upper = ci_lower, ci_upper


class EvaluatorOffline:
    """Pick the best offline configuration by penalised SRM loss.

    The evaluator reads two text files made of ``CATS-offline`` sections:

    * the SRM file supplies ``n``, ``h`` and ``srm`` for every configuration;
    * the test file supplies ``test_loss``, ``max_loss`` and ``nb_examples``
      for the same configurations, in the same order.

    It then computes a penalised SRM value per configuration, keeps the first
    configuration as the "initial" model and the one with the smallest penalty
    as the "optimized" model, and attaches confidence intervals to both.

    Attributes:
        srm_file_name: Path of the SRM result file.
        test_file_name: Path of the test result file.
        costs: One ``LossStructOff`` per ``CATS-offline`` section, in file
            order. Grows on every call to :meth:`eval`.
        initial: Record of the initial model (first section of the SRM file).
        optimized: Record with the smallest ``srm_penalty``.
        delta: Multiplier used in the penalty term.
        conf_alpha: Level passed on to ``ConfidenceInterval.calculate``.
        quiet: When true, :meth:`eval` prints no result tables.
        pmin: Constant used in the denominator of the penalty term.
    """

    def __init__(self, srm_file_name, test_file_name, delta, alpha, quiet):
        self.srm_file_name = srm_file_name
        self.test_file_name = test_file_name
        self.costs = []
        # Placeholders whose penalty is the largest finite float, so that any
        # parsed configuration with a smaller penalty replaces them.
        self.initial = LossStructOff(0, 0, 0, sys.float_info.max, 0, 0, 0, 0, 0)
        self.optimized = LossStructOff(0, 0, 0, sys.float_info.max, 0, 0, 0, 0, 0)
        self.delta = delta
        self.conf_alpha = alpha
        self.quiet = quiet
        self.pmin = 0.05  # epsilon = 5%

    @staticmethod
    def _value_after(line, prefix):
        """Return the float that follows a leading ``prefix`` on ``line``.

        Everything after the first ``len(prefix)`` characters is converted.
        ``float`` ignores surrounding whitespace, so the line may or may not
        end with a newline; the final line of a file that lacks a trailing
        newline is therefore parsed in full instead of losing its last
        character.

        Raises:
            ValueError: If the remainder of the line is not a number.
        """
        return float(line[len(prefix):])

    def _read_srm_file(self):
        """Append one record per ``CATS-offline`` section of the SRM file."""
        with open(self.srm_file_name, "r") as data_file:
            for line in data_file:
                if "CATS-offline" in line:
                    # A header line starts a new configuration record.
                    self.costs.append(LossStructOff(0, 0, 0, 0, 0, 0, 0, 0, 0))
                elif "n = " in line:
                    self.costs[-1].n = self._value_after(line, "n = ")
                elif "h = " in line:
                    self.costs[-1].h = self._value_after(line, "h = ")
                elif "srm" in line:
                    self.costs[-1].srm = float(nextword("=", line.split()))

    def _read_test_file(self):
        """Fill in the test statistics of the records read from the SRM file.

        Sections are matched to records by position. A mismatch between the
        ``n``/``h`` values of the two files is reported on stdout but does not
        stop processing.
        """
        count = 0  # number of section headers seen so far
        with open(self.test_file_name, "r") as data_file:
            for line in data_file:
                if "CATS-offline" in line:
                    count += 1
                elif "n = " in line:
                    if self.costs[count - 1].n != self._value_after(line, "n = "):
                        print("error: n is not matched")
                elif "h = " in line:
                    if self.costs[count - 1].h != self._value_after(line, "h = "):
                        print("error: h is not matched")
                elif "test_loss" in line:
                    self.costs[count - 1].test_loss = float(
                        nextword("=", line.split())
                    )
                elif "max_loss" in line:
                    self.costs[count - 1].max_loss = float(
                        nextword("=", line.split())
                    )
                elif "nb_examples" in line:
                    self.costs[count - 1].nb_examples = float(
                        nextword("=", line.split())
                    )

    def eval(self):
        """Parse both files, select the models and optionally print results.

        Populates ``costs``, ``initial`` and ``optimized`` (including their
        confidence intervals). Returns ``None``.

        Raises:
            OSError: If either file cannot be opened.
            ValueError: If a numeric field cannot be parsed.
            IndexError: If no ``CATS-offline`` section was read.
        """
        self._read_srm_file()
        self._read_test_file()

        self.calc_srm_penalty()
        self.get_optimized()

        self.saveConfidenceIntervals(self.initial)
        self.saveConfidenceIntervals(self.optimized)

        if not self.quiet:
            self.printAllResults()
            print("\ninitial model:")
            self.printResults(self.initial)
            print("optimized model:")
            self.printResults(self.optimized)

    def return_loss(self, model):
        """Return ``(test_loss, ci_lower, ci_upper)`` for the chosen model.

        Args:
            model: ``"init"`` for the initial model, ``"opt"`` for the
                optimized one.

        Returns:
            A 3-tuple for a recognised ``model``; ``None`` for any other value.
        """
        if model == "init":
            chosen = self.initial
        elif model == "opt":
            chosen = self.optimized
        else:
            return None
        return chosen.test_loss, chosen.ci_lower, chosen.ci_upper

    def calc_srm_penalty(self):
        """Set ``srm_penalty`` on every record in ``costs``.

        With ``t = n * delta / (h * pmin * nb_examples)`` the penalty is
        ``srm + sqrt(t * srm) + t``.

        Raises:
            ZeroDivisionError: If ``h`` or ``nb_examples`` of a record is 0.
        """
        for c in self.costs:
            denominator = c.h * self.pmin * c.nb_examples
            # TODO: fix (this formula was already flagged as needing a fix).
            c.srm_penalty = (
                c.srm
                + math.sqrt(c.n * self.delta * c.srm / denominator)
                + (c.n * self.delta / denominator)
            )

    def get_optimized(self):
        """Select the initial and the optimized record from ``costs``.

        The first record becomes ``initial``; a warning is printed when it is
        not the expected ``n == 4, h == 1`` configuration. ``optimized``
        becomes the first record whose penalty is strictly smaller than every
        earlier candidate, starting from the current ``optimized``.

        Raises:
            IndexError: If ``costs`` is empty.
        """
        first = self.costs[0]
        # todo: if changed to depth parameters
        if first.n != 4 or first.h != 1:
            print("error in finding initial model")
        self.initial = first
        for c in self.costs:
            if c.srm_penalty < self.optimized.srm_penalty:
                self.optimized = c

    def saveConfidenceIntervals(self, cost):
        """Store the confidence interval of ``cost.test_loss`` on ``cost``.

        The bounds come from ``ConfidenceInterval.calculate`` and are written
        to ``cost.ci_lower`` and ``cost.ci_upper``.
        """
        cost.ci_lower, cost.ci_upper = ConfidenceInterval.calculate(
            cost.nb_examples, cost.test_loss, cost.max_loss, self.conf_alpha
        )

    def printAllResults(self):
        """Print the metrics and confidence interval of every record."""
        for cost in self.costs:
            print(
                "n, h, srm, srm_penalty, test_loss = {0}, {1}, {2}, {3}, {4}".format(
                    cost.n, cost.h, cost.srm, cost.srm_penalty, cost.test_loss
                )
            )
            print("C.I. = {0}, {1}".format(cost.ci_lower, cost.ci_upper))

    def printResults(self, cost):
        """Print the metrics and confidence interval of a single record."""
        print(
            "n, h, srm, srm_penalty, test_loss = \n {0}, {1}, {2}, {3}, {4}".format(
                cost.n, cost.h, cost.srm, cost.srm_penalty, cost.test_loss
            )
        )
        print("C.I. = {0}, {1}".format(cost.ci_lower, cost.ci_upper))


def parse_cli_args(argv):
    """Parse the command-line options of this script.

    Recognised options (each value option takes one argument):

    * ``-d`` / ``--srm_file``: path of the SRM result file
    * ``-p`` / ``--test_file``: path of the test result file
    * ``-c`` / ``--delta``: penalty multiplier (float)
    * ``-a`` / ``--alpha``: confidence level parameter (float)
    * ``-r`` / ``--return_model``: ``init`` or ``opt``
    * ``-q`` / ``--quiet``: suppress the result tables

    Args:
        argv: Argument list without the program name (``sys.argv[1:]``).

    Returns:
        A dict with the keys ``srm_file``, ``test_file``, ``delta``,
        ``alpha``, ``model`` and ``quiet``; options that were not given keep
        their default value.

    Raises:
        getopt.GetoptError: On an unknown option or a missing argument.
        ValueError: If ``delta`` or ``alpha`` is not a number.
    """
    settings = {
        "srm_file": "../../results/black_friday_offline_srm.txt",
        "test_file": "../../results/black_friday_offline_test.txt",
        "delta": 1,
        "alpha": 0.05,
        "model": "init",
        "quiet": False,
    }

    # Every long option that takes a value needs a trailing "=", otherwise
    # getopt treats it as a flag and rejects "--test_file=PATH".
    opts, _ = getopt.getopt(
        argv,
        "d:p:c:a:r:q",
        ["srm_file=", "test_file=", "delta=", "alpha=", "return_model=", "quiet"],
    )
    for opt, arg in opts:
        if opt in ("-d", "--srm_file"):
            settings["srm_file"] = arg
        elif opt in ("-p", "--test_file"):
            settings["test_file"] = arg
        elif opt in ("-c", "--delta"):
            settings["delta"] = float(arg)
        elif opt in ("-a", "--alpha"):
            settings["alpha"] = float(arg)
        elif opt in ("-r", "--return_model"):
            settings["model"] = arg
        elif opt in ("-q", "--quiet"):
            settings["quiet"] = True
    return settings


if __name__ == "__main__":
    settings = parse_cli_args(sys.argv[1:])

    # Run the evaluation and print (test_loss, ci_lower, ci_upper) of the
    # requested model to stdout.
    evaluator = EvaluatorOffline(
        settings["srm_file"],
        settings["test_file"],
        settings["delta"],
        settings["alpha"],
        settings["quiet"],
    )
    evaluator.eval()
    print(evaluator.return_loss(settings["model"]))
