import numpy as np

from fs_model import FLGnnA
import torch
import fs_loader
from utils import Grid_Search
import matplotlib.pyplot as plt
import seaborn as sns

# Run on the GPU when one is available; fall back to CPU so the script still works without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

def Grid_Search(config):
    """Expand a hyper-parameter grid into every combination of its values.

    ``config`` maps each parameter name to an iterable of candidate values.
    Returns a list of dicts, one per point of the Cartesian product, with keys
    in ``config``'s insertion order; the first key varies slowest and the last
    key fastest. An empty ``config`` yields ``[{}]``; if any key has no
    candidates the result is ``[]``.

    Each candidate iterable is consumed exactly once, so one-shot iterators
    (e.g. generators) are expanded fully rather than exhausted after the first
    combination.

    NOTE: this module-level definition shadows the ``Grid_Search`` imported
    from ``utils`` at the top of the file.
    """
    import itertools

    names = list(config)
    # itertools.product materialises each value iterable once and yields
    # tuples in the same nested-loop order the previous recursion produced.
    combos = itertools.product(*(config[name] for name in names))
    return [dict(zip(names, combo)) for combo in combos]

def _mean_firing_strengths(model, data, device):
    """Average each firing-strength output of ``model`` over every graph in ``data``.

    ``model(graph)`` is expected to return ``(output, fs)``, where ``fs`` is a
    sequence of firing strengths. These are presumably tensors with a fixed
    shape per position; confirm against ``FLGnnA.forward``. Each position of
    ``fs`` is summed across graphs and divided by the number of graphs.

    The model is put in eval mode and the loop runs under ``torch.no_grad()``
    because this is pure inference. Without ``no_grad``, the running sum would
    keep every graph's autograd history alive until the loop ends.

    Returns one averaged entry per position of ``fs``, or ``[]`` if ``data``
    is empty.
    """
    model.eval()
    totals = []
    count = 0
    with torch.no_grad():
        for graph in data:
            _, fs = model(graph.to(device))
            # Out-of-place sum: never mutate the tensors the model handed back.
            totals = [t + f for t, f in zip(totals, fs)] if totals else list(fs)
            count += 1
    return [t / count for t in totals]


def _plot_ecdf(values, path):
    """Plot the empirical CDF of ``values`` on a new figure, save it as SVG to ``path``, then show it."""
    if isinstance(values, torch.Tensor):
        # seaborn needs host-side data; CUDA / grad-tracking tensors cannot be
        # converted to NumPy implicitly.
        values = values.detach().cpu().numpy()
    fig = plt.figure()  # fresh figure so successive plots do not draw over each other
    sns.ecdfplot(values)
    plt.grid()
    plt.xlabel("firing strength", fontsize=18)
    plt.ylabel("proportion", fontsize=18)
    # 11 ticks at 0.0, 0.1, ..., 1.0 paired with 11 percentage labels.
    plt.yticks(np.linspace(0, 1, 11), labels=[f"{p}%" for p in range(0, 110, 10)])
    fig.savefig(path, dpi=300, format="svg")
    plt.show()
    plt.close(fig)


if __name__ == '__main__':
    configure = {
        "hidden": [256],
        "out_channels": [1],
        "windows": [3],
        "stride": [1],
        "order": [0],
        "concat": [True],
        "extract_ratio": [1],
        "extractor": ["pool"],
        "cross": [0.9],
        "num_mf": [3],
        "fix": [True],
        "layer": [2],
        "norm": [False],
        "value_intervals": [[0, 1]],
        "optim": [
            {"lr": 0.005},
            # {"lr": 0.0045},
            # {"lr": 0.0065},
        ],
        "type": ["binary_classify"],
        "epoch": [80],
    }
    data = fs_loader.molhiv()
    for cfg in Grid_Search(configure):
        for _run in range(1):  # number of independent runs per configuration
            m = FLGnnA(**cfg)
            # m.load_state_dict(torch.load(r"m.pt"))
            m = m.to(device)
            res = _mean_firing_strengths(m, data, device)
            for i, strength in enumerate(res):
                # One file per output; keep the historical name when there is only one.
                path = "test.svg" if len(res) == 1 else f"test_{i}.svg"
                _plot_ecdf(strength, path)
