import gc
import torch

from utils.weight_matrix import build_W_by_topology
from least_square.build_ls import RunCfg, build_quadratic_problem, build_model
from least_square.train_algo_ls import train_core_quadratic
from utils.function import set_seed

def train_I(models, problem, W, cfg):
    """Train with approach "I" by delegating to ``train_core_quadratic``.

    Side effect: overwrites ``cfg.run_name`` (on the caller's object) with a
    label made from the learning rate, topology and noise seed.

    Returns:
        Whatever ``train_core_quadratic`` returns; ``run_compare`` collects it
        as this approach's history.
    """
    label = "I_lr{}_{}_{}".format(cfg.lr, cfg.topology, cfg.seed_noise)
    cfg.run_name = label
    result = train_core_quadratic(
        models=models,
        problem=problem,
        W=W,
        cfg=cfg,
        approach="I",
    )
    return result


def train_II(models, problem, W, cfg):
    """Train with approach "II" by delegating to ``train_core_quadratic``.

    Side effect: overwrites ``cfg.run_name`` (on the caller's object) with a
    label made from the learning rate, topology and noise seed.

    Returns:
        Whatever ``train_core_quadratic`` returns; ``run_compare`` collects it
        as this approach's history.
    """
    label = "II_lr{}_{}_{}".format(cfg.lr, cfg.topology, cfg.seed_noise)
    cfg.run_name = label
    result = train_core_quadratic(
        models=models,
        problem=problem,
        W=W,
        cfg=cfg,
        approach="II",
    )
    return result


def run_compare(cfg: RunCfg):
    """Run approach I and approach II back to back and collect their histories.

    Approach I gets a mixing matrix built with a uniform split ratio of 1 per
    user; approach II gets one built from ``cfg.split_ratio``. Before each
    run the global seed is reset and the quadratic problem and per-user models
    are rebuilt from scratch.

    Args:
        cfg: Run configuration; ``run_name`` is overwritten by the trainers.

    Returns:
        A two-element list: ``[history_of_I, history_of_II]``.
    """
    set_seed(cfg.seed)

    n_users = cfg.num_users
    uniform_ratio = [1] * n_users
    mixing_I = build_W_by_topology(n_users, cfg.topology, uniform_ratio)
    mixing_II = build_W_by_topology(n_users, cfg.topology, cfg.split_ratio)

    schedule = (
        ("A", train_I, mixing_I),
        ("B", train_II, mixing_II),
    )

    histories = []
    for label, run_fn, mixing in schedule:
        print(f"\n========== Running {label} ==========\n")
        # Re-seed before rebuilding, presumably so both approaches start from
        # the same generated problem/initialization (assumes the builders draw
        # from the RNGs that set_seed controls — TODO confirm).
        set_seed(cfg.seed)

        problem = build_quadratic_problem(cfg)
        starts = problem["x_init_all"]

        models = []
        for user_idx in range(cfg.num_users):
            start_point = starts[user_idx].to(cfg.device)
            models.append(build_model(cfg, x_init=start_point))

        histories.append(run_fn(models=models, problem=problem, W=mixing, cfg=cfg))

    return histories


if __name__ == "__main__":
    # Learning rates to sweep.
    lrs = [0.1]
    # Noise seeds to sweep; each one is forwarded as cfg.seed_noise (the
    # problem seed cfg.seed stays fixed at 3407 below).
    seeds_1 = [42, 1123, 3407, 9527, 1024,
               2048, 4096, 777, 888, 999]
    # Candidate per-user split ratios (16 entries each, matching num_users);
    # only split_ratios[1] is used by the sweep below.
    split_ratios = [[0.3, 0.8, 1.0, 0.9, 0.7, 1.0, 2.0, 2.2, 1.2, 1.4, 0.8, 0.5, 1.5, 0.6, 0.6, 0.5],
                    [0.4, 2.2, 1.2, 0.5, 1.0, 0.6, 1.5, 0.5, 1.0, 0.7, 1.3, 0.9, 1.4, 0.6, 1.2, 1.0]]

    # Same iteration order as a nested "for lr in lrs: for seed in seeds_1".
    sweep = [(lr_, seed) for lr_ in lrs for seed in seeds_1]
    for lr_, seed in sweep:
        cfg = RunCfg(
            device="cuda",
            seed=3407,
            d=10,
            num_users=16,
            lr=lr_,
            iterations=300,
            eval_every=3,
            rho=0.01,
            topology="static",
            split_ratio=split_ratios[1],
            gamma_low=5.5,
            gamma_high=12.5,
            bias_nodes=list(range(16)),
            mu_norm=3.0,
            noise_sigma=1.0,
            seed_noise=seed,
            use_wandb=False,
            entity="",
            project="",
        )
        history_both = run_compare(cfg)
        # Drop unreachable objects from the previous run before starting the next.
        gc.collect()