import numpy as np
import matplotlib.pyplot as plt
from SVMAgent import SVMAgent, FastComNetwork, SVMOracle, DatasetModel
import random

# Hyperparameter configuration
NUM_AGENTS = 16  # number of agents/clients; also passed to DatasetModel as num_agent
NUM_ROUNDS = 2000  # training rounds per eta run
T_RESTART = 100  # agent actions are re-initialized once every T_RESTART rounds
DELTA = 0.001  # passed as `delta` to oracle.get_zo_grad (zeroth-order oracle only)
D = 0.005  # Online learning bound (fixed radius D); clips the norm of the selected agent's update
R = 1  # Chebyshev acceleration rounds (passed to network.propagate_actions/propagate_weights)
p = 0.99  # Matrix diagonal function value (unused in this script, but kept)
BATCH_SIZE = 128  # passed to DatasetModel as mb_size (minibatch size)
DATASET_NAME = 'rcv'

# Learning rates (eta, passed to each agent as lr) to compare
etas = [0.01, 0.05, 0.1]

# Seed both RNGs before the dataset is built, in case DatasetModel draws random numbers.
random.seed(42)
np.random.seed(42)

# Initialize dataset
dataset = DatasetModel(dsname=DATASET_NAME, num_agent=NUM_AGENTS, mb_size=BATCH_SIZE)
# NOTE(review): lam is presumably the regularization weight and alpha an SVM loss parameter — confirm in SVMOracle.
oracle = SVMOracle(alpha=2, lam=1e-5)


def create_matrix(n):
    """Build the n-by-n communication matrix whose every entry equals 1/n.

    Used as the matrix handed to FastComNetwork in train_DOC2S.
    """
    entry = 1 / n
    shape = (n, n)
    return np.full(shape, entry)


# DOC²S training function and its private helpers.


def _oracle_grad(point, x_mb, y_mb, oracle_type):
    """Return the oracle gradient at `point` on the minibatch (x_mb, y_mb).

    Uses ``oracle.get_gradients`` when ``oracle_type == '1st'``; otherwise the
    zeroth-order estimator ``oracle.get_zo_grad`` with ``delta=DELTA``.
    Reads the module-level ``oracle`` and ``DELTA``.
    """
    if oracle_type == '1st':
        return oracle.get_gradients(point, x_mb, y_mb)
    return oracle.get_zo_grad(point, x_mb, y_mb, delta=DELTA)


def _apply_delta_updates(agents, selected, grad):
    """Set every agent's action (Delta) for the current round.

    The selected agent takes a step ``action - lr * grad``, rescaled so its
    norm is at most ``agent.D`` and then multiplied by ``agent.NUM_AGENTS``.
    Every other agent's action is set to zeros of the same shape.
    """
    for i, agent in enumerate(agents):
        if i != selected:
            agent.set_action(np.zeros_like(agent.get_action()))
            continue
        unprojected = agent.get_action() - agent.lr * grad
        norm = np.linalg.norm(unprojected)
        # Project onto the radius-D ball; skip near-zero norms to avoid dividing by ~0.
        scale = min(1, agent.D / norm) if norm > 1e-8 else 1.0
        # NOTE(review): multiplying by n presumably compensates for only one of n
        # clients being sampled per round — confirm against the DOC²S derivation.
        agent.set_action(agent.NUM_AGENTS * scale * unprojected)


def train_DOC2S(agents, num_rounds, t_restart, oracle_type, *, return_metrics=False):
    """Run DOC²S training and record the averaged model's test loss every round.

    Relies on the module-level globals NUM_AGENTS, dataset, oracle, R and DELTA.

    Args:
        agents: list of NUM_AGENTS SVMAgent instances; mutated in place.
        num_rounds: number of training rounds.
        t_restart: agent actions are re-initialized in every round k with
            ``k % t_restart == t_restart - 1``.
        oracle_type: '1st' for first-order gradients; any other value selects
            the zeroth-order estimator.
        return_metrics: if True, also return the per-round maximum agent test
            loss and the norm of the mean gradient across agents.

    Returns:
        ``losses`` (per-round test loss of the network-average weight), or
        ``(losses, max_fun, mean_grad)`` when ``return_metrics`` is True.
    """
    network = FastComNetwork(create_matrix(NUM_AGENTS))
    # NOTE(review): hoisted out of the loop; assumes get_test_set() returns the
    # same fixed test data on every call — confirm in DatasetModel.
    x_test, y_test = dataset.get_test_set()
    losses = []
    max_fun = []
    mean_grad = []
    for k in range(num_rounds):
        # Periodically reset agent actions.
        if k % t_restart == t_restart - 1:
            for agent in agents:
                agent.initialize_action()

        # Client sampling: exactly one client takes a local step this round.
        selected = np.random.randint(NUM_AGENTS)
        x_mb, y_mb = dataset.get_sample(selected)

        # NOTE(review): get_grad_point() is called on every agent, while
        # get_grad_points() is called only on the selected one; confirm these
        # are two distinct SVMAgent methods.
        for agent in agents:
            agent.get_grad_point()
        chosen = agents[selected]
        chosen.set_weight(chosen.DOC2S_get_new_weight())
        grad_point = chosen.get_grad_points()

        grad = _oracle_grad(grad_point, x_mb, y_mb, oracle_type)
        _apply_delta_updates(agents, selected, grad)

        # Chebyshev-accelerated communication.
        network.propagate_actions(agents, R)
        network.propagate_weights(agents, R)

        # Test loss of the network-average weight.
        avg_w = network.get_average_weight(agents)
        losses.append(oracle.get_fn_val(avg_w, x_test, y_test))

        # Worst (maximum) test loss over the individual agents.
        weights = [agent.get_weight() for agent in agents]
        max_fun.append(max(oracle.get_fn_val(w, x_test, y_test) for w in weights))

        # Norm of the mean gradient across agents, each evaluated at its own
        # weight (previously the zeroth-order branch reused the selected
        # agent's grad_point for every agent).
        gradients = [_oracle_grad(agent.get_weight(), x_mb, y_mb, oracle_type)
                     for agent in agents]
        mean_grad.append(np.linalg.norm(sum(gradients) / len(gradients)))

    if return_metrics:
        return losses, max_fun, mean_grad
    return losses


# --- Training loop ---


def _build_doc2s_agents(lr):
    """Return a fresh list of NUM_AGENTS SVMAgent objects sharing learning rate `lr` and radius D."""
    return [
        SVMAgent(dataset.input_dim, id=agent_id, lr=lr, D=D, NUM_AGENTS=NUM_AGENTS)
        for agent_id in range(NUM_AGENTS)
    ]


# Loss history (one numpy array per run), keyed by learning rate eta.
all_losses = {}

for eta in etas:
    print(f"--- Training DOC²S with eta = {eta} ---")
    # Each eta starts from newly constructed agents.
    doc2s_agents = _build_doc2s_agents(eta)
    loss_history = train_DOC2S(doc2s_agents, NUM_ROUNDS, T_RESTART, '1st')
    all_losses[eta] = np.array(loss_history)

print("--- All training finished ---")

# --- Plotting ---

step = T_RESTART
# Sample the loss curves every T_RESTART rounds: x = [0, 100, 200, ...].
x = np.arange(0, NUM_ROUNDS, step)

plt.figure(figsize=(9, 8))

# Per-curve styles; indices wrap around if there are more etas than styles.
colors = ['black', 'blue', 'green']
markers = ['.', '^', 's']

# One curve per learning rate.
for i, (eta, loss_history) in enumerate(all_losses.items()):
    loss_sampled = loss_history[x]
    plt.plot(x, loss_sampled, label=f'DOC²S (η={eta})',
             color=colors[i % len(colors)],
             marker=markers[i % len(markers)],
             markersize=11.5, linewidth=1.5)

plt.xlabel(r"$\mathrm{Computation~rounds}$", fontsize=31)
plt.ylabel(r"$\mathrm{Function~value}$", fontsize=31)

# Axis range and ticks.
plt.xlim(0, NUM_ROUNDS + 30)
plt.xticks(np.arange(0, NUM_ROUNDS, step * 3),
           fontsize=20)
plt.yticks(fontsize=20)

# A single styled legend; an earlier bare plt.legend() would only be replaced by this one.
plt.legend(fontsize=27, framealpha=0.9)

plt.grid(True, linestyle='--', alpha=0.3)

plt.tight_layout()
plt.savefig("doc2s_eta_comparison.png")
print("Plot saved to doc2s_eta_comparison.png")

# plt.show()  # disabled: no display is available when running on the VM