from molfreesolve_model import FLGnnA
import torch
from tqdm import tqdm
import molfreesolv_loader
import numpy as np
from torch.cuda.amp import autocast as autocast
from ogb.graphproppred.evaluate import Evaluator
from torch.cuda.amp import GradScaler
from itertools import count
from molfreesolv_layer import device


def Grid_Search(config):
    """Expand a grid specification into every concrete setting.

    Args:
        config: Mapping from option name to an iterable of candidate values.

    Returns:
        A list of dicts, one per combination, each mapping every option name
        to a single chosen value. Combinations are ordered with the first key
        of ``config`` varying slowest. An empty ``config`` yields ``[{}]``.
    """
    # Start from the single empty selection and widen it one option at a time.
    combos = [{}]
    for name, choices in config.items():
        combos = [partial | {name: choice}
                  for partial in combos
                  for choice in choices]
    return combos

def get_parameter_number(model):
    """Count the parameter elements of ``model``.

    Args:
        model: Object exposing ``parameters()``; each yielded parameter must
            provide ``numel()`` and a ``requires_grad`` flag.

    Returns:
        A dict with ``'Total'`` (all elements) and ``'Trainable'`` (elements
        of parameters whose ``requires_grad`` is truthy).
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return {'Total': total, 'Trainable': trainable}


def fit(m, config: dict, train_dataset, test_dataset, valid_dataset):
    """Train ``m`` and return the lowest per-epoch RMSE reported by ``metric``.

    Args:
        m: Model to optimise; it is called as ``m(graphs)`` on every batch.
        config: Run settings. Reads ``"epoch"`` (default 10), ``"optim"``
            (keyword arguments for ``torch.optim.Adam``) and ``"type"``
            (``"regress"``, ``"binary_classify"`` or ``"multi_classify"``),
            which selects the loss function.
        train_dataset: Sized iterable of training batches. Each batch must
            support ``len()``, ``.to(device)`` and expose a ``.y`` target.
        test_dataset: Batches scored with ``metric`` after every epoch.
        valid_dataset: Forwarded to ``metric`` as its first argument.

    Returns:
        The smallest ``"rmse"`` value that ``metric`` produced over all epochs.

    Raises:
        KeyError: If ``config["type"]`` is missing or is not a known task type.
    """
    epoch = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.05, "weight_decay": 5e-4})

    loss_types = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }
    critical = loss_types[config.get("type")]()

    optimizer = torch.optim.Adam(**optim_config, params=m.parameters())

    # Cosine annealing with warm restarts; an earlier ReduceLROnPlateau
    # variant was retired in favour of this schedule.
    # NOTE(review): eta_min is fixed at 0.0035 independently of the configured
    # lr — confirm it is meant to be below the learning rate in use.
    reduce_schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=epoch, T_mult=1, eta_min=0.0035)

    m.train()
    scaler = GradScaler()
    loss_mean = -1  # placeholder shown in the progress bar during epoch 1
    sub_res = float("inf")

    # The context manager closes the progress bar even if training raises.
    with tqdm(total=len(train_dataset) * epoch) as pbar:
        for e in range(epoch):
            loss_record = []
            for graphs in train_dataset:
                # Batches of length 1 are not trained on. Check before any
                # device transfer or gradient reset, and still advance the bar
                # so it can reach its total.
                if len(graphs) == 1:
                    pbar.update()
                    continue

                graphs = graphs.to(device)
                optimizer.zero_grad()
                with autocast():
                    pre_y = m(graphs)
                    loss = critical(pre_y, graphs.y.float())
                # GradScaler scales the loss before backward and manages the
                # optimizer step for mixed-precision training.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                pbar.set_description(f'epoch: {e + 1}, loss_train: {loss}, loss_mean: {loss_mean}, best_test: {sub_res}')
                pbar.update()
                loss_record.append(loss.detach().cpu().numpy())

            # NOTE(review): the returned score is the minimum over epochs of
            # the RMSE on ``test_dataset``, i.e. the epoch is selected on the
            # same data that is reported — confirm this is intended.
            test_res = metric(m, valid_dataset, test_dataset)["rmse"]
            if test_res < sub_res:
                sub_res = test_res

            # Avoid numpy's empty-mean warning when every batch was skipped.
            loss_mean = np.mean(loss_record) if loss_record else float("nan")
            # NOTE(review): step() receives the 0-based index of the epoch
            # that just finished — confirm this is the intended position in
            # the schedule.
            reduce_schedule.step(e)

    return sub_res


def metric(m, train_dataset, test_dataset):
    """Score ``m`` on ``test_dataset`` with the ``ogbg-molfreesolv`` evaluator.

    Args:
        m: Model to evaluate; called as ``m(graphs)`` on every batch.
        train_dataset: Unused; kept so existing call sites keep working.
        test_dataset: Iterable of batches. Each batch must support
            ``.to(device)`` and expose a ``.y`` target tensor.

    Returns:
        The result of ``Evaluator.eval`` for the collected targets and
        predictions (callers in this file read its ``"rmse"`` entry).
    """
    evl = Evaluator("ogbg-molfreesolv")

    # Remember the caller's mode so evaluation does not silently flip a model
    # that was already in eval mode back to training. Defaults to True, which
    # matches the previous unconditional ``m.train()``.
    was_training = getattr(m, "training", True)
    m.eval()
    y_true = []
    y_pred = []
    try:
        with torch.no_grad():
            for graphs in test_dataset:
                graphs = graphs.to(device)
                pre_y = m(graphs)
                # Flatten both to column vectors and gather them on the CPU.
                y_true.append(graphs.y.view(-1, 1).cpu())
                y_pred.append(pre_y.view(-1, 1).cpu())
    finally:
        # Restore the mode even when a batch raises part-way through.
        m.train(was_training)

    input_dict = {
        "y_true": torch.cat(y_true, dim=0).numpy(),
        "y_pred": torch.cat(y_pred, dim=0).numpy(),
    }
    return evl.eval(input_dict)


if __name__ == '__main__':
    # Hyper-parameter grid: every key lists the candidate values to try.
    # Each expanded combination is passed both to the FLGnnA constructor and
    # to fit() as its run configuration.
    configure = {
        "hidden": [100, 200, 300],
        "out_channels": [1],
        "windows": [2, 3, 4],
        "stride": [2, 1, 3],
        "order": [0, 1],
        "A_P2": [True],
        "refine_ratio": [1],
        "refiner": ["pool"],
        "dropout": [0.1, 0.3, 0],
        "cross": [0.9],
        "num_mf": [2],
        "fix": [True],
        "attention": [False],
        "layer": [3, 2, 1],
        "residual": [False],
        "norm": [False],
        "value_intervals": [[-1, 1]],
        "optim": [
            {"lr": 0.003, "weight_decay": 1e-3},
            {"lr": 0.003, "weight_decay": 1e-5},
            {"lr": 0.005, "weight_decay": 1e-3},
            {"lr": 0.005, "weight_decay": 1e-5},
        ],
        "type": ["regress"],
        "epoch": [350],
    }

    train, valid, test = molfreesolv_loader.molfreesolv()
    for cfg in Grid_Search(configure):
        print(cfg)
        scores = []
        # Train three freshly initialised models per configuration.
        for _ in range(3):
            m = FLGnnA(**cfg).to(device)
            # fit() takes the splits positionally as (train, test, valid).
            scores.append(fit(m, cfg, train, test, valid))
        # Mean and standard deviation of the three scores, ignoring NaNs.
        print(np.nanmean(scores), np.nanstd(scores))
        print(get_parameter_number(m))
