import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.optimizer import Optimizer
import torch.distributed as dist
import time

import numpy as np

def zeropower_via_newtonschulz5(G, steps: int):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: tensor with at least 2 dims. The last two dims form the matrix; any
            leading dims are treated as a batch and processed independently.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A float32 tensor with the same shape as ``G``.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750,  2.0315)

    # Computed in float32 (upstream Muon uses bfloat16 here).
    X = G.float()

    # Work on the wide orientation so the Gram matrix X @ X^T is the smaller one.
    # transpose(-2, -1) swaps only the matrix dims, so batched inputs stay batched
    # (Tensor.T would reverse every dim of an N-D tensor).
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.transpose(-2, -1)

    # The Frobenius norm bounds the spectral norm, so this ensures spectral norm <= 1.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if transposed:
        X = X.transpose(-2, -1)
    return X

def muon_update(momentum, beta=0.95, ns_steps=5, nesterov=True):
    """Turn a momentum buffer into a scaled, orthogonalized Muon update.

    The buffer is viewed as a matrix; 4-D conv filters are flattened to
    ``(out_channels, -1)``. The matrix is orthogonalized with
    ``zeropower_via_newtonschulz5``, then scaled by ``sqrt(max(rows, cols))``
    of that matrix.

    Args:
        momentum: buffer to orthogonalize. This function does not modify it.
        beta: currently unused; the momentum EMA is performed by the caller.
        ns_steps: number of Newton-Schulz iterations.
        nesterov: currently unused. A Nesterov look-ahead would need the raw
            gradient, which is not passed in.

    Returns:
        The scaled update. For conv filters it stays flattened to 2-D, so
        callers reshape it back to the parameter shape.
    """
    update = momentum
    if update.ndim == 4:  # conv filters: (out, in, kh, kw) -> (out, in * kh * kw)
        # reshape (not view) so non-contiguous buffers are handled too.
        update = update.reshape(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)

    # Scale by the shape of the matrix that was actually orthogonalized. For a
    # 4-D buffer, momentum.size(-2)/size(-1) would only be the kernel size.
    update *= max(update.size(-2), update.size(-1)) ** 0.5

    return update


class LocalMuonOptimizer(Optimizer):
    """Local-update optimizer with Muon groups and periodic parameter averaging.

    On each ``step`` only nodes listed in ``self.sampled_clients`` update their
    parameters:

    * groups with ``use_muon=True`` take a momentum step whose direction comes
      from ``muon_update``;
    * all other groups take a plain momentum-SGD step. Their ``betas``, ``eps``
      and ``weight_decay`` are validated but not used yet.

    Every ``local_step`` steps, each parameter is summed across ranks with
    ``torch.distributed.all_reduce``, divided by ``n_nodes``, and a new set of
    clients is sampled.

    Args:
        param_groups: iterable of dicts, each with a boolean ``use_muon`` key.
        node_id: id of this node, compared against the sampled client ids.
        graph: accepted for interface compatibility; not used by this class.
        local_step: number of steps between parameter averages.
        n_nodes: number of nodes sampled from, and the divisor after the
            all-reduce sum.
        n_sampled_nodes: number of clients sampled (without replacement) per round.
        lr: accepted for interface compatibility; the per-group ``lr`` is used.
        beta: momentum coefficient used for both Muon and SGD groups.
        device: device on which the per-parameter buffers are allocated.
    """

    def __init__(self, param_groups, node_id: int, graph, local_step, n_nodes, n_sampled_nodes, lr=1e-5, beta=0.9, device="cuda"):

        self.node_id = node_id
        self.device = device

        self.n_nodes = n_nodes
        self.n_sampled_nodes = n_sampled_nodes

        self.local_step = local_step
        self.step_counter = 0
        self.beta = beta

        # Fixed seed: every node draws the same client sequence, so all ranks
        # agree on who participates in each round without communicating.
        self.rng = np.random.default_rng(0)
        self.client_sampling()

        # Materialize so a generator is not exhausted before super().__init__.
        param_groups = list(param_groups)
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

        # Per-parameter buffers are stored inside each param group.
        # NOTE: this replaces a Muon group's ``momentum`` hyperparameter with the
        # buffer list; the coefficient actually used by ``step`` is ``self.beta``.
        for group in self.param_groups:
            params = group["params"]
            group["momentum"] = [torch.zeros_like(p, device=self.device) for p in params]
            # The three buffers below are allocated but not read by this class;
            # they are kept so code that relies on these keys keeps working.
            group["local_c"] = [torch.zeros_like(p, device=self.device) for p in params]
            group["global_c"] = [torch.zeros_like(p, device=self.device) for p in params]
            group["grad_accum"] = [torch.zeros_like(p, device=self.device) for p in params]

    @torch.no_grad()
    def client_sampling(self):
        """Draw ``n_sampled_nodes`` distinct node ids from ``0..n_nodes-1`` into ``self.sampled_clients``."""
        self.sampled_clients = list(self.rng.choice(np.arange(0, self.n_nodes), size=self.n_sampled_nodes, replace=False))

    def _update_momentum(self, buf, grad):
        """In-place EMA: ``buf <- beta * buf + (1 - beta) * grad``."""
        buf.mul_(self.beta).add_(grad, alpha=1 - self.beta)

    def _muon_step(self, group):
        """Apply one Muon momentum step to every parameter of ``group`` that has a gradient."""
        for p, buf in zip(group["params"], group["momentum"]):
            if p.grad is None:
                continue
            self._update_momentum(buf, p.grad)
            update = muon_update(buf)
            p -= group["lr"] * update.reshape(p.shape)

    def _sgd_step(self, group):
        """Apply one momentum-SGD step to every parameter of ``group`` that has a gradient."""
        for p, buf in zip(group["params"], group["momentum"]):
            if p.grad is None:
                continue
            self._update_momentum(buf, p.grad)
            p -= group["lr"] * buf

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one local step, and average parameters every ``local_step`` steps.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with gradients enabled before the update, as in
                ``torch.optim.Optimizer.step``.

        Returns:
            The closure's return value, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.node_id in self.sampled_clients:
            for group in self.param_groups:
                if group["use_muon"]:
                    self._muon_step(group)
                else:
                    # Plain momentum SGD for now.
                    self._sgd_step(group)

        # Every node (sampled or not) advances the counter, so all ranks reach
        # the collective ``average`` call on the same step.
        self.step_counter += 1
        if self.step_counter % self.local_step == 0:
            self.average()
            self.client_sampling()

        return loss

    @torch.no_grad()
    def average(self):
        """Replace every parameter with its all-reduce sum divided by ``n_nodes``.

        NOTE(review): assumes the process-group world size equals ``n_nodes``.
        Non-sampled nodes also contribute their (un-updated) parameters to the
        mean; confirm this is the intended aggregation.
        """
        for group in self.param_groups:
            for p in group["params"]:
                dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
                p /= self.n_nodes
                
