import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.optimizer import Optimizer
import torch.distributed as dist
import time

import numpy as np

class ScaffoldOptimizer(Optimizer):
    """SCAFFOLD-style optimizer with momentum, client sampling and periodic averaging.

    On each call to :meth:`step`, only the clients in ``sampled_clients``
    update their parameters. They apply the control-variate-corrected update
    ``p -= lr * (momentum - local_c + global_c)``. Every ``local_step`` calls,
    :meth:`average` synchronizes parameters and the global control variate
    across ranks via ``torch.distributed``, and a new client subset is drawn.

    Args:
        params: iterable of parameters (or param-group dicts) to optimize.
        node_id: this process's client id; compared against the sampled ids.
        graph: not used by this class; accepted for interface compatibility.
        local_step: number of ``step`` calls between averaging rounds (>= 1).
        n_nodes: total number of clients; the sampling range and the
            divisor used when averaging.
        n_sampled_nodes: number of clients drawn (without replacement) per round.
        lr: learning rate.
        beta: momentum coefficient; momentum is updated as
            ``beta * momentum + (1 - beta) * grad``.
        device: device on which the optimizer state buffers are allocated.

    Raises:
        ValueError: if ``local_step`` is smaller than 1.
    """

    def __init__(self, params, node_id: int, graph, local_step, n_nodes, n_sampled_nodes, lr=1e-5, beta=0.9, device="cuda"):
        # A zero period would raise ZeroDivisionError inside step(); a negative
        # one would flip the sign of the averaged control variates. Reject both.
        if local_step < 1:
            raise ValueError(f"local_step must be >= 1, got {local_step}")

        self.node_id = node_id
        self.device = device

        self.n_nodes = n_nodes
        self.n_sampled_nodes = n_sampled_nodes

        self.local_step = local_step
        self.step_counter = 0
        self.beta = beta

        # Fixed seed: all ranks draw the same sequence of client subsets as
        # long as they call client_sampling() the same number of times.
        self.rng = np.random.default_rng(0)
        self.client_sampling()

        defaults = dict(lr=lr)
        super(ScaffoldOptimizer, self).__init__(params, defaults)

        # Per-parameter state is kept in each param group, index-aligned with
        # group["params"].
        for group in self.param_groups:
            for key in ("momentum", "local_c", "global_c", "grad_accum"):
                group[key] = [torch.zeros_like(p, device=self.device) for p in group["params"]]

    @torch.no_grad()
    def client_sampling(self):
        """Draw ``n_sampled_nodes`` distinct client ids from ``range(n_nodes)``."""
        self.sampled_clients = list(self.rng.choice(np.arange(0, self.n_nodes), size=self.n_sampled_nodes, replace=False))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one local optimization step.

        Args:
            closure: optional callable that re-evaluates the model, calls
                ``backward`` and returns the loss. It runs with gradients
                enabled *before* the update, so the update uses its gradients.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; autograd must be re-enabled so
            # the closure's backward() populates p.grad before it is consumed.
            with torch.enable_grad():
                loss = closure()

        if self.node_id in self.sampled_clients:
            for group in self.param_groups:
                lr = group["lr"]
                for p, momentum, local_c, global_c, grad_accum in zip(
                    group["params"],
                    group["momentum"],
                    group["local_c"],
                    group["global_c"],
                    group["grad_accum"],
                ):
                    if p.grad is None:
                        # No gradient for this parameter (it did not take part
                        # in the backward pass); leave it and its state untouched.
                        continue
                    # momentum <- beta * momentum + (1 - beta) * grad, in place.
                    momentum.mul_(self.beta).add_(p.grad, alpha=1 - self.beta)
                    # Control-variate correction: remove local drift, add global estimate.
                    p -= lr * (momentum - local_c + global_c)
                    # Accumulated over the round; averaged into the control variates.
                    grad_accum.add_(momentum)

        self.step_counter += 1
        if self.step_counter % self.local_step == 0:
            self.average()
            self.client_sampling()

        return loss

    @torch.no_grad()
    def average(self):
        """Finish a communication round.

        Sampled clients set both control variates to the round's mean
        accumulated momentum and reset the accumulator. Every rank then sums
        its parameters and global control variate with ``dist.all_reduce``
        and divides by ``n_nodes``. This is a collective call: it requires an
        initialized ``torch.distributed`` process group and must be called by
        every rank in the same order.
        """
        if self.node_id in self.sampled_clients:
            for group in self.param_groups:
                for local_c, global_c, grad_accum in zip(group["local_c"], group["global_c"], group["grad_accum"]):
                    mean_update = grad_accum / self.local_step
                    local_c.copy_(mean_update)
                    global_c.copy_(mean_update)
                    grad_accum.zero_()

        # NOTE(review): non-sampled ranks contribute their previous global_c to
        # this sum, and the divisor is n_nodes (presumably equal to the world
        # size) -- confirm this matches the intended SCAFFOLD server update.
        for group in self.param_groups:
            for p, global_c in zip(group["params"], group["global_c"]):
                dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
                p /= self.n_nodes
                dist.all_reduce(global_c, op=dist.ReduceOp.SUM)
                global_c /= self.n_nodes
