import copy
import os
import sys  

# Make ./src importable so the local `dataset` and `model` modules imported
# below can be found (presumably they live there). The path is relative to
# the current working directory, so the script must be launched from the
# project root.
sys.path.insert(0,'./src')
# Hide every CUDA device from this process so training runs on CPU; this must
# be set before the torch-based imports below initialise CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="-1"

import json
import argparse
from argparse import Namespace

from sklearn.model_selection import StratifiedKFold, train_test_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import numpy as np
from dataset import TUDataset
from torch_geometric.data import DataLoader

from model import *
    
def _load_or_create_folds(dataset_name, n_folds, labels, folds_dir='./data/folds'):
    """Return the cached train/val/test index splits for every CV fold.

    The first time a (dataset, n_folds) pair is seen, the splits are
    generated and written as JSON to
    ``<folds_dir>/<dataset_name>_folds_<n_folds>.txt``. Every later run, and
    every fold, reuses exactly the same partition.

    Generation uses a stratified, shuffled K-fold split (``random_state=1``).
    Inside each fold's outer-train part, 10% is held out as a stratified
    validation set (``random_state=0``).

    Args:
        dataset_name: Dataset name; used only to build the cache file name.
        n_folds: Number of outer stratified folds.
        labels: One integer class label per graph, in dataset order.
        folds_dir: Directory holding the cached split files.

    Returns:
        A list with one ``[train_idx, val_idx, test_idx]`` entry per fold.
        Each entry is a list of ints indexing into the dataset.
    """
    path = os.path.join(folds_dir, '%s_folds_%d.txt' % (dataset_name, n_folds))
    if not os.path.isfile(path):
        print('GENERATING %d FOLDS FOR %s' % (n_folds, dataset_name) )
        labels_arr = np.asarray(labels)
        skf = StratifiedKFold(n_splits=n_folds, random_state=1, shuffle=True)

        folds_split = []
        for train_val_idx, test_idx in skf.split(np.arange(len(labels_arr)), labels_arr):
            # Carve 10% of the outer-train part off as a stratified validation set.
            train_idx, val_idx = train_test_split([int(i) for i in train_val_idx],
                                                  stratify=labels_arr[train_val_idx].tolist(),
                                                  test_size=int(len(train_val_idx) * 0.1),
                                                  random_state=0)
            folds_split.append([train_idx, val_idx, [int(i) for i in test_idx]])

        # The cache directory may not exist yet (e.g. on a fresh checkout).
        os.makedirs(folds_dir, exist_ok=True)
        with open(path, 'w') as f:
            f.write(json.dumps(folds_split))

    with open(path, 'r') as f:
        return json.loads(f.read())


def run_training_process_with_validation(run_params):
    """Train on one cross-validation fold, then test and validate the model.

    Args:
        run_params: Parsed command-line namespace. This function reads
            ``dataset``, ``fold``, ``folds``, ``batch_size`` and ``project``.
            The whole namespace is also passed to ``Model`` and to
            ``pl.Trainer.from_argparse_args``. It is mutated in place:
            ``in_features``, ``labels`` and ``num_classes`` are filled in
            from the dataset.

    Side effects:
        May create the fold cache under ``./data/folds``. Writes checkpoints
        to ``./checkpoints/<project>_<dataset>_fold<fold>``.
    """
    print("######################################### NEW TRAIN on FOLD %d ######################################" % run_params.fold)

    dataset = TUDataset('./data',run_params.dataset)
    graph_labels = [int(d.y) for d in dataset]

    ###### Load or generate splits
    folds = _load_or_create_folds(run_params.dataset, run_params.folds, graph_labels)
    train_i_split, val_i_split, test_i_split = folds[run_params.fold]

    train_dataset = dataset[train_i_split]
    val_dataset = dataset[val_i_split]
    test_dataset = dataset[test_i_split]

    train_loader = DataLoader(train_dataset, batch_size=run_params.batch_size, shuffle=True)
    # Evaluation order does not affect the metrics, so keep it deterministic.
    val_loader = DataLoader(val_dataset, batch_size=run_params.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=run_params.batch_size, shuffle=False)

    class MyDataModule(pl.LightningDataModule):
        """Thin data module that serves the fold's pre-built loaders."""
        def setup(self,stage=None):
            pass
        def train_dataloader(self):
            return train_loader
        def val_dataloader(self):
            return val_loader
        def test_dataloader(self):
            return test_loader

    run_params.in_features = train_dataset.num_features
    # NOTE(review): this overrides --labels with the feature count. Presumably
    # node features are one-hot node labels (7 for MUTAG, matching the CLI
    # default); confirm this is intended.
    run_params.labels = train_dataset.num_features
    run_params.num_classes = train_dataset.num_classes

    model = Model(run_params)

    # Keeps the best-val_acc checkpoint plus the last one.
    checkpoint_callback = ModelCheckpoint(
        save_last=True,
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max',
        dirpath='./checkpoints/%s_%s_fold%d' % (run_params.project,run_params.dataset,run_params.fold)
    )
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=300,
        verbose=False,
        mode='min')

    trainer = pl.Trainer.from_argparse_args(run_params,
                                            callbacks=[checkpoint_callback,early_stop_callback])
    trainer.fit(model, datamodule=MyDataModule())

    # NOTE(review): no model/ckpt_path is passed, so whether these evaluate the
    # best or the last weights depends on the installed PyTorch Lightning
    # version; confirm.
    trainer.test(datamodule=MyDataModule())
    trainer.validate(datamodule=MyDataModule())
    
  
def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``). This parser accepts the
    usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--project", default='GKNN')
    parser.add_argument("--dataset", default='MUTAG')
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("--folds", default=10, type=int)

    parser.add_argument("--nodes", default=6, type=int)
    parser.add_argument("--labels", default=7, type=int)
    parser.add_argument("--hidden", default=16, type=int)

    parser.add_argument("--layers", default=1, type=int)
    parser.add_argument("--hops", default=2, type=int) #submask radius
    parser.add_argument("--kernel", default='wl', type=str)
    # str2bool, not bool: bool("False") would be True.
    parser.add_argument("--normalize", default=True, type=str2bool)

    parser.add_argument("--pooling", default='add', type=str)

    parser.add_argument("--jsd_weight", default=1e4, type=float)
    parser.add_argument("--max_cc", default=True, type=str2bool)

    parser.add_argument("--max_epochs", default=500, type=int)
    parser.add_argument("--lr", default=1e-3, type=float)
    parser.add_argument("--lr_graph", default=1e-2, type=float)

    parser.add_argument("--batch_size", default=32, type=int)
    params = parser.parse_args()

    run_training_process_with_validation(params)
        