import torch
import numpy as np
import matplotlib.pyplot as plt
import argparse
import time
from Client import DatasetHandler, ResNetAgent, FastComNetworkPyTorch, ComNetworkPyTorch


def parse_args():
    """Parse the command-line options for the algorithm comparison.

    Returns:
        argparse.Namespace with attributes ``dataset``, ``num_agents``,
        ``rounds``, ``batch_size``, ``lr``, ``D`` and ``seed``.
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default, help) -- registered in this exact order.
    option_table = (
        ('--dataset', str, 'cifar10', None),
        ('--num_agents', int, 16, None),
        ('--rounds', int, 10000, 'Total communication rounds'),
        ('--batch_size', int, 128, None),
        ('--lr', float, 0.05, None),
        ('--D', float, 0.001, None),
        ('--seed', int, 42, None),
    )
    for flag, kind, fallback, description in option_table:
        cli.add_argument(flag, type=kind, default=fallback, help=description)
    return cli.parse_args()


# --- 1. DOC2S Algorithm ---
def train_DOC2S(agents, network, dataset_handler, args):
    """Train with DOC2S: one randomly sampled client updates per round.

    Per round:
      * every 100th round (k = 99, 199, ...) all agents call ``initialize_action``;
      * one agent index is drawn with ``np.random.randint(num_agents)``;
      * that agent computes a gradient on its next batch, forms
        ``a - lr * g`` and rescales it onto the L2 ball of radius ``D`` when
        its norm exceeds ``D``; the result is multiplied by ``num_agents``;
      * the sampled agent stores that action, every other agent calls
        ``initialize_action``;
      * actions are mixed with ``network.propagate(..., R=1)``, then each
        agent's ``params + action`` is mixed the same way and becomes its
        new parameters.

    Args:
        agents: one agent object per client.
        network: object whose ``propagate(list, R=1)`` returns one mixed
            value per agent (indexed by agent position).
        dataset_handler: provides per-agent train loaders and a test loader.
        args: parsed CLI options (num_agents, rounds, lr, D).

    Returns:
        (losses, accs): test loss / accuracy of ``agents[0]`` recorded every
        20 rounds.
    """
    print("Training DOC2S...")
    n = args.num_agents
    data_streams = [iter(dataset_handler.get_train_loader(idx)) for idx in range(n)]
    loss_history = []
    acc_history = []

    for rnd in range(args.rounds):
        # Periodic restart of every agent's action.
        if rnd % 100 == 99:
            for ag in agents:
                ag.initialize_action()

        # Uniform client sampling.
        chosen = np.random.randint(n)

        try:
            batch = next(data_streams[chosen])
        except StopIteration:
            # Shard exhausted: begin a new pass over this client's loader.
            data_streams[chosen] = iter(dataset_handler.get_train_loader(chosen))
            batch = next(data_streams[chosen])

        worker = agents[chosen]
        grad = worker.compute_gradient(batch)
        step = worker.get_action() - args.lr * grad

        # Projection onto the radius-D ball (only shrinks, never grows).
        step_norm = torch.norm(step)
        if step_norm > 1e-8 and step_norm > args.D:
            factor = args.D / step_norm
        else:
            factor = 1.0
        # NOTE(review): scaled by N, presumably so that averaging over the N
        # agents (all others reset) recovers the projected step -- confirm
        # against the mixing weights used by network.propagate.
        proposal = n * factor * step

        for idx, ag in enumerate(agents):
            if idx != chosen:
                ag.initialize_action()
            else:
                ag.set_action(proposal)

        # Consensus on actions.
        mixed_actions = network.propagate([ag.get_action() for ag in agents], R=1)
        for idx, ag in enumerate(agents):
            ag.set_action(mixed_actions[idx])

        # Consensus on the tentatively updated weights (params + action).
        shifted = [ag.get_flat_params() + ag.get_action() for ag in agents]
        mixed_params = network.propagate(shifted, R=1)
        for idx, ag in enumerate(agents):
            ag.set_flat_params(mixed_params[idx])

        if rnd % 20 == 0:
            loss, acc = agents[0].get_test_loss(dataset_handler.get_test_loader())
            loss_history.append(loss)
            acc_history.append(acc)
            print(f"[DOC2S] R {rnd} | Loss {loss:.3f} | Acc {acc:.2f}")

    return loss_history, acc_history


# --- 2. MEDOL Algorithm (1st Order) ---
def train_MEDOL(agents, network, dataset_handler, args):
    """Train with a first-order MEDOL variant under full participation.

    Every round, each agent in turn:
      * computes a stochastic gradient on its next batch at its CURRENT
        weights (NOTE(review): not at a shifted point -- this is the
        simplified first-order form);
      * updates its action ``a <- Proj_D(a - lr * g)``, where the projection
        rescales onto the L2 ball of radius ``D`` when the norm exceeds ``D``;
      * adds the new action to its weights.
    Afterwards all actions and all weights are mixed with
    ``network.propagate`` (default arguments).

    Args:
        agents: one agent object per client.
        network: object whose ``propagate(list)`` returns one mixed value per
            agent (indexed by agent position).
        dataset_handler: provides per-agent train loaders and a test loader.
        args: parsed CLI options (num_agents, rounds, lr, D).

    Returns:
        (losses, accs): test loss / accuracy of ``agents[0]`` recorded every
        20 rounds.
    """
    print("Training MEDOL...")
    streams = [iter(dataset_handler.get_train_loader(idx)) for idx in range(args.num_agents)]
    loss_log, acc_log = [], []

    def draw_batch(idx):
        # Next batch for agent idx; restarts its loader once exhausted.
        try:
            return next(streams[idx])
        except StopIteration:
            streams[idx] = iter(dataset_handler.get_train_loader(idx))
            return next(streams[idx])

    for rnd in range(args.rounds):
        for idx, node in enumerate(agents):
            batch = draw_batch(idx)
            grad = node.compute_gradient(batch)

            candidate = node.get_action() - args.lr * grad
            magnitude = torch.norm(candidate)
            if magnitude > 1e-8 and magnitude > args.D:
                shrink = args.D / magnitude
            else:
                shrink = 1.0
            node.set_action(shrink * candidate)

            # Apply the (projected) action to this agent's weights.
            node.set_flat_params(node.get_flat_params() + node.get_action())

        # Gossip: mix actions first, then weights, then write both back.
        mixed_actions = network.propagate([node.get_action() for node in agents])
        mixed_params = network.propagate([node.get_flat_params() for node in agents])
        for idx, node in enumerate(agents):
            node.set_action(mixed_actions[idx])
            node.set_flat_params(mixed_params[idx])

        if rnd % 20 == 0:
            loss, acc = agents[0].get_test_loss(dataset_handler.get_test_loader())
            loss_log.append(loss)
            acc_log.append(acc)
            print(f"[MEDOL] R {rnd} | Loss {loss:.3f} | Acc {acc:.2f}")

    return loss_log, acc_log


# --- 3. DGFM Algorithm (1st Order / Gradient Tracking) ---
def train_DGFM(agents, network, dataset_handler, args):
    """Train with first-order decentralized gradient tracking (DGFM).

    Each agent holds a tracker ``v`` and its previous stochastic gradient
    ``g_prev``. Every round, all agents compute a fresh gradient ``g_new`` at
    their current weights and set ``v <- v + g_new - g_prev``. The trackers
    are then mixed with ``network.propagate``, every agent steps
    ``w <- w - lr * v``, and the weights are mixed.

    Args:
        agents: one agent object per client (uses compute_gradient,
            get/set_prev_grad, get/set_tracker, get/set_flat_params,
            get_test_loss).
        network: object whose ``propagate(list, R=...)`` returns one mixed
            value per agent (indexed by agent position).
        dataset_handler: provides per-agent train loaders and a test loader.
        args: parsed CLI options (num_agents, rounds, lr).

    Returns:
        (losses, accs): test loss / accuracy of ``agents[0]`` recorded every
        20 rounds.
    """
    print("Training DGFM (Gradient Tracking)...")
    train_iters = [iter(dataset_handler.get_train_loader(i)) for i in range(args.num_agents)]
    losses, accs = [], []

    def next_batch(i):
        """Return agent i's next batch, starting a new pass when its loader is exhausted."""
        try:
            return next(train_iters[i])
        except StopIteration:
            # Persist the fresh iterator so later rounds continue from it
            # instead of hitting the exhausted one again.
            train_iters[i] = iter(dataset_handler.get_train_loader(i))
            return next(train_iters[i])

    # Initialization: g_0 at the starting weights and v_0 = g_0.
    # Deliberately NOT wrapped in torch.no_grad(): the main loop below calls
    # compute_gradient with autograd enabled, and disabling it here would
    # break any backward() call made inside compute_gradient.
    for i, agent in enumerate(agents):
        g = agent.compute_gradient(next_batch(i))
        agent.set_prev_grad(g)
        agent.set_tracker(g)

    # Initial mix of trackers with R=5 (original note: "better mixing at
    # start"; presumably R is the number of gossip steps -- confirm).
    trackers = network.propagate([a.get_tracker() for a in agents], R=5)
    for i, a in enumerate(agents):
        a.set_tracker(trackers[i])

    for k in range(args.rounds):
        for i, agent in enumerate(agents):
            # Fresh stochastic gradient at the current weights.
            g_new = agent.compute_gradient(next_batch(i))
            g_prev = agent.get_prev_grad()

            # Tracker update: v = v + g_new - g_prev
            agent.set_tracker(agent.get_tracker() + g_new - g_prev)
            agent.set_prev_grad(g_new)  # Save for next round

        # Consensus on trackers.
        trackers = network.propagate([a.get_tracker() for a in agents])
        for i, a in enumerate(agents):
            a.set_tracker(trackers[i])

        # Descent step along the tracked direction: w = w - lr * v
        for agent in agents:
            w = agent.get_flat_params()
            agent.set_flat_params(w - args.lr * agent.get_tracker())

        # Consensus on weights.
        weights = network.propagate([a.get_flat_params() for a in agents])
        for i, a in enumerate(agents):
            a.set_flat_params(weights[i])

        if k % 20 == 0:
            loss, acc = agents[0].get_test_loss(dataset_handler.get_test_loader())
            losses.append(loss)
            accs.append(acc)
            print(f"[DGFM ] R {k} | Loss {loss:.3f} | Acc {acc:.2f}")

    return losses, accs


if __name__ == '__main__':
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Every train_* function records test metrics when k % 20 == 0. The
    # plot's x-axis must use the same stride, otherwise plt.plot receives x
    # and y of different lengths (a stride of 100 here raised a ValueError).
    eval_every = 20


    def reset_rng():
        """Reseed torch and NumPy so each algorithm starts from the same RNG state."""
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)


    reset_rng()
    dataset = DatasetHandler(args.dataset, args.num_agents, args.batch_size)
    num_cls = dataset.num_classes


    # Helper to build a fresh set of agents for each algorithm.
    def get_agents():
        return [ResNetAgent(i, num_cls, args.lr, args.D, args.num_agents, device) for i in range(args.num_agents)]


    # Each algorithm is preceded by reset_rng() so that initial weights and
    # random streams match across runs, keeping the comparison fair.

    # Train DOC2S
    # Note: DOC2S uses Chebyshev Network
    reset_rng()
    net_doc2s = FastComNetworkPyTorch(args.num_agents, device)
    loss_doc, acc_doc = train_DOC2S(get_agents(), net_doc2s, dataset, args)

    # Train MEDOL
    # Note: MEDOL often uses Ring/Standard Network
    reset_rng()
    net_medol = ComNetworkPyTorch(args.num_agents, device)
    loss_med, acc_med = train_MEDOL(get_agents(), net_medol, dataset, args)

    # Train DGFM
    reset_rng()
    net_dgfm = ComNetworkPyTorch(args.num_agents, device)
    loss_dgfm, acc_dgfm = train_DGFM(get_agents(), net_dgfm, dataset, args)

    # Plotting: one x value per recorded evaluation (rounds 0, 20, 40, ...).
    steps = np.arange(0, args.rounds, eval_every)
    # Thin the markers so long runs (hundreds of points) stay legible.
    mark_every = max(1, len(steps) // 20)
    plt.figure(figsize=(9.5, 8.5))

    plt.plot(steps, loss_doc, 'k.-', label='DOC2S', markersize=10, markevery=mark_every)
    plt.plot(steps, loss_med, 'r^--', label='MEDOL', markersize=10, markevery=mark_every)
    plt.plot(steps, loss_dgfm, 'bs-.', label='DGFM (GT)', markersize=10, markevery=mark_every)

    plt.xlabel("Communication Rounds", fontsize=21)
    plt.ylabel("Test Loss", fontsize=21)
    plt.legend(fontsize=25)
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.title(f"Comparison on {args.dataset.upper()}", fontsize=16)

    plt.tight_layout()
    plt.savefig(f"compare_{args.dataset}.png")
    print("Saved plot to compare_" + args.dataset + ".png")