import networkx as nx
import numpy as np

from ...core.Base import LinUCB_Neighbor


class OffNCUCB_estimate_gamma(LinUCB_Neighbor):
    """Neighbourhood LinUCB with an offline phase and a per-user estimated threshold gamma.

    ``run`` proceeds in two phases:

    1. Offline phase (``self.offline_learn_T`` rounds). Per-user statistics are
       collected through the parent class (``offline_learn_recommend`` /
       ``store_info``).
    2. Test phase. For every user:
       - a threshold gamma is estimated from the offline estimates;
       - the user's neighbours are selected with that threshold;
       - recommendations use the pooled statistics of the user and its neighbours.

    Most state read here is created by ``LinUCB_Neighbor``, which is not visible
    in this file. This includes ``theta``, ``S``, ``b``, ``N``, ``Sinv``,
    ``S_Neighbor``, ``b_Neighbor``, ``N_Neighbor``, ``Sinv_Neighbor``,
    ``theta_Neighbor``, ``ucb_lambda``, ``alpha``, ``delta``, ``C1``,
    ``offline_learn_T``, ``rewards`` and ``best_rewards``.
    """

    def __init__(self, nu, d, T, ni, choose_gamma_alpha, alpha, delta, C1, model="Pessimistic", gamma_initial=0,
                 choose_gamma_alpha_bias=0.005):
        """Create the learner.

        Args:
            nu: number of users; graph nodes are ``0 .. nu - 1``.
            d: feature dimension.
            T: total number of rounds (offline + test).
            ni: forwarded unchanged to ``LinUCB_Neighbor``.
            choose_gamma_alpha: CI multiplier used when estimating gamma.
            alpha: forwarded to the parent; used as the CI multiplier when
                building edges (``connect_edges`` / ``connect_edges_for_test_user``).
            delta: confidence parameter, forwarded to the parent.
            C1: constant in the CI denominator, forwarded to the parent.
            model: ``"Pessimistic"`` estimates gamma with ``choose_gamma_alpha``.
                Any other value uses
                ``choose_gamma_alpha - choose_gamma_alpha_bias``.
            gamma_initial: threshold for the initial edge set. Also returned as
                gamma when a user has no qualifying initial neighbours.
            choose_gamma_alpha_bias: CI multiplier in ``_compute_Gamma``. It is
                also subtracted from ``choose_gamma_alpha`` in the
                non-pessimistic model.
        """
        super().__init__(nu, d, T, ni, alpha, delta, C1)
        self.nu = nu
        self.gamma_initial = gamma_initial
        self.G = nx.Graph()
        self.G.add_nodes_from(range(nu))
        self.final_neighbors = {user: [] for user in range(nu)}
        self.model = model
        self.choose_gamma_alpha = choose_gamma_alpha
        self.choose_gamma_alpha_bias = choose_gamma_alpha_bias

    def compute_CI_u(self, N1, delta, C1):
        """Return the confidence radius for a user with ``N1`` observations.

        CI = (sqrt(d * log(1 + N1 / (d * lambda)) + 2 * log(2 * nu / delta))
              + sqrt(lambda)) / sqrt(lambda + C1 * N1)

        Here ``lambda`` is ``self.ucb_lambda``.
        """
        numerator = np.sqrt(
            self.d * np.log(1 + N1 / (self.d * self.ucb_lambda)) + 2 * np.log(2 * self.nu / delta)) + np.sqrt(
            self.ucb_lambda)
        denominator = np.sqrt(self.ucb_lambda + C1 * N1)
        return numerator / denominator

    def _compute_Gamma(self, u1, u2):
        """Return ``||theta_u1 - theta_u2|| - choose_gamma_alpha_bias * (CI_u1 + CI_u2)``.

        Requires ``precompute_CI`` to have been called (reads ``self.CI``).
        """
        theta_diff = np.linalg.norm(self.theta[u1] - self.theta[u2])
        CI_sum = self.compute_CI_sum(u1, u2)
        return theta_diff - self.choose_gamma_alpha_bias * CI_sum

    def connect_edges(self):
        """Add an edge for every pair with ``||theta_u - theta_v|| - alpha * (CI_u + CI_v) > gamma_initial``.

        Requires ``precompute_CI`` to have been called.

        Returns:
            dict mapping each node to the list of its neighbours in ``self.G``.
        """
        # Pairwise distances between all user estimates, shape (nu, nu).
        theta_diff = np.linalg.norm(self.theta[:, np.newaxis, :] - self.theta[np.newaxis, :, :], axis=2)
        CI_sum = self.CI[:, np.newaxis] + self.CI[np.newaxis, :]
        # NOTE(review): this uses self.alpha, whereas _compute_Gamma uses
        # choose_gamma_alpha_bias for the same kind of lower bound -- confirm intended.
        Gamma_matrix = theta_diff - self.alpha * CI_sum
        # Strict upper triangle: each unordered pair once, no self-loops.
        upper = np.triu(np.ones(Gamma_matrix.shape, dtype=bool), k=1)
        valid_edges = np.argwhere((Gamma_matrix > self.gamma_initial) & upper)
        # tolist() hands networkx plain int pairs instead of numpy rows.
        self.G.add_edges_from(valid_edges.tolist())
        return {node: list(self.G.neighbors(node)) for node in self.G.nodes}

    def connect_edges_for_test_user(self, test_user, gamma_value=0):
        """Connect ``test_user`` to every other user v with
        ``||theta_u - theta_v|| + alpha * (CI_u + CI_v) < gamma_value``.

        Returns:
            list of the selected neighbour indices (``test_user`` excluded).
        """
        theta_diff = np.linalg.norm(self.theta[test_user] - self.theta, axis=1)
        CI_sum = self.CI[test_user] + self.CI
        Gamma = theta_diff + self.alpha * CI_sum
        valid_neighbors = np.where(Gamma < gamma_value)[0]
        valid_neighbors = valid_neighbors[valid_neighbors != test_user]
        self.G.add_edges_from([(test_user, v) for v in valid_neighbors])
        return list(valid_neighbors)

    def estimate_gamma(self, test_user, initial_neighbors):
        """Estimate the neighbourhood threshold gamma for ``test_user``.

        Only initial neighbours v with ``_compute_Gamma(test_user, v) > 0`` are
        considered. The estimate is the minimum over those v of
        ``||theta_u - theta_v|| + c * (CI_u + CI_v)``. The coefficient ``c`` is:

        - ``choose_gamma_alpha`` for the ``"Pessimistic"`` model;
        - ``choose_gamma_alpha - choose_gamma_alpha_bias`` otherwise.

        Returns ``self.gamma_initial`` when no initial neighbour qualifies.
        """
        M_u_test = [v for v in initial_neighbors if self._compute_Gamma(test_user, v) > 0]
        if not M_u_test:
            return self.gamma_initial

        if self.model == "Pessimistic":
            ci_coef = self.choose_gamma_alpha
        else:
            ci_coef = self.choose_gamma_alpha - self.choose_gamma_alpha_bias

        # The test user's confidence radius does not depend on v: compute it once.
        ci_test = self.compute_CI_u(self.N[test_user], self.delta, self.C1)
        return min(
            np.linalg.norm(self.theta[test_user] - self.theta[v])
            + ci_coef * (ci_test + self.compute_CI_u(self.N[v], self.delta, self.C1))
            for v in M_u_test
        )

    def compute_temp_theta(self, test_user):
        """Pool the statistics of ``test_user`` and its final neighbours.

        Returns:
            tuple ``(theta_temp, Sinv_temp, N_temp)``: the pooled ridge
            estimate, the inverse of the pooled Gram matrix, and the pooled
            observation count.

        If the pooled Gram matrix is singular, ``theta_temp`` and ``Sinv_temp``
        fall back to the user's own statistics.
        """
        neighbor_indices = self.final_neighbors[test_user]
        S_temp = self.S[test_user] + self.S[neighbor_indices].sum(axis=0)
        b_temp = self.b[test_user] + self.b[neighbor_indices].sum(axis=0)
        N_temp = self.N[test_user] + self.N[neighbor_indices].sum(axis=0)
        try:
            # solve() must sit inside the try: it raises LinAlgError on a
            # singular matrix before inv() would ever be reached.
            theta_temp = np.linalg.solve(S_temp, b_temp)
            Sinv_temp = np.linalg.inv(S_temp)
        except np.linalg.LinAlgError:
            Sinv_temp = np.linalg.inv(self.S[test_user])
            theta_temp = np.linalg.solve(self.S[test_user], self.b[test_user])
        return theta_temp, Sinv_temp, N_temp

    def precompute_CI(self):
        """Cache every user's confidence radius in ``self.CI`` (shape ``(nu,)``)."""
        self.CI = np.array([self.compute_CI_u(self.N[u], self.delta, self.C1) for u in range(self.nu)])

    def compute_CI_sum(self, u1, u2):
        """Return ``CI_u1 + CI_u2`` from the cached ``self.CI``."""
        return self.CI[u1] + self.CI[u2]

    def store_info_Neighbor_test_user(self, test_user):
        """Accumulate the statistics of ``test_user`` and its final neighbours into its ``*_Neighbor`` slots.

        Then refresh ``Sinv_Neighbor`` and ``theta_Neighbor`` for that user.

        The sums are added with ``+=``, so calling this twice for the same user
        double-counts.
        """
        i = test_user
        # NOTE(review): lambda*I is stripped from each contribution, presumably
        # because S_Neighbor[i] already carries one regulariser -- confirm in parent.
        reg = self.ucb_lambda * np.eye(self.d)
        for u in [i, *self.final_neighbors[i]]:
            self.S_Neighbor[i] += self.S[u] - reg
            self.b_Neighbor[i] += self.b[u]
            self.N_Neighbor[i] += self.N[u]

        self.Sinv_Neighbor[i], self.theta_Neighbor[i] = self._update_inverse(
            self.S_Neighbor[i], self.b_Neighbor[i], self.Sinv_Neighbor[i], None, self.N_Neighbor[i]
        )

    def _select_item_ucb(self, S, Sinv, theta, items, N):
        """Return the index of the item maximising ``x.theta + 0.1 * lcb_beta(N) * x^T Sinv x``.

        ``S`` is accepted for interface compatibility but unused.
        """
        # NOTE(review): the bonus uses the quadratic form without sqrt() and
        # scales lcb_beta by 0.1 -- confirm both are intended.
        return np.argmax(np.dot(items, theta) + 0.1*self.lcb_beta(N) * (np.matmul(items, Sinv) * items).sum(axis = 1))

    def test_recommend(self, i, items, t):
        """Recommend for user ``i`` from its pooled neighbourhood statistics (``t`` is unused)."""
        return self._select_item_ucb(self.S_Neighbor[i], self.Sinv_Neighbor[i], self.theta_Neighbor[i], items, self.N_Neighbor[i])

    def run(self, envir):
        """Run the offline phase, estimate gammas and neighbourhoods, then the test phase.

        Prints the list of estimated gammas and their mean.

        Returns:
            the mean estimated gamma (also stored in ``self.avg_estimate_gamma``).
        """
        # Offline phase: collect per-user statistics.
        for t in range(self.offline_learn_T):
            self.I = envir.generate_users()
            for i in self.I:
                self.items = envir.get_items()
                recommended = self.offline_learn_recommend(t=t, i=i)
                x = self.items[recommended]
                y, r, br = envir.feedback(i=i, k=recommended)
                self.store_info(i=i, x=x, y=y, t=t, r=r, br=br)

        # Refit the estimates from the collected statistics.
        if len(self.theta) == self.d:
            # NOTE(review): treats theta as a single shared estimate when its
            # length equals d -- ambiguous if nu == d; confirm.
            self.Sinv, self.theta = self._update_inverse(self.S, self.b, 0, 0, 0)
        else:
            for user_index in range(self.nu):
                self.Sinv[user_index], self.theta[user_index] = self._update_inverse(self.S[user_index],
                                                                                     self.b[user_index], 0, 0, 0)

        # The initial graph only seeds gamma estimation; it is then discarded.
        self.precompute_CI()
        self.initial_neighbor_lists = self.connect_edges()
        self.G = nx.Graph()
        self.G.add_nodes_from(range(self.nu))

        test_users = list(range(self.nu))
        np.random.shuffle(test_users)
        estimated_gammas = []
        for test_user in test_users:
            initial_neighbors = self.initial_neighbor_lists[test_user]
            estimated_gamma = self.estimate_gamma(test_user, initial_neighbors)
            estimated_gammas.append(estimated_gamma)
            self.final_neighbors[test_user] = self.connect_edges_for_test_user(test_user, estimated_gamma)
            self.store_info_Neighbor_test_user(test_user)
        print(estimated_gammas)
        print(np.mean(estimated_gammas))

        # Test phase: recommend from pooled neighbourhood statistics.
        for t in range(self.offline_learn_T, self.T):
            self.I = envir.generate_users()
            for test_user in self.I:
                items = envir.get_items()
                recommended = self.test_recommend(test_user, items, self.offline_learn_T)
                y, r, br = envir.test_feedback(i=test_user, k=recommended)
                self.store_info_test(t=t, r=r, br=br)
        self.test_rewards[:] = self.rewards[self.offline_learn_T:]
        self.best_test_rewards[:] = self.best_rewards[self.offline_learn_T:]
        self.avg_estimate_gamma = np.mean(estimated_gammas)
        return self.avg_estimate_gamma
