from CiteSeer_model import FLGnnA
import torch
from CiteSeer_loader import citeseer
import numpy as np
from torch.cuda.amp import autocast as autocast, GradScaler
from sklearn.metrics import f1_score
from tqdm import tqdm

# Device every model instance is moved to in the __main__ block. Training in
# fit() also uses torch.cuda.amp autocast/GradScaler, so a CUDA GPU is expected.
device = "cuda"

def Grid_Search(config):
    """Expand a search space into every combination of its candidate values.

    ``config`` maps each hyper-parameter name to a list of candidate values.
    The result is a list of dicts, one per element of the Cartesian product,
    ordered with the first key varying slowest. An empty ``config`` yields
    ``[{}]``; any key with no candidates yields ``[]``.
    """
    combos = [dict()]
    for key, candidates in config.items():
        # Extend every partial combination with each candidate for this key.
        combos = [partial | {key: value} for partial in combos for value in candidates]
    return combos


def fit(m, config: dict, dataset):
    """Train ``m`` on ``dataset`` and return the best test F1 observed.

    Each epoch is one full-batch step. The step uses Adam with mixed-precision
    autocast and GradScaler, followed by a CosineAnnealingWarmRestarts
    learning-rate update. After every step the model is scored with
    ``metric``, and the running maximum of its F1 is tracked.

    Args:
        m: model to train; called as ``m(dataset)``.
        config: training options. Keys read here:
            ``"epoch"`` (number of steps, default 10),
            ``"optim"`` (keyword arguments for ``torch.optim.Adam``, default
            ``{"lr": 0.05, "weight_decay": 5e-4}``), and
            ``"type"`` (``"regress"``, ``"binary_classify"`` or
            ``"multi_classify"``; selects the training loss).
        dataset: object exposing ``y`` and ``train_mask`` (plus whatever
            ``metric`` reads).

    Returns:
        The highest F1 returned by ``metric`` over all epochs (0 if
        ``epoch`` is 0).

    Raises:
        ValueError: if ``config["type"]`` is missing or not a supported type.
    """
    epoch = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.05, "weight_decay": 5e-4})

    loss_by_type = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }
    task_type = config.get("type")
    if task_type not in loss_by_type:
        raise ValueError(
            f"config['type'] must be one of {sorted(loss_by_type)}, got {task_type!r}"
        )
    critical = loss_by_type[task_type]()

    optimizer = torch.optim.Adam(**optim_config, params=m.parameters())
    reduce_schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, eta_min=0.001
    )
    scaler = GradScaler()

    m.train()
    metric_test = 0
    # Context manager guarantees the progress bar is closed even if a step raises.
    with tqdm(total=epoch) as pbar:
        for e in range(epoch):
            optimizer.zero_grad()
            # Mixed-precision forward + loss; GradScaler scales the loss so
            # reduced-precision gradients do not underflow in backward().
            with autocast():
                pre_y = m(dataset)
                loss = critical(pre_y[dataset.train_mask], dataset.y[dataset.train_mask])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Explicit epoch index, as accepted by CosineAnnealingWarmRestarts.step.
            reduce_schedule.step(e)

            # metric() switches to eval mode and disables grad tracking itself.
            f1_test, loss_test = metric(m, dataset)
            # NOTE(review): keeping the max F1 across epochs selects the epoch
            # on the test split, so the reported score is optimistic; consider
            # selecting on a validation split instead.
            metric_test = max(metric_test, f1_test)
            pbar.set_description(
                f"epoch: {e + 1} \\ {epoch}, loss_train: {float(loss)}, "
                f"test_loss: {float(loss_test)}, best_f1: {metric_test}"
            )
            pbar.update()
    return metric_test


def f1(true, pre):
    """Return the micro-averaged F1 score between two label tensors.

    Both tensors are detached from the autograd graph and copied to host
    NumPy arrays (``true`` first, then ``pre``). They are then passed to
    ``sklearn.metrics.f1_score`` with ``average="micro"``.
    """
    def to_host(tensor):
        return tensor.detach().cpu().numpy()

    return f1_score(to_host(true), to_host(pre), average="micro")


def metric(m, dataset):
    """Evaluate ``m`` on the test split of ``dataset``.

    Runs one forward pass in eval mode without gradient tracking. On
    ``dataset.test_mask`` it computes the cross-entropy loss and the
    micro-averaged F1 of the argmax predictions (argmax over dim 1). The
    model's previous train/eval mode is restored afterwards, even if the
    forward pass raises.

    Args:
        m: model called as ``m(dataset)``; its output is scored as class
            logits with ``CrossEntropyLoss``.
        dataset: object exposing ``y`` and ``test_mask``.

    Returns:
        ``(f1_test, loss)``: the F1 from ``f1`` and the test cross-entropy
        loss tensor.
    """
    was_training = m.training
    m.eval()
    critical = torch.nn.CrossEntropyLoss()
    try:
        with torch.no_grad():
            logits = m(dataset)
            mask = dataset.test_mask
            logits_test = logits[mask]
            y_test = dataset.y[mask]
            loss = critical(logits_test, y_test)
            # Row-wise argmax on the masked rows equals masking the full argmax.
            f1_test = f1(y_test, logits_test.argmax(dim=1))
    finally:
        # Restore whatever mode the caller had, instead of always forcing train().
        m.train(was_training)
    return f1_test, loss


if __name__ == '__main__':
    # Search space: each key lists its candidate values; Grid_Search expands
    # it into one config per combination. The same dict is passed both to the
    # model constructor and to fit().
    search_space = {
        "in_channels": [3703],
        "hidden": [800],
        "out_channels": [6],
        "windows": [3],
        "stride": [3],
        "order": [1],
        "A_P2": [True],
        "refine_ratio": [1],
        "refiner": ["pool"],
        "dropout": [0.35],
        "cross": [0.9],
        "num_mf": [2],
        "fix": [True],
        "attention": [False],
        "residual": [False],
        "layer": [3],
        "norm": [False],
        "value_intervals": [[-1, 1]],
        "optim": [
            # NOTE(review): 4e-1 is an unusually large Adam weight_decay; confirm intended.
            {"lr": 0.005, "weight_decay": 4e-1},
        ],
        "type": ["multi_classify"],
        "epoch": [400],
    }
    n_runs = 30  # independent restarts (fresh model each time) per configuration
    dataset = citeseer()
    for cfg in Grid_Search(search_space):
        print(cfg)
        scores = [fit(FLGnnA(**cfg).to(device), cfg, dataset) for _ in range(n_runs)]
        print(np.mean(scores), np.std(scores))
