import os
import re
from typing import List, Tuple

import hydra
import numpy as np
import torch
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, open_dict

from common.utils import save_in_pickle
from data.datamodule_ddm import DDMDataModule
from model.pl_modules.ddm import DDM


def _epoch_from_ckpt_path(ckpt_path: str) -> int:
    """Return the epoch number encoded as ``epoch=<digits>`` in a checkpoint filename.

    Handles both ``epoch=7.ckpt`` and ``epoch=7-step=1234.ckpt`` style names.

    Raises:
        ValueError: if the filename contains no ``epoch=<digits>`` token.
    """
    match = re.search(r"epoch=(\d+)", os.path.basename(ckpt_path))
    if match is None:
        raise ValueError(f"Cannot infer epoch from checkpoint path: {ckpt_path!r}")
    return int(match.group(1))


def _predict_and_save(trainer: Trainer, pl_model: DDM, datamodule: DDMDataModule, cfg: DictConfig) -> None:
    """Run prediction from ``cfg.ckpt_path`` and pickle the concatenated outputs.

    The outputs are written to ``<ckpt prefix before 'checkpoints'>inference_data/`` as
    ``synthetics_epoch=<E>_seed=<S>`` and ``coefficients_epoch=<E>_seed=<S>``.

    Raises:
        ValueError: if ``cfg.ckpt_path`` contains no ``checkpoints`` segment or no
            ``epoch=<digits>`` token in its filename.
    """
    # Each prediction item unpacks into a (synthetics, coefficients) pair.
    out: List[Tuple[np.ndarray, np.ndarray]] = trainer.predict(
        model=pl_model, datamodule=datamodule, return_predictions=True, ckpt_path=cfg.ckpt_path
    )
    synthetics = np.concatenate([s for s, _ in out], axis=0)
    coefficients = np.concatenate([c for _, c in out], axis=0)

    epoch = _epoch_from_ckpt_path(cfg.ckpt_path)
    # Store results next to the run's "checkpoints" directory rather than inside it.
    inference_data_path = cfg.ckpt_path[:cfg.ckpt_path.index('checkpoints')] + 'inference_data/'
    save_in_pickle(inference_data_path, f'synthetics_epoch={epoch}_seed={cfg.seed}', synthetics)
    save_in_pickle(inference_data_path, f'coefficients_epoch={epoch}_seed={cfg.seed}', coefficients)


def launch(cfg: DictConfig) -> None:
    """Build logger, data, model and trainer from ``cfg`` and run a single job.

    Without ``cfg.ckpt_path`` the model is trained from scratch. With ``cfg.ckpt_path``
    the checkpoint is used for prediction and the results are pickled, unless the
    local ``continue_train`` switch is enabled, in which case training resumes from it.

    Side effects: adds ``path_storage`` / ``path_checkpoint`` to ``cfg``, logs ``cfg``
    as hyperparameters, and always finishes the W&B run on exit (including on error).
    """
    print(f'Seed={cfg.seed}')
    seed_everything(cfg.seed, workers=True)

    # os.path.join tolerates project_root with or without a trailing slash.
    logger = WandbLogger(**cfg.logger, settings=wandb.Settings(code_dir=os.path.join(cfg.project_root, "src/")))

    try:
        with open_dict(cfg):
            run_id = 'fdr' if cfg.trainer.fast_dev_run else logger.experiment.id
            cfg.path_storage = f"storage/{run_id}/"
            # path_storage already ends with "/"; join avoids a "//checkpoints" path.
            cfg.path_checkpoint = os.path.join(cfg.path_storage, "checkpoints")

        logger.log_hyperparams(cfg)

        datamodule: DDMDataModule = instantiate(cfg.datamodule, _recursive_=False)
        pl_model: DDM = instantiate(cfg.model, pipeline=datamodule.pipeline)
        trainer: Trainer = instantiate(cfg.trainer, logger=logger)

        # Hard-coded switch: set True to resume training from cfg.ckpt_path instead of predicting.
        continue_train = False
        # Exactly one branch runs; previously a resumed fit fell through to a second, fresh fit.
        if cfg.ckpt_path and continue_train:
            trainer.fit(model=pl_model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
        elif cfg.ckpt_path:
            _predict_and_save(trainer, pl_model, datamodule, cfg)
        else:
            trainer.fit(model=pl_model, datamodule=datamodule)
    finally:
        wandb.finish()


@hydra.main(version_base=None, config_path="conf", config_name="default")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: adapt the composed config to this machine, then call ``launch``.

    - Disables the W&B logger for fast-dev runs and checkpoint-based runs.
    - Selects CUDA when available (never for ``fast_dev_run``) and sizes devices,
      dataloader workers and pinned memory accordingly.
    - Propagates the number of dataset features to the datamodule and denoising network.
    """
    use_cuda = not cfg.trainer.fast_dev_run and torch.cuda.is_available()

    with open_dict(cfg):
        if cfg.trainer.fast_dev_run or cfg.ckpt_path:
            # Inside open_dict so this also works when the logger config has no "mode" key.
            cfg.logger.mode = 'disabled'

        cfg.trainer.accelerator = 'cuda' if use_cuda else 'cpu'
        # Tie the device count to the chosen accelerator: with accelerator='cpu', an integer
        # `devices` is read by Lightning as a number of CPU processes, not GPUs.
        cfg.trainer.devices = torch.cuda.device_count() if use_cuda else "auto"

        # os.cpu_count() can return None when the count is undeterminable.
        cfg.datamodule.num_workers = (os.cpu_count() or 0) if use_cuda else 0
        cfg.datamodule.pin_memory = use_cuda

        n_features = len(cfg.dataset.feature_names)
        cfg.datamodule.dataset.predict.n_features = n_features
        cfg.denoising_network.n_features = n_features

    launch(cfg)


if __name__ == "__main__":
    # No argument is passed: the @hydra.main decorator builds ``cfg`` from the
    # ``default`` config in ``conf/`` and injects it into ``main``.
    main()  # pylint: disable=no-value-for-parameter
