import torch as pt
from torch.utils.data import random_split
from torch.nn.functional import one_hot

import pytorch_lightning as ptl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import torch_geometric as ptg
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.transforms import BaseTransform

import ray
from ray import air, tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
from ray.tune.search import Repeater
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search.basic_variant import BasicVariantGenerator
from ray.tune.integration.pytorch_lightning import TuneReportCallback

from GNN_models import GNN_Module, AppendEVs


from sklearn.model_selection import KFold, train_test_split, StratifiedKFold
import time
import json
import glob
import os

import networkx as nx
import numpy as np
import copy

from pytorch_lightning.callbacks import Callback
from pytorch_lightning import LightningModule


def train_model(model_config, experiment_config, data_split, verbose=0, trainer=None, model=None):
    """Fit a model on the training part of one data split.

    The split is ``experiment_config['data_splits'][data_split['data_split_idx']]``.
    Its first element is further divided into training and validation data here.
    That element holds the fold's training indices, or a ready-made
    ``(train, valid)`` pair when ``experiment_config['has_predetermined_split']``
    is set for node-level tasks.

    Args:
        model_config: Model hyperparameters. Must contain ``graph_level``. When a
            trainer has to be built, it must also contain ``learning_rate``; a
            value of 0 runs Lightning's learning-rate finder before training.
            ``num_outputs``/``num_inputs`` are written into it when a new model
            is built.
        experiment_config: Experiment settings (dataset, splits, batch size,
            epoch budgets, callbacks, model class, ...).
        data_split: Dict holding the split index under ``data_split_idx``.
        verbose: 0 disables Lightning's progress bar and model summary.
        trainer: Optional pre-built trainer. When given, no logger, callbacks or
            learning-rate search are set up here.
        model: Optional pre-built model. Built from
            ``experiment_config['model_class']`` when None.

    Returns:
        ``(trainer, model)`` after ``trainer.fit``.
    """
    # Cross-validation runs use their own epoch budget.
    num_epochs = experiment_config['cv_num_epochs'] if experiment_config['is_cv'] else experiment_config['num_epochs']
    fold_train = experiment_config['data_splits'][data_split['data_split_idx']][0]

    if model_config['graph_level']:
        # Graph-level task: split the fold's graphs, stratified on each graph's first label entry.
        dataset = experiment_config['dataset'][fold_train]
        labels = [dataset[i].y[0].numpy() for i in range(len(dataset))]
        train_dataset, validate_dataset = train_test_split(dataset, test_size=experiment_config['val_size'], stratify=labels)
    else:
        # Node-level task: both loaders receive the same dataset object; the split
        # is expressed through ``train_mask`` / ``val_mask`` index lists on its ``data``.
        train_dataset = validate_dataset = experiment_config['dataset']
        if experiment_config['has_predetermined_split']:
            train_mask, val_mask = fold_train
        else:
            node_labels = train_dataset[0].y.numpy()[fold_train]
            train_mask, val_mask = train_test_split(fold_train, test_size=experiment_config['val_size'], stratify=node_labels)
        train_dataset.data.train_mask = train_mask
        validate_dataset.data.val_mask = val_mask

    train_loader = DataLoader(train_dataset, batch_size=experiment_config['batch_size'], shuffle=True, pin_memory=True, num_workers=2)
    validate_loader = DataLoader(validate_dataset, batch_size=experiment_config['batch_size'], shuffle=False, pin_memory=True, num_workers=2)

    if model is None:
        model_config['num_outputs'] = experiment_config['num_classes']
        model_config['num_inputs'] = experiment_config['num_features']
        model = experiment_config['model_class'](model_config)

    if trainer is None:
        logger = TensorBoardLogger('', name='')
        metrics = {'train_loss': 'train_loss', 'val_loss': 'val_loss', 'train_acc': 'train_acc', 'val_acc': 'val_acc', 'val_score': 'val_score'}
        # Forward the logged metrics to Ray Tune at the end of every validation run.
        callbacks = [TuneReportCallback(metrics, on='validation_end')] + experiment_config['callbacks']
        if model_config['learning_rate'] == 0:
            # A learning rate of 0 requests an automatic choice. Run Lightning's LR
            # range test on a throwaway trainer; ``update_attr=True`` asks it to
            # write the suggested rate back onto the model.
            lr_trainer = ptl.Trainer(max_epochs=1000, precision='bf16-mixed', accelerator='auto')
            ptl.tuner.tuning.Tuner(lr_trainer).lr_find(model, train_loader, min_lr=1e-6, max_lr=1e+0, num_training=100, mode='exponential', early_stop_threshold=100, update_attr=True)

        trainer = Trainer(
            precision='bf16-mixed',
            accelerator='auto',
            devices=1,
            # Scale the minimum with the epoch budget actually in use (cv_num_epochs
            # in CV runs), so it can never exceed max_epochs.
            min_epochs=num_epochs // 5,
            max_epochs=num_epochs,
            logger=logger,
            callbacks=callbacks,
            enable_progress_bar=verbose != 0,
            enable_model_summary=verbose != 0,
            # Checkpoints are only written in CV runs.
            enable_checkpointing=experiment_config['is_cv'],
        )
        logger.log_hyperparams(model_config | {key: value for key, value in experiment_config.items() if key not in ['callbacks', 'cv_callbacks', 'dataset']})

    trainer.fit(model, train_loader, validate_loader)

    return trainer, model

def _latest_version_checkpoints():
    """Return the sorted ``.ckpt`` paths of the newest ``version_<N>`` run in the working directory.

    Version numbers are compared numerically, so ``version_10`` counts as newer
    than ``version_9``. Only files matching ``version_<N>/checkpoints/*.ckpt``
    are considered.

    Raises:
        FileNotFoundError: if no checkpoint exists under any version directory.
    """
    import re

    checkpoints_by_version = {}
    for path in glob.glob('version_*/checkpoints/*.ckpt'):
        match = re.match(r'version_(\d+)[\\/]', path)
        if match is not None:
            checkpoints_by_version.setdefault(int(match.group(1)), []).append(path)
    if not checkpoints_by_version:
        raise FileNotFoundError(f"no checkpoints found under version_*/checkpoints/ in {os.getcwd()}")
    return sorted(checkpoints_by_version[max(checkpoints_by_version)])


def train_model_cv(data_split, model_configs, experiment_config, verbose=0):
    """Train one cross-validation fold, then test and validate every saved checkpoint.

    Args:
        data_split: Dict holding the fold index under ``data_split_idx``.
        model_configs: Per-fold model configs. Entry ``data_split_idx`` is used,
            or entry 0 when ``experiment_config['test'] >= 0``.
        experiment_config: Experiment settings. ``training_procedure_cv`` is
            called to train, and must return ``(trainer, model)``.
        verbose: Forwarded to the training procedure.

    Side effects:
        Prints each evaluated checkpoint path and writes ``evals.json`` to the
        working directory. It maps each checkpoint of the newest
        ``version_<N>`` directory to its merged test and validation metrics.

    Raises:
        FileNotFoundError: if training produced no checkpoints.
    """
    fold_idx = data_split['data_split_idx']
    model_config = model_configs[fold_idx if experiment_config['test'] < 0 else 0]
    trainer, _model = experiment_config['training_procedure_cv'](model_config, experiment_config, data_split, verbose=verbose, model=None)

    fold_train, fold_test = experiment_config['data_splits'][fold_idx]
    dataset = experiment_config['dataset']
    if model_config['graph_level']:
        test_loader = DataLoader(dataset[fold_test], batch_size=1, shuffle=False)
        # NOTE(review): this draws a fresh random, unstratified validation split of the
        # fold's training graphs. If training_procedure_cv drew its own split, the
        # two generally differ, so these "validation" graphs may include training
        # graphs. Confirm this is intended.
        _, val_dataset = train_test_split(dataset[fold_train], test_size=experiment_config['val_size'])
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    else:
        # Node-level task: test and validation share one dataset object; the
        # split is set through ``test_mask`` / ``val_mask`` index lists on its ``data``.
        test_dataset = val_dataset = dataset
        test_dataset.data.test_mask = fold_test
        if experiment_config['has_predetermined_split']:
            _, val_dataset.data.val_mask = fold_train
        else:
            _, val_dataset.data.val_mask = train_test_split(fold_train, test_size=experiment_config['val_size'], stratify=val_dataset[0].y.numpy()[fold_train])

        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
        # NOTE(review): purpose of clearing train_mask on the config is unclear from here.
        model_config['train_mask'] = []

    evals = {}
    for checkpoint in _latest_version_checkpoints():
        print(checkpoint)
        test_metrics = trainer.test(dataloaders=test_loader, ckpt_path=checkpoint)[0]
        val_metrics = trainer.validate(dataloaders=val_loader, ckpt_path=checkpoint)[0]
        evals[checkpoint] = test_metrics | val_metrics
    with open('evals.json', 'w') as f:
        json.dump(evals, f)
 
def tmh(model_config, experiment_config, data_split, verbose=0, trainer=None, model=None):
    """Run ``experiment_config['training_procedure_cv']`` once and discard its result.

    The procedure must return a ``(trainer, model)`` pair. The ``trainer``
    argument is accepted for signature compatibility only and is not forwarded.
    Always returns None.
    """
    procedure = experiment_config['training_procedure_cv']
    _trainer, _model = procedure(
        model_config,
        experiment_config,
        data_split,
        verbose=verbose,
        model=model,
    )
    
    
def hyperparameter_search(experiment_config, model_config, data_split, num_workers, model=None, verbose=0, scope="last-10-avg"):
    """Run a Ray Tune search over ``model_config`` for one data split and return the best config.

    Each trial runs ``experiment_config['training_procedure']``. Trials are grouped
    by their hyperparameter values, ignoring the ``iteration`` repeat key. Each
    trial is scored as ``[mean of its 50 lowest val_score values, mean of its 50
    highest val_acc values]``, and the scores are averaged per group. The group
    with the highest averaged ``val_acc`` component wins.

    Args:
        experiment_config: Experiment settings; needs ``training_procedure`` and
            ``experiment_name``.
        model_config: Ray Tune ``param_space`` (may contain search-space entries).
        data_split: Dict holding ``data_split_idx``. It is forwarded to the
            training procedure and used in the run name.
        num_workers: Each trial requests ``2 / num_workers`` GPUs and 4 CPUs.
        model: Optional model forwarded to the training procedure.
        verbose: Forwarded to the training procedure.
        scope: Currently unused; kept for interface compatibility.

    Returns:
        The Ray Tune config of the winning hyperparameter combination.

    Raises:
        RuntimeError: if no trial produced usable (non-NaN) metrics.
    """
    scheduler = FIFOScheduler()
    reporter = CLIReporter(
        parameter_columns=['learning_rate', 'dropout', 'activation_func'],
        metric_columns=["train_loss", "val_loss", "train_acc", "val_acc", "val_score", "training_iteration"])
    train_fn_with_parameters = tune.with_parameters(experiment_config['training_procedure'], experiment_config=experiment_config, data_split=data_split, verbose=verbose, model=model)

    # Sized so that ``num_workers`` concurrent trials share two GPUs.
    resources_per_trial = {"cpu": 4, "gpu": 2/num_workers}
    tuner = tune.Tuner(
        tune.with_resources(
            train_fn_with_parameters,
            resources=resources_per_trial
        ),
        tune_config=tune.TuneConfig(
            metric="val_acc",
            mode="max",
            scheduler=scheduler,
            num_samples=1,
        ),
        run_config=air.RunConfig(
            local_dir='./ray_results',
            name="tune_"+experiment_config['experiment_name']+'/'+str(data_split['data_split_idx']),
            progress_reporter=reporter,
            callbacks=[tune.logger.TBXLoggerCallback()],
            verbose=1
        ),
        param_space=model_config,
    )
    results = tuner.fit()
    print(results.get_dataframe(filter_metric='training_iteration', filter_mode='max'))

    try:
        keys = results[0].config.keys()
        value_per_hyperparameters = {}
        for result in results:
            df = result.metrics_dataframe
            if df is None:
                # Ray leaves this as None when a trial reported no results;
                # skip it rather than abort the whole search.
                continue
            tag = tuple(str(result.config[key]) for key in keys if key != 'iteration')
            value = [np.mean(sorted(df['val_score'])[:50]),
                     np.mean(sorted(df['val_acc'], reverse=True)[:50])]
            value_per_hyperparameters.setdefault(tag, ([], result.config))[0].append(value)

        # Average per combination and drop combinations whose scores contain NaN.
        mean_values = [(np.mean(values, axis=0), config)
                       for values, config in value_per_hyperparameters.values()
                       if not np.isnan(np.mean(values))]
        if not mean_values:
            raise RuntimeError("hyperparameter search produced no trial with usable val_score/val_acc metrics")

        # Highest mean val_acc wins. The sort is stable, so ties go to the
        # last-inserted combination.
        best_value, best_config = sorted(mean_values, key=lambda x: x[0][1])[-1]
        print(best_value, best_config)
    finally:
        ray.shutdown()

    return best_config
    

    
    
def crossvalidation(model_config, experiment_config,  num_workers, verbose=0):
    """Run k-fold cross-validation of one fixed ``model_config`` with Ray Tune.

    Builds shuffled ``KFold`` splits (``random_state=7``) over
    ``experiment_config['dataset']`` and stores them in
    ``experiment_config['data_splits']``. It then launches one ``train_model_cv``
    trial per (fold, iteration) pair and prints the best and worst fold by
    ``val_acc`` (``last-10-avg`` scope).

    NOTE(review): ``train_model_cv`` reads ``experiment_config['test']`` and
    evaluates saved checkpoints. The caller therefore presumably must have enabled
    checkpointing (``is_cv`` plus a ``ModelCheckpoint`` in ``callbacks``). Confirm.

    Args:
        model_config: Model config used unchanged for every fold.
        experiment_config: Experiment settings (mutated: ``data_splits`` is set).
        num_workers: Each trial requests ``1 / num_workers`` GPUs and 1 CPU.
        verbose: Forwarded to ``train_model_cv``.
    """
    num_folds = experiment_config['cv_num_folds']
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=7)
    experiment_config['data_splits'] = {idx: (train_idx.tolist(), test_idx.tolist()) for idx, (train_idx, test_idx) in enumerate(kf.split(experiment_config['dataset']))}

    data_split = {'data_split_idx': tune.grid_search(list(range(num_folds))),
                  'iteration': tune.grid_search(list(range(experiment_config['cv_iterations'])))}

    scheduler = FIFOScheduler()
    reporter = CLIReporter(metric_columns=["train_loss", "val_loss", "train_acc", "val_acc", "training_iteration"])
    # train_model_cv takes a per-fold list of configs as ``model_configs``; it has
    # no ``model_config`` parameter. Every fold uses the same fixed config here.
    train_fn_with_parameters = tune.with_parameters(train_model_cv, model_configs=[model_config] * num_folds, experiment_config=experiment_config, verbose=verbose)
    resources_per_trial = {"cpu": 1, "gpu": 1/num_workers}
    tuner = tune.Tuner(
        tune.with_resources(
            train_fn_with_parameters,
            resources=resources_per_trial
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            num_samples=1,
        ),
        run_config=air.RunConfig(
            local_dir='./ray_results_cv',
            name= experiment_config['experiment_name'],
            progress_reporter=reporter,
            callbacks=[tune.logger.TBXLoggerCallback()],
            verbose=1
        ),
        param_space=data_split,
    )
    results = tuner.fit()
    print("Accuracy of best fold: ", results.get_best_result(metric="val_acc", mode="max", scope="last-10-avg").metrics['val_acc'])
    print("Accuracy of worst fold: ", results.get_best_result(metric="val_acc", mode="min", scope="last-10-avg").metrics['val_acc'])
    
def _build_cv_data_splits(model_config, experiment_config):
    """Return the fold index splits used by ``crossvalidation_with_hyperparameter_search``.

    - Graph-level tasks: shuffled ``StratifiedKFold`` (``random_state=42``) over
      graphs, stratified on the first entry of each graph's ``y``.
    - Node-level tasks with ``has_predetermined_split``: the dataset's
      ``get_idx_split()`` becomes the single fold 0, mapped to
      ``((train, valid), test)``.
    - Other node-level tasks: ``StratifiedKFold`` over the nodes of ``dataset[0]``.

    The KFold variants map each fold index to ``(train_indices, test_indices)``.
    """
    dataset = experiment_config['dataset']
    if model_config['graph_level']:
        y_labels = np.array([dataset[i].y.numpy()[0] for i in range(len(dataset))])
    elif experiment_config['has_predetermined_split']:
        splits = dataset.get_idx_split()
        print(splits)
        return {0: ((splits['train'].tolist(), splits['valid'].tolist()), splits['test'].tolist())}
    else:
        y_labels = dataset[0].y.numpy()

    kf = StratifiedKFold(n_splits=experiment_config['cv_num_folds'], shuffle=True, random_state=42)
    return {idx: (train_idx.tolist(), test_idx.tolist())
            for idx, (train_idx, test_idx) in enumerate(kf.split(y_labels, y=y_labels))}


def _dump_json(filename, obj):
    """Serialize ``obj`` to ``filename`` as JSON, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'w') as f:
        json.dump(obj, f)


def crossvalidation_with_hyperparameter_search(model_config, experiment_config,  num_workers, verbose=0):
    """Choose hyperparameters per fold, then run repeated cross-validation with Ray Tune.

    Steps:
      1. Build fold splits into ``experiment_config['data_splits']``.
      2. Decide which folds to process: only ``experiment_config['test']`` when
         it is >= 0, otherwise all ``cv_num_folds`` folds.
      3. Obtain a config for those folds in one of three ways:
         - load it from the JSON file ``experiment_config['hyperparameters']``
           (``skip_hyperparameter_search``);
         - search once and reuse the result for every fold
           (``hyperparameter_search_only_first_fold``);
         - search per fold.
         ``pretrained_model_path`` is popped from ``model_config`` after the
         first load/search.
      4. Enable checkpointing (``is_cv`` and a ``ModelCheckpoint`` on ``val_acc``).
         Save the configs and the serializable part of ``experiment_config`` as JSON.
      5. Launch one ``train_model_cv`` trial per (fold, iteration) pair and print
         the best and worst fold by ``val_acc`` (``last-10-avg`` scope).

    Args:
        model_config: Base model config / Ray Tune search space (mutated as above).
        experiment_config: Experiment settings (mutated: ``data_splits``,
            ``is_cv``, ``callbacks``).
        num_workers: Each trial requests ``2 / num_workers`` GPUs and 4 CPUs.
        verbose: Forwarded to the search and to ``train_model_cv``.
    """
    root_dir = 'ray_results_cv/'+experiment_config['dataset_name']+'/' + experiment_config['experiment_name'] + '/'
    os.makedirs(root_dir, exist_ok=True)
    experiment_config['data_splits'] = _build_cv_data_splits(model_config, experiment_config)

    if experiment_config['test'] >= 0:
        folds = [experiment_config['test']]
    else:
        folds = list(range(experiment_config['cv_num_folds']))
    tune_hparams_file = 'ray_results/tune_' + experiment_config['experiment_name'] + '/' + 'hyperparameters.json'

    best_results_for_fold = []
    for i in folds:
        data_split = {'data_split_idx': i}
        # The one-shot branches must fire on the first fold actually processed.
        # With a single selected test fold that is not fold 0, testing ``i == 0``
        # would skip them and leave no configs for training.
        is_first_fold = i == folds[0]

        if experiment_config['skip_hyperparameter_search']:
            if is_first_fold:
                with open(experiment_config['hyperparameters'], 'r') as f:
                    best_results_for_fold = json.load(f)
                model_config.pop('pretrained_model_path', None)
        elif experiment_config['hyperparameter_search_only_first_fold']:
            if is_first_fold:
                best_config = hyperparameter_search(experiment_config=experiment_config, model_config=model_config, num_workers=num_workers, data_split=data_split, verbose=verbose)
                best_results_for_fold = [model_config | best_config] * experiment_config['cv_num_folds']
                model_config.pop('pretrained_model_path', None)
        else:
            best_config = hyperparameter_search(experiment_config=experiment_config, model_config=model_config, num_workers=num_workers, data_split=data_split, verbose=verbose)
            best_results_for_fold.append(model_config | best_config)
            model_config.pop('pretrained_model_path', None)

        # Persist the configs collected so far after every fold.
        _dump_json(tune_hparams_file, best_results_for_fold)

    experiment_config['is_cv'] = True
    experiment_config['callbacks'] = [ModelCheckpoint(filename='{epoch}-{val_acc:.2f}', save_top_k=experiment_config['cv_save_top_k'], monitor='val_acc', mode='max', save_last=True)] + experiment_config['cv_callbacks']

    print(best_results_for_fold)

    _dump_json(root_dir + 'hyperparameters.json', best_results_for_fold)
    # Callables, callbacks and the dataset are not JSON-serializable.
    excluded_keys = ('callbacks', 'cv_callbacks', 'dataset', 'model_class', 'training_procedure', 'training_procedure_cv')
    _dump_json(root_dir + 'experiment_config.json', {key: value for key, value in experiment_config.items() if key not in excluded_keys})
    _dump_json('hyperparameters/' + experiment_config['experiment_name'], best_results_for_fold)

    data_split = {'data_split_idx': tune.grid_search(folds),
                  'iteration': tune.grid_search(list(range(experiment_config['cv_iterations'])))}

    scheduler = FIFOScheduler()
    reporter = CLIReporter(metric_columns=["train_loss", "val_loss", "train_acc", "val_acc", "training_iteration"])
    train_fn_with_parameters = tune.with_parameters(train_model_cv, model_configs=best_results_for_fold, experiment_config=experiment_config, verbose=verbose)
    resources_per_trial = {"cpu": 4, "gpu": 2/num_workers}
    tuner = tune.Tuner(
        tune.with_resources(
            train_fn_with_parameters,
            resources=resources_per_trial
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            num_samples=1,
        ),
        run_config=air.RunConfig(
            local_dir='./ray_results_cv/'+experiment_config['dataset_name']+'/',
            name= experiment_config['experiment_name'],
            progress_reporter=reporter,
            callbacks=[tune.logger.TBXLoggerCallback()],
            verbose=1
        ),
        param_space=data_split,
    )
    results = tuner.fit()
    print("Accuracy of best fold: ", results.get_best_result(metric="val_acc", mode="max", scope="last-10-avg").metrics['val_acc'])
    print("Accuracy of worst fold: ", results.get_best_result(metric="val_acc", mode="min", scope="last-10-avg").metrics['val_acc'])