import torch
import numpy as np
import networkx as nx
# Numerical tolerance used by the FTRL solvers below: as the lower clamp for
# probabilities and as a convergence threshold (scaled by 100 in the outer loop).
# NOTE(review): 10E-8 == 1e-7, not 1e-8 — confirm which tolerance was intended.
EPS = 10E-8

# Softmax over dim=1. FedExp3 applies it to its [n_agents, n_arms] score matrix,
# so despite the name each *row* (one agent) is normalized to sum to 1.
col_softmax = torch.nn.Softmax(dim=1)

class FedExp3:
    """Federated Exp3 learner with gossip-mixed cumulative loss estimates.

    ``Z`` holds one row per agent and one column per arm. Each round every
    agent samples from a mixture of ``softmax(-lr * Z)`` (row-wise) and the
    uniform distribution; the mixing weight is drawn from ``expr_scheduler``.
    """

    def __init__(self, n_agents, n_arms, W, expr_scheduler, device, rng, lr=0.1) -> None:
        self.device = device
        self.rng = rng
        self.W = W
        self.lr = lr
        self.expr_scheduler = expr_scheduler
        # Cumulative (gossip-mixed) importance-weighted loss estimates.
        self.Z = torch.zeros([n_agents, n_arms], device=device)
        # Uniform exploration distribution, one row per agent.
        self.X0 = torch.ones([n_agents, n_arms], device=device) / n_arms

    def action(self):
        """Sample one arm per agent.

        Returns:
            ``(one_hot, probs)`` where ``one_hot`` is ``[n_agents, n_arms]``
            and ``probs`` is the sampling distribution that was used.
        """
        n_arms = self.Z.shape[-1]
        exploit = torch.softmax(-self.lr * self.Z, dim=1)
        explore_rate = next(self.expr_scheduler)
        probs = (1 - explore_rate) * exploit + explore_rate * self.X0
        picks = torch.multinomial(probs, num_samples=1, generator=self.rng)
        one_hot = torch.nn.functional.one_hot(picks, num_classes=n_arms).squeeze(1)
        return one_hot, probs

    def update(self, loss_matrix, actions, probs):
        """Mix ``Z`` through ``W`` and add the importance-weighted observed losses."""
        loss_estimate = loss_matrix.to(self.device) * actions / probs
        self.Z = self.W.float() @ self.Z.float() + loss_estimate

class FedFTRL:
    """Federated FTRL bandit learner with gossip-mixed cumulative loss estimates.

    Each agent owns one row of ``self.L`` (``[n_agents, n_arms]``). ``action``
    solves for the regularized-leader distribution ``self.x`` and samples one
    arm per agent; ``update`` mixes ``L`` through ``W`` and adds clipped
    importance-weighted loss estimates, re-injecting the clipping deviation
    recorded ``D`` rounds earlier.

    Args:
        n_agents, n_arms: number of agents (rows) and arms (columns).
        graph: communication graph; stored but not used by this class.
        W: mixing matrix as a torch tensor (``W.float()`` is used).
        lr: iterator yielding the learning rate for each round.
        gamma: iterator yielding the weight of the ``gamma * log(x)`` term.
        D: delay used as a row offset into ``self.deviation`` (must be an int).
        device: torch device for internal tensors.
        rng: ``torch.Generator`` used for sampling.
        spectral_gap: spectral gap of ``W`` used in the clipping threshold.
    """

    # Initial number of rows in the per-round deviation buffer; it grows on demand.
    _INITIAL_DEVIATION_ROWS = 3001

    def __init__(self, n_agents, n_arms, graph, W, lr, gamma, D, device, rng, spectral_gap) -> None:
        self.L = torch.zeros([n_agents, n_arms], device=device)
        self.x = torch.ones([n_agents, n_arms], device=device) / n_arms
        self.lr_list = lr
        self.gamma_list = gamma
        self.W = W
        self.device = device
        self.rng = rng
        self.bias = torch.zeros([n_agents], device=device)
        self.V = n_agents
        self.K = n_arms
        self.graph = graph
        self.time_step = 0
        self.D = D
        self.spectral_gap = spectral_gap
        # Row t holds the summed clipping deviation recorded at round t.
        self.deviation = torch.zeros([self._INITIAL_DEVIATION_ROWS, n_arms], device=device)

    def action(self):
        """Advance one round, solve for ``self.x`` and sample one arm per agent.

        Returns:
            ``(one_hot, probs)``: ``[n_agents, n_arms]`` one-hot actions and the
            row-normalized distribution ``self.x`` they were sampled from.
        """
        self.time_step += 1
        self.learning_rate = next(self.lr_list)
        self.gamma = next(self.gamma_list)
        self.solve_optimization()
        # Renormalize to absorb the solver's residual error before sampling.
        self.x /= self.x.sum(dim=1, keepdim=True)
        A = torch.multinomial(self.x, num_samples=1, generator=self.rng)
        A_one_hot = torch.nn.functional.one_hot(A, num_classes=self.K).squeeze(1)
        return A_one_hot, self.x

    def _ensure_deviation_rows(self, n_rows):
        """Zero-pad ``self.deviation`` so that it has at least ``n_rows`` rows."""
        current = self.deviation.shape[0]
        if n_rows <= current:
            return
        # Grow geometrically so long runs do not reallocate every round.
        new_total = max(n_rows, 2 * current)
        padding = self.deviation.new_zeros((new_total - current, self.deviation.shape[1]))
        self.deviation = torch.cat([self.deviation, padding], dim=0)

    def update(self, loss_matrix, actions, probs):
        """Gossip-mix ``L`` and add this round's clipped loss estimates.

        Args:
            loss_matrix: ``[n_agents, n_arms]`` losses for this round.
            actions: one-hot actions returned by ``action``.
            probs: the distribution returned by ``action``.
        """
        L_t = loss_matrix.to(self.device)
        C_t = min(np.log(self.time_step) + np.log(self.V), np.sqrt(self.V)) / self.spectral_gap + 2 + self.D
        # Clip probabilities from below so the importance weights stay bounded;
        # G0 is the unclipped estimate, used only to measure the clipping bias.
        G = L_t * actions / probs.clamp(min=10.0 * C_t * self.gamma)
        G0 = L_t * actions / probs

        # Make sure row `time_step` exists, so runs longer than the initial
        # buffer do not index past its end.
        self._ensure_deviation_rows(self.time_step + 1)
        self.deviation[self.time_step] += torch.sum(torch.relu(G0 - G), 0)

        self.L = torch.mm(self.W.float(), self.L.float()) + self.V * (G + actions * self.bias.unsqueeze(1))

        # Re-inject, for every agent, the deviation recorded D rounds ago.
        if self.time_step >= self.D:
            self.L += self.deviation[self.time_step - self.D].repeat(self.V, 1)

        self.bias.zero_()

    def solve_optimization(self):
        """Find per-agent biases so that each row of ``self.x`` sums to 1.

        Alternates between solving the unconstrained first-order condition for
        the current bias (``solve_unconstrained``) and a Newton-style update of
        ``self.bias`` on ``f = sum(x) - 1``, safeguarded by a ``[lower, upper]``
        bracket and a doubling step. Stops after ``max_iter`` iterations, when
        ``|f|`` is small, or when ``x`` stops changing. Updates ``self.x`` and
        ``self.bias`` in place.
        """
        with torch.no_grad():
            max_iter = 1000
            V = self.V
            device = self.device

            mark = torch.zeros((V,), device=device)
            lower = torch.full((V,), -torch.inf, device=self.device)
            upper = torch.full((V,), torch.inf, device=self.device)
            step = torch.ones(V, device=device)
            for _ in range(max_iter):
                x_pre = self.x
                self.x = self.solve_unconstrained(self.L + self.bias[:, None], self.x)
                f = self.x.sum(dim=1) - 1.0
                df = (0.5 * self.learning_rate / (self.x ** 1.5) + self.gamma / self.x).reciprocal().sum(dim=1)

                next_bias = self.bias + f / df
                above = f > 0
                lower = torch.where(above, self.bias, lower)
                upper = torch.where(~above, self.bias, upper)

                step.mul_(2.)

                # The eight masks below are mutually exclusive, so at most one
                # term contributes to each agent's new bias (none on NaN).
                self.bias.zero_()
                self.bias += torch.where(above & (upper == torch.inf) & (next_bias > lower + step), lower + step, mark)
                self.bias += torch.where(above & (upper == torch.inf) & (next_bias <= lower + step), next_bias, mark)
                self.bias += torch.where(above & (upper != torch.inf) & (next_bias > upper), (lower + upper) * 0.5, mark)
                self.bias += torch.where(above & (upper != torch.inf) & (next_bias <= upper), next_bias, mark)

                self.bias += torch.where(~above & (lower == -torch.inf) & (next_bias < upper - step), upper - step, mark)
                self.bias += torch.where(~above & (lower == -torch.inf) & (next_bias >= upper - step), next_bias, mark)
                self.bias += torch.where(~above & (lower != -torch.inf) & (next_bias < lower), (lower + upper) * 0.5, mark)
                self.bias += torch.where(~above & (lower != -torch.inf) & (next_bias >= lower), next_bias, mark)

                if torch.max(torch.abs(f)) < 100 * EPS:
                    break

                if torch.max(torch.abs(x_pre - self.x)) < 100 * EPS:
                    break

    def solve_unconstrained(self, loss, warmstart):
        """Solve ``loss - lr/sqrt(x) + gamma*log(x) = 0`` elementwise by damped Newton.

        Steps are halved when they would overshoot past 0 or exceed 1 in the
        other direction, and iterates are clamped to at least ``EPS``. Returns
        ``warmstart`` itself if the first step is already below ``EPS``.
        """
        with torch.no_grad():
            x_val = warmstart

            for _ in range(100):
                func_val = loss - self.learning_rate / torch.sqrt(x_val) + (torch.log(x_val) * self.gamma)
                dif_func_val = 0.5 * self.learning_rate / (x_val ** 1.5) + self.gamma / x_val
                dif_x = func_val / dif_func_val

                dif_x = torch.where(dif_x > x_val, 0.5 * x_val, dif_x)
                dif_x = torch.where(dif_x < x_val - 1.0, 0.5 * (x_val - 1.0), dif_x)
                if torch.max(torch.abs(dif_x)) < EPS:
                    break
                x_val_new = x_val - dif_x
                x_val = x_val_new.clamp(min=EPS)

            return x_val

class FTRL:
    """Non-federated FTRL bandit baseline: every agent row learns on its own.

    ``update`` adds the importance-weighted loss estimate to the agent's own
    cumulative loss ``self.L`` without any mixing; ``graph``, ``W``, ``D`` and
    ``spectral_gap`` are stored for interface parity but not used here.

    Args:
        n_agents, n_arms: number of agents (rows) and arms (columns).
        lr: iterator yielding the learning rate for each round.
        device: torch device for internal tensors.
        rng: ``torch.Generator`` used for sampling.
    """

    def __init__(self, n_agents, n_arms, graph, W, lr, D, device, rng, spectral_gap) -> None:
        self.L = torch.zeros([n_agents, n_arms], device=device)
        self.x = torch.ones([n_agents, n_arms], device=device) / n_arms
        self.lr_list = lr
        self.W = W
        self.device = device
        self.rng = rng
        self.bias = torch.zeros([n_agents], device=device)
        self.V = n_agents
        self.K = n_arms
        self.graph = graph
        self.time_step = 0
        self.D = D
        self.spectral_gap = spectral_gap

    def action(self):
        """Advance one round, solve for ``self.x`` and sample one arm per agent.

        Returns:
            ``(one_hot, probs)``: ``[n_agents, n_arms]`` one-hot actions and the
            row-normalized distribution ``self.x`` they were sampled from.
        """
        self.time_step += 1
        self.learning_rate = next(self.lr_list)
        self.solve_optimization()
        # Renormalize to absorb the solver's residual error before sampling.
        self.x /= self.x.sum(dim=1, keepdim=True)
        A = torch.multinomial(self.x, num_samples=1, generator=self.rng)
        A_one_hot = torch.nn.functional.one_hot(A, num_classes=self.K).squeeze(1)
        return A_one_hot, self.x

    def update(self, loss_matrix, actions, probs):
        """Add this round's importance-weighted loss (plus bias on the pulled arm) to ``L``."""
        L_t = loss_matrix.to(self.device)
        G = L_t * actions / probs + actions * self.bias.unsqueeze(1)

        self.L += G
        self.bias.zero_()

    def solve_optimization(self):
        """Find per-agent biases so that each row of ``self.x`` sums to 1.

        Alternates between the unconstrained solve and a Newton-style bias
        update on ``f = sum(x) - 1``; the proposal is clipped to a doubling
        step from the bracket end and replaced by the bracket midpoint if it
        leaves ``[lower, upper]``. Updates ``self.x`` and ``self.bias``.
        """
        with torch.no_grad():
            max_iter = 100
            V = self.V
            device = self.device

            lower = torch.full((V,), -torch.inf, device=device)
            upper = torch.full((V,), torch.inf, device=device)
            step = torch.ones(V, device=device)

            for _ in range(max_iter):
                x_pre = self.x
                self.x = self.solve_unconstrained(self.L + self.bias[:, None], self.x)
                f = self.x.sum(dim=1) - 1.0
                df = (0.5 * self.learning_rate / (self.x ** 1.5)).reciprocal().sum(dim=1)

                next_bias = self.bias + f / df
                above = f > 0
                lower = torch.where(above, self.bias, lower)
                upper = torch.where(~above, self.bias, upper)

                step.mul_(2.)

                beta_clipped = torch.where(
                    above & (next_bias > lower + step), lower + step,
                    torch.where(~above & (next_bias < upper - step), upper - step, next_bias)
                )

                mid = 0.5 * (lower + upper)
                beta_clipped = torch.where(above & (beta_clipped > upper), mid, beta_clipped)
                beta_clipped = torch.where(~above & (beta_clipped < lower), mid, beta_clipped)

                self.bias = beta_clipped

                if torch.max(torch.abs(f)) < 100 * EPS:
                    break

                if torch.max(torch.abs(x_pre - self.x)) < 100 * EPS:
                    break

    def solve_unconstrained(self, loss, warmstart):
        """Solve ``loss - lr/sqrt(x) = 0`` elementwise by damped Newton iterations.

        Overshooting steps are halved and iterates are clamped to at least
        ``EPS``; returns ``warmstart`` itself if the first step is below ``EPS``.
        """
        with torch.no_grad():
            x_val = warmstart

            for _ in range(100):
                func_val = loss - self.learning_rate / torch.sqrt(x_val)
                dif_func_val = 0.5 * self.learning_rate / (torch.sqrt(x_val) ** 3)
                dif_x = func_val / dif_func_val

                dif_x = torch.where(dif_x > x_val, 0.5 * x_val, dif_x)
                dif_x = torch.where(dif_x < x_val - 1.0, 0.5 * (x_val - 1.0), dif_x)
                if torch.max(torch.abs(dif_x)) < EPS:
                    break
                x_val_new = x_val - dif_x
                x_val = x_val_new.clamp(min=EPS)

            return x_val

class Gossip_UCB:
    """Gossip-based cooperative UCB bandit.

    Every agent pulls each arm once in the first ``K`` rounds. Afterwards each
    agent either explores an under-sampled arm (relative to its neighbours'
    sample-count estimates) or picks ``argmax(theta + C)``, where ``theta`` is
    a gossip-tracked estimate of the network-average sample mean and ``C`` is
    a UCB bonus.

    Args:
        n_agents, n_arms: number of agents (rows) and arms (columns).
        graph: communication graph; stored but not used by this class.
        T: horizon; stored but not used by this class.
        W: mixing matrix as a torch tensor; its nonzero pattern defines neighbours.
        device: torch device for internal tensors.
    """

    def __init__(self, n_agents, n_arms, graph, T, W, device) -> None:
        self.V = n_agents
        self.K = n_arms
        self.graph = graph
        self.T = T
        self.W = W
        self.device = device
        self.time_step = 0
        self.C = torch.zeros([n_agents, n_arms], device=device)
        self.num_sample = torch.zeros([n_agents, n_arms], device=device)
        self.n_estimate = torch.zeros([n_agents, n_arms], device=device)
        self.sample_mean = torch.zeros([n_agents, n_arms], device=device)
        self.cum_reward_estimate = torch.zeros([n_agents, n_arms], device=device)
        self.theta = torch.zeros([n_agents, n_arms], device=device)

    def action(self):
        """Choose one arm per agent and return the ``[n_agents, n_arms]`` one-hot matrix."""
        self.time_step += 1
        actions = torch.zeros(self.V, device=self.device, dtype=torch.long)
        if self.time_step <= self.K:
            # Initialization: every agent pulls arm (t - 1).
            actions.fill_(self.time_step - 1)
            self.num_sample[:, self.time_step - 1] += 1
            self.n_estimate = self.num_sample.clone()
            # `actions` is 1-D, so one_hot is already [V, K]; no squeeze needed
            # (squeezing dim 1 would collapse the result when K == 1).
            return torch.nn.functional.one_hot(actions, num_classes=self.K)

        # new_n_estimate[i, k] = max over neighbours j (W[i, j] != 0) of
        # max(num_sample[i, k], n_estimate[j, k]).
        num_sample_expa = self.num_sample.unsqueeze(1).expand(self.V, self.V, self.K)
        n_estimate_expa = self.n_estimate.unsqueeze(0).expand(self.V, self.V, self.K)
        mask = (self.W != 0).unsqueeze(-1)
        combined = torch.maximum(num_sample_expa, n_estimate_expa) * mask
        # NOTE(review): self.n_estimate itself is never reassigned after the
        # initialization phase; confirm whether it should track new_n_estimate.
        self.new_n_estimate = new_n_estimate = combined.max(dim=1).values

        # Arms this agent has sampled far less often than its neighbourhood.
        A_mask = self.num_sample < (new_n_estimate - self.V)
        Q = self.theta + self.C

        has_candidate = A_mask.any(dim=1)
        no_candidate_idx = (~has_candidate).nonzero(as_tuple=True)[0]

        if len(no_candidate_idx) > 0:
            actions[no_candidate_idx] = torch.argmax(Q[no_candidate_idx], dim=1)
        if len(no_candidate_idx) < self.V:
            candidate_idx = has_candidate.nonzero(as_tuple=True)[0]
            A_mask_candidates = A_mask[candidate_idx]
            valid_arms = [torch.where(row)[0] for row in A_mask_candidates]
            rand_choices = torch.cat([arms[torch.randint(0, len(arms), (1,))] for arms in valid_arms]).to(self.device)
            actions[candidate_idx] = rand_choices

        self.num_sample.scatter_add_(1, actions.unsqueeze(1),
                                     torch.ones_like(actions, dtype=torch.float, device=self.device).unsqueeze(1))
        return torch.nn.functional.one_hot(actions, num_classes=self.K)

    def update(self, loss_matrix, actions):
        """Accumulate observed rewards and refresh ``C``, ``sample_mean`` and ``theta``."""
        L_t = loss_matrix.to(self.device) * actions
        self.cum_reward_estimate += L_t
        a = 0.1
        if self.time_step <= self.K:
            self.C[:, self.time_step - 1] = torch.sqrt(torch.tensor(2 * a * np.log(self.time_step)))
            # Copy rather than alias: cum_reward_estimate is updated in place
            # above, which would otherwise also change the "previous" theta and
            # sample_mean before they are used at step K + 1.
            self.sample_mean = self.cum_reward_estimate.clone()
            self.theta = self.cum_reward_estimate.clone()
        else:
            self.C = torch.sqrt(torch.tensor(2 * a * np.log(self.time_step)) / self.num_sample)
            new_sample_mean = self.cum_reward_estimate / self.num_sample
            # Gossip tracking: theta <- W theta + (s_t - s_{t-1}).
            self.theta = torch.mm(self.W.float(), self.theta.float()) + new_sample_mean - self.sample_mean
            self.sample_mean = new_sample_mean

class DRRB_bandit:
    """Round-robin arm-elimination bandit where all agents pull the same arm.

    Arms in the active set ``S`` are pulled in turn. When any agent's estimate
    ``mu`` for the pulled arm falls more than ``2 * U`` below that agent's best
    estimate, the arm is scheduled for removal at ``time_step + len(S) * D``.
    At the end of each pass over ``S`` the estimates are gossip-mixed through
    ``W`` and arms whose deadline has passed are dropped (at least one arm is
    always kept).

    Args:
        n_agents, n_arms: number of agents and arms.
        T: horizon; stored but not used by this class.
        W: mixing matrix as a torch tensor.
        graph: communication graph; stored but not used by this class.
        delta: confidence parameter in the width ``U``.
        device: torch device for internal tensors.
        spectral_gap: stored but not used by this class.
        D: multiplier for the elimination delay.
    """

    def __init__(self, n_agents, n_arms, T, W, graph, delta, device, spectral_gap, D=1.6) -> None:
        self.V = n_agents
        self.K = n_arms
        self.D = D
        self.T = T
        self.W = W
        self.graph = graph
        self.device = device
        self.delta = delta
        self.time_step = 0
        self.U = torch.ones([n_agents, n_arms], device=device)
        self.S = list(range(n_arms))
        self.B = torch.zeros((self.V, self.K), device=device)
        self.mu = torch.zeros([n_agents, n_arms], device=device)
        self.tau = torch.zeros([n_agents, n_arms], device=device)
        self.choice_index = 0
        self.time_elimination = torch.zeros(n_arms, device=device)
        self.spectral_gap = spectral_gap
        self.actions = torch.zeros(n_agents, device=device)
        self.cache_reward = torch.zeros([n_agents, n_arms], device=device)
        # Per-agent maximum of mu; always a tensor of shape [n_agents].
        self.max_mu = torch.zeros(n_agents, device=device)

    def action(self):
        """Pull the next arm of ``S`` for every agent; return ``[n_agents, n_arms]`` one-hot."""
        self.time_step += 1
        arm = self.actions = self.S[self.choice_index]
        self.tau[:, arm] += 1
        self.choice_index = (self.choice_index + 1) % len(self.S)
        chosen = torch.full((self.V, 1), arm, dtype=torch.long, device=self.device)
        return torch.nn.functional.one_hot(chosen, num_classes=self.K).squeeze(1)

    def update(self, loss_matrix, actions):
        """Record rewards, update confidence widths and, at the end of a pass, gossip and eliminate."""
        X_t = loss_matrix.to(self.device) * actions
        if self.time_step <= self.K:
            self.mu += X_t
            return

        self.cache_reward += X_t

        k = self.actions
        mu_k = self.mu[:, k]
        U_k = self.U[:, k]
        # Compare each agent against its own best estimate.
        if self.time_elimination[k] == 0 and (mu_k < self.max_mu - 2 * U_k).any():
            self.time_elimination[k] = self.time_step + len(self.S) * self.D
        tau_k = self.tau[:, k]
        sqrt_term = torch.sqrt(-np.log(self.delta) / (2.0 * self.V * tau_k))
        self.U[:, k] = sqrt_term + 0.3 / ((tau_k + 1))

        # Only continue once a full pass over S has completed.
        if self.choice_index != 0:
            return
        tmp = torch.mm(self.W.float(), self.mu.float())
        self.mu = (1 - 1.0 / (self.tau + 1)) * tmp + self.cache_reward / (self.tau + 1)
        # Keep only the values; torch.max(..., dim) returns a (values, indices) pair.
        self.max_mu = torch.max(self.mu, 1).values
        idx = 0
        while len(self.S) > 1 and idx < len(self.S):
            deadline = self.time_elimination[self.S[idx]]
            if deadline != 0 and self.time_step >= deadline:
                del self.S[idx]
            else:
                idx += 1
        self.cache_reward.zero_()

class CommNet:
    """Thin wrapper around a networkx communication graph."""

    def __init__(self, nx_graph) -> None:
        self.comm_net = nx_graph

    def max_deg_gossip(self, spectral_gap=False):
        """Return the max-degree gossip matrix ``I - (D - A) / (d_max + 1)``.

        ``D`` is the diagonal degree matrix and ``A`` the adjacency matrix, both
        in the graph's node-iteration order. When ``spectral_gap`` is truthy,
        returns ``(P, compute_spectral_gap(P))`` instead of just ``P``.
        """
        degrees = [deg for _, deg in self.comm_net.degree()]
        largest = max(degrees)
        degree_matrix = np.diag(degrees)
        adjacency = nx.to_numpy_array(self.comm_net)
        laplacian = degree_matrix - adjacency
        gossip = np.eye(len(degrees)) - laplacian / (largest + 1)
        if not spectral_gap:
            return gossip
        return gossip, compute_spectral_gap(gossip)

    def get_diameter(self):
        """Return the diameter of the communication graph."""
        return nx.diameter(self.comm_net)

def cube_FedExp3_scheduler(spectral_gap=0, T=1., V=1., gamma=0.01):
    """Yield FedExp3 exploration rates ``gamma * C_W**(1/3) / t**(1/3)``, t = 1, 2, ...

    ``C_W = min(2 log T + log V, sqrt V) / spectral_gap + 3``.

    Raises:
        ValueError: on the first ``next`` if ``spectral_gap`` is not positive.
            ``C_W`` divides by it, so 0 (the default) would otherwise yield
            NaN/inf rates silently.
    """
    if not spectral_gap > 0:
        raise ValueError(f"spectral_gap must be positive, got {spectral_gap!r}")
    C_W = min(2 * np.log(T) + np.log(V), np.sqrt(V)) / spectral_gap + 3
    # Loop-invariant numerator.
    scale = gamma * (C_W ** (1 / 3))
    step = 1
    while True:
        yield scale / step ** (1 / 3)
        step += 1

def cube_FedFTRL_lr(V, D):
    """Yield FedFTRL learning rates ``10 / sqrt(V * t)`` for t = 1, 2, ...

    ``D`` is accepted for signature compatibility and is not used.
    """
    t = 0
    while True:
        t += 1
        yield 10.0 / np.sqrt(V * t)

def cube_FedFTRL_gamma(D, K, spectral_gap, V):
    """Yield FedFTRL ``gamma_t = 1 / (8 V sqrt(t * C_t / log K))`` for t = 1, 2, ...

    ``C_t = min(log(V t), sqrt V) / spectral_gap + 2 + D``.
    """
    log_k = np.log(K)
    t = 0
    while True:
        t += 1
        c_t = min(np.log(V * t), np.sqrt(V)) / spectral_gap + 2 + D
        yield 1.0 / (8 * V * np.sqrt(t * c_t / log_k))

def cube_FTRL_lr(V, D):
    """Yield FTRL learning rates ``1 / sqrt(t)`` for t = 1, 2, ...

    ``V`` and ``D`` are accepted for signature compatibility and are not used.
    """
    t = 0
    while True:
        t += 1
        yield 1.0 / np.sqrt(t)

def compute_spectral_gap(P):
    """Return ``1 - sigma_2(P)``, one minus the second-largest singular value.

    ``P`` may be a NumPy array or a torch tensor (copied to CPU). The SVD is
    taken with ``hermitian=True``, so ``P`` is assumed symmetric — TODO confirm
    for non-symmetric mixing matrices. For a 1x1 matrix (single agent) there is
    no second singular value; the gap is reported as 1.0 instead of raising
    IndexError.
    """
    if isinstance(P, torch.Tensor):
        P = P.detach().cpu().numpy()
    singular_values = np.linalg.svd(P, compute_uv=False, hermitian=True)
    if singular_values.shape[-1] < 2:
        return 1.0
    return 1 - singular_values[1]

def fedexp3_ub_exact(n_epochs, n_agents, n_arms, spectral_gap, lr_array, gamma_array):
    """Yield the running FedExp3 regret upper bound after each round.

    Starts from ``log(K) / lr_last``. Each round adds
    ``K^2/2 * lr_t / gamma_t + K^2 / gamma_last * C_w * lr_t + gamma_t``, where
    ``C_w = 3 + min(2 log n_epochs + log n_agents, sqrt n_agents) / spectral_gap``.
    """
    mixing = min(
        2 * np.log(n_epochs) + np.log(n_agents),
        np.sqrt(n_agents)
    )
    C_w = 3 + mixing / spectral_gap
    final_lr = lr_array[-1]
    final_gamma = gamma_array[-1]
    k_sq = n_arms ** 2
    total = np.log(n_arms) / final_lr
    for eta, g in zip(lr_array, gamma_array):
        explore_term = k_sq / 2 * eta / g
        mixing_term = k_sq / final_gamma * C_w * eta
        total += explore_term + mixing_term + g
        yield total