import numpy as np
import matplotlib

matplotlib.rcParams["text.usetex"] = True
import matplotlib.pyplot as plt

# Shared random generator with a fixed seed; every class below draws from it.
rng = np.random.default_rng(2383)
# dtype used for the arrays allocated throughout this file.
float_type = np.float64

# Horizon: number of rounds each algorithm plays in its run() loop.
T = 100_000
# delta = 1/T; it only appears inside the log terms of the exploration-length
# formulas below (presumably a failure probability -- confirm against the paper).
delta = 1.0 / T


class DataFromLowerBound:
    """Synthetic dueling instance sized by the globals ``K``, ``D``, ``two_pow_d``.

    Arms are split at index ``two_pow_d`` into a lower and an upper half.
    Pairs inside the same half get zero features; cross-half pairs get a
    +/-1 feature vector whose first ``D - 1`` entries encode the bits of the
    lower-half arm index (least significant bit first) and whose last entry
    is the bias sign.

    Attributes set here: ``base``, ``theta``, ``p``, ``phi``, ``phi_outers``,
    ``B`` (Borda scores), ``B_star`` (best Borda score) and ``R`` (cumulative
    regret trace, starting at 0).
    """

    def __init__(self) -> None:
        pair_shape = (K, K)
        self.base = np.zeros(dtype=float_type, shape=pair_shape)

        # Parameter with magnitude 1/(4D) per coordinate, random signs, and a
        # fixed positive bias weight in the last slot.
        self.theta = np.ones(D) / D / 4
        flip = rng.random(D) < 0.5
        self.theta[flip] *= -1
        self.theta[-1] = 0.25

        self.p = np.zeros(dtype=float_type, shape=pair_shape)
        self.phi = np.zeros(dtype=float_type, shape=pair_shape + (D,))
        self.phi_outers = np.zeros(dtype=float_type, shape=pair_shape + (D, D))

        assert self.base.shape == self.phi.shape[:2]

        self.gen_data()

    def gen_data(self):
        """Fill ``base``, ``phi``, ``phi_outers`` and derive ``p``, ``B``, ``B_star``."""
        half = two_pow_d

        # Baseline table, one value per quadrant of the (i, j) grid.
        self.base[:half, :half] = 0.5
        self.base[half:, half:] = 0.5
        self.base[:half, half:] = 0.75
        self.base[half:, :half] = 0.25

        # Features: only cross-half pairs are non-zero.
        for i in range(K):
            i_low = i < half
            for j in range(K):
                j_low = j < half
                if i_low == j_low:
                    continue  # same half: features and outer product stay zero
                if i_low:
                    sign, code = 1, i
                else:
                    sign, code = -1, j
                feat = self.phi[i, j]  # view into self.phi, written in place
                feat[-1] = sign
                for d in range(D - 1):
                    bit = (code >> d) & 1  # lsb first in [0]
                    feat[d] = (bit * 2 - 1) * sign
                self.phi_outers[i, j, :, :] = np.outer(feat, feat)

        # Linear link: preference probability is the linear score shifted by 0.5.
        self.p = self.phi @ self.theta + 0.5
        self.B = self.p.sum(axis=1) / K  # borda score
        self.B_star = self.B.max()

        self.R = [0]

    def sample(self, i, j):
        """Duel ``(i, j)`` ``N[i, j]`` times, adding +0.5 / -0.5 to ``r[i, j]`` per outcome.

        ``self.N`` and ``self.r`` are not created by this class; they are
        expected to come from the algorithm class mixed in with it.
        """
        repeats = np.int32(self.N[i, j])
        for _ in range(repeats):
            self.r[i, j] += 0.5 if self.sample_once(i, j) else -0.5

    def sample_once(self, i, j):
        """Play one duel: append the new cumulative regret to ``R`` and return 1 or 0."""
        previous = self.R[-1]
        self.R.append(previous + 2 * self.B_star - self.B[i] - self.B[j])
        return 1 if rng.random() < self.p[i, j] else 0


class DataEvents:
    """Dueling instance loaded from ``data/events_ds_post.npz``.

    Reads the arrays ``theta``, ``phi`` and ``p`` from the archive and derives
    ``phi_outers`` (per-pair feature outer products), ``B`` (Borda scores),
    ``B_star`` (best Borda score) and ``R`` (cumulative regret trace, starting
    at 0). Array sizes are taken from the globals ``K`` and ``D``.
    """

    def __init__(self) -> None:
        # np.load on an .npz returns an NpzFile that holds the archive open;
        # the with-block closes it once the three arrays have been read.
        with np.load("data/events_ds_post.npz") as data_dict:
            self.theta = data_dict["theta"]
            self.phi = data_dict["phi"]
            self.p = data_dict["p"]

        self.phi_outers = np.zeros(dtype=float_type, shape=(K, K, D, D))
        for i in range(K):
            for j in range(K):
                self.phi_outers[i, j, :, :] = np.outer(self.phi[i, j], self.phi[i, j])

        self.B = np.sum(self.p, axis=1) / K  # borda score
        self.B_star = self.B.max()

        self.R = [0]

    def sample(self, i, j):
        """Duel ``(i, j)`` ``N[i, j]`` times and add the number of wins to ``r[i, j]``.

        ``self.N`` and ``self.r`` are not created by this class; they are
        expected to come from the algorithm class mixed in with it.
        """
        for _ in range(np.int32(self.N[i, j])):
            if self.sample_once(i, j):
                self.r[i, j] += 1

    def sample_once(self, i, j):
        """Play one duel: append the new cumulative regret to ``R`` and return 1 or 0."""
        self.R.append(self.R[-1] + 2 * self.B_star - self.B[i] - self.B[j])
        if rng.random() < self.p[i, j]:
            return 1
        else:
            return 0


class BETC_Linear:
    """Borda explore-then-commit using a G-optimal design over pair features.

    This is a mixin: a data class later in the MRO must provide ``phi``,
    ``phi_outers``, ``theta``, ``B``, ``B_star``, ``R``, ``sample`` and
    ``sample_once``. The methods read the module globals ``K``, ``D``, ``T``,
    ``eps``, ``tau`` and (for GLM subclasses) ``F``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.N = np.zeros(dtype=float_type, shape=(K, K))  # planned pulls per pair
        self.r = np.zeros(dtype=float_type, shape=(K, K))  # accumulated feedback per pair
        self.pi = np.ones(dtype=float_type, shape=(K, K))
        self.pi = self.pi / self.pi.sum()  # design weights, start uniform
        self.g_pi_m = np.zeros(dtype=float_type, shape=(K, K))

    def find_gpi(self, V_inv):
        """Store ``phi[i, j]^T V_inv phi[i, j]`` for every pair in ``g_pi_m``.

        Returns:
            ``(max_value, (i, j))`` for the pair with the largest value.
        """
        for i in range(K):
            for j in range(K):
                a = self.phi[i, j, :]
                self.g_pi_m[i, j] = a.T @ (V_inv) @ (a)
        idx = np.unravel_index(np.argmax(self.g_pi_m), self.g_pi_m.shape)
        return self.g_pi_m[idx], idx

    def run(self):
        """Compute the design, explore according to it, then commit to one arm.

        Steps: (1) 20 Frank-Wolfe-style updates of ``pi``; (2) set
        ``N[i, j] = ceil(2 * D * pi[i, j] / eps**2)``; (3) ``ceil(tau)`` uniformly
        random warm-up duels; (4) ``sample`` every pair; (5) estimate theta and
        the Borda scores; (6) play ``(i_star, i_star)`` until ``R`` holds
        ``T + 1`` entries.
        """
        # Global numpy print formatting; set once, it does not depend on the loop.
        np.set_printoptions(formatter={"float": "{: 0.3f}".format})

        # G-optimal design
        iters = 20
        for _ in range(iters):
            V = np.zeros(dtype=float_type, shape=(D, D))
            for i in range(K):
                for j in range(K):
                    V += self.pi[i, j] * self.phi_outers[i, j, :, :]

            V_inv = np.linalg.inv(V)
            g_pi, a_k = self.find_gpi(V_inv)
            # Step size towards the pair with the largest g value.
            gamma_k = (1.0 / D * g_pi - 1) / (g_pi - 1)

            self.pi = (1 - gamma_k) * self.pi
            self.pi[a_k[0], a_k[1]] += gamma_k
        print(
            "explored dims",
            np.sum(self.pi > 1.0 / K / K),
            "|supp(pi)|",
            D * (D + 1) / 2,
        )

        # update N(i, j)
        for i in range(K):
            for j in range(K):
                self.N[i, j] = np.ceil(2 * D * self.pi[i, j] / eps / eps)

                if self.N[i, j] != 1:
                    print(self.N[i, j], "N({},{})".format(i, j))

        # initial rounds to make sure design matrix is non singular
        for _ in range(np.ceil(tau).astype(np.int32)):
            i = rng.integers(0, K)
            j = rng.integers(0, K)
            self.sample_once(i, j)

        # explore according to optimal design
        for i in range(K):
            for j in range(K):
                self.sample(i, j)

        theta_hat = self.estimate_theta()

        # Concrete runs always use a subclass, so comparing self.__class__ with
        # BETC_Linear never selected the linear branch. Dispatch on the GLM
        # subclass instead: GLM variants apply the global link F, everything
        # else uses the linear score shifted by 0.5.
        if isinstance(self, BETC_GLM):
            B_hat = np.sum(F(self.phi @ theta_hat) / K, axis=1)
        else:
            B_hat = np.sum((self.phi @ theta_hat + 0.5) / K, axis=1)
        i_star = np.argmax(B_hat)
        print(B_hat, "B_hat")
        print(B_hat.max(), "B_hat.max()", i_star, "B_hat.argmax()")
        print(self.B_star, "B_star", self.B.argmax(), "i_star")

        # Commit: every sample_once call appended one entry to R.
        t = len(self.R) - 1
        for _ in range(t, T):
            self.sample_once(i_star, i_star)

    def estimate_theta(self):
        """Closed-form least-squares style estimate ``V^{-1} sum(phi * r)``.

        ``V`` is the sum of feature outer products weighted by ``N[i, j]``.
        Prints the estimate next to the true ``theta`` and returns the estimate.
        """
        V = np.zeros(dtype=float_type, shape=(D, D))
        for i in range(K):
            for j in range(K):
                a = self.phi[i, j]
                V += np.outer(a, a) * self.N[i, j]
        V_inv = np.linalg.inv(V)

        theta_hat = np.zeros(dtype=float_type, shape=(D))
        for i in range(K):
            for j in range(K):
                theta_hat += self.phi[i, j] * self.r[i, j]
        theta_hat = V_inv @ theta_hat
        print(theta_hat, "theta_hat")
        print(self.theta, "theta")
        return theta_hat


class BETC_GLM(BETC_Linear):
    """BETC variant whose parameter estimate comes from logistic gradient ascent."""

    def estimate_theta(self):
        """Run 100 fixed-step gradient-ascent updates and return the final theta.

        The ascent direction is ``sum_k x_k * (y_k - sigmoid(x_k . theta))`` over
        the flattened pair features ``x_k`` and feedback ``y_k = r`` values.
        Prints the estimate next to the true ``theta``.

        NOTE(review): the sigmoid term is not weighted by ``self.N`` although
        ``r`` can accumulate more than one outcome per pair -- confirm intended.
        """
        features = self.phi.reshape(-1, D)
        feedback_col = self.r.reshape(-1).reshape(-1, 1)

        def ascent_direction(theta):
            # Column of 1 + exp(-x . theta), broadcast across the D feature columns.
            denom_col = 1 + np.exp(-features @ theta).reshape(-1, 1)
            per_pair = features * feedback_col - features / denom_col
            return per_pair.sum(axis=0)

        theta = np.ones(D) / D
        step_size = 1e-4 * 5
        for _ in range(100):
            theta += step_size * ascent_direction(theta)

        print(theta, "theta_hat")
        print(self.theta, "theta")
        return theta


class DEXP3:
    """Exponential-weights dueling strategy mixed with uniform exploration.

    Mixin: relies on ``sample_once`` from the data class and on the globals
    ``K`` and ``T``.
    """

    def run(self):
        """Play ``T`` duels, drawing both arms from ``q`` and updating the weights."""
        eta_scale = 1
        gamma_scale = 1
        eta = eta_scale * np.power((np.log(K) / T / np.sqrt(K)), 2.0 / 3)
        gamma = gamma_scale * np.sqrt(eta * K)

        q = np.ones(dtype=float_type, shape=(K,)) / K
        scores = np.zeros(dtype=float_type, shape=(K,))
        for t in range(T):
            first = rng.multinomial(1, q).argmax()
            second = rng.multinomial(1, q).argmax()
            outcome = self.sample_once(first, second)

            # Importance-weighted credit goes to the first arm only.
            scores[first] += 1.0 / K / q[first] * outcome / q[second]

            weights = np.exp(eta * scores)
            q_tilda = weights / weights.sum()
            q = (1 - gamma) * q_tilda + gamma / K

            if t % 10000 == 0:
                print(first, second, t, "t")
        print(q)


class UCB_Borda:
    """UCB on empirical win rates against uniformly random opponents.

    Mixin: relies on ``sample_once`` from the data class and on the globals
    ``K`` and ``T``.
    """

    def run(self):
        """Play ``T`` duels: arm ``i`` maximises mean + bonus, opponent ``j`` is uniform."""
        alpha = 0.3
        N = np.zeros(dtype=float_type, shape=(K))  # pulls per arm
        W = np.zeros(dtype=float_type, shape=(K))  # wins per arm
        for t in range(T):
            mu = W / (N + 1e-10)
            # Rounds are 0-indexed, so use log(t + 1): log(0) is -inf and made
            # every bound NaN (with a RuntimeWarning) in the first round.
            cb = np.sqrt(np.log(t + 1) / (N + 1e-10))
            i = np.argmax(mu + alpha * cb)
            j = rng.integers(0, K)
            ret = self.sample_once(i, j)
            N[i] += 1
            W[i] += ret
            if t % 10000 == 0:
                print(i, j, t, "t")


class ETC_Borda:
    """Explore-then-commit on empirical Borda win rates.

    Mixin: relies on ``sample_once`` from the data class and on the globals
    ``K``, ``T`` and ``delta``.
    """

    def run(self):
        """Cycle through arm pairs for ``K * N`` rounds, then repeat the best arm."""
        raw_pulls = np.ceil(
            np.power(K, -2.0 / 3)
            * np.power(T, 2.0 / 3)
            * np.power(np.log(K / delta), 1.0 / 3)
        ).astype(np.int32)
        # Round the per-arm budget up to a multiple of K.
        pulls_per_arm = (np.ceil(1.0 * raw_pulls / K) * K).astype(np.int32)
        print(pulls_per_arm, "N", T, K)
        explore_rounds = K * pulls_per_arm

        plays = np.zeros(dtype=float_type, shape=(K))
        wins = np.zeros(dtype=float_type, shape=(K))
        for t in range(explore_rounds):
            # First arm cycles fastest; the opponent advances once per full cycle.
            cycle, x = divmod(t, K)
            y = cycle % K
            if self.sample_once(x, y):
                wins[x] += 1
            plays[x] += 1
            if t % 10000 == 0:
                print(x, y, t, "t")

        i_star = np.argmax(wins / plays)
        print(i_star, "i_star")
        for _ in range(explore_rounds, T):
            self.sample_once(i_star, i_star)


class BEXP_Linear:
    """Exponential-weights Borda strategy with a linear one-sample estimator.

    Mixin: relies on ``phi``, ``phi_outers``, ``B`` and ``sample_once`` from the
    data class and on the globals ``K``, ``D`` and ``T``.
    """

    def run(self):
        """Play ``T`` duels, re-estimating theta each round and updating ``q``."""
        from scipy.special import softmax

        eta = (
            np.power(np.log(K), 2.0 / 3) * np.power(D, -1.0 / 3) * np.power(T, -2.0 / 3)
        )
        gamma = np.sqrt(eta * D)
        print(eta, gamma, "eta, gamma")
        q = np.ones(dtype=float_type, shape=(K,)) / K
        W = np.zeros(dtype=float_type, shape=(K))
        # Defined up front so the final print works even when T == 0.
        theta_hat = np.zeros(dtype=float_type, shape=(D,))
        for t in range(T):
            print(t, end="\r")

            # Inverse-CDF draws. Rounding can leave the last cumulative value
            # slightly below 1, in which case the count would be K; clamp it so
            # the index always stays in range.
            cdf = np.cumsum(q)
            x = min(int(np.sum(cdf < rng.random())), K - 1)
            y = min(int(np.sum(cdf < rng.random())), K - 1)

            # Map the 1/0 outcome to +1/-1.
            ret = 1 if self.sample_once(x, y) else -1

            # Q = sum_ij q_i q_j phi_ij phi_ij^T
            Q = (self.phi_outers * np.outer(q, q).reshape(K, K, 1, 1)).sum(axis=(0, 1))
            Q_inv = np.linalg.inv(Q)
            theta_hat = Q_inv @ self.phi[x, y] * ret
            B_hat = np.sum(self.phi @ theta_hat, axis=1) / K
            W += eta * B_hat
            q_tilda = softmax(W)
            q = (1 - gamma) * q_tilda + gamma / K
        print(q, "q", np.argmax(q), "argmax")
        print(self.B, "B_gt", np.argmax(self.B), "argmax")
        print(theta_hat, "theta_hat")


def eps_set1():
    """Set the globals ``eps`` and ``tau`` from ``D``, ``T``, ``K`` and ``delta``, then print them."""
    global eps, tau
    # Both quantities share the same logarithmic factor.
    log_term = np.log(3 * K * K / delta)
    eps = np.power(1.0 * D / T, 1.0 / 3) * np.power(log_term, -1.0 / 6)
    tau = np.power(D * log_term * np.power(T, 2.0 / 3), 1.0 / 3)
    print(eps, "eps1", tau, "tau")


def eps_set2():
    """Set the globals ``eps = D^(1/6) * T^(-1/3)`` and ``tau = D + log(1/delta)``, then print them."""
    global eps, tau
    dim_factor = np.power(D, 1.0 / 6)
    horizon_factor = np.power(T, -1.0 / 3)
    eps = dim_factor * horizon_factor
    tau = D + np.log(1.0 / delta)
    print(eps, "eps2", tau, "tau")


class ETC_Borda_sim(ETC_Borda, DataFromLowerBound):
    """ETC_Borda run on the synthetic DataFromLowerBound instance."""

    def __init__(self) -> None:
        # ETC_Borda defines no __init__, so this resolves through the MRO to
        # DataFromLowerBound.__init__. One call is enough: a second call only
        # regenerated the same data and drew extra random numbers.
        super().__init__()


class BETC_Linear_sim(BETC_Linear, DataFromLowerBound):
    """BETC_Linear with the eps_set1 parameters on the synthetic instance."""

    def __init__(self) -> None:
        eps_set1()
        # BETC_Linear.__init__ itself calls super().__init__(), which reaches
        # DataFromLowerBound.__init__ through the MRO, so one cooperative call
        # initialises both the data and the algorithm state exactly once.
        super().__init__()


class BETC_Linear_Match_sim(BETC_Linear, DataFromLowerBound):
    """BETC_Linear with the eps_set2 parameters on the synthetic instance."""

    def __init__(self) -> None:
        eps_set2()
        # BETC_Linear.__init__ itself calls super().__init__(), which reaches
        # DataFromLowerBound.__init__ through the MRO, so one cooperative call
        # initialises both the data and the algorithm state exactly once.
        super().__init__()


class BEXP3_sim(BEXP_Linear, DataFromLowerBound):
    """BEXP_Linear run on the synthetic DataFromLowerBound instance."""

    def __init__(self) -> None:
        # BEXP_Linear defines no __init__, so this resolves through the MRO to
        # DataFromLowerBound.__init__. The previous code issued the very same
        # call twice, regenerating the data for no benefit.
        super().__init__()


class DEXP3_sim(DEXP3, DataFromLowerBound):
    """DEXP3 run on the synthetic DataFromLowerBound instance."""

    def __init__(self) -> None:
        # DEXP3 defines no __init__, so this resolves through the MRO to
        # DataFromLowerBound.__init__. One call is enough: a second call only
        # regenerated the same data and drew extra random numbers.
        super().__init__()


class UCB_Borda_sim(UCB_Borda, DataFromLowerBound):
    """UCB_Borda run on the synthetic DataFromLowerBound instance."""

    def __init__(self) -> None:
        # UCB_Borda defines no __init__, so this resolves through the MRO to
        # DataFromLowerBound.__init__. One call is enough: a second call only
        # regenerated the same data and drew extra random numbers.
        super().__init__()


class ETC_Borda_ev(ETC_Borda, DataEvents):
    """ETC_Borda run on the DataEvents instance loaded from disk."""

    def __init__(self) -> None:
        # ETC_Borda defines no __init__, so this resolves through the MRO to
        # DataEvents.__init__. One call is enough: a second call only re-read
        # the same data file.
        super().__init__()


class BETC_GLM_ev(BETC_GLM, DataEvents):
    """BETC_GLM with the eps_set1 parameters on the DataEvents instance."""

    def __init__(self) -> None:
        eps_set1()
        # BETC_GLM inherits BETC_Linear.__init__, which calls super().__init__()
        # and thereby reaches DataEvents.__init__ through the MRO. One
        # cooperative call loads the data and sets up the algorithm state once.
        super().__init__()


class BETC_GLM_Match_ev(BETC_GLM, DataEvents):
    """BETC_GLM with the eps_set2 parameters on the DataEvents instance."""

    def __init__(self) -> None:
        eps_set2()
        # BETC_GLM inherits BETC_Linear.__init__, which calls super().__init__()
        # and thereby reaches DataEvents.__init__ through the MRO. One
        # cooperative call loads the data and sets up the algorithm state once.
        super().__init__()


class DEXP3_ev(DEXP3, DataEvents):
    """DEXP3 run on the DataEvents instance loaded from disk."""

    def __init__(self) -> None:
        # DEXP3 defines no __init__, so this resolves through the MRO to
        # DataEvents.__init__. One call is enough: a second call only re-read
        # the same data file.
        super().__init__()


class UCB_Borda_ev(UCB_Borda, DataEvents):
    """UCB_Borda run on the DataEvents instance loaded from disk."""

    def __init__(self) -> None:
        # UCB_Borda defines no __init__, so this resolves through the MRO to
        # DataEvents.__init__. One call is enough: a second call only re-read
        # the same data file.
        super().__init__()


class BEXP3_ev(BEXP_Linear, DataEvents):
    """BEXP_Linear run on the DataEvents instance loaded from disk."""

    def __init__(self) -> None:
        # BEXP_Linear defines no __init__, so this resolves through the MRO to
        # DataEvents.__init__. The previous code issued the very same call
        # twice, reading the data file twice for no benefit.
        super().__init__()


def legend_label(cls_name):
    """Turn an algorithm class name into the label shown in the plot legend.

    Removes a trailing ``_sim`` and/or ``_ev`` suffix, then replaces ``_`` with
    ``-`` and ``Linear`` with ``GLM``. Suffixes are removed explicitly because
    ``str.strip("_sim")`` strips any of the characters ``_``, ``s``, ``i``,
    ``m`` from both ends and would eat into names ending in those letters.
    """
    for suffix in ("_sim", "_ev"):
        if cls_name.endswith(suffix):
            cls_name = cls_name[: -len(suffix)]
    return cls_name.replace("_", "-").replace("Linear", "GLM")


def simulation():
    """Configure the experiment globals, optionally run it, and plot the regret curves.

    Sets the globals ``K``, ``D``, ``two_pow_d`` and ``F``. With
    ``no_sim = True`` only the plotting half runs and the per-algorithm
    ``figure/<prefix>_<class>.npz`` files must already exist. Otherwise each
    algorithm is run ``iters`` times and its regret traces are saved first.
    The figure is written to ``figure/<prefix>.pdf`` and post-processed with
    the external ``pdfcrop`` and ``pdffonts`` tools.
    """
    # print(DEXP3_sim.__mro__)
    global K, D, two_pow_d, F, eps, tau
    no_sim = True
    ######## Hard case #########################
    # Shrink this list to run or plot a subset of the algorithms.
    algo_classes = [
        UCB_Borda_sim,
        DEXP3_sim,
        ETC_Borda_sim,
        BETC_Linear_sim,
        BETC_Linear_Match_sim,
        BEXP3_sim,
    ]
    # no_sim = False
    filename_prefix = "simulated"
    D = 6
    assert D <= 31
    K = 2 ** (D + 1)
    two_pow_d = 2**D
    D = D + 1  # for the bias
    F = lambda x: x + 0.5

    # ###### Real Data ########################
    # algo_classes = [BETC_GLM_ev, BETC_GLM_Match_ev]
    # algo_classes = [ETC_Borda_ev]
    # algo_classes = [
    #     UCB_Borda_ev,
    #     DEXP3_ev,
    #     ETC_Borda_ev,
    #     BETC_GLM_ev,
    #     BETC_GLM_Match_ev,
    #     BEXP3_ev,
    # ]
    # # no_sim = False
    # filename_prefix = "events"
    # D = 5
    # K = 100
    # two_pow_d = 2**D
    # F = lambda x: 1.0 / (1 + np.exp(-x))

    iters = 50
    for algo_cls in algo_classes:
        print("-----------------------", algo_cls.__name__, "-----------------------")
        if no_sim:
            break  # comment out to actually run simulation, otherwise just plot
        R_hist = []
        for i in range(iters):
            # A fresh instance per repetition; re-calling __init__ on a reused
            # object initialised the first repetition twice.
            this_algo = algo_cls()
            this_algo.run()
            R_hist.append(np.array(this_algo.R))
            print(i, "i")
        R_hist = np.array(R_hist)
        np.savez(
            "figure/{}_{}.npz".format(filename_prefix, algo_cls.__name__), R_hist=R_hist
        )

    tick_gap = T // 15
    fig, ax = plt.subplots()
    fig.set_size_inches(3, 3)
    ax.set_xlabel("T")
    ax.set_ylabel("Regret(T)")
    ax.ticklabel_format(axis="y", scilimits=[0, 2])
    ax.grid(True)
    fig.tight_layout()
    for algo_cls in algo_classes:
        # Keep every tick_gap-th round of each saved regret trace.
        R_hist = np.load("figure/{}_{}.npz".format(filename_prefix, algo_cls.__name__))[
            "R_hist"
        ][:, ::tick_gap]
        mean = R_hist.mean(axis=0)
        std = R_hist.std(axis=0)
        ticks = mean.shape[0]
        x_t = np.array(range(ticks)) * tick_gap
        plt.plot(
            x_t,
            mean,
            label=legend_label(algo_cls.__name__),
        )
        # Shaded band of one standard deviation around the mean.
        plt.fill_between(
            x=x_t,
            y1=mean - std,
            y2=mean + std,
            alpha=0.20,
        )

    plt.legend(loc="upper left", prop={"size": 9})

    # plt.show()
    fig_name = f"figure/{filename_prefix}.pdf"
    plt.savefig(fig_name, bbox_inches="tight")
    plt.close()
    import os

    # External tools: crop the PDF in place, then list any Type 3 fonts.
    os.system(f"pdfcrop {fig_name} {fig_name} > /dev/null")
    os.system(f"pdffonts {fig_name} | grep 'Type 3'")


def realdata():
    """Build a single BETC_GLM_ev instance and run it once."""
    BETC_GLM_ev().run()


# Script entry point: run the simulation/plotting pipeline when executed directly.
if __name__ == "__main__":
    simulation()
    # realdata()
