import os
import time
from util.environment_setup import setup_environment
import logging
import os.path as osp
from logging import Logger
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from util.train_utils import find_last_checkpoint_path
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
import torch.nn as nn
from tasks.tuab_pretrain import TuabPretrainTask

# Run the project's environment setup at import time, i.e. before the OmegaConf
# resolvers below are registered and before Hydra composes the config in run().
# NOTE(review): presumably this exports the environment variables that the
# `env` resolver reads later — confirm in util.environment_setup.
# The YAML path is relative; presumably it is resolved against the current
# working directory — TODO confirm.
setup_environment(config_file='util/environment.yaml')


import hashlib

def generate_version_name(cfg):
    """Return a deterministic run identifier derived from the masking ratios.

    The identifier is the hexadecimal MD5 digest of the string
    ``token_mask_ratio=<value>_channel_mask_ratio=<value>``, so two configs
    with the same two ratios always map to the same version name.

    Args:
        cfg: Any object exposing ``token_mask_ratio`` and
            ``channel_mask_ratio`` attributes.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    fragments = (
        f"token_mask_ratio={cfg.token_mask_ratio}",
        f"channel_mask_ratio={cfg.channel_mask_ratio}",
    )
    # MD5 is used purely as a short, stable fingerprint, not for security.
    fingerprint = hashlib.md5("_".join(fragments).encode())
    return fingerprint.hexdigest()

def _resolve_env(key):
    """Return the value of environment variable ``key`` (``None`` if unset)."""
    return os.getenv(key)


# Custom OmegaConf resolvers made available to the configs:
#   ${env:NAME}            -> value of the environment variable NAME
#   ${get_method:pkg.func} -> callable located by hydra.utils.get_method
OmegaConf.register_new_resolver("env", _resolve_env)
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)

# Module-level logger named after this module.
logger: Logger = logging.getLogger(__name__)

def train(cfg: DictConfig):
    """Build the experiment described by ``cfg`` and run the enabled stages.

    Steps, in order: seed the RNGs, create a TensorBoard logger whose version
    is a hash of the masking ratios, instantiate the data module and the task
    model through Hydra, optionally load pretrained weights, set up the
    checkpoint callback plus any extra callbacks, build the ``Trainer`` and
    finally run fit / validate / test as requested by the config.

    Args:
        cfg: Hydra/OmegaConf configuration. Keys read here: ``seed``,
            ``token_mask_ratio``, ``channel_mask_ratio``, ``tag``,
            ``io.base_output_path``, ``io.checkpoint_dirpath``,
            ``data_module``, ``task``, ``pretrained_checkpoint``, ``resume``,
            ``model_checkpoint``, ``callbacks``, ``trainer``,
            ``find_unused_parameters`` (only when ``trainer.strategy`` is
            ``"ddp"``), ``training``, ``final_validate`` and ``final_test``.
            The config is not modified.
    """
    seed_everything(cfg.seed)

    # The same version name is used for the TensorBoard run directory and for
    # the checkpoint directory below.
    version_name = generate_version_name(cfg)

    # TensorBoard logger
    tb_logger = TensorBoardLogger(
        save_dir=osp.expanduser(cfg.io.base_output_path),
        name=cfg.tag,
        version=version_name,
    )

    # Log hyperparameters to TensorBoard
    tb_logger.log_hyperparams({
        'token_mask_ratio': cfg.token_mask_ratio,
        'channel_mask_ratio': cfg.channel_mask_ratio,
    })

    # Data module
    print("===> Loading datasets")
    data_module = hydra.utils.instantiate(cfg.data_module)

    # PyTorch Lightning module; the full config is passed as a positional argument.
    print("===> Start building model")
    model = hydra.utils.instantiate(cfg.task, cfg)
    print(model)

    # Load pretrained checkpoint
    if cfg.pretrained_checkpoint is not None:
        print(f"===> Loading pretrained_checkpoint from {cfg.pretrained_checkpoint}")
        # load_from_checkpoint is a classmethod that returns a NEW module, so it
        # is called on the class: Lightning 2.x rejects calls made on an instance.
        model = type(model).load_from_checkpoint(checkpoint_path=cfg.pretrained_checkpoint)

    # Checkpoints are stored in <io.checkpoint_dirpath>/<tag>/<version_name>.
    # "~" is expanded, consistent with the TensorBoard save_dir above.
    checkpoint_dirpath = osp.join(
        osp.expanduser(cfg.io.checkpoint_dirpath), cfg.tag, version_name
    )
    print(f"Checkpoint path: {checkpoint_dirpath}")

    # Resume training from the last saved checkpoint only when requested.
    last_ckpt = None
    if cfg.resume:
        last_ckpt = find_last_checkpoint_path(checkpoint_dirpath)

    print("===> Checkpoint callbacks")
    model_checkpoint = ModelCheckpoint(dirpath=checkpoint_dirpath, **cfg.model_checkpoint)
    callbacks = [model_checkpoint]

    # Other PyTorch Lightning callbacks
    print("===> Instantiate other callbacks")
    for _, callback in cfg.callbacks.items():
        callbacks.append(hydra.utils.instantiate(callback))

    # Trainer
    print("===> Instantiate trainer")
    # Copy the trainer options into a plain dict so the "ddp" shortcut can be
    # swapped for a configured DDPStrategy without mutating cfg.
    trainer_kwargs = dict(cfg.trainer)
    if trainer_kwargs.get("strategy") == "ddp":
        trainer_kwargs["strategy"] = DDPStrategy(
            find_unused_parameters=cfg.find_unused_parameters
        )
    trainer = Trainer(**trainer_kwargs, logger=tb_logger, callbacks=callbacks)

    # Train the model
    if cfg.training:
        print("===> Start training")
        trainer.fit(model, data_module, ckpt_path=last_ckpt)

    if cfg.final_validate:
        print("===> Start validation")
        trainer.validate(model, data_module)
    if cfg.final_test:
        print("===> Start testing")
        trainer.test(model, data_module)

    if not cfg.training:
        # Evaluation-only runs still write the current weights to last.ckpt.
        # NOTE(review): Lightning presumably needs a model attached to the
        # trainer here, i.e. validate or test must have run — confirm.
        trainer.save_checkpoint(f"{checkpoint_dirpath}/last.ckpt")


@hydra.main(config_path="./config", config_name="defaults", version_base="1.1")
def run(cfg: DictConfig):
    """Hydra entry point: report the setup, then hand the config to ``train``.

    Prints the installed PyTorch Lightning version and the composed config
    (as YAML, with interpolations resolved) before starting the run.

    Args:
        cfg: Configuration composed by Hydra from ``./config/defaults``.
    """
    lightning_version = pl.__version__
    print(f"PyTorch-Lightning Version: {lightning_version}")

    # Dump the resolved config so the exact settings of this run are visible.
    resolved_yaml = OmegaConf.to_yaml(cfg, resolve=True)
    print(resolved_yaml)

    train(cfg)


if __name__ == "__main__":
    # Tip: export HYDRA_FULL_ERROR=1 before launching to get complete stack
    # traces from Hydra while debugging.
    # Measure the wall-clock duration of the whole Hydra run.
    started_at = time.time()
    run()
    elapsed_minutes = (time.time() - started_at) / 60
    print(f"Time taken (minutes): {elapsed_minutes}")



