import copy

from utils.function import set_seed
from utils.weight_matrix import build_W_by_topology
from train_classification.build import RunCfg, build_data, build_model
from train_classification.train_algo import train_I, train_II
from train_classification.resnet import *


def run_compare(cfg: RunCfg):
    """Train with ``train_II`` and ``train_I`` from identical initial weights.

    A single reference model is built via ``build_model(cfg)`` and its
    ``state_dict`` is loaded into every per-user model of both runs, so the
    two algorithms start from the same parameters. Before each run the RNG
    is re-seeded with ``cfg.seed`` and the data loaders are rebuilt.

    Args:
        cfg: Run configuration. This function reads ``seed``, ``num_users``,
            ``topology``, ``split_ratio``, ``dataset`` and ``device``; the
            whole object is also forwarded to ``build_model``,
            ``build_data`` and the trainers.

    Returns:
        A two-element list ``[result_of_train_II, result_of_train_I]``
        (in that order), holding whatever each trainer returns.
    """
    set_seed(cfg.seed)
    # Shared starting point for every model in both runs.
    init = build_model(cfg).state_dict()

    # train_I uses a matrix built from uniform per-user ratios; train_II uses
    # the configured (non-uniform) split ratios.
    W_I = build_W_by_topology(
        cfg.num_users, cfg.topology, [1] * cfg.num_users, seed=cfg.seed
    )
    W_II = build_W_by_topology(
        cfg.num_users, cfg.topology, cfg.split_ratio, seed=cfg.seed
    )

    history_both = []
    # Order matters for the returned list: train_II ("B") first, then
    # train_I ("A").
    for trainer, W in ((train_II, W_II), (train_I, W_I)):
        # Re-seed so both algorithms see the same RNG stream for data
        # construction and model instantiation.
        set_seed(cfg.seed)
        loaders = build_data(cfg)

        if cfg.dataset == "cifar10_res":
            models = [ResNet18WithBN() for _ in range(cfg.num_users)]
        else:
            # Previously `models` was left undefined for any other dataset,
            # raising NameError below. Fall back to the same factory that
            # produced `init`, so the state dict is guaranteed to match.
            models = [build_model(cfg) for _ in range(cfg.num_users)]

        for m in models:
            m.load_state_dict(copy.deepcopy(init))
            m.to(cfg.device)

        history_both.append(trainer(models=models, loaders=loaders, W=W, cfg=cfg))

    return history_both

# Script entry point: sweep learning rates and seeds, running run_compare()
# once per (lr, seed) combination.
if __name__ == "__main__":
    # Learning rates to sweep (currently a single value).
    lrs = [0.07]
    # One full comparison run is launched per seed.
    seeds = [42, 123, 3407, 9527, 1024, 2048]
    # Two candidate per-user split-ratio vectors, 16 entries each (matching
    # num_users=16 below). Only index 1 is selected in the config below.
    split_ratios = [[0.3, 0.8, 1.0, 0.9, 0.7, 1.0, 2.0, 2.2, 1.2, 1.4, 0.8, 0.5, 1.5, 0.6, 0.6, 0.5],
                    [0.4, 2.2, 1.2, 0.5, 1.0, 0.6, 1.5, 0.5, 1.0, 0.7, 1.3, 0.9, 1.4, 0.6, 1.2, 1.0]]
    for lr_ in lrs:
        # Reset per learning rate.
        # NOTE(review): nothing is appended to this list in the visible
        # code — confirm whether results are collected further down.
        history_both_all=[]
        for seed in seeds:
            cfg = RunCfg(
                device="cuda",
                seed=seed,
                dataset="cifar10_res",
                lr=lr_,
                weight_decay=5e-4,
                batch_size_local=128,
                num_users=16,
                iterations=2500,
                eval_every=30,
                split_ratio=split_ratios[1],
                # presumably an Erdős–Rényi graph — confirm against
                # build_W_by_topology
                topology="er",
                use_wandb=False,
                entity="",
                project="",
                run_name=None,
            )
            # run_compare returns [train_II result, train_I result].
            hist = run_compare(cfg)