import numpy as np
import matplotlib.pyplot as plt
from SVMAgent import SVMAgent, MLPAgent, FastComNetwork, SVMOracle, MLPOracle, DatasetModel
import random

# --- Configuration ---
# Network size and iteration schedule.
NUM_AGENTS = 16      # agents in the network (also the size of the mixing matrix)
NUM_ROUNDS = 5_000   # iterations per step-size run
T_RESTART = 200      # period, in rounds, between calls to initialize_action() on every agent

# Algorithm parameters forwarded to agents / the communication network.
DEFAULT_D = 0.01     # passed to each agent as ``D``; the selected agent's action is clipped to norm <= agent.D
DEFAULT_R = 2        # passed to propagate_actions/propagate_weights (presumably #communication rounds — confirm)

# Model and mini-batch sizes.
BATCH_SIZE_SVM = 128  # mb_size for the a9a dataset
BATCH_SIZE_MLP = 128  # mb_size for the MNIST dataset
HIDDEN_DIM = 256      # hidden width shared by MLPAgent and MLPOracle

# Seed both the NumPy global RNG (used for agent selection) and the stdlib RNG
# so repeated runs are reproducible.
np.random.seed(42)
random.seed(42)


def create_matrix(n):
    """Build the uniform n-by-n averaging matrix.

    Every entry equals ``1 / n``, so each row and column sums to (approximately)
    one. Passing ``n == 0`` raises ``ZeroDivisionError``.
    """
    weight = 1 / n
    return np.full(shape=(n, n), fill_value=weight)


def run_eta_exp(agent_cls, dataset, oracle, etas, ds_type='svm'):
    """Run the decentralized training loop once per step size and collect test losses.

    For every step size a fresh set of ``NUM_AGENTS`` agents and a new
    ``FastComNetwork`` (sharing one uniform mixing matrix) are created. Each
    round a single agent, drawn uniformly at random, draws a mini-batch and
    updates its weight. It then takes a gradient step on its action, which is
    projected onto the L2 ball of radius ``agent.D`` and scaled by
    ``agent.NUM_AGENTS``. All other agents' actions are set to zero before
    actions and weights are propagated.

    Args:
        agent_cls: Agent class instantiated ``NUM_AGENTS`` times per step size.
        dataset: Provides ``input_dim``, ``get_sample(agent_index)`` and
            ``get_test_set()``.
        oracle: Provides ``get_gradients(w, x, y)`` and ``get_fn_val(w, *test_set)``.
        etas: Iterable of step sizes; each is passed to the agents as ``lr``.
        ds_type: ``'svm'`` builds agents from ``input_dim`` only. Any other value
            also passes ``HIDDEN_DIM``. Also shown in the progress message.

    Returns:
        dict mapping each step size to the list of test-set function values,
        evaluated at the network-average weight on rounds 0, 10, 20, ...
    """
    mixing = create_matrix(NUM_AGENTS)
    curves = {}

    for step in etas:
        print(f"Running {ds_type.upper()} with eta={step}...")
        # SVM agents take only the input dimension; other models also need HIDDEN_DIM.
        dims = (dataset.input_dim,) if ds_type == 'svm' else (dataset.input_dim, HIDDEN_DIM)
        agents = [
            agent_cls(*dims, id=idx, lr=step, D=DEFAULT_D, NUM_AGENTS=NUM_AGENTS)
            for idx in range(NUM_AGENTS)
        ]
        network = FastComNetwork(mixing)
        history = []

        for rnd in range(NUM_ROUNDS):
            # Periodic restart on the last round of every T_RESTART-long period.
            if rnd % T_RESTART == T_RESTART - 1:
                for ag in agents:
                    ag.initialize_action()

            active = np.random.randint(NUM_AGENTS)
            x_batch, y_batch = dataset.get_sample(active)
            for ag in agents:
                ag.get_grad_point()
            worker = agents[active]
            worker.set_weight(worker.DOC2S_get_new_weight())

            grad = oracle.get_gradients(worker.get_weight(), x_batch, y_batch)

            for idx, ag in enumerate(agents):
                if idx != active:
                    # Idle agents contribute a zero action this round.
                    ag.set_action(np.zeros_like(ag.get_action()))
                    continue
                # Gradient step, then projection onto the L2 ball of radius ag.D;
                # near-zero vectors are left unscaled to avoid dividing by ~0.
                stepped = ag.get_action() - ag.lr * grad
                magnitude = np.linalg.norm(stepped)
                if magnitude > 1e-8:
                    shrink = min(1, ag.D / magnitude)
                else:
                    shrink = 1.0
                # Scaled by the agent count — presumably to offset the 1/N uniform
                # averaging when only one agent is non-zero (TODO confirm).
                ag.set_action(ag.NUM_AGENTS * shrink * stepped)

            network.propagate_actions(agents, DEFAULT_R)
            network.propagate_weights(agents, DEFAULT_R)

            if rnd % 10 == 0:
                mean_w = network.get_average_weight(agents)
                history.append(oracle.get_fn_val(mean_w, *dataset.get_test_set()))

        curves[step] = history
    return curves


# --- Run Experiments ---
# Step sizes swept for both models (largest first).
etas = [0.1, 0.05, 0.01]

# SVM on the a9a dataset.
print(">>> SVM Experiment...")
ds_svm = DatasetModel(
    dsname='a9a',
    num_agent=NUM_AGENTS,
    mb_size=BATCH_SIZE_SVM,
    max_sample=5000,
)
oracle_svm = SVMOracle(alpha=2, lam=1e-5)
res_svm = run_eta_exp(SVMAgent, ds_svm, oracle_svm, etas, ds_type='svm')

# MLP (hidden width HIDDEN_DIM) on MNIST.
print(">>> MLP Experiment...")
ds_mlp = DatasetModel(
    dsname='mnist',
    num_agent=NUM_AGENTS,
    mb_size=BATCH_SIZE_MLP,
    max_sample=3000,
)
oracle_mlp = MLPOracle(lam=1e-5, hidden_dim=HIDDEN_DIM)
res_mlp = run_eta_exp(MLPAgent, ds_mlp, oracle_mlp, etas, ds_type='mlp')

# --- Plotting ---
# run_eta_exp records one loss every 10 rounds (k = 0, 10, 20, ...), so the
# x-axis uses the same stride to line up with the recorded samples.
x_axis = np.arange(0, NUM_ROUNDS, 10)
fig, axs = plt.subplots(1, 2, figsize=(16, 6))

# One panel per experiment; both panels share the same labels and styling.
for ax, res, title in ((axs[0], res_svm, 'SVM/a9a - Step Size vs Loss'),
                       (axs[1], res_mlp, 'MLP/MNIST - Step Size vs Loss')):
    for eta, loss in res.items():
        # Raw f-string: '\eta' is matplotlib mathtext, not a Python escape.
        # In a plain string '\e' is an invalid escape (SyntaxWarning on 3.12+).
        ax.plot(x_axis, loss, label=rf'$\eta={eta}$')
    ax.set_title(title)
    ax.set_xlabel('Computation Rounds')
    ax.set_ylabel('Function Value')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('Task3_StepSize_Results.png')
print("Results saved to Task3_StepSize_Results.png")