from reddit_model import FLGnnA
import torch
from tqdm import tqdm
import numpy as np
from torch_geometric.data import Data
from torch.cuda.amp import autocast as autocast
from sklearn.metrics import f1_score, roc_auc_score
from reddit_loader import reddit
import itertools
from torch.cuda.amp import GradScaler

# Target device for the models built in the ``__main__`` block (``.to(device)``).
# The training loop also uses the CUDA AMP helpers imported from torch.cuda.amp.
device: str = "cuda"


def Grid_Search(config):
    """Expand a hyper-parameter grid into every combination of its values.

    Args:
        config: Mapping from option name to an iterable of candidate values.

    Returns:
        A list of fresh dicts, one per element of the Cartesian product of the
        candidate lists. Each dict maps every key of ``config`` to one choice.
        Earlier keys vary slowest. An empty ``config`` yields ``[{}]``. Any
        key with no candidates yields ``[]``.
    """
    combos = [dict()]
    for key in list(config.keys()):
        # Extend every partial assignment with each candidate for this key.
        combos = [partial | {key: value} for partial in combos for value in config[key]]
    return combos


def get_parameter_number(model):
    """Count the scalar parameters of ``model``.

    Returns:
        ``{'Total': n_all, 'Trainable': n_grad}``. Here ``n_all`` sums
        ``numel()`` over all parameters, and ``n_grad`` does so only over
        parameters with ``requires_grad`` set.
    """
    total = trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}


def fit(m, config: dict, dataset, data_sampler, test_sampler):
    """Train ``m`` with mixed precision and return the best test score seen.

    After every epoch the model is scored with ``metric`` on ``test_sampler``.
    The highest score observed is returned.

    Args:
        m: Model called as ``m(x_sample, adjs)``. If ``config["fix"]`` is
            falsy, ``m.get_segregate_parapms()`` supplies the optimizer's
            parameters; otherwise ``m.parameters()`` is used.
        config: Training options. Keys read here:
            ``"epoch"`` (default 10);
            ``"optim"`` (keyword arguments for ``torch.optim.Adam``, default
            ``{"lr": 0.01}``);
            ``"type"`` (``"regress"``, ``"binary_classify"`` or
            ``"multi_classify"``, which selects the loss);
            ``"fix"`` (default True).
        dataset: Object with ``x`` (inputs) and ``y`` (targets), both indexed
            by the node ids the samplers yield.
        data_sampler: Re-iterable sized collection of
            ``(batch_size, n_id, adjs)`` training batches.
        test_sampler: Sampler of the same form, passed to ``metric``.

    Returns:
        The best value returned by ``metric`` over all epochs, or 0 if no
        epoch scored above 0.

    Raises:
        KeyError: If ``config["type"]`` is missing or not a supported task.
    """
    epoch = config.get("epoch", 10)
    optim_config = config.get("optim", {"lr": 0.01})

    loss_classes = {
        "regress": torch.nn.MSELoss,
        "binary_classify": torch.nn.BCEWithLogitsLoss,
        "multi_classify": torch.nn.CrossEntropyLoss,
    }
    task_type = config.get("type")
    if task_type not in loss_classes:
        raise KeyError(
            f"unsupported task type {task_type!r}; expected one of {sorted(loss_classes)}"
        )
    critical = loss_classes[task_type]()

    if not config.get("fix", True):
        params = m.get_segregate_parapms()
    else:
        params = m.parameters()
    optimizer = torch.optim.Adam(params=params, **optim_config)

    # The scheduler is stepped once per batch with a global step count, so
    # T_0 / T_mult are measured in optimizer steps, not epochs.
    # (The deprecated ``verbose`` keyword is omitted; False was the default.)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=100, T_mult=2, eta_min=0.0001
    )

    m.train()
    length = len(data_sampler)
    step_counter = itertools.count(1)
    best = 0
    scaler = GradScaler()
    with tqdm(total=epoch * length) as pbar:
        for e in range(epoch):
            for batch_size, n_id, adjs in data_sampler:
                optimizer.zero_grad()
                x_sample = dataset.x[n_id]
                # Only the forward pass and loss run under autocast; backward
                # and the optimizer step go through the GradScaler.
                with autocast():
                    pre_y = m(x_sample, adjs)
                    # Only the first batch_size outputs / node ids are scored;
                    # presumably the sampler lists its target nodes first.
                    loss = critical(pre_y[:batch_size], dataset.y[n_id[:batch_size]])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step(next(step_counter))
                pbar.set_description(
                    f"epoch: {e + 1}/{epoch}, loss_train: {loss.item():.4f}, test_best: {best}"
                )
                pbar.update()
            cur = metric(m, dataset, test_sampler, data_sampler)
            if cur > best:
                best = cur
    return best


def f1(true, pre):
    """Return the micro-averaged F1 score of predictions ``pre`` against ``true``.

    Both arguments are torch tensors of labels. They are detached and copied
    to the CPU before conversion to NumPy, so tensors that live on a GPU or
    require grad are accepted as well as plain CPU tensors.
    """
    true = true.detach().cpu().numpy()
    pre = pre.detach().cpu().numpy()
    return f1_score(true, pre, average="micro")


def metric(m, dataset, test_sampler, train_sampler):
    """Compute, print and return the micro F1 score of ``m`` on ``test_sampler``.

    Each batch's prediction is the arg-max over dim 1 of the model output for
    its first ``batch_size`` nodes. The model is switched to eval mode while
    scoring. Its previous train/eval mode is restored afterwards, even if
    scoring raises.

    Args:
        m: Model called as ``m(x_sample, adjs)``. Its output is assumed to
            hold per-class scores along dim 1 (TODO confirm for non-classify
            task types).
        dataset: Object with ``x`` and ``y`` indexed by node ids.
        test_sampler: Iterable of ``(batch_size, n_id, adjs)`` batches.
        train_sampler: Accepted for interface compatibility; not used
            (the training F1 is printed as ``...``).

    Returns:
        The micro F1 score computed by ``f1`` over all test batches.
    """
    was_training = m.training
    m.eval()
    try:
        y_pres = []
        y_true = []
        with torch.no_grad():
            for batch_size, n_id, adjs in test_sampler:
                x_sample = dataset.x[n_id]
                with autocast():
                    out = m(x_sample, adjs)
                y_pres.append(out[:batch_size].argmax(dim=1).cpu())
                y_true.append(dataset.y[n_id[:batch_size]].cpu())
            f1_test = f1(torch.cat(y_true, dim=0), torch.cat(y_pres, dim=0))
        print(f"f1_test: {f1_test},  f1_train: {'...'}")
    finally:
        m.train(was_training)
    return f1_test


if __name__ == '__main__':
    # Hyper-parameter grid: every value is a list of candidates, and
    # Grid_Search expands the Cartesian product into individual configs.
    configure = {
        "in_channels": [602],
        "hidden": [1024],
        "out_channels": [41],
        "windows": [6],
        "stride": [10],
        "order": [1],
        "A_P2": [True],
        "extract_ratio": [1],
        "extractor": ["pool"],
        "cross": [0.9],
        "num_mf": [2],
        "attention": [False],
        "fix": [True],
        "residual": [False],
        "layer": [2],
        "norm": [False],
        "value_intervals": [[-1, 1]],
        "optim": [
            {"lr": 0.004, "weight_decay": 1e-8},
        ],
        "type": ["multi_classify"],
        "epoch": [7],
    }

    # valid_sampler is currently unused.
    dataset, train_sampler, valid_sampler, test_sampler = reddit()
    for cfg in Grid_Search(configure):
        # Run each configuration three times; report mean/std of the best test score.
        scores = []
        for _ in range(3):
            print(cfg)
            model = FLGnnA(**cfg).to(device)
            scores.append(fit(model, cfg, dataset, train_sampler, test_sampler=test_sampler))
            # Drop the reference before the next run builds a fresh model so
            # the old one's memory can be released first.
            del model
        print(np.mean(scores), np.std(scores))
