import numpy as np
import matplotlib.pyplot as plt
from SVMAgent import SVMAgent, MLPAgent, FastComNetwork, SVMOracle, MLPOracle, DatasetModel
import random

# --- Common Configuration ---
# Network size and training schedule.
NUM_AGENTS = 16     # number of agents built per experiment
NUM_ROUNDS = 5_000  # computation rounds executed for every R value
T_RESTART = 200     # every T_RESTART rounds, all agent actions are re-initialized

# Optimizer hyper-parameters.
DELTA = 0.001       # forwarded as `delta=` to the oracle's zeroth-order gradient call
DEFAULT_D = 0.01    # forwarded to each agent as `D` (used as the action-norm bound)
DEFAULT_ETA = 0.05  # forwarded to each agent as `lr`

# Data / model sizes.
BATCH_SIZE_SVM = 128  # minibatch size for the a9a / SVM dataset
BATCH_SIZE_MLP = 128  # minibatch size for the MNIST / MLP dataset
HIDDEN_DIM = 256      # hidden width passed to the MLP agents and oracle

# Seed both RNG sources this script relies on so runs are reproducible.
np.random.seed(42)
random.seed(42)


def create_default_matrix(n):
    """Build the n x n communication matrix with every entry equal to 1/n.

    Each row and each column therefore sums to 1: every agent weights all
    agents (itself included) equally. A ring topology would be an
    alternative, but it is not implemented here.

    Raises:
        ZeroDivisionError: if ``n`` is 0.
    """
    uniform_weight = 1 / n
    return np.full(shape=(n, n), fill_value=uniform_weight)


# --- Training Routine ---
def run_experiment(agent_cls, dataset, oracle, oracle_type, R_vals, ds_type='svm', record_every=10):
    """Run the decentralized training loop for each R in ``R_vals`` and record consensus error.

    For every R, a fresh set of ``NUM_AGENTS`` agents and a fresh
    ``FastComNetwork`` are built, and ``NUM_ROUNDS`` rounds are executed.
    In each round a single uniformly chosen agent computes a first- or
    zeroth-order gradient on its local minibatch and updates its action.
    All other agents' actions are reset to zero. Actions and then weights
    are propagated over the network using R.

    Args:
        agent_cls: Agent class to instantiate (e.g. ``SVMAgent`` / ``MLPAgent``).
        dataset: Object exposing ``input_dim`` and ``get_sample(agent_index)``,
            which returns an ``(x_mb, y_mb)`` pair.
        oracle: Gradient oracle. ``get_gradients`` is used for ``'1st'``,
            ``get_zo_grad`` for ``'0th'``.
        oracle_type: ``'1st'`` or ``'0th'``.
        R_vals: Iterable of R values to sweep; each becomes a key of the result.
        ds_type: ``'svm'`` (agents built from ``input_dim`` only) or ``'mlp'``
            (agents additionally receive ``HIDDEN_DIM``).
        record_every: Record the consensus error whenever the round index is a
            multiple of this value (starting at round 0). The default of 10
            matches the x-axis used by the plotting section.

    Returns:
        dict mapping each R to the list of recorded consensus errors.

    Raises:
        ValueError: if ``oracle_type``, ``ds_type`` or ``record_every`` is invalid.
    """
    # Fail fast: previously any unknown oracle_type silently ran the
    # zeroth-order path, and any unknown ds_type silently ran the MLP path.
    if oracle_type not in ('1st', '0th'):
        raise ValueError(f"oracle_type must be '1st' or '0th', got {oracle_type!r}")
    if ds_type not in ('svm', 'mlp'):
        raise ValueError(f"ds_type must be 'svm' or 'mlp', got {ds_type!r}")
    if record_every < 1:
        raise ValueError(f"record_every must be a positive integer, got {record_every!r}")

    results = {}
    matrix = create_default_matrix(NUM_AGENTS)

    for R in R_vals:
        print(f"Running {ds_type.upper()} {oracle_type} order, R={R}...")
        # Fresh agents for every R so the sweeps are independent of each other.
        if ds_type == 'svm':
            agents = [agent_cls(dataset.input_dim, id=i, lr=DEFAULT_ETA, D=DEFAULT_D, NUM_AGENTS=NUM_AGENTS)
                      for i in range(NUM_AGENTS)]
        else:
            agents = [agent_cls(dataset.input_dim, HIDDEN_DIM, id=i, lr=DEFAULT_ETA, D=DEFAULT_D, NUM_AGENTS=NUM_AGENTS)
                      for i in range(NUM_AGENTS)]

        network = FastComNetwork(matrix)
        consensus_errors = []

        for k in range(NUM_ROUNDS):
            # Periodic restart on the last round of each T_RESTART-round window
            # (k = T_RESTART-1, 2*T_RESTART-1, ...).
            if k % T_RESTART == T_RESTART - 1:
                for agent in agents:
                    agent.initialize_action()

            # Only one uniformly chosen agent computes a gradient this round.
            selected = np.random.randint(NUM_AGENTS)
            x_mb, y_mb = dataset.get_sample(selected)

            # Prepare the zeroth-order evaluation point on every agent, in index order.
            # NOTE(review): this calls get_grad_point() while the gradient step
            # below calls get_grad_points(); confirm both exist on the agent classes.
            for agent in agents:
                agent.get_grad_point()

            # Tentative weight update for the selected agent.
            new_w = agents[selected].DOC2S_get_new_weight()
            agents[selected].set_weight(new_w)

            # First order: oracle gradient at the selected agent's updated weight.
            # Zeroth order: estimate at the agent's prepared grad point.
            grad_pt = agents[selected].get_grad_points()
            if oracle_type == '1st':
                grad = oracle.get_gradients(agents[selected].get_weight(), x_mb, y_mb)
            else:
                grad = oracle.get_zo_grad(grad_pt, x_mb, y_mb, delta=DELTA)

            # Action update: the selected agent takes a step of size agent.lr,
            # rescales the result so its norm is at most agent.D, then multiplies
            # by NUM_AGENTS. The other agents' actions are zeroed.
            # NOTE(review): the NUM_AGENTS factor presumably compensates for only
            # one agent being active per round; confirm against the algorithm.
            for i, agent in enumerate(agents):
                if i == selected:
                    unproj = agent.get_action() - agent.lr * grad
                    norm = np.linalg.norm(unproj)
                    # Skip rescaling for a (near-)zero norm to avoid dividing by ~0.
                    scale = min(1, agent.D / norm) if norm > 1e-8 else 1.0
                    agent.set_action(agent.NUM_AGENTS * scale * unproj)
                else:
                    agent.set_action(np.zeros_like(agent.get_action()))

            # Communication: propagate actions, then weights; R is forwarded to
            # the network (presumably the number of communication steps).
            network.propagate_actions(agents, R)
            network.propagate_weights(agents, R)

            # Record the consensus error periodically.
            if k % record_every == 0:
                consensus_errors.append(network.calculate_consensus_error(agents))

        results[R] = consensus_errors
    return results


# Values of R swept for every task / oracle-order combination below.
# A tuple keeps the sweep in one place and guards it against accidental mutation.
R_VALUES = (1, 2, 3, 4)

# --- Run SVM Task ---
print(">>> Loading a9a (SVM)...")
ds_svm = DatasetModel(dsname='a9a', num_agent=NUM_AGENTS, mb_size=BATCH_SIZE_SVM, max_sample=5000)
oracle_svm = SVMOracle(alpha=2, lam=1e-5)

svm_res_1st = run_experiment(SVMAgent, ds_svm, oracle_svm, '1st', R_VALUES, 'svm')
svm_res_0th = run_experiment(SVMAgent, ds_svm, oracle_svm, '0th', R_VALUES, 'svm')

# --- Run MLP Task ---
print(">>> Loading MNIST (MLP)...")
ds_mlp = DatasetModel(dsname='mnist', num_agent=NUM_AGENTS, mb_size=BATCH_SIZE_MLP, max_sample=3000)
oracle_mlp = MLPOracle(lam=1e-5, hidden_dim=HIDDEN_DIM)

mlp_res_1st = run_experiment(MLPAgent, ds_mlp, oracle_mlp, '1st', R_VALUES, 'mlp')
mlp_res_0th = run_experiment(MLPAgent, ds_mlp, oracle_mlp, '0th', R_VALUES, 'mlp')

# --- Plotting ---
# The step of 10 must match the interval at which run_experiment records
# consensus errors (every 10 rounds by default), or the series lengths differ.
x_axis = np.arange(0, NUM_ROUNDS, 10)
fig, axs = plt.subplots(2, 2, figsize=(15, 12))

# One entry per subplot: (axes, {R: consensus errors}, title).
panels = [
    (axs[0, 0], svm_res_1st, 'SVM/a9a (1st Order) - Consensus Error'),
    (axs[0, 1], svm_res_0th, 'SVM/a9a (0th Order) - Consensus Error'),
    (axs[1, 0], mlp_res_1st, 'MLP/MNIST (1st Order) - Consensus Error'),
    (axs[1, 1], mlp_res_0th, 'MLP/MNIST (0th Order) - Consensus Error'),
]
for ax, errors_by_R, title in panels:
    for R, res in errors_by_R.items():
        ax.plot(x_axis, res, label=f'R={R}')
    ax.set_title(title)
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Axis labels only on the outer edges of the grid: left column and bottom row.
for ax in axs[:, 0]:
    ax.set_ylabel('Consensus Error')
for ax in axs[-1, :]:
    ax.set_xlabel('Computation Rounds')

plt.tight_layout()
plt.savefig('Task1_Consensus_Results.png')
# Release the figure's memory now that it has been written to disk.
plt.close(fig)
print("Results saved to Task1_Consensus_Results.png")