import torch
import wandb

from flood_echo.datasets.datasets import get_lightning_dataset, get_dataset
from flood_echo.network.model import get_model

from torch_geometric.seed import seed_everything

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, Timer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

from flood_echo.network.flood_echo import FloodModel
from flood_echo.network.gin import GIN
from flood_echo.network.recgnn import RecGNN
from flood_echo.network.pgn import PGN

# Presumably the directory for saved models (inferred from the name).
# NOTE(review): not referenced by the visible code in this file —
# train_and_eval writes and reads checkpoints under 'checkpoints/' instead;
# confirm whether this constant is still used elsewhere.
MODEL_DIRECTORY = 'models/'


def _load_checkpoint(model_type, checkpoint_path):
    """Restore a trained model of the given architecture from a checkpoint file.

    Args:
        model_type: Architecture name as used in ``config.model``; one of
            ``'GIN'``, ``'RecGNN'``, ``'FloodEcho'`` or ``'PGN'``.
        checkpoint_path: Path to the ``.ckpt`` file to load.

    Returns:
        The model returned by the matching class's ``load_from_checkpoint``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported architectures.
    """
    model_classes = {
        'GIN': GIN,
        'RecGNN': RecGNN,
        'FloodEcho': FloodModel,
        'PGN': PGN,
    }
    try:
        model_cls = model_classes[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of {sorted(model_classes)}"
        ) from None
    return model_cls.load_from_checkpoint(checkpoint_path)


def train_and_eval(config, cluster=None):
    """Train a model and evaluate it on the test set, or only test a saved one.

    In inference mode (``config.inference_mode`` truthy) the checkpoint
    ``checkpoints/<config.load_model>.ckpt`` is loaded and tested, and the
    test duration is printed (and logged to W&B when ``config.use_wandb``).
    Otherwise the model from ``get_model(config)`` is trained with early
    stopping and checkpointing; afterwards both the final weights and the
    best checkpoint (by the monitored validation metric) are tested.

    Args:
        config: Run configuration. Attributes read here: ``use_wandb``,
            ``wandb_project_name``, ``dataset``, ``run_number``, ``model``,
            ``model_name``, ``epochs`` and, in inference mode,
            ``load_model``; ``inference_mode`` is optional.
        cluster: Unused; kept for compatibility with existing callers.

    Raises:
        ValueError: If a checkpoint has to be loaded and ``config.model`` is
            not a supported architecture.
    """
    checkpoint_dir = 'checkpoints/'

    # W&B init
    if config.use_wandb:
        wandb.init(project=config.wandb_project_name, config=config)

    # NOTE(review): config.device is not passed to the Trainer, so Lightning
    # chooses the accelerator/devices itself. Pass e.g.
    # `devices=[config.device]` to pl.Trainer if a specific GPU is intended.

    # Load datasets.
    data = get_lightning_dataset(config.dataset, config)

    # Seed training (dataset splits are defined in the dataset module).
    seed_everything(config.run_number)

    model = get_model(config)

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    logger = WandbLogger(project=config.wandb_project_name) if config.use_wandb else None

    # PGN is selected on validation accuracy (maximised); every other model on
    # validation loss (minimised).
    # NOTE(review): early stopping monitors "val_*_step" while checkpointing
    # monitors "valid_*_epoch" — confirm both keys are actually logged.
    if config.model == "PGN":
        early_stop_monitor, checkpoint_monitor, mode = "val_acc_step", "valid_acc_epoch", "max"
    else:
        early_stop_monitor, checkpoint_monitor, mode = "val_loss_step", "valid_loss_epoch", "min"

    early_stop_callback = EarlyStopping(
        monitor=early_stop_monitor, patience=25, verbose=False, mode=mode
    )
    checkpoint_callback = ModelCheckpoint(
        monitor=checkpoint_monitor,
        dirpath=checkpoint_dir,
        filename=f"{config.model_name}",
        verbose=True,
        mode=mode,
    )
    print("filename", config.model_name)

    # Stop training after 2 days and 12 hours (Timer duration is DD:HH:MM:SS).
    timer = Timer(duration="02:12:00:00")
    # TODO: look into optimizer, schedule, loss (balanced).
    trainer = pl.Trainer(
        gradient_clip_val=1,
        gradient_clip_algorithm='norm',
        max_epochs=config.epochs,
        logger=logger,
        callbacks=[lr_monitor, early_stop_callback, checkpoint_callback, timer],
        num_sanity_val_steps=0,
        log_every_n_steps=1,
    )

    print(config)
    print(logger)

    if getattr(config, 'inference_mode', False):
        load_model_path = checkpoint_dir + config.load_model + '.ckpt'
        load_model = _load_checkpoint(config.model, load_model_path)

        trainer.test(load_model, data)

        test_start = timer.start_time("test")
        test_end = timer.end_time("test")
        test_elapsed = timer.time_elapsed("test")
        print(test_start, test_end)
        print(test_elapsed)
        print(test_end - test_start)
        # wandb.log requires an active run, which only exists when W&B is enabled.
        if config.use_wandb:
            wandb.log({"test_time": test_elapsed})
    else:
        trainer.fit(model, data)

        # best_model_path stays empty if no checkpoint was ever saved (e.g.
        # training stopped before the first validation).
        best_val_model_path = checkpoint_callback.best_model_path
        best_val_model = (
            _load_checkpoint(config.model, best_val_model_path)
            if best_val_model_path
            else None
        )

        # Final pass on the test set with the last-epoch weights.
        trainer.test(model, data)

        if best_val_model is None:
            print("No best checkpoint was saved; skipping test of the best validation model.")
        else:
            trainer.test(best_val_model, data)
        # TODO: consider evaluating on the test set throughout training.