import numpy as np
from ...core.Base import LinUCB_IND
import math


class Cluster:
    """Aggregate statistics for a group of users that share one linear model.

    Attributes:
        users: list of the user ids that belong to this cluster.
        S: matrix passed in by the caller; inverted to obtain ``Sinv``.
        b: vector passed in by the caller; ``theta = Sinv @ b``.
        N: counter supplied by the caller (stored as given).
        checks: dict mapping user id -> bool, marking users already checked.
        Sinv: inverse of ``S`` computed at construction time.
        theta: ``Sinv @ b`` computed at construction time.
        checked: True when the number of truthy entries in ``checks`` equals
            the number of users.
    """

    def __init__(self, users, S, b, N, checks):
        """Build a cluster from its members and their aggregated statistics.

        The containers and arrays are copied so that later in-place updates of
        the cluster (e.g. ``cluster.S += ...``) can never modify the objects
        the caller passed in.

        Raises:
            numpy.linalg.LinAlgError: if ``S`` is singular or not square.
        """
        self.users = list(users)  # own list: members are removed/extended later
        self.S = np.array(S, dtype=float)  # np.array copies by default
        self.b = np.array(b, dtype=float)
        self.N = N
        self.checks = dict(checks)
        self.Sinv = np.linalg.inv(self.S)
        self.theta = np.matmul(self.Sinv, self.b)
        self.checked = self._all_checked()

    def _all_checked(self):
        """Return True when as many checks are set as there are users."""
        return len(self.users) == sum(self.checks.values())

    def update_check(self, i):
        """Mark user ``i`` as checked and refresh the ``checked`` flag."""
        self.checks[i] = True
        self.checked = self._all_checked()


class SCLUB(LinUCB_IND):
    """Clustered linear bandit that groups users by splitting and merging clusters.

    Each cluster aggregates the per-user statistics (``S``, ``b``, ``N``) of
    its members, and recommendations for a user are computed from the
    statistics of the cluster that user currently belongs to.

    NOTE(review): ``self.nu``, ``self.d``, ``self.ucb_lambda``, the per-user
    containers ``self.S`` / ``self.b`` / ``self.N`` / ``self.Sinv`` /
    ``self.theta``, the reward buffers, ``self.offline_learn_T`` and the
    helpers ``_select_item_ucb``, ``_update_inverse``,
    ``offline_learn_recommend`` and ``store_info`` are not defined in this
    class; they are presumably provided by ``LinUCB_IND`` -- confirm against
    the base class.
    """

    def __init__(self, nu, d, T, ni):
        """Start with all ``nu`` users in one cluster (id 0).

        Args:
            nu: number of users; user ids are ``0 .. nu - 1``.
            d: dimension of the square statistics matrices / feature vectors.
            T: length of the ``num_clusters`` history buffer.
            ni: forwarded unchanged to ``LinUCB_IND.__init__``.
        """
        super(SCLUB, self).__init__(nu, d, T, ni)

        self.clusters = {
            0: Cluster(
                users=list(range(nu)),
                S=self.ucb_lambda * np.eye(d),
                b=np.zeros(d),
                N=0,
                checks={i: False for i in range(nu)},
            )
        }
        # cluster_inds[i] is the id (key of self.clusters) of user i's cluster.
        self.cluster_inds = np.zeros(nu, dtype=int)

        self.T = T
        self.num_clusters = np.ones(self.T)
        # Upper bound on split/merge passes performed by update().
        self.max_refine_iters = 5

    def _init_each_stage(self):
        """Reset the per-user check flags of every cluster.

        Not called by any method defined in this class.
        """
        for cluster in self.clusters.values():
            cluster.checks = {i: False for i in cluster.users}
            cluster.checked = False

    def _recommend_from_cluster(self, i, items, t):
        """Select an item for user ``i`` using the statistics of its cluster."""
        cluster = self.clusters[int(self.cluster_inds[i])]
        return self._select_item_ucb(
            cluster.S, cluster.Sinv, cluster.theta, items, cluster.N, t
        )

    def recommend(self, i, items, t):
        """Return the item chosen for user ``i`` from ``items`` at step ``t``."""
        return self._recommend_from_cluster(i, items, t)

    def test_recommend(self, i, items, t):
        """Same selection rule as ``recommend``; used by the test phase of ``run``."""
        return self._recommend_from_cluster(i, items, t)

    def _factT(self, T):
        """Return ``sqrt((1 + log(1 + T)) / (1 + T))``, the threshold scale."""
        return np.sqrt((1 + np.log(1 + T)) / (1 + T))

    def store_info_test(self, t, r, br):
        """Record reward ``r`` and best reward ``br`` at index ``t``.

        The values are assigned (not accumulated), so a later call with the
        same ``t`` overwrites the earlier one.
        """
        self.rewards[t] = r
        self.best_rewards[t] = br

    def _split_or_merge(self, theta, N1, N2, split=True):
        """Threshold test on the norm of a parameter difference ``theta``.

        With ``split=True`` returns True when the norm exceeds
        ``_factT(N1) + _factT(N2)``; with ``split=False`` returns True when
        the norm is below half of that bound.
        """
        alpha = 1
        bound = alpha * (self._factT(N1) + self._factT(N2))
        gap = np.linalg.norm(theta)
        if split:
            return gap > bound
        return gap < bound / 2

    def _cluster_avg_freq(self, c, t):
        """Return ``N / (number_of_users * t)`` for cluster ``c``.

        Returns 0.0 when the denominator is not positive (``t == 0`` or an
        empty cluster) or when ``N`` is not finite, to avoid NaN/inf results.
        """
        cluster = self.clusters[c]
        denom = len(cluster.users) * t
        if denom <= 0:
            return 0.0
        num = cluster.N
        if not np.isfinite(num):
            return 0.0
        return float(num) / float(denom)

    def _split_or_merge_p(self, p1, p2, t, split=True):
        """Threshold test on the gap between two frequencies ``p1`` and ``p2``.

        With ``split=True`` returns True when ``|p1 - p2|`` exceeds
        ``sqrt(2) * _factT(t)``; with ``split=False`` returns True when it is
        below half of that bound.
        """
        alpha_p = np.sqrt(2)
        bound = alpha_p * self._factT(t)
        gap = np.abs(p1 - p2)
        if split:
            return gap > bound
        return gap < bound / 2

    def _find_available_index(self):
        """Return the smallest non-negative integer not used as a cluster id."""
        candidate = 0
        while candidate in self.clusters:
            candidate += 1
        return candidate

    def split(self, i, t):
        """Mark user ``i`` as checked and move it to its own cluster if it deviates.

        The user is split off when either its frequency ``N[i] / (t + 1)``
        differs from the cluster's average frequency, or its ``theta`` differs
        from the cluster's ``theta``, by more than the respective threshold.
        The old cluster's ``S``, ``b``, ``N`` are reduced by the user's
        statistics; its ``Sinv`` / ``theta`` are NOT recomputed here.
        """
        c = int(self.cluster_inds[i])
        cluster = self.clusters[c]

        cluster.update_check(i)

        should_split = self._split_or_merge_p(
            self.N[i] / (t + 1),
            self._cluster_avg_freq(c, t + 1),
            t + 1,
            split=True,
        ) or self._split_or_merge(
            self.theta[i] - cluster.theta,
            self.N[i],
            cluster.N,
            split=True,
        )
        if not should_split:
            return

        cnew = self._find_available_index()
        # Copy the user's statistics: clusters are updated in place by merge()
        # and update(), which must never write through to self.S[i]/self.b[i].
        self.clusters[cnew] = Cluster(
            users=[i],
            S=np.array(self.S[i], copy=True),
            b=np.array(self.b[i], copy=True),
            N=self.N[i],
            checks={i: True},
        )
        self.cluster_inds[i] = cnew

        cluster.users.remove(i)
        cluster.S = cluster.S - self.S[i] + self.ucb_lambda * np.eye(self.d)
        cluster.b = cluster.b - self.b[i]
        cluster.N = cluster.N - self.N[i]
        del cluster.checks[i]

        # The user was the last member: drop the now-empty cluster so it is
        # neither counted in num_clusters nor considered by merge().
        if not cluster.users:
            del self.clusters[c]

    def merge(self, t):
        """Merge every pair of fully-checked clusters that look alike.

        Two clusters are merged (the higher id into the lower id) when both
        the ``theta`` difference and the average-frequency difference fall
        below their merge thresholds.

        NOTE(review): the surviving cluster's ``Sinv`` / ``theta`` are not
        recomputed here, so later comparisons in the same pass use its
        pre-merge ``theta``; ``update()`` refreshes them after each pass.
        """
        cmax = max(self.clusters)
        ridge = self.ucb_lambda * np.eye(self.d)
        for c1 in range(cmax + 1):
            if c1 not in self.clusters or not self.clusters[c1].checked:
                continue

            for c2 in range(c1 + 1, cmax + 1):
                if c2 not in self.clusters or not self.clusters[c2].checked:
                    continue

                keep, absorb = self.clusters[c1], self.clusters[c2]
                similar = self._split_or_merge(
                    keep.theta - absorb.theta,
                    keep.N,
                    absorb.N,
                    split=False,
                ) and self._split_or_merge_p(
                    self._cluster_avg_freq(c1, t + 1),
                    self._cluster_avg_freq(c2, t + 1),
                    t + 1,
                    split=False,
                )
                if not similar:
                    continue

                for i in absorb.users:
                    self.cluster_inds[i] = c1

                keep.users += absorb.users
                # Each cluster's S contains one regularisation term; subtract
                # one so the merged matrix still contains exactly one.
                keep.S += absorb.S - ridge
                keep.b += absorb.b
                keep.N += absorb.N
                keep.checks = {**keep.checks, **absorb.checks}

                del self.clusters[c2]

    def _refresh_cluster_estimates(self):
        """Recompute ``Sinv`` and ``theta`` of every existing cluster."""
        for cluster in self.clusters.values():
            cluster.Sinv, cluster.theta = self._update_inverse(
                cluster.S, cluster.b, 0, 0, 0
            )

    def update(self):
        """Fold the per-user statistics into the clusters and refine the clustering.

        Adds every user's ``S`` (minus the regularisation term), ``b`` and
        ``N`` to its current cluster, then repeats split/merge passes (at most
        ``max_refine_iters``) until the assignment stops changing. After each
        pass every user's ``theta`` is overwritten with its cluster's
        ``theta``. Finally records the cluster count at index
        ``offline_learn_T - 1`` of ``num_clusters``.

        NOTE(review): the statistics are added on every call, so calling this
        more than once would count the same user data repeatedly; ``run()``
        calls it exactly once.
        """
        tau = self.offline_learn_T - 1
        ridge = self.ucb_lambda * np.eye(self.d)
        for user_index in range(self.nu):
            cluster = self.clusters[int(self.cluster_inds[user_index])]
            cluster.S += self.S[user_index] - ridge
            cluster.b += self.b[user_index]
            cluster.N += self.N[user_index]
        # Iterate over the clusters that actually exist; ids are not
        # guaranteed to be the contiguous range 0..len-1 once merges happen.
        self._refresh_cluster_estimates()

        prev_assignments = self.cluster_inds.copy()
        prev_num_clusters = len(self.clusters)

        for _ in range(self.max_refine_iters):
            for i in range(self.nu):
                self.split(i, tau)
            self.merge(tau)
            # Recompute parameters of all clusters after split/merge.
            self._refresh_cluster_estimates()
            # Write the cluster estimate back to each member.
            for i in range(self.nu):
                ci = int(self.cluster_inds[i])
                self.theta[i] = self.clusters[ci].theta
            # Convergence check: stop once neither the assignment nor the
            # number of clusters changed during this pass.
            changed = (
                not np.array_equal(self.cluster_inds, prev_assignments)
                or len(self.clusters) != prev_num_clusters
            )
            if not changed:
                break
            prev_assignments = self.cluster_inds.copy()
            prev_num_clusters = len(self.clusters)

        self.num_clusters[tau] = len(self.clusters)
        print("new")

    def run(self, envir):
        """Run the learning phase, cluster the users, then run the test phase.

        Steps ``0 .. offline_learn_T - 1`` collect feedback through
        ``offline_learn_recommend`` / ``store_info``; the estimates are then
        refreshed and ``update()`` builds the clusters; steps
        ``offline_learn_T .. T - 1`` recommend from the clusters and record
        rewards with ``store_info_test``.

        Args:
            envir: environment object; this method calls its
                ``generate_users``, ``get_items``, ``feedback`` and
                ``test_feedback`` methods.

        Returns:
            Always 0.
        """
        for t in range(self.offline_learn_T):
            self.I = envir.generate_users()
            for i in self.I:
                self.items = envir.get_items()
                recommended = self.offline_learn_recommend(t=t, i=i)
                x = self.items[recommended]
                y, r, br = envir.feedback(i=i, k=recommended)
                self.store_info(i=i, x=x, y=y, t=t, r=r, br=br)

        # A single shared model is a 1-D vector of length d. Checking the
        # length alone would misclassify per-user storage whenever nu == d.
        if np.ndim(self.theta) == 1 and len(self.theta) == self.d:
            self.Sinv, self.theta = self._update_inverse(self.S, self.b, 0, 0, 0)
        else:
            for user_index in range(self.nu):
                self.Sinv[user_index], self.theta[user_index] = self._update_inverse(
                    self.S[user_index], self.b[user_index], 0, 0, 0
                )

        self.update()

        # NOTE(review): test_users is never read afterwards. The shuffle is
        # kept because it advances NumPy's global RNG state, which may affect
        # the reproducibility of anything that draws from it later.
        test_users = list(range(self.nu))
        np.random.shuffle(test_users)

        for t in range(self.offline_learn_T, self.T):
            self.I = envir.generate_users()
            for test_user in self.I:
                self.items = envir.get_items()
                recommended = self.test_recommend(
                    test_user, self.items, self.offline_learn_T
                )
                y, r, br = envir.test_feedback(i=test_user, k=recommended)
                self.store_info_test(t=t, r=r, br=br)
        self.test_rewards[:] = self.rewards[self.offline_learn_T:]
        self.best_test_rewards[:] = self.best_rewards[self.offline_learn_T:]
        return 0