import numpy as np
import matplotlib.pyplot as plt
from SVMAgent import SVMAgent, MLPAgent, FastComNetwork, SVMOracle, MLPOracle, DatasetModel
import random

# --- Configuration ---
# Topology and schedule.
NUM_AGENTS: int = 16        # number of agents on the communication ring
NUM_ROUNDS: int = 5_000     # simulated rounds per connectivity setting
T_RESTART: int = 200        # every agent re-initialises its action once per T_RESTART rounds

# Hyper-parameters handed to every agent and to the network propagation.
DEFAULT_D: float = 0.05     # passed to agents as ``D``; used as the action-projection radius
DEFAULT_ETA: float = 0.05   # passed to agents as ``lr`` (step size)
DEFAULT_R: int = 2          # second argument of FastComNetwork.propagate_* (presumably gossip rounds — confirm)

# Minibatch sizes and MLP width.
BATCH_SIZE_SVM: int = 128
BATCH_SIZE_MLP: int = 128
HIDDEN_DIM: int = 256       # hidden dimension shared by the MLP agents and MLPOracle

# Seed both global generators so that code drawing from them is reproducible.
np.random.seed(42)
random.seed(42)


def create_ring_matrix_neighbors(n, neighbors_count):
    """Build a uniform ring (circulant) mixing matrix.

    Agent ``i`` gives equal weight to itself and to the
    ``(neighbors_count - 1) // 2`` agents on each side of it on a ring of
    ``n`` agents (indices wrap modulo ``n``).

    Args:
        n: number of agents, i.e. rows/columns of the matrix; must be >= 1.
        neighbors_count: size of each agent's neighbourhood *including itself*.
            Must be a positive odd integer so the window is symmetric.

    Returns:
        ``(n, n)`` float array ``W``. Every row sums to 1 and ``W`` is
        symmetric, so it is doubly stochastic. When ``neighbors_count >= n``
        every agent is connected to every agent with weight ``1 / n``.

    Raises:
        ValueError: if ``n < 1`` or ``neighbors_count`` is not a positive
            odd integer.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if (
        int(neighbors_count) != neighbors_count
        or neighbors_count < 1
        or neighbors_count % 2 == 0
    ):
        # An even count cannot form a symmetric window around agent i.
        # The old code then produced rows summing to (count - 1) / count.
        raise ValueError(
            f"neighbors_count must be a positive odd integer, got {neighbors_count}"
        )

    half_width = (int(neighbors_count) - 1) // 2
    W = np.zeros((n, n))
    for i in range(n):
        # Using a set collapses duplicate indices when the window wraps past
        # the whole ring (neighbors_count > n), so each row still sums to 1.
        window = sorted({(i + offset) % n for offset in range(-half_width, half_width + 1)})
        W[i, window] = 1.0 / len(window)
    return W


def run_connectivity_exp(agent_cls, dataset, oracle, neighbors_list, ds_type='svm'):
    """Run one decentralized experiment per ring-connectivity level.

    For each ``nb`` in ``neighbors_list``:

    - A fresh set of ``NUM_AGENTS`` agents is created.
    - The agents are placed on the ring mixing matrix from
      ``create_ring_matrix_neighbors(NUM_AGENTS, nb)``.
    - ``NUM_ROUNDS`` rounds are simulated.

    In each round:

    - One uniformly random agent draws a minibatch.
    - That agent updates its weight via ``DOC2S_get_new_weight`` and computes
      a first-order gradient at the new weight.
    - It takes a projected step on its action. Every other agent's action is
      set to zero.
    - Actions and weights are then propagated over the network with
      ``DEFAULT_R``.

    Args:
        agent_cls: agent class to instantiate (``SVMAgent`` / ``MLPAgent`` here).
        dataset: provides ``input_dim``, ``get_sample(agent_id)`` and
            ``get_test_set()``.
        oracle: provides ``get_gradients(w, x, y)`` and ``get_fn_val(w, x, y)``.
        neighbors_list: iterable of neighbourhood sizes to sweep.
        ds_type: ``'svm'`` or ``'mlp'`` (case-insensitive). Selects the agent
            constructor signature; MLP agents additionally take ``HIDDEN_DIM``.

    Returns:
        dict mapping each neighbourhood size to the list of test-set function
        values of the network-average weight, recorded at rounds
        ``k = 0, 10, 20, ...``.

    Raises:
        ValueError: if ``ds_type`` is neither ``'svm'`` nor ``'mlp'``.
    """
    model_kind = ds_type.lower()
    if model_kind not in ('svm', 'mlp'):
        # Any unrecognised value used to fall silently into the MLP branch.
        raise ValueError(f"ds_type must be 'svm' or 'mlp', got {ds_type!r}")

    results = {}

    for nb in neighbors_list:
        print(f"Running {ds_type.upper()} with neighbors={nb}...")
        if model_kind == 'svm':
            agents = [agent_cls(dataset.input_dim, id=i, lr=DEFAULT_ETA, D=DEFAULT_D, NUM_AGENTS=NUM_AGENTS)
                      for i in range(NUM_AGENTS)]
        else:
            agents = [agent_cls(dataset.input_dim, HIDDEN_DIM, id=i, lr=DEFAULT_ETA, D=DEFAULT_D, NUM_AGENTS=NUM_AGENTS)
                      for i in range(NUM_AGENTS)]

        matrix = create_ring_matrix_neighbors(NUM_AGENTS, nb)
        network = FastComNetwork(matrix)
        losses = []

        for k in range(NUM_ROUNDS):
            # Periodic restart: fires on the last round of every T_RESTART-round window.
            if k % T_RESTART == T_RESTART - 1:
                for agent in agents:
                    agent.initialize_action()

            selected = np.random.randint(NUM_AGENTS)
            x_mb, y_mb = dataset.get_sample(selected)

            for agent in agents:
                agent.get_grad_point()
            active = agents[selected]
            active.set_weight(active.DOC2S_get_new_weight())

            # Only the first-order (gradient) oracle is used for the connectivity test.
            grad = oracle.get_gradients(active.get_weight(), x_mb, y_mb)

            for i, agent in enumerate(agents):
                if i == selected:
                    # Gradient step on the action, projected onto the L2 ball of
                    # radius agent.D, then multiplied by NUM_AGENTS. Presumably this
                    # lets network averaging (all other agents send zero) recover
                    # the un-scaled step — confirm against FastComNetwork.
                    unproj = agent.get_action() - agent.lr * grad
                    norm = np.linalg.norm(unproj)
                    scale = min(1, agent.D / norm) if norm > 1e-8 else 1.0
                    agent.set_action(agent.NUM_AGENTS * scale * unproj)
                else:
                    agent.set_action(np.zeros_like(agent.get_action()))

            network.propagate_actions(agents, DEFAULT_R)
            network.propagate_weights(agents, DEFAULT_R)

            if k % 10 == 0:
                avg_w = network.get_average_weight(agents)
                losses.append(oracle.get_fn_val(avg_w, *dataset.get_test_set()))

        results[nb] = losses
    return results


# --- Run Experiments ---
# Odd neighbourhood sizes 3, 5, 7, 9 (each agent counts itself as a neighbour).
neighbors = list(range(3, 10, 2))

# SVM on the a9a dataset.
print(">>> SVM Experiment...")
ds_svm = DatasetModel(dsname='a9a', num_agent=NUM_AGENTS, mb_size=BATCH_SIZE_SVM, max_sample=5_000)
oracle_svm = SVMOracle(alpha=2, lam=1e-5)
res_svm = run_connectivity_exp(SVMAgent, ds_svm, oracle_svm, neighbors, ds_type='svm')

# MLP on the MNIST dataset.
print(">>> MLP Experiment...")
ds_mlp = DatasetModel(dsname='mnist', num_agent=NUM_AGENTS, mb_size=BATCH_SIZE_MLP, max_sample=3_000)
oracle_mlp = MLPOracle(lam=1e-5, hidden_dim=HIDDEN_DIM)
res_mlp = run_connectivity_exp(MLPAgent, ds_mlp, oracle_mlp, neighbors, ds_type='mlp')

# --- Plotting ---
# Losses are recorded every 10 rounds, so the x-axis uses the same cadence.
x_axis = np.arange(0, NUM_ROUNDS, 10)
fig, axs = plt.subplots(1, 2, figsize=(16, 6))

# (axis, per-neighbourhood loss curves, panel title) for each experiment.
panels = (
    (axs[0], res_svm, 'SVM/a9a - Connectivity vs Loss'),
    (axs[1], res_mlp, 'MLP/MNIST - Connectivity vs Loss'),
)
for ax, panel_results, panel_title in panels:
    for nb, loss in panel_results.items():
        ax.plot(x_axis, loss, label=f'Neighbors={nb}')
    ax.set_title(panel_title)
    ax.set_xlabel('Computation Rounds')
    ax.set_ylabel('Function Value')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('Task2_Connectivity_Results.png')
print("Results saved to Task2_Connectivity_Results.png")