import numpy as np
import pytorch_lightning as pl
import sys
import torch
from pytorch_lightning.callbacks import Callback, EarlyStopping
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

import neural_network
import survival_analysis_loss
import survival_game_loss
import survival_model


class PrintCallback(Callback):
    """Lightning callback that announces the start of each training epoch."""

    def on_train_epoch_start(self, trainer, pl_module):
        """Print a fixed marker line; ``trainer`` and ``pl_module`` are ignored."""
        banner = 'epoch start (PrintCallback)'
        print(banner)

class TrainCallback(Callback):
    """Debug callback that runs prediction over the training data each epoch.

    Args:
        dataset: object whose ``datamodule`` provides
            ``predict_train_dataloader()``.
        model: the module handed to ``trainer.predict`` (note: this is the
            stored model, not the ``pl_module`` Lightning passes in).
    """

    def __init__(self, dataset, model):
        self.dataset, self.model = dataset, model

    def on_train_epoch_start(self, trainer, pl_module):
        """Predict on the training loader and print the raw result and the
        shape of its first element.

        NOTE(review): this calls ``trainer.predict`` from inside ``fit`` on
        the same trainer; confirm the Lightning version in use tolerates that.
        """
        print('epoch start (TrainCallback)')
        loader = self.dataset.datamodule.predict_train_dataloader()
        predictions = trainer.predict(
            model=self.model,
            dataloaders=loader,
            ckpt_path=None,
        )
        print('y_pred', predictions)
        # Raises IndexError if prediction produced no batches.
        print('y_pred shape', predictions[0].shape)

def create_trainer(args, trial=None):
    """Build a ``pl.Trainer`` from the run configuration in ``args``.

    Exactly one stopping/checkpointing strategy is installed, chosen in order:

    1. ``args.early_stopping_epoch > 0``: ``EarlyStopping`` on ``val_loss``
       with that value as its patience.
    2. otherwise, a non-empty ``args.early_stopping_threshold`` mapping of
       ``{metric_name: threshold}``: one ``EarlyStopping`` per metric, with
       ``stopping_threshold`` set to that threshold.
    3. otherwise: a ``ModelCheckpoint`` keeping the single best model by
       ``val_loss``.

    Args:
        args: configuration namespace; reads ``output_log``, ``dir_name``,
            ``model_name``, ``early_stopping_epoch``,
            ``early_stopping_threshold`` and ``num_epoch``.
        trial: unused here; accepted so callers can pass it through
            (presumably a hyper-parameter search trial — TODO confirm).

    Returns:
        The configured ``pl.Trainer``.
    """
    # TensorBoard logging is opt-in; otherwise logger=None is passed on.
    logger = None
    if args.output_log:
        logger = TensorBoardLogger(args.dir_name,
                                   name=args.model_name,
                                   default_hp_metric=False)

    callback_list = []
    if args.early_stopping_epoch > 0:
        callback_list.append(
            EarlyStopping('val_loss', patience=args.early_stopping_epoch))
    elif args.early_stopping_threshold:
        # early_stopping_threshold is a {metric: threshold} mapping (it is
        # iterated with .items()). ``dict > 0`` raises TypeError in Python 3,
        # so test truthiness instead; a falsy value (0, empty mapping) still
        # falls through to the checkpoint branch.
        for val_name, threshold in args.early_stopping_threshold.items():
            # Very large patience so that, in practice, only the threshold
            # ends training.
            callback_list.append(
                EarlyStopping(val_name, patience=1000000,
                              stopping_threshold=threshold))
    else:
        callback_list.append(
            pl.callbacks.ModelCheckpoint(save_top_k=1, monitor='val_loss'))

    trainer = pl.Trainer(logger=logger, deterministic=True, gpus=None,
                         enable_checkpointing=True,
                         max_epochs=args.num_epoch,
                         callbacks=callback_list,
                         enable_model_summary=False)
    return trainer

def create_model(datamodule, args, verbose=False):
    """Build the loss, network(s) and survival model selected by ``args``.

    Args:
        datamodule: provides ``y_max`` (passed to the loss) and ``x``, whose
            ``shape[1]`` is used as the network input size.
        args: configuration; reads ``model``, ``loss_function`` and
            ``neural_network``.
        verbose: forwarded to ``survival_model.MLP``; when true the finished
            model is also printed.

    Returns:
        A ``survival_model.Softmax`` or ``survival_model.SurvivalGame``.

    Raises:
        SystemExit: with status 1 (after printing a message) when
            ``args.neural_network`` or ``args.model`` is not recognised.
    """
    # Loss: SurvivalGame has its own loss; every other model uses SurvivalLoss.
    if args.model == 'SurvivalGame':
        loss_fn = survival_game_loss.SurvivalGameLoss(args, datamodule.y_max)
    else:
        loss_fn = survival_analysis_loss.SurvivalLoss(args, datamodule.y_max,
                                                      args.loss_function)

    input_len = datamodule.x.shape[1]
    if args.neural_network == 'MLP':
        # Named ``net`` so it does not shadow the module-level
        # ``neural_network`` import.
        net = survival_model.MLP(input_len, args, True, verbose)
    else:
        print('Unknown neural_network: %s' % args.neural_network)
        # Non-zero status: a bare sys.exit() reports success to the shell.
        sys.exit(1)

    if args.model == 'Softmax':
        model = survival_model.Softmax(net, loss_fn, args)
    elif args.model == 'SurvivalGame':
        # NOTE(review): ``net`` is built above but unused on this path. It is
        # kept because, if MLP initialises weights randomly, removing it would
        # change nn_f/nn_g's seeded initialisation — confirm before removing.
        nn_f = survival_model.MLP(input_len, args, True, verbose)
        nn_g = survival_model.MLP(input_len, args, True, verbose)
        model = survival_model.SurvivalGame(nn_f, nn_g, loss_fn, args)
    else:
        print('Unknown model: %s' % args.model)
        sys.exit(1)

    if verbose:
        print(model)
    return model

def execute_lightning(datamodule, args, trial=None):
    """Fit a freshly built model with PyTorch Lightning.

    The model comes from ``create_model`` and the trainer from
    ``create_trainer`` (which receives ``trial``). Training uses the
    datamodule's train and validation loaders.

    Returns:
        The trainer after ``fit`` has run.
    """
    model = create_model(datamodule, args)
    trainer = create_trainer(args, trial)
    loaders = (datamodule.train_dataloader(), datamodule.val_dataloader())

    # The model is explicitly placed on the CPU before fitting.
    model = model.to(torch.device('cpu'))
    trainer.fit(model, *loaders)
    return trainer

def execute(dataset, args, trial=None):
    """Build a model and train it with ``neural_network.train``.

    ``trial`` is accepted for signature parity with ``execute_lightning``
    but is not used. Returns the model object that was passed to training.
    """
    built = create_model(dataset, args)
    neural_network.train(dataset, built, args)
    return built
