import os
import logging
import os.path as osp
from logging import Logger

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from util.train_utils import find_last_checkpoint_path
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
import torch.nn as nn
from tasks.TSTT_pretrain import TSTTPretrainTask
#from lightning.pytorch.profilers import SimpleProfiler, AdvancedProfiler
from util.environment_setup import setup_environment
import hashlib

# Configure the process environment once, at import time, so it is in place
# before Hydra composes the config inside ``run()``. The YAML path is relative;
# presumably resolved against the launch directory — confirm in
# util.environment_setup.setup_environment.
setup_environment(config_file="util/environment.yaml")


def generate_version_name_finetune(cfg):
    """Return a deterministic MD5 hex digest that names a finetuning run.

    The digest covers ``cfg.finetune_pretrained``, ``cfg.freeze_layers``,
    ``cfg.token_mask_ratio`` and ``cfg.channel_mask_ratio``. The literal
    ``_test`` suffix is part of the hashed text. Editing it, the key names or
    their order changes every generated name, and these names are used as
    run/version directory names.
    """
    digest = hashlib.md5()
    digest.update(
        f"pretrained={cfg.finetune_pretrained}_freeze_layers={cfg.freeze_layers}_token_mask_ratio={cfg.token_mask_ratio}_channel_mask_ratio={cfg.channel_mask_ratio}_test".encode()
    )
    return digest.hexdigest()

def generate_version_name(cfg):
    """Return a deterministic MD5 hex digest identifying a run by its masking ratios.

    Only ``cfg.token_mask_ratio`` and ``cfg.channel_mask_ratio`` contribute, so
    runs that differ in anything else share the same name.
    """
    payload = f"token_mask_ratio={cfg.token_mask_ratio}_channel_mask_ratio={cfg.channel_mask_ratio}".encode()
    return hashlib.md5(payload).hexdigest()

# Custom OmegaConf resolvers available to the Hydra config files.
# ``${env:KEY}``: value of environment variable KEY (None when it is unset).
OmegaConf.register_new_resolver(
    "env",
    lambda key: os.getenv(key),
)
# ``${get_method:dotted.path}``: the callable found at that import path.
OmegaConf.register_new_resolver(
    "get_method",
    hydra.utils.get_method,
)

# Module logger; the code visible here reports progress with print() instead.
logger: Logger = logging.getLogger(__name__)

def _freeze_decoder(model) -> None:
    """Disable gradients for ``model.model.decoder_layers`` and ``model.model.upsampler``."""
    print("===> Freezing Decoder only")
    for submodule in (model.model.decoder_layers, model.model.upsampler):
        for param in submodule.parameters():
            param.requires_grad = False


def _freeze_all_but_classifier(model) -> None:
    """Disable gradients for every parameter of ``model.model`` except its classifier."""
    print("===> Freezing layers")
    for param in model.model.parameters():
        param.requires_grad = False
    for param in model.model.classifier.parameters():
        param.requires_grad = True


def _load_pretrained_backbone(model, cfg: DictConfig) -> None:
    """Replace ``model.model`` with a pretrained TSTT backbone and attach a fresh classifier.

    The pretrained checkpoint is read from
    ``<cfg.io.checkpoint_dirpath>/<cfg.tag_old>/<generate_version_name(cfg)>/last.ckpt``.
    The new classifier is an ``nn.Linear`` mapping the backbone width to
    ``cfg.num_classes_finetune``.

    Raises:
        NotImplementedError: if ``cfg.tag`` does not contain ``"TSTT"``. The
            classifier width for other models (S*D) has not been implemented.
    """
    print("===> Finetune from pretrained model")
    if "TSTT" not in cfg.tag:
        # Checked before loading the checkpoint so the failure is immediate and explicit
        # (previously ``nn.Linear(None, ...)`` failed with an opaque TypeError).
        raise NotImplementedError(
            f"Finetuning classifier width is only implemented for TSTT models (tag={cfg.tag!r}); "
            "the S*D width for other models is not implemented yet."
        )

    pretrained_ckpt = osp.join(
        cfg.io.checkpoint_dirpath, cfg.tag_old, generate_version_name(cfg), "last.ckpt"
    )
    pretrained_task = TSTTPretrainTask.load_from_checkpoint(checkpoint_path=pretrained_ckpt)
    model.model = pretrained_task.model

    # For TSTT the classifier input width is C*D: num_channels * embed_dim.
    d_model = cfg.model.embed_dim * cfg.model.num_channels
    model.model.classifier = nn.Linear(d_model, cfg.num_classes_finetune)


def _build_trainer(cfg: DictConfig, tb_logger, callbacks):
    """Create a Lightning ``Trainer`` from ``cfg.trainer`` plus the given logger and callbacks.

    If ``cfg.trainer.strategy`` is the string ``"ddp"``, it is replaced by a
    ``DDPStrategy`` configured with ``cfg.find_unused_parameters``.
    """
    print("===> Instantiate trainer")
    # Work on a plain-dict copy rather than mutating ``cfg``. Hydra configs are in
    # struct mode, where ``del cfg.trainer.strategy`` raises. The copy also copes
    # with a missing ``strategy`` key.
    trainer_kwargs = OmegaConf.to_container(cfg.trainer, resolve=True)
    if trainer_kwargs.get("strategy") == "ddp":
        trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
    return Trainer(**trainer_kwargs, logger=tb_logger, callbacks=callbacks)


def train(cfg: DictConfig):
    """Build data, model, callbacks and trainer from ``cfg`` and run the enabled stages.

    Stages are gated by ``cfg.training`` (fit, resuming from the last checkpoint
    when ``cfg.resume``), ``cfg.final_validate`` and ``cfg.final_test``. When
    ``cfg.training`` is false, the trainer state is written to
    ``<checkpoint_dirpath>/last.ckpt`` at the end.
    """
    seed_everything(cfg.seed)

    version_name = generate_version_name_finetune(cfg)

    # TensorBoard logs go to <base_output_path>/<tag>/<version_name>.
    tb_logger = TensorBoardLogger(
        save_dir=osp.expanduser(cfg.io.base_output_path), name=cfg.tag, version=version_name
    )
    tb_logger.log_hyperparams({
        'token_mask_ratio': cfg.token_mask_ratio,
        'channel_mask_ratio': cfg.channel_mask_ratio
    })

    print("===> Loading datasets")
    data_module = hydra.utils.instantiate(cfg.data_module)

    print("===> Start building model")
    model = hydra.utils.instantiate(cfg.task, cfg)
    print(model)
    if cfg.pretrained_checkpoint is not None:
        print(f"===> Loading pretrained_checkpoint from {cfg.pretrained_checkpoint}")
        # load_from_checkpoint is a classmethod returning a new module. Call it on the
        # class: Lightning 2.x raises TypeError when it is called on an instance.
        model = type(model).load_from_checkpoint(checkpoint_path=cfg.pretrained_checkpoint)

    # Checkpoints live under <io.checkpoint_dirpath>/<tag>/<version_name>,
    # e.g. /usr/scratch/mjolnir/thoriri/checkpoints/<tag>/<version_name>.
    checkpoint_dirpath = osp.join(cfg.io.checkpoint_dirpath, cfg.tag, version_name)
    print(f"Checkpoint path: {checkpoint_dirpath}")
    last_ckpt = find_last_checkpoint_path(checkpoint_dirpath) if cfg.resume else None

    print("===> Checkpoint callbacks")
    callbacks = [ModelCheckpoint(dirpath=checkpoint_dirpath, **cfg.model_checkpoint)]

    print("===> Instantiate other callbacks")
    for _, callback in cfg.callbacks.items():
        callbacks.append(hydra.utils.instantiate(callback))

    if cfg.finetune_pretrained:
        _load_pretrained_backbone(model, cfg)
    # Freeze everything except the new head only for finetuning with freeze_layers.
    # Every other configuration freezes just the decoder and upsampler.
    if cfg.finetune_pretrained and cfg.freeze_layers:
        _freeze_all_but_classifier(model)
    else:
        _freeze_decoder(model)

    trainer = _build_trainer(cfg, tb_logger, callbacks)

    if cfg.training:
        print("===> Start training")
        trainer.fit(model, data_module, ckpt_path=last_ckpt)
    if cfg.final_validate:
        print("===> Start validation")
        trainer.validate(model, data_module)
    if cfg.final_test:
        print("===> Start testing")
        trainer.test(model, data_module)

    if not cfg.training:
        # NOTE(review): Trainer.save_checkpoint needs a model attached to the trainer.
        # Here that only happens via validate/test above. Confirm behaviour when both
        # final_validate and final_test are off.
        trainer.save_checkpoint(osp.join(checkpoint_dirpath, "last.ckpt"))


@hydra.main(config_path="./config", config_name="defaults", version_base="1.1")
def run(cfg: DictConfig):
    """Hydra entry point: report the Lightning version and resolved config, then train.

    ``cfg`` is composed by Hydra from ``./config`` using the ``defaults`` config.
    """
    lightning_version = pl.__version__
    print(f"PyTorch-Lightning Version: {lightning_version}")
    config_dump = OmegaConf.to_yaml(cfg, resolve=True)
    print(config_dump)
    train(cfg)


if __name__ == "__main__":
    # Ask Hydra for complete stack traces instead of its abbreviated error report.
    os.environ.update(HYDRA_FULL_ERROR="1")
    run()



