import os
from typing import List, Optional, Tuple
import hydra
import pytorch_lightning as L
import pyrootutils
from datetime import datetime
import torch
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from omegaconf import DictConfig
from omegaconf import OmegaConf
import GPUtil

from biggs.data.predictor_data_module import PredictorDataModule
from biggs.models.predictor_module import PredictorModule
from pytorch_lightning.trainer import Trainer

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# The setup_root call above locates the project root via the ".project-root" marker
# file and is equivalent to:
# - adding the project root dir to PYTHONPATH
#       (so users don't need to install the project as a package)
#       (necessary before importing any local modules, e.g. `from biggs import utils`)
# - setting the PROJECT_ROOT environment variable
#       (used as the base for paths in "configs/paths/default.yaml")
#       (so all file paths are the same no matter where the code is run from)
# - loading environment variables from ".env" in the root dir
#
# You can remove it if you either:
# 1. install the project as a package or move entry files to the project root dir, or
# 2. set `root_dir` to "." in "configs/paths/default.yaml".
#
# More info: https://github.com/ashleve/pyrootutils
#
# NOTE(review): `biggs.data.predictor_data_module` and `biggs.models.predictor_module`
# are imported at the top of this file, i.e. BEFORE this call runs. Unless `biggs` is
# already importable (installed, or the project root already on sys.path), those
# imports fail before setup_root can help — consider moving them below this call.
# ------------------------------------------------------------------------------------ #

# Local import placed after setup_root so the project root is on sys.path.
from biggs import utils

# Module-level logger obtained from the project's logging helper.
log = utils.get_pylogger(__name__)


def _select_task_cfg(cfg: DictConfig) -> DictConfig:
    """Return the experiment sub-config for the task named in ``cfg.data.task``.

    'GFP' maps to ``cfg.experiment.gfp`` and 'AAV' to ``cfg.experiment.aav``.

    Raises:
        ValueError: if ``cfg.data.task`` is neither 'GFP' nor 'AAV'.
    """
    experiment_keys = {'GFP': 'gfp', 'AAV': 'aav'}
    task = cfg.data.task
    if task not in experiment_keys:
        raise ValueError(f"Unknown task: {task}")
    return cfg.experiment[experiment_keys[task]]


def _build_ckpt_dir(root: str, task_cfg: DictConfig, filter_range) -> str:
    """Create (if needed) and return the checkpoint directory for this run.

    Layout::

        root/mutations_<min_mutant_dist>/percentile_<filter values joined by '_'>/
            <smoothed|unsmoothed>/run_<timestamp>

    The timestamp has one-second resolution.
    """
    percentile = '_'.join(str(x) for x in filter_range)
    timestamp = datetime.now().strftime("%dD_%mM_%YY_%Hh_%Mm_%Ss")
    smoothed = 'smoothed' if task_cfg.smoothed_data_fname else 'unsmoothed'
    ckpt_dir = os.path.join(
        root,
        f'mutations_{task_cfg.min_mutant_dist}',
        f'percentile_{percentile}',
        smoothed,
        f'run_{timestamp}',
    )
    os.makedirs(ckpt_dir, exist_ok=True)
    return ckpt_dir


def _trainer_kwargs(trainer_cfg: DictConfig) -> dict:
    """Build keyword arguments for ``Trainer`` from ``trainer_cfg``.

    If the config does not set ``devices``, up to 8 GPUs currently reported as
    available by ``GPUtil.getAvailable`` (ordered by memory) are selected. If
    none are reported, ``devices`` is left unset so ``Trainer``'s own default
    applies instead of an empty device list.
    """
    kwargs = dict(trainer_cfg)
    if 'devices' in kwargs:
        # Respect an explicit setting. Passing ``devices`` both through the
        # config and as a separate keyword would raise a TypeError.
        return kwargs
    available_gpus = [int(x) for x in GPUtil.getAvailable(order='memory', limit=8)]
    if available_gpus:
        kwargs['devices'] = available_gpus
    else:
        log.warning("No available GPUs reported by GPUtil; leaving `devices` unset.")
    return kwargs


def train(cfg: DictConfig) -> None:
    """Train a ``PredictorModule`` on the task named in ``cfg.data.task``.

    Steps:
    1. Optionally seed the RNGs.
    2. Build the data module.
    3. Create a per-run checkpoint directory.
    4. Build the model, callbacks and trainer.
    5. Save the fully populated config next to the checkpoints.
    6. Run ``trainer.fit``.

    Args:
        cfg: Hydra config. Keys read here:
            - ``seed`` and ``debug`` (both optional)
            - ``data`` (including ``task``)
            - ``experiment.gfp`` or ``experiment.aav``
            - ``model``
            - ``callbacks`` (including ``model_checkpoint.dirpath``)
            - ``trainer``
            - ``ckpt_path`` (optional)

            ``cfg`` is mutated in place before it is saved: the checkpoint
            ``dirpath`` and ``model.predictor.seq_len`` are overwritten.

    Raises:
        ValueError: if ``cfg.data.task`` is not 'GFP' or 'AAV'.
    """
    # Seed pytorch, numpy and python.random. Compare against None so that an
    # explicit ``seed: 0`` is honoured instead of being treated as "unset".
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    # Set-up data. Read filter_percentile early so a bad task config fails
    # before the (potentially expensive) data module is constructed.
    task_cfg = _select_task_cfg(cfg)
    filter_range = task_cfg.filter_percentile
    log.info(f'Training predictor on task {cfg.data.task}')
    datamodule: LightningDataModule = PredictorDataModule(
        **cfg.data,
        task_cfg=task_cfg,
    )

    write_path = datamodule._dataset._write_path
    log.info(
        f"Preprocessed base sequences has saved to {write_path}.")

    # No experiment logger is wired up yet: ``logger`` is None in every mode.
    # NOTE(review): how Trainer treats logger=None depends on the Lightning version — confirm.
    logger = None
    if cfg.get("debug"):
        log.info("Debug mode! Not logging to wandb...")

    # Copy the dataset's sequence length into the model config BEFORE building
    # the model, so the model is built from the same value the saved config
    # records (assumes PredictorModule reads cfg.model.predictor.seq_len).
    cfg.model.predictor.seq_len = datamodule._dataset._seq_len
    model: LightningModule = PredictorModule(cfg.model)

    callbacks_cfg = cfg.callbacks
    ckpt_dir = _build_ckpt_dir(
        callbacks_cfg.model_checkpoint.dirpath, task_cfg, filter_range)
    # Redirect the checkpoint callback to this run's directory. This mutates
    # cfg, so the saved config records the path as well.
    callbacks_cfg.model_checkpoint.dirpath = ckpt_dir
    log.info(f'Model checkpoints being saved to: {ckpt_dir}')
    callbacks: List[Callback] = utils.instantiate_callbacks(callbacks_cfg)

    trainer: Trainer = Trainer(
        **_trainer_kwargs(cfg.trainer), callbacks=callbacks, logger=logger)

    # Write config to same directory as checkpoints.
    OmegaConf.save(config=cfg, f=os.path.join(ckpt_dir, 'config.yaml'))

    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))



@hydra.main(version_base="1.3", config_path="../configs", config_name="train_predictor.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point for predictor training.

    Hydra composes ``cfg`` from ``../configs/train_predictor.yaml`` (plus any
    overrides), then this function runs the project extras followed by
    ``train``. Nothing is returned, so no metric is reported back to Hydra.
    """
    # Project-level extras driven by cfg (e.g. tag prompts, config-tree printing).
    utils.extras(cfg)
    train(cfg)
    return None


if __name__ == "__main__":
    # The hydra.main decorator supplies ``cfg``; call without arguments.
    main()
