import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
from collections import defaultdict, OrderedDict
import pickle
from tqdm import tqdm
import random
import math
import argparse    
import json
import random
import numpy as np
from timm.scheduler import *

from model.resnet_cifar10 import *

from optimizer.gossip_optimizer import *
from optimizer.fedadam_optimizer import *
from optimizer.scaffold_optimizer import *
from optimizer.fedmuon_optimizer import *
from optimizer.local_muon_optimizer import *
from optimizer.scaffold_adam_optimizer import *
from data.loader_dirichlet import *

# Reproducibility settings: turn off cuDNN benchmark mode and request
# deterministic cuDNN algorithms for every process that imports this module
# (with the "spawn" start method that includes the worker processes).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def get_lr(it, lr, warmup_iters=0, min_lr=1e-5, lr_decay_iters=100):
    """Return the learning rate to use at iteration ``it``.

    The schedule is currently constant: ``lr`` is returned unchanged, and
    ``it``, ``warmup_iters``, ``min_lr`` and ``lr_decay_iters`` are accepted
    only so that call sites do not have to change when a schedule is enabled.

    A warmup + cosine schedule was sketched for this function but is disabled.
    For reference, it was:

    * ``it < warmup_iters``: linear warmup,
      ``lr * (it + 1) / (warmup_iters + 1)``;
    * ``it > lr_decay_iters``: ``min_lr``;
    * otherwise: cosine decay,
      ``min_lr + 0.5 * (1 + cos(pi * r)) * (lr - min_lr)`` with
      ``r = (it - warmup_iters) / (lr_decay_iters - warmup_iters)``.

    Args:
        it: Current iteration (epoch) index. Ignored.
        lr: Base learning rate.
        warmup_iters: Length of the warmup phase. Ignored.
        min_lr: Floor of the decay. Ignored.
        lr_decay_iters: Iteration at which the decay would end. Ignored.

    Returns:
        ``lr``, unmodified.
    """
    # Constant schedule: hand the base learning rate straight back.
    return lr

def _build_muon_param_groups(net):
    """Split ``net``'s parameters into Adam and Muon parameter groups.

    Grouping rules (taken from the parameter name and ``ndim``):

    * name contains ``"embed"``                     -> Adam group
    * ``ndim < 2``                                  -> Adam group
    * ``ndim >= 2`` and name without ``"embed"``    -> Muon group

    The ``lr`` values stored in the groups are only initial values; ``run``
    overwrites them at the start of every epoch.

    NOTE(review): a parameter whose name contains "embed" and whose
    ``ndim < 2`` would land in both Adam groups - confirm the model has no
    such parameter.

    Returns:
        A list ``[adam_embed_group, adam_scalar_group, muon_group]`` of dicts,
        each carrying a ``use_muon`` flag.
    """
    hidden_matrix_params = [p for n, p in net.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in net.named_parameters() if "embed" in n]
    scalar_params = [p for p in net.parameters() if p.ndim < 2]

    adam_groups = [dict(params=embed_params, lr=0.22), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    return [*adam_groups, muon_group]


def run(rank, size, datasets, config):
    """Train one node and pickle its loss/accuracy history.

    Args:
        rank: Index of this node; selects the device, the tqdm bar position
            and the name of the output file.
        size: Number of processes. Unused here; kept for the callback
            signature used by ``init_process``.
        datasets: Passed unchanged to ``datasets_to_loaders``.
        config: Mapping read for the keys ``optimizer``, ``seed``, ``device``,
            ``batch``, ``graph``, ``local_step``, ``lr``, ``beta``,
            ``n_nodes``, ``n_sampled_nodes``, ``epochs``, ``log_path`` and,
            for the Muon optimizers, ``adam_lr`` and ``muon_lr``.

    Raises:
        ValueError: If ``config["optimizer"]`` is not a supported name.

    Side effects:
        Writes the history dict to ``config["log_path"] + "node<rank>.pk"``.
    """
    # Fail fast on an unknown optimizer name; otherwise ``optimizer`` would be
    # left unbound and the first training step would die with a confusing
    # UnboundLocalError.
    supported_optimizers = ("fedavg", "fedadam", "scaffold", "scaffold_adam", "fedmuon", "localmuon")
    optimizer_name = config["optimizer"]
    if optimizer_name not in supported_optimizers:
        raise ValueError(
            f"unknown optimizer {optimizer_name!r}; expected one of: " + ", ".join(supported_optimizers)
        )

    # Initialize the model parameters with the same seed value on every node.
    torch.manual_seed(config["seed"])
    random.seed(config["seed"])
    np.random.seed(config["seed"])

    torch.set_num_threads(1)

    device = config["device"][rank]
    net = ResNetCifar10(device=device).to(device)

    loaders = datasets_to_loaders(datasets, config["batch"])

    # Keyword arguments shared by every optimizer constructor.
    common = dict(
        node_id=rank,
        graph=config["graph"],
        local_step=config["local_step"],
        lr=config["lr"],
        device=device,
        n_nodes=config["n_nodes"],
        n_sampled_nodes=config["n_sampled_nodes"],
    )

    if optimizer_name == "fedavg":
        optimizer = GossipOptimizer(params=net.parameters(), beta=config["beta"], **common)
    elif optimizer_name == "fedadam":
        # FedAdamOptimizer is the only one constructed without ``beta``.
        optimizer = FedAdamOptimizer(params=net.parameters(), **common)
    elif optimizer_name == "scaffold":
        optimizer = ScaffoldOptimizer(params=net.parameters(), beta=config["beta"], **common)
    elif optimizer_name == "scaffold_adam":
        optimizer = ScaffoldAdamOptimizer(params=net.parameters(), beta=config["beta"], **common)
    elif optimizer_name == "fedmuon":
        optimizer = FedMuonOptimizer(_build_muon_param_groups(net), beta=config["beta"], **common)
    else:  # "localmuon" (guaranteed by the validation above)
        optimizer = LocalMuonOptimizer(_build_muon_param_groups(net), beta=config["beta"], **common)

    uses_muon = optimizer_name in ("fedmuon", "localmuon")

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "test_loss": [], "test_acc": []}

    # Evaluate every ``eval_every`` epochs and always on the last one.
    eval_every = 1
    last_epoch = config["epochs"] - 1

    with tqdm(range(config["epochs"]), desc=("node " + str(rank)), position=rank) as pbar:
        for epoch in pbar:
            if uses_muon:
                # Muon optimizers carry two learning rates, selected per
                # parameter group through its ``use_muon`` flag.
                adam_lr = get_lr(epoch, config["adam_lr"], lr_decay_iters=config["epochs"])
                muon_lr = get_lr(epoch, config["muon_lr"], lr_decay_iters=config["epochs"])
                for param_group in optimizer.param_groups:
                    param_group["lr"] = muon_lr if param_group["use_muon"] else adam_lr
            # NOTE: the other optimizers keep the ``lr`` given at construction;
            # no per-epoch schedule is applied to them.

            train_loss, train_acc = net.run(loaders, optimizer)

            if epoch % eval_every == 0 or epoch == last_epoch:
                val_loss, val_acc = net.run_val(loaders)
                test_loss, test_acc = net.run_test(loaders)

                # Save loss and accuracy.
                history["train_loss"].append(train_loss)
                history["test_loss"].append(test_loss)
                history["val_loss"].append(val_loss)
                history["train_acc"].append(train_acc)
                history["test_acc"].append(test_acc)
                history["val_acc"].append(val_acc)

                pbar.set_postfix(OrderedDict(loss=(round(train_loss, 2), round(test_loss, 2)), acc=(round(train_acc, 2), round(test_acc, 2))))

    # ``log_path`` is used as a plain string prefix (no separator is inserted).
    history_path = config["log_path"] + "node" + str(rank) + ".pk"
    with open(history_path, "wb") as history_file:
        pickle.dump(history, history_file)
    
    
def init_process(rank, size, datasets, config, fn, backend='gloo'):
    """Join the process group, run ``fn`` and tear the group down again.

    Args:
        rank: Rank of this process within the group.
        size: World size (total number of processes).
        datasets: Forwarded to ``fn``.
        config: Forwarded to ``fn``.
        fn: Callable invoked as ``fn(rank, size, datasets, config)``.
        backend: ``torch.distributed`` backend name.
    """
    # Rendezvous address. The port is hard-coded, so two jobs started on the
    # same host at the same time will collide.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        fn(rank, size, datasets, config)
    finally:
        # Release the process group even when ``fn`` raises.
        dist.destroy_process_group()


if __name__ == "__main__":
    # ---- command line -----------------------------------------------------
    parser = argparse.ArgumentParser()
    # ``log`` is concatenated directly with "node<rank>.pk" by run(), so it
    # acts as a path prefix (pass a trailing "/" to write into a directory).
    parser.add_argument('log', default="./results", type=str)
    parser.add_argument('--n_nodes', default=16, type=int)
    parser.add_argument('--n_sampled_nodes', default=4, type=int)
    # NOTE: stored in the config, but the data is always loaded with load_CIFAR10.
    parser.add_argument('--dataset', default="cifar10", type=str)
    # The names below are the ones run() dispatches on. The default used to be
    # "gossip", which matches none of them and crashed every worker.
    parser.add_argument('--optimizer', default="fedavg", type=str,
                        choices=["fedavg", "fedadam", "scaffold", "scaffold_adam", "fedmuon", "localmuon"])
    parser.add_argument('--batch', default=32, type=int)
    parser.add_argument('--seed', default=0, type=int)
    # NOTE: parsed but not used anywhere in this script.
    parser.add_argument('--cuda', default=None, type=str)
    parser.add_argument('--lr', default=1e-3, type=float)
    # Learning rates applied by run() for the fedmuon / localmuon optimizers.
    parser.add_argument('--muon_lr', default=0., type=float)
    parser.add_argument('--adam_lr', default=0., type=float)
    parser.add_argument('--epoch', default=100, type=int)
    parser.add_argument('--alpha', default=10.0, type=float)
    parser.add_argument('--local_step', default=5, type=int)
    parser.add_argument('--beta', default=0.9, type=float)
    args = parser.parse_args()

    # ---- configuration shared with every worker ------------------------------
    # Keep this a defaultdict: run() reads config["graph"], which is never
    # assigned here, and relies on getting an empty dict back.
    config = defaultdict(dict)
    config.update(
        n_nodes=args.n_nodes,
        n_sampled_nodes=args.n_sampled_nodes,
        dataset=args.dataset,
        optimizer=args.optimizer,
        lr=args.lr,
        adam_lr=args.adam_lr,
        muon_lr=args.muon_lr,
        seed=args.seed,
        epochs=args.epoch,
        log_path=args.log,
        batch=args.batch,
        alpha=args.alpha,
        local_step=args.local_step,
        beta=args.beta,
    )

    n_nodes = config["n_nodes"]

    # Nodes are assigned round-robin to device indices cuda:0 .. cuda:7.
    config["device"] = {node_id: f"cuda:{node_id % 8}" for node_id in range(n_nodes)}

    # ---- data -------------------------------------------------------------
    torch.manual_seed(config["seed"])
    random.seed(config["seed"])
    np.random.seed(config["seed"])

    datasets = load_CIFAR10(n_nodes, batch=config["batch"], alpha=config["alpha"], val_rate=0.1, seed=config["seed"])

    # ---- launch one worker process per node ----------------------------------
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # The start method was already set; keep whatever is configured.
        pass

    processes = []
    for rank in range(n_nodes):
        print(rank)
        # Each node gets its own training split; val and test are shared.
        node_datasets = {"train": datasets["train"][rank], "val": datasets["val"], "test": datasets["test"]}
        p = mp.Process(target=init_process, args=(rank, n_nodes, node_datasets, config, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface worker failures through the exit status instead of finishing
    # "successfully" with missing result files.
    failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed_ranks:
        raise SystemExit(f"worker process(es) for node(s) {failed_ranks} exited abnormally")
