from molpcba_model import FLGnnA
import torch
from tqdm import tqdm
import molpcba_loader
import numpy as np
from torch.cuda.amp import autocast as autocast
from ogb.graphproppred.evaluate import Evaluator
from torch.cuda.amp import GradScaler
from itertools import count

# Compute device used for the model and every batch; falls back to CPU so the
# script still runs (slowly, with AMP effectively disabled) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

def Grid_Search(config):
    """Expand a dict of candidate-value lists into every combination.

    The first key of ``config`` varies slowest and the last key fastest, and
    each returned dict lists its keys in the same order as ``config``.

    Args:
        config: Mapping from option name to a list of candidate values.

    Returns:
        A list of dicts, one per point of the Cartesian product. An empty
        ``config`` yields ``[{}]``; any empty candidate list yields ``[]``.
    """
    combos = [dict()]
    for name, candidates in config.items():
        # Extend every partial assignment with each candidate for this key.
        combos = [partial | {name: value} for partial in combos for value in candidates]
    return combos


def fit(m, config: dict, train_dataset, test_dataset, valid_dataset, lr_statue=None):
    """Train ``m`` with mixed precision, selecting checkpoints on the validation split.

    After every epoch the model is scored on ``valid_dataset``. When the
    validation AP beats the best so far (and the save threshold), the model is
    also scored on ``test_dataset``. The model, the optimizer state and
    ``config`` are then saved to ``molplcba-{valid_ap}-{epoch}.tar``. Training
    stops early once the validation AP reaches the target.

    Args:
        m: Model called as ``m(graphs)`` on each batch after ``graphs.to(device)``.
        config: Run configuration. Keys read here:
            ``"epoch"`` is the number of epochs (default 10).
            ``"optim"`` holds the Adam keyword arguments (default
            ``{"lr": 0.05, "weight_decay": 5e-4}``).
            ``"type"`` selects the loss: ``"regress"`` (MSE),
            ``"binary_classify"`` (BCE with logits) or ``"multi_classify"``
            (cross-entropy).
            ``"save_threshold"`` is the minimum validation AP before any
            checkpoint is written (default 0.24).
            ``"target_ap"`` is the validation AP that triggers early stopping
            (default 0.25).
        train_dataset: Iterable of batches exposing ``.to(device)`` and ``.y``;
            must support ``len()`` for the progress bar.
        test_dataset: Batches scored only for the checkpoint that wins on validation.
        valid_dataset: Batches used for checkpoint selection and early stopping.
        lr_statue: Optional optimizer ``state_dict`` to resume from.

    Returns:
        Test AP of the best-validation checkpoint. Returns ``nan`` if validation
        AP never exceeded the save threshold.

    Raises:
        KeyError: if ``config["type"]`` is not a known task type.
    """
    epochs = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.05, "weight_decay": 5e-4})
    save_threshold = config.get("save_threshold", 0.24)
    target_ap = config.get("target_ap", 0.25)

    optimizer = torch.optim.Adam(**optim_config, params=m.parameters())
    if lr_statue is not None:
        optimizer.load_state_dict(lr_statue)

    loss_classes = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }
    task = config.get("type")
    if task not in loss_classes:
        raise KeyError(f"unknown task type {task!r}; expected one of {sorted(loss_classes)}")
    criterion = loss_classes[task]()

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=0.001
    )

    m.train()
    scaler = GradScaler()
    loss_mean = -1
    last_valid = -1
    best_valid = save_threshold
    best_test = float("nan")
    for e in range(epochs):
        loss_record = []
        pbar = tqdm(total=len(train_dataset))
        try:
            for graphs in train_dataset:
                graphs = graphs.to(device)
                optimizer.zero_grad()
                # NaN != NaN, so this masks out NaN targets (presumably unlabeled tasks).
                is_labeled = graphs.y == graphs.y
                with autocast():
                    pre_y = m(graphs)
                    loss = criterion(pre_y.to(torch.float32)[is_labeled], graphs.y.float()[is_labeled])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                loss_value = loss.item()
                loss_record.append(loss_value)
                pbar.set_description(
                    f'epoch: {e + 1}, loss_train: {loss_value}, loss_mean: {loss_mean}, ap_valid: {last_valid}'
                )
                pbar.update()
        finally:
            # Close the bar even if training raises, and before evaluation prints.
            pbar.close()

        # step() with no argument advances T_cur by one epoch. The old step(e) call
        # left the schedule one epoch behind.
        scheduler.step()
        loss_mean = np.mean(loss_record)

        # Select on validation only; scoring the test split here would leak it into selection.
        last_valid = metric(m, train_dataset, valid_dataset)["ap"]
        if last_valid > best_valid:
            best_valid = last_valid
            best_test = metric(m, train_dataset, test_dataset)["ap"]
            torch.save(
                {
                    "FL-GNN": m.state_dict(),
                    "opt": optimizer.state_dict(),
                    "config": config,
                    "valid_ap": best_valid,
                    "test_ap": best_test,
                },
                f"molplcba-{best_valid}-{e}.tar",
            )
            if best_valid >= target_ap:
                break
    return best_test
def metric(m, train_dataset, test_dataset):
    """Score ``m`` on ``test_dataset`` with the OGB ``ogbg-molpcba`` evaluator.

    The model is switched to eval mode for the pass. Afterwards it is restored
    to whatever training mode it had on entry, even if evaluation raises.

    Args:
        m: Model called as ``m(graphs)`` on each batch after ``graphs.to(device)``.
        train_dataset: Unused; kept so existing call sites keep working.
        test_dataset: Iterable of batches exposing ``.to(device)`` and ``.y``;
            this is the split that is actually evaluated.

    Returns:
        The dict returned by ``Evaluator.eval`` (callers read its ``"ap"`` key).

    Raises:
        ValueError: if ``test_dataset`` yields no batches.
    """
    was_training = m.training
    m.eval()
    evl = Evaluator("ogbg-molpcba")
    y_true = []
    y_pred = []
    try:
        with torch.no_grad():
            for graphs in test_dataset:
                graphs = graphs.to(device)
                pre_y = m(graphs)
                # Align target layout with the prediction tensor before stacking.
                y_true.append(graphs.y.view(pre_y.shape).cpu())
                y_pred.append(pre_y.cpu())
    finally:
        m.train(was_training)

    if not y_true:
        raise ValueError("metric: test_dataset yielded no batches")

    input_dict = {
        "y_true": torch.cat(y_true, dim=0).numpy(),
        "y_pred": torch.cat(y_pred, dim=0).numpy(),
    }
    evl_test = evl.eval(input_dict)
    print(f"eval: {evl_test}")
    return evl_test


if __name__ == '__main__':
    import traceback

    configure = {
        "hidden": [400],
        "out_channels": [128],
        "windows": [3],
        "stride": [1],
        "order": [0],
        "A_P2": [True],
        "extract_ratio": [1],
        "extractor": ["pool"],
        "cross": [0.9],
        "num_mf": [2],
        "fix": [True],
        "layer": [5],
        "norm": [False],
        "attention": [False],
        "residual": [False],
        "dropout": [0.1],
        "value_intervals": [[-1, 1]],
        "optim": [
            # {"lr": 0.005},
            # {"lr": 0.003},
            {"lr": 0.004, "weight_decay": 1e-8},
            # {"lr": 0.0045}
            # {"lr": 0.0065},
        ],
        "type": ["binary_classify"],
        "epoch": [40],
    }
    n_repeats = 3  # independent runs per configuration

    train, valid, test = molpcba_loader.molpcba()
    for cfg in Grid_Search(configure):
        res = []
        m = None
        print(cfg)
        for _ in range(n_repeats):
            try:
                m = FLGnnA(**cfg).to(device)
                score = fit(m, cfg, train, test, valid)
                res.append(score)
            except Exception:
                # Keep the grid search going, but show why this run failed.
                # Narrower than a bare except, so Ctrl-C still stops the script.
                print("something wrong")
                traceback.print_exc()

        if res:
            print(np.nanmean(res), np.nanstd(res))
        else:
            print("no successful runs for this configuration")

        # Parameter count of the last model built for this configuration.
        if m is not None:
            n_total = sum(p.numel() for p in m.parameters())
            n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
            print({"Total": n_total, "Trainable": n_trainable})

