import time
import json
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from scipy.stats import mode
from torch_geometric.data import DataLoader

from trainable_scattering.utils import get_dataset

# Dataset names iterated by train_all_shallow when it enumerates configurations.
# NOTE(review): "PROTIENS" looks like a misspelling of "PROTEINS" (which is also
# listed below); confirm whether get_dataset accepts it before relying on it.
TU_DATASETS = [
    "NCI1",
    "NCI109",
    "DD",
    "PROTIENS",
    "PROTEINS",
    "MUTAG",
    "PTC",
    "ENZYMES",
    "REDDIT-BINARY",
    "REDDIT-MULTI-12K",
    "IMDB-BINARY",
    "IMDB-MULTI",
    "COLLAB",
    "REDDIT-MULTI-5K",
]


def train_shallow(args):
    """Fit a standardised SVM on the training split and predict the test split.

    Parameters
    ----------
    args : dict
        Passed unchanged to ``get_dataset`` (together with the device string
        ``"cpu"``). Must also provide ``args["g"]`` and ``args["c"]``, which
        are used as the ``gamma`` and ``C`` arguments of ``SVC``.

    Returns
    -------
    tuple
        ``(predictions, test_y)``: the classifier's predictions for the test
        split and the test labels, both as numpy arrays.

    Notes
    -----
    The validation split, training loader and dataset object returned by
    ``get_dataset`` are not used here. The validation split used to be
    converted to numpy and immediately discarded; that conversion is dropped.
    """
    def to_numpy(ds):
        # With batch_size == len(ds) the loader yields exactly one batch that
        # holds the whole split, so taking the first batch is sufficient.
        # assumes the dataset's transform produces `x` rows aligned with `y`
        # — TODO confirm against get_dataset.
        loader = DataLoader(ds, batch_size=len(ds), shuffle=False)
        batch = next(iter(loader))
        return batch.x.numpy(), batch.y.numpy()

    train_ds, _val_ds, test_ds, _train_loader, _dataset = get_dataset(args, "cpu")

    train_X, train_y = to_numpy(train_ds)
    test_X, test_y = to_numpy(test_ds)

    # Features are standardised with statistics fitted on the training split
    # only; the same scaling is then applied inside predict().
    clf = make_pipeline(StandardScaler(), SVC(gamma=args["g"], C=args["c"]))
    clf.fit(train_X, train_y)
    return clf.predict(test_X), test_y



def train_model_shallow(in_dir, out_file):
    """Grid-search SVM ``gamma``/``C`` for one configuration and save the best run.

    Parameters
    ----------
    in_dir : path-like
        Path of a JSON file holding the argument dict for ``train_shallow``
        (despite the name it is opened as a file). The dict must contain a
        ``"dataset"`` key, which is printed in the summary line.
    out_file : path-like
        Destination passed to ``numpy.savez``. The archive stores ``pred``,
        ``true``, ``acc``, ``best_g`` and ``best_c``.

    Notes
    -----
    NOTE(review): the winning hyper-parameters are chosen by accuracy on the
    labels returned by ``train_shallow`` (its test split), not on a validation
    split — confirm this is the intended protocol.
    """
    start = time.time()
    with open(str(in_dir), "r", encoding="utf-8") as fp:
        args = json.load(fp)

    G_pool = [0.0001, 0.001, 0.01, 0.1, 1, 10]
    C_pool = [0.01, 0.1, 1, 10, 25, 50, 100]

    best_acc = -1
    best_c = -1
    best_g = -1
    best_pred = None
    best_true = None
    for c in C_pool:
        for g in G_pool:
            predict, test_y = train_shallow({**args, "c": c, "g": g})
            # Vectorised count of matches instead of a Python-level sum().
            acc = np.count_nonzero(predict == test_y) / len(test_y)
            # Strict '>' keeps the first configuration on ties.
            if acc > best_acc:
                best_acc = acc
                best_c = c
                best_g = g
                best_pred = predict
                # Keep the labels that belong to the stored predictions.
                best_true = test_y

    to_return = {
        'pred': best_pred,
        'true': best_true,
        'acc': best_acc,
        'best_g': best_g,
        'best_c': best_c,
    }
    end = time.time()
    print(args["dataset"], best_acc, best_g, best_c, "%0.2f" % (end - start))
    np.savez(str(out_file), **to_return)


def train_all_shallow():
    """Enumerate every shallow-SVM configuration and write the result pickles.

    For each dataset in ``TU_DATASETS``, each split setting, each test seed,
    each validation seed and each ``(C, gamma)`` pair, an argument dict is
    appended to a local list, with ``"i"`` set to its running index.

    The training step is currently disabled. The removed code called
    ``train_shallow`` per configuration, kept the best predictions per
    validation seed, and aggregated them across validation seeds with
    ``scipy.stats.mode``. As a consequence ``results`` and ``best_params`` are
    never filled, and the two pickles written at the end,
    ``shallow_results_v3.pkl`` and ``shallow_best_params_v3.pkl``, contain
    empty DataFrames with only their column headers.
    """
    G_pool = [0.0001, 0.001, 0.01, 0.1, 1, 10]
    C_pool = [0.01, 0.1, 1, 10, 25, 50, 100]
    num_val_seeds = [9, 8, 5, 3]  # fixed with 0.2 validation set size
    num_test_seeds = [10, 5, 2, 10]
    all_splits = [[8, 1, 1], [7, 1, 2], [4, 1, 5], [2, 1, 7]]
    results = []
    best_params = []

    # NOTE(review): arg_list is built but not consumed anywhere in this
    # function; it is kept so the enumeration order and indices stay available
    # for whoever re-enables training.
    arg_list = []
    for dataset in TU_DATASETS:
        # The three lists above are parallel: one entry per split setting.
        for splits, n_test, n_val in zip(all_splits, num_test_seeds, num_val_seeds):
            for test_seed in range(n_test):
                for val_seed in range(n_val):
                    for c in C_pool:
                        for g in G_pool:
                            arg_list.append({
                                "dataset": dataset,
                                "splits": splits,
                                "test_seed": test_seed,
                                "transform": "fast_scatter",
                                "val_seed": val_seed,
                                "g": g,
                                "c": c,
                                # Running 0-based index of this configuration.
                                "i": len(arg_list),
                            })

    param_df = pd.DataFrame(
        best_params,
        columns=['split', 'dataset', 'test_seed', 'val_seed', 'g', 'c'],
    )
    df = pd.DataFrame(results, columns=['split', 'dataset', 'test_seed', 'accuracy'])
    df.to_pickle('shallow_results_v3.pkl')
    param_df.to_pickle('shallow_best_params_v3.pkl')


# Script entry point: runs the full enumeration when the file is executed directly.
if __name__ == '__main__':
    train_all_shallow()
    # Leftover snippet for loading a results pickle. Note that train_all_shallow
    # writes 'shallow_results_v3.pkl', not the file name used here.
    #df = pd.read_pickle('shallow_results.pkl')
